Skip to content

Commit 1e62ca8

Browse files
committed
Add mutable context
1 parent b2192e7 commit 1e62ca8

File tree

15 files changed

+594
-573
lines changed

15 files changed

+594
-573
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgent.kt

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@ package ai.koog.agents.core.agent
22

33
import ai.koog.agents.core.agent.config.AIAgentConfig
44
import ai.koog.agents.core.agent.config.AIAgentConfigBase
5-
import ai.koog.agents.core.agent.entity.*
65
import ai.koog.agents.core.agent.context.AIAgentContext
6+
import ai.koog.agents.core.agent.context.AIAgentLLMContext
7+
import ai.koog.agents.core.agent.entity.AIAgentStateManager
8+
import ai.koog.agents.core.agent.entity.AIAgentStorage
9+
import ai.koog.agents.core.agent.entity.AIAgentStrategy
10+
import ai.koog.agents.core.dsl.builder.forwardTo
11+
import ai.koog.agents.core.dsl.builder.strategy
12+
import ai.koog.agents.core.dsl.extension.*
713
import ai.koog.agents.core.environment.AIAgentEnvironment
814
import ai.koog.agents.core.environment.AIAgentEnvironmentUtils.mapToToolResult
915
import ai.koog.agents.core.environment.ReceivedToolResult
@@ -19,14 +25,6 @@ import ai.koog.agents.core.tools.*
1925
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
2026
import ai.koog.agents.features.common.config.FeatureConfig
2127
import ai.koog.agents.utils.Closeable
22-
import ai.koog.agents.core.agent.context.AIAgentLLMContext
23-
import ai.koog.agents.core.dsl.builder.forwardTo
24-
import ai.koog.agents.core.dsl.builder.strategy
25-
import ai.koog.agents.core.dsl.extension.nodeExecuteTool
26-
import ai.koog.agents.core.dsl.extension.nodeLLMRequest
27-
import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult
28-
import ai.koog.agents.core.dsl.extension.onAssistantMessage
29-
import ai.koog.agents.core.dsl.extension.onToolCall
3028
import ai.koog.prompt.dsl.prompt
3129
import ai.koog.prompt.executor.model.PromptExecutor
3230
import ai.koog.prompt.llm.LLModel
@@ -184,9 +182,9 @@ public open class AIAgent(
184182
val preparedEnvironment = pipeline.transformEnvironment(strategy, this, this)
185183

186184
val agentContext = AIAgentContext(
187-
preparedEnvironment,
185+
environment = preparedEnvironment,
188186
agentInput = agentInput,
189-
agentConfig,
187+
config = agentConfig,
190188
llm = AIAgentLLMContext(
191189
toolRegistry.tools.map { it.descriptor },
192190
toolRegistry,
@@ -431,16 +429,16 @@ public open class AIAgent(
431429
}
432430

433431
/**
434-
* Creates a single-run strategy for an AI agent.
435-
* This strategy defines a simple execution flow where the agent processes input,
436-
* calls tools, and sends results back to the agent.
437-
* The flow consists of the following steps:
438-
* 1. Start the agent.
439-
* 2. Call the LLM with the input.
440-
* 3. Execute a tool based on the LLM's response.
441-
* 4. Send the tool result back to the LLM.
442-
* 5. Repeat until LLM indicates no further tool calls are needed or the agent finishes.
443-
*/
432+
* Creates a single-run strategy for an AI agent.
433+
* This strategy defines a simple execution flow where the agent processes input,
434+
* calls tools, and sends results back to the agent.
435+
* The flow consists of the following steps:
436+
* 1. Start the agent.
437+
* 2. Call the LLM with the input.
438+
* 3. Execute a tool based on the LLM's response.
439+
* 4. Send the tool result back to the LLM.
440+
* 5. Repeat until LLM indicates no further tool calls are needed or the agent finishes.
441+
*/
444442
public fun singleRunStrategy(): AIAgentStrategy = strategy("single_run") {
445443
val nodeCallLLM by nodeLLMRequest("sendInput")
446444
val nodeExecuteTool by nodeExecuteTool("nodeExecuteTool")

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentContext.kt

Lines changed: 67 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import ai.koog.agents.core.environment.AIAgentEnvironment
99
import ai.koog.agents.core.feature.AIAgentFeature
1010
import ai.koog.agents.core.feature.AIAgentPipeline
1111
import ai.koog.agents.core.tools.ToolDescriptor
12+
import ai.koog.agents.core.utils.RWLock
1213
import kotlin.uuid.ExperimentalUuidApi
1314
import kotlin.uuid.Uuid
1415

@@ -30,57 +31,64 @@ import kotlin.uuid.Uuid
3031
* @param pipeline The AI agent pipeline responsible for coordinating AI agent execution and processing.
3132
*/
3233
@OptIn(ExperimentalUuidApi::class)
33-
internal class AIAgentContext(
34-
environment: AIAgentEnvironment,
35-
agentInput: String,
36-
config: AIAgentConfigBase,
34+
public class AIAgentContext(
35+
override val environment: AIAgentEnvironment,
36+
override val agentInput: String,
37+
override val config: AIAgentConfigBase,
3738
llm: AIAgentLLMContext,
3839
stateManager: AIAgentStateManager,
3940
storage: AIAgentStorage,
40-
sessionUuid: Uuid,
41-
strategyId: String,
41+
override val sessionUuid: Uuid,
42+
override val strategyId: String,
4243
@OptIn(InternalAgentsApi::class)
43-
pipeline: AIAgentPipeline,
44+
override val pipeline: AIAgentPipeline,
4445
) : AIAgentContextBase {
45-
private var _environment: AIAgentEnvironment = environment
46-
private var _agentInput: String = agentInput
47-
private var _config: AIAgentConfigBase = config
48-
private var _llm: AIAgentLLMContext = llm
49-
private var _stateManager: AIAgentStateManager = stateManager
50-
private var _storage: AIAgentStorage = storage
51-
private var _sessionUuid: Uuid = sessionUuid
52-
private var _strategyId: String = strategyId
5346

54-
@OptIn(InternalAgentsApi::class)
55-
private var _pipeline: AIAgentPipeline = pipeline
56-
57-
override val environment: AIAgentEnvironment
58-
get() = _environment
59-
60-
override val agentInput: String
61-
get() = _agentInput
47+
/**
48+
* Mutable wrapper for AI agent context properties.
49+
*/
50+
internal class MutableAIAgentContext(
51+
var llm: AIAgentLLMContext,
52+
var stateManager: AIAgentStateManager,
53+
var storage: AIAgentStorage,
54+
) {
55+
private val rwLock = RWLock()
56+
57+
/**
58+
* Creates a copy of the current [MutableAIAgentContext].
59+
* @return A new instance of [MutableAIAgentContext] with copies of all mutable properties.
60+
*/
61+
suspend fun copy(): MutableAIAgentContext {
62+
return rwLock.withReadLock {
63+
MutableAIAgentContext(llm.copy(), stateManager.copy(), storage.copy())
64+
}
65+
}
66+
67+
/**
68+
* Replaces the current context with the provided context.
69+
* @param llm The LLM context to replace the current context with.
70+
* @param stateManager The state manager to replace the current context with.
71+
* @param storage The storage to replace the current context with.
72+
*/
73+
suspend fun replace(llm: AIAgentLLMContext?, stateManager: AIAgentStateManager?, storage: AIAgentStorage?) {
74+
rwLock.withWriteLock {
75+
llm?.let { this.llm = llm }
76+
stateManager?.let { this.stateManager = stateManager }
77+
storage?.let { this.storage = storage }
78+
}
79+
}
80+
}
6281

63-
override val config: AIAgentConfigBase
64-
get() = _config
82+
private val mutableAIAgentContext = MutableAIAgentContext(llm, stateManager, storage)
6583

6684
override val llm: AIAgentLLMContext
67-
get() = _llm
68-
69-
override val stateManager: AIAgentStateManager
70-
get() = _stateManager
85+
get() = mutableAIAgentContext.llm
7186

7287
override val storage: AIAgentStorage
73-
get() = _storage
74-
75-
override val sessionUuid: Uuid
76-
get() = _sessionUuid
88+
get() = mutableAIAgentContext.storage
7789

78-
override val strategyId: String
79-
get() = _strategyId
80-
81-
@OptIn(InternalAgentsApi::class)
82-
override val pipeline: AIAgentPipeline
83-
get() = _pipeline
90+
override val stateManager: AIAgentStateManager
91+
get() = mutableAIAgentContext.stateManager
8492

8593
/**
8694
* A map storing features associated with the current AI agent context.
@@ -160,21 +168,29 @@ internal class AIAgentContext(
160168
pipeline = pipeline ?: @OptIn(InternalAgentsApi::class) this.pipeline,
161169
)
162170

171+
/**
172+
* Creates a copy of the current [AIAgentContext] with deep copies of all mutable properties.
173+
*
174+
* @return A new instance of [AIAgentContext] with copies of all mutable properties.
175+
*/
176+
override suspend fun fork(): AIAgentContextBase = copy(
177+
llm = this.llm.copy(),
178+
storage = this.storage.copy(),
179+
stateManager = this.stateManager.copy(),
180+
)
181+
163182
/**
164183
* Replaces the current context with the provided context.
184+
* This method is used to update the current context with values from another context,
185+
* particularly useful in scenarios like parallel node execution where contexts need to be merged.
165186
*
166-
* @param context The context to replace the current context with.
167-
* @throws UnsupportedOperationException This method is not fully implemented due to the constraints of immutable properties.
187+
* @param context The context to replace the current context with.]]
168188
*/
169-
override fun replaceWith(context: AIAgentContextBase) {
170-
_environment = context.environment
171-
_agentInput = context.agentInput
172-
_config = context.config
173-
_llm = context.llm
174-
_stateManager = context.stateManager
175-
_storage = context.storage
176-
_sessionUuid = context.sessionUuid
177-
_strategyId = context.strategyId
178-
_pipeline = @OptIn(InternalAgentsApi::class) this.pipeline
189+
override suspend fun replace(context: AIAgentContextBase) {
190+
mutableAIAgentContext.replace(
191+
context.llm,
192+
context.stateManager,
193+
context.storage
194+
)
179195
}
180196
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentContextBase.kt

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ public interface AIAgentContextBase {
8080
*/
8181
public val storage: AIAgentStorage
8282

83-
// TODO: use Uuid?
8483
/**
8584
* A unique identifier for the current session associated with the AI agent context.
8685
* Used to track and differentiate sessions within the execution of the agent pipeline.
@@ -111,6 +110,7 @@ public interface AIAgentContextBase {
111110
@InternalAgentsApi
112111
public val pipeline: AIAgentPipeline
113112

113+
114114
/**
115115
* Retrieves a feature from the agent's storage using the specified key.
116116
*
@@ -181,12 +181,21 @@ public interface AIAgentContextBase {
181181
pipeline: AIAgentPipeline? = null,
182182
): AIAgentContextBase
183183

184+
/**
185+
* Creates a copy of the current [AIAgentContext] with deep copies of all mutable properties.
186+
* This method is particularly useful in scenarios like parallel node execution
187+
* where contexts need to be sent to different threads and then merged back together using [replace].
188+
*
189+
* @return A new instance of [AIAgentContext] with copies of all mutable properties.
190+
*/
191+
public suspend fun fork(): AIAgentContextBase
192+
184193
/**
185194
* Replaces the current context with the provided context.
186195
* This method is used to update the current context with values from another context,
187196
* particularly useful in scenarios like parallel node execution where contexts need to be merged.
188197
*
189198
* @param context The context to replace the current context with.
190199
*/
191-
public fun replaceWith(context: AIAgentContextBase)
200+
public suspend fun replace(context: AIAgentContextBase)
192201
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ public data class AIAgentLLMContext(
3636
private val clock: Clock
3737
) {
3838

39+
/**
40+
* Creates a deep copy of this LLM context.
41+
*
42+
* @return A new instance of [AIAgentLLMContext] with deep copies of mutable properties.
43+
*/
44+
public suspend fun copy(): AIAgentLLMContext {
45+
return rwLock.withReadLock {
46+
this.copy()
47+
}
48+
}
49+
3950
private val rwLock = RWLock()
4051

4152
/**

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentState.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ internal class AIAgentState internal constructor(
1515
override fun close() {
1616
isActive = false
1717
}
18+
19+
internal fun copy(): AIAgentState {
20+
return AIAgentState(
21+
iterations = iterations
22+
)
23+
}
1824
}
1925

2026
/**
@@ -44,4 +50,10 @@ public class AIAgentStateManager internal constructor(
4450

4551
result
4652
}
53+
54+
public suspend fun copy(): AIAgentStateManager {
55+
return withStateLock {
56+
AIAgentStateManager(state.copy())
57+
}
58+
}
4759
}

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentStorage.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,17 @@ public class AIAgentStorage internal constructor() {
3131
private val mutex = Mutex()
3232
private val storage = mutableMapOf<AIAgentStorageKey<*>, Any>()
3333

34+
/**
35+
* Creates a deep copy of this storage.
36+
*
37+
* @return A new instance of [AIAgentStorage] with the same content as this one.
38+
*/
39+
internal suspend fun copy(): AIAgentStorage {
40+
val newStorage = AIAgentStorage()
41+
newStorage.putAll(this.toMap())
42+
return newStorage
43+
}
44+
3445
/**
3546
* Sets the value associated with the given key in the storage.
3647
*

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentSubgraphBuilder.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ public class ParallelAIAgentNodeBuilder<Input, Output> internal constructor(
280280
val mapResults = supervisorScope {
281281
nodes.map { node ->
282282
async(dispatcher) {
283-
val nodeContext = initialContext.copy()
283+
val nodeContext = (initialContext as? ai.koog.agents.core.agent.context.AIAgentContext)?.fork() ?: initialContext.fork()
284284
val result = node.execute(nodeContext, input)
285285
ParallelNodeResult(node.name, input, nodeContext, result)
286286
}
@@ -302,7 +302,7 @@ public class ReduceAIAgentNodeBuilder<Input, Output> internal constructor(
302302
) : AIAgentNodeBuilder<List<ParallelNodeResult<Input, Output>>, Output>(
303303
execute = { input ->
304304
val (context, output) = execute(input)
305-
this.replaceWith(context)
305+
this.replace(context)
306306

307307
output
308308
}

agents/agents-features/agents-features-memory/src/jvmTest/kotlin/ai/koog/agents/memory/AIAgentMemoryTest.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ package ai.koog.agents.memory
33
import ai.koog.agents.core.agent.config.AIAgentConfig
44
import ai.koog.agents.core.agent.context.AIAgentLLMContext
55
import ai.koog.agents.core.agent.session.AIAgentLLMWriteSession
6+
import ai.koog.agents.core.tools.ToolRegistry
67
import ai.koog.agents.memory.config.MemoryScopeType
78
import ai.koog.agents.memory.config.MemoryScopesProfile
89
import ai.koog.agents.memory.feature.AgentMemory
910
import ai.koog.agents.memory.model.*
1011
import ai.koog.agents.memory.providers.AgentMemoryProvider
1112
import ai.koog.agents.memory.providers.NoMemory
13+
import ai.koog.agents.testing.tools.MockEnvironment
1214
import ai.koog.prompt.dsl.Prompt
1315
import ai.koog.prompt.dsl.PromptBuilder
1416
import ai.koog.prompt.dsl.prompt
@@ -121,7 +123,7 @@ class AIAgentMemoryTest {
121123
prompt = prompt("test") { },
122124
model = testModel,
123125
promptExecutor = promptExecutor,
124-
environment = MockAgentEnvironment(),
126+
environment = MockEnvironment(toolRegistry = ToolRegistry.EMPTY, promptExecutor),
125127
config = AIAgentConfig(Prompt.Empty, testModel, 100),
126128
clock = testClock
127129
)
@@ -212,7 +214,7 @@ class AIAgentMemoryTest {
212214
prompt = prompt("test") { },
213215
model = testModel,
214216
promptExecutor = promptExecutor,
215-
environment = MockAgentEnvironment(),
217+
environment = MockEnvironment(toolRegistry = ToolRegistry.EMPTY, promptExecutor),
216218
config = AIAgentConfig(Prompt.Empty, testModel, 100),
217219
clock = testClock
218220
)
@@ -335,7 +337,7 @@ class AIAgentMemoryTest {
335337
prompt = prompt("test") { },
336338
model = testModel,
337339
promptExecutor = promptExecutor,
338-
environment = MockAgentEnvironment(),
340+
environment = MockEnvironment(toolRegistry = ToolRegistry.EMPTY, promptExecutor),
339341
config = AIAgentConfig(Prompt.Empty, testModel, 100),
340342
clock = testClock
341343
)
@@ -406,7 +408,7 @@ class AIAgentMemoryTest {
406408
prompt = prompt("test") { },
407409
model = testModel,
408410
promptExecutor = promptExecutor,
409-
environment = MockAgentEnvironment(),
411+
environment = MockEnvironment(toolRegistry = ToolRegistry.EMPTY, promptExecutor),
410412
config = AIAgentConfig(Prompt.Empty, testModel, 100),
411413
clock = testClock
412414
)

0 commit comments

Comments
 (0)