Skip to content

Commit b78cc97

Browse files
committed
Add dsl inside merge that will allow lambdas like fold, selectBy, selectLast
1 parent 66ed9f9 commit b78cc97

File tree

5 files changed

+138
-47
lines changed

5 files changed

+138
-47
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package ai.koog.agents.core.dsl.builder
2+
3+
import ai.koog.agents.core.agent.config.AIAgentConfigBase
4+
import ai.koog.agents.core.agent.context.AIAgentContextBase
5+
import ai.koog.agents.core.agent.context.AIAgentLLMContext
6+
import ai.koog.agents.core.agent.entity.AIAgentStateManager
7+
import ai.koog.agents.core.agent.entity.AIAgentStorage
8+
import ai.koog.agents.core.agent.entity.AIAgentStorageKey
9+
import ai.koog.agents.core.annotation.InternalAgentsApi
10+
import ai.koog.agents.core.environment.AIAgentEnvironment
11+
import ai.koog.agents.core.feature.AIAgentFeature
12+
import ai.koog.agents.core.feature.AIAgentPipeline
13+
import ai.koog.agents.core.tools.ToolDescriptor
14+
import kotlin.uuid.ExperimentalUuidApi
15+
import kotlin.uuid.Uuid
16+
17+
/**
18+
* Context for merging parallel node execution results.
19+
*
20+
* This class provides DSL methods for selecting and folding results from parallel node executions.
21+
* It delegates all AIAgentContextBase methods and properties to the underlying context.
22+
*
23+
* @param Input The input type of the parallel nodes
24+
* @param Output The output type of the parallel nodes
25+
* @property underlyingContextBase The underlying context to delegate to
26+
* @property results The results of the parallel node executions
27+
*/
28+
@OptIn(ExperimentalUuidApi::class, InternalAgentsApi::class)
29+
public class AIAgentParallelNodesMergeContext<Input, Output>(
30+
private val underlyingContextBase: AIAgentContextBase,
31+
public val results: List<ParallelResult<Input, Output>>
32+
) : AIAgentContextBase {
33+
// Delegate all properties to the underlying context
34+
override val environment: AIAgentEnvironment get() = underlyingContextBase.environment
35+
override val agentInput: String get() = underlyingContextBase.agentInput
36+
override val config: AIAgentConfigBase get() = underlyingContextBase.config
37+
override val llm: AIAgentLLMContext get() = underlyingContextBase.llm
38+
override val stateManager: AIAgentStateManager get() = underlyingContextBase.stateManager
39+
override val storage: AIAgentStorage get() = underlyingContextBase.storage
40+
override val sessionUuid: Uuid get() = underlyingContextBase.sessionUuid
41+
override val strategyId: String get() = underlyingContextBase.strategyId
42+
override val pipeline: AIAgentPipeline get() = underlyingContextBase.pipeline
43+
44+
// Delegate all methods to the underlying context
45+
override fun <Feature : Any> feature(key: AIAgentStorageKey<Feature>): Feature? =
46+
underlyingContextBase.feature(key)
47+
48+
override fun <Feature : Any> feature(feature: AIAgentFeature<*, Feature>): Feature? =
49+
underlyingContextBase.feature(feature)
50+
51+
override fun <Feature : Any> featureOrThrow(feature: AIAgentFeature<*, Feature>): Feature =
52+
underlyingContextBase.featureOrThrow(feature)
53+
54+
override fun copyWithTools(tools: List<ToolDescriptor>): AIAgentContextBase =
55+
underlyingContextBase.copyWithTools(tools)
56+
57+
override fun copy(
58+
environment: AIAgentEnvironment?,
59+
agentInput: String?,
60+
config: AIAgentConfigBase?,
61+
llm: AIAgentLLMContext?,
62+
stateManager: AIAgentStateManager?,
63+
storage: AIAgentStorage?,
64+
sessionUuid: Uuid?,
65+
strategyId: String?,
66+
pipeline: AIAgentPipeline?
67+
): AIAgentContextBase = underlyingContextBase.copy(
68+
environment, agentInput, config, llm, stateManager,
69+
storage, sessionUuid, strategyId, pipeline
70+
)
71+
72+
override suspend fun fork(): AIAgentContextBase = underlyingContextBase.fork()
73+
74+
override suspend fun replace(context: AIAgentContextBase): Unit = underlyingContextBase.replace(context)
75+
76+
/**
77+
* Selects a result based on a predicate.
78+
*
79+
* @param predicate The predicate to use for selection
80+
* @return The NodeExecutionResult with the selected output and context
81+
* @throws NoSuchElementException if no result matches the predicate
82+
*/
83+
public fun selectBy(predicate: (Output) -> Boolean): NodeExecutionResult<Output> {
84+
return results.first(predicate = { predicate(it.result.output) }).result
85+
}
86+
87+
/**
88+
* Folds the result output into a single value and leaves the base context.
89+
*
90+
* @param initial The initial value for the fold operation
91+
* @param operation The operation to apply to each result
92+
* @return The NodeExecutionResult with the folded output and the context from the first result
93+
* @throws NoSuchElementException if the results list is empty
94+
*/
95+
public fun <R> fold(
96+
initial: R,
97+
operation: (acc: R, result: Output) -> R
98+
): NodeExecutionResult<R> {
99+
val folded = results.map { it.result.output }.fold(initial, operation)
100+
return NodeExecutionResult(folded, underlyingContextBase)
101+
}
102+
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentSubgraphBuilder.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,11 @@ public abstract class AIAgentSubgraphBuilderBase<Input, Output> {
135135
*/
136136
public fun <Input, Output> merge(
137137
name: String? = null,
138-
execute: suspend AIAgentContextBase.(List<ParallelResult<Input, Output>>) -> NodeExecutionResult<Output>,
138+
execute: suspend AIAgentParallelNodesMergeContext<Input, Output>.() -> NodeExecutionResult<Output>,
139139
): AIAgentNodeDelegateBase<List<AsyncParallelResult<Input, Output>>, Output> {
140140
return AIAgentNodeDelegate(name, AIAgentParallelMergeNodeBuilder(execute))
141141
}
142142

143-
144143
/**
145144
* Creates an edge between nodes.
146145
* @param edgeIntermediate Intermediate edge builder
@@ -371,10 +370,12 @@ public class AIAgentParallelTransformNodeBuilder<Input, OldOutput, NewOutput> in
371370
*/
372371
@OptIn(ExperimentalUuidApi::class)
373372
public class AIAgentParallelMergeNodeBuilder<Input, Output> internal constructor(
374-
private val merge: suspend AIAgentContextBase.(List<ParallelResult<Input, Output>>) -> NodeExecutionResult<Output>,
373+
private val merge: suspend AIAgentParallelNodesMergeContext<Input, Output>.() -> NodeExecutionResult<Output>,
375374
) : AIAgentNodeBuilder<List<AsyncParallelResult<Input, Output>>, Output>(
376375
execute = { input ->
377-
val result = merge(input.map { it.await() })
376+
val parallelResults = input.map { it.await() }
377+
val mergeContext = AIAgentParallelNodesMergeContext(this, parallelResults)
378+
val result = with(mergeContext) { merge() }
378379
this.replace(result.context)
379380

380381
result.output

agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/ParallelNodesTest.kt

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ package ai.koog.agents.core.dsl.extension
33
import ai.koog.agents.core.agent.AIAgent
44
import ai.koog.agents.core.agent.config.AIAgentConfig
55
import ai.koog.agents.core.agent.entity.AIAgentStorageKey
6+
import ai.koog.agents.core.dsl.builder.NodeExecutionResult
67
import ai.koog.agents.core.dsl.builder.forwardTo
78
import ai.koog.agents.core.dsl.builder.strategy
89
import ai.koog.agents.core.tools.ToolRegistry
9-
import ai.koog.agents.features.eventHandler.feature.EventHandler
1010
import ai.koog.agents.testing.tools.DummyTool
1111
import ai.koog.agents.testing.tools.getMockExecutor
1212
import ai.koog.agents.testing.tools.mockLLMAnswer
@@ -16,12 +16,12 @@ import kotlinx.coroutines.test.runTest
1616
import kotlin.test.Test
1717
import kotlin.test.assertEquals
1818
import kotlin.test.assertFalse
19-
import kotlin.test.assertTrue
19+
import kotlin.test.assertNotNull
2020

2121
class ParallelNodesTest {
2222

2323
@Test
24-
fun testContextSubstitution() = runTest {
24+
fun testParallelTransformMergeFold() = runTest {
2525
// Create a key to store and retrieve values from the context
2626
val testKey = AIAgentStorageKey<String>("testKey")
2727

@@ -43,33 +43,28 @@ class ParallelNodesTest {
4343
}
4444

4545
// Create a parallel node that executes all three nodes
46-
val parallelNode by parallel<Unit, String>(
46+
val parallelNode by parallel(
4747
node1, node2, node3,
4848
name = "parallelNode",
4949
)
5050

51-
val reduceNode by merge<Unit, String>(name = "reduceNode") { results ->
52-
// Use the context from the third node (node3)
53-
val nodeResult = results.find { it.nodeName == "node3" }!!
54-
nodeResult.result.context to nodeResult.result.output
55-
}
56-
5751
// Node to verify the context after parallel execution
58-
val verifyNode by node<String, String>("verifyNode") { input ->
52+
val verifyNode by transform<Unit, String, String>("verifyNode") { input ->
5953
// The context should have been replaced with node3's context
60-
val value = storage.get(testKey)
61-
"$input, context value: $value"
54+
input + " with value: " + storage.get(testKey)
55+
}
56+
57+
val reduceNode by merge<Unit, String>(name = "reduceNode") {
58+
fold("All results:\n") { acc, output -> acc + output + "\n" }
6259
}
6360

6461
// Connect the nodes
6562
edge(nodeStart forwardTo parallelNode transformed { })
66-
edge(parallelNode forwardTo reduceNode)
67-
edge(reduceNode forwardTo verifyNode)
68-
edge(verifyNode forwardTo nodeFinish)
63+
edge(parallelNode forwardTo verifyNode)
64+
edge(verifyNode forwardTo reduceNode)
65+
edge(reduceNode forwardTo nodeFinish)
6966
}
7067

71-
val results = mutableListOf<String?>()
72-
7368
val agentConfig = AIAgentConfig(
7469
prompt = prompt("test-agent") {},
7570
model = OllamaModels.Meta.LLAMA_3_2,
@@ -80,30 +75,22 @@ class ParallelNodesTest {
8075
mockLLMAnswer("Default test response").asDefaultResponse
8176
}
8277

83-
val runner = AIAgent(
78+
val agent = AIAgent(
8479
promptExecutor = testExecutor,
8580
strategy = agentStrategy,
8681
agentConfig = agentConfig,
8782
toolRegistry = ToolRegistry.Companion {
8883
tool(DummyTool())
8984
}
90-
) {
91-
install(EventHandler.Feature) {
92-
onAgentFinished = { _, result -> results += result }
93-
}
94-
}
95-
96-
runner.run("")
85+
)
9786

98-
// Verify that we have one result
99-
assertEquals(1, results.size)
87+
val result = agent.runAndGetResult("")
10088

101-
// Verify that the context was properly substituted (should contain value3)
102-
val result = results.first() ?: ""
103-
assertTrue(
104-
result.contains("context value: value3"),
105-
"Result should contain 'context value: value3', but was: $result"
106-
)
89+
assertNotNull(result)
90+
assertEquals("All results:\n" +
91+
"Result from node1 with value: value1\n" +
92+
"Result from node2 with value: value2\n" +
93+
"Result from node3 with value: value3\n", result)
10794
}
10895

10996
@Test
@@ -147,8 +134,8 @@ class ParallelNodesTest {
147134
)
148135

149136
// Create nodes to verify the context isolation during parallel execution
150-
val verifyNode by merge<Unit, String>("verifyNode") { results ->
151-
this to results.map {
137+
val verifyNode by merge<Unit, String>("verifyNode") {
138+
val output = results.map {
152139
// This node should only see the changes from node1
153140
val value1 = it.result.context.storage.get(testKey1)
154141
val value2 = it.result.context.storage.get(testKey2)
@@ -194,6 +181,8 @@ class ParallelNodesTest {
194181

195182
"Correct: Node ${it.nodeName} sees no changes from other nodes"
196183
}.joinToString("\n")
184+
185+
NodeExecutionResult(output, this)
197186
}
198187

199188
// Connect the nodes

examples/src/main/kotlin/ai/koog/agents/example/parallelexecution/BestJokeAgent.kt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package ai.koog.agents.example.parallelexecution
33
import ai.koog.agents.core.agent.AIAgent
44
import ai.koog.agents.core.agent.config.AIAgentConfig
55
import ai.koog.agents.core.dsl.builder.NodeExecutionResult
6-
import ai.koog.agents.core.dsl.builder.forwardTo
76
import ai.koog.agents.core.dsl.builder.strategy
87
import ai.koog.agents.core.tools.ToolRegistry
98
import ai.koog.agents.core.tools.annotations.LLMDescription
@@ -82,11 +81,12 @@ fun main(args: Array<String>) = runBlocking {
8281
"My favorite joke: $joke"
8382
}
8483

85-
val nodeSelectBestJoke by merge<String, String>() { results ->
84+
val nodeSelectBestJoke by merge<String, String>() {
8685
val results = results.map { it.result }
8786
val context = results.map { it.context }
8887
val jokes = results.map { it.output }
8988

89+
// Use LLM to determine the best joke
9090
val bestJokeIndex = this.llm.writeSession {
9191
model = OpenAIModels.Chat.GPT4o
9292
updatePrompt {
@@ -104,7 +104,6 @@ fun main(args: Array<String>) = runBlocking {
104104
}
105105
}
106106

107-
108107
val response = requestLLMStructured(JsonStructuredData.createJsonStructure<JokeWinner>())
109108
val bestJoke = response.getOrNull()!!.structure
110109
bestJoke.index
@@ -116,7 +115,7 @@ fun main(args: Array<String>) = runBlocking {
116115
nodeStart then nodeGenerateJokes then nodeTransformJoke then nodeSelectBestJoke then nodeFinish
117116
}
118117

119-
// Create agent config
118+
// Create agent config
120119
val agentConfig = AIAgentConfig(
121120
prompt = prompt("best-joke-agent") {
122121
system("You are a joke generator that creates the best jokes about given topics.")
@@ -125,7 +124,7 @@ fun main(args: Array<String>) = runBlocking {
125124
maxAgentIterations = 10
126125
)
127126

128-
// Create the agent
127+
// Create the agent
129128
val agent = AIAgent(
130129
promptExecutor = MultiLLMPromptExecutor(
131130
LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey),
@@ -141,7 +140,7 @@ fun main(args: Array<String>) = runBlocking {
141140
val topic = "programming"
142141
println("Generating jokes about: $topic")
143142

144-
// Run the agent
143+
// Run the agent
145144
val result = agent.run(topic)
146145
println("Final result: $result")
147146
}

prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/dsl/Prompt.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import kotlin.time.Duration
1515
*
1616
* @property messages The list of [Message] objects associated with the prompt.
1717
* @property id The unique identifier for the prompt.
18-
* @property params The language model pa parameters associated with the prompt. Defaults to [LLMParams].
18+
* @property params The language model parameters associated with the prompt. Defaults to [LLMParams].
1919
*/
2020
@Serializable
2121
public data class Prompt(

0 commit comments

Comments
 (0)