Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ package ai.koog.agents.core.agent

import ai.koog.agents.core.agent.config.AIAgentConfig
import ai.koog.agents.core.agent.config.AIAgentConfigBase
import ai.koog.agents.core.agent.entity.*
import ai.koog.agents.core.agent.context.AIAgentContext
import ai.koog.agents.core.agent.context.AIAgentLLMContext
import ai.koog.agents.core.agent.entity.AIAgentStateManager
import ai.koog.agents.core.agent.entity.AIAgentStorage
import ai.koog.agents.core.agent.entity.AIAgentStrategy
import ai.koog.agents.core.dsl.builder.forwardTo
import ai.koog.agents.core.dsl.builder.strategy
import ai.koog.agents.core.dsl.extension.*
import ai.koog.agents.core.environment.AIAgentEnvironment
import ai.koog.agents.core.environment.AIAgentEnvironmentUtils.mapToToolResult
import ai.koog.agents.core.environment.ReceivedToolResult
Expand All @@ -19,14 +25,6 @@ import ai.koog.agents.core.tools.*
import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi
import ai.koog.agents.features.common.config.FeatureConfig
import ai.koog.agents.utils.Closeable
import ai.koog.agents.core.agent.context.AIAgentLLMContext
import ai.koog.agents.core.dsl.builder.forwardTo
import ai.koog.agents.core.dsl.builder.strategy
import ai.koog.agents.core.dsl.extension.nodeExecuteTool
import ai.koog.agents.core.dsl.extension.nodeLLMRequest
import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult
import ai.koog.agents.core.dsl.extension.onAssistantMessage
import ai.koog.agents.core.dsl.extension.onToolCall
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.model.PromptExecutor
import ai.koog.prompt.llm.LLModel
Expand Down Expand Up @@ -185,9 +183,9 @@ public open class AIAgent(
val preparedEnvironment = pipeline.transformEnvironment(strategy, this, this)

val agentContext = AIAgentContext(
preparedEnvironment,
environment = preparedEnvironment,
agentInput = agentInput,
agentConfig,
config = agentConfig,
llm = AIAgentLLMContext(
toolRegistry.tools.map { it.descriptor },
toolRegistry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import ai.koog.agents.core.environment.AIAgentEnvironment
import ai.koog.agents.core.feature.AIAgentFeature
import ai.koog.agents.core.feature.AIAgentPipeline
import ai.koog.agents.core.tools.ToolDescriptor
import ai.koog.agents.core.utils.RWLock
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid

Expand All @@ -30,18 +31,65 @@ import kotlin.uuid.Uuid
* @param pipeline The AI agent pipeline responsible for coordinating AI agent execution and processing.
*/
@OptIn(ExperimentalUuidApi::class)
internal class AIAgentContext(
public class AIAgentContext(
override val environment: AIAgentEnvironment,
override val agentInput: String,
override val config: AIAgentConfigBase,
override val llm: AIAgentLLMContext,
override val stateManager: AIAgentStateManager,
override val storage: AIAgentStorage,
llm: AIAgentLLMContext,
stateManager: AIAgentStateManager,
storage: AIAgentStorage,
override val sessionUuid: Uuid,
override val strategyId: String,
@OptIn(InternalAgentsApi::class)
override val pipeline: AIAgentPipeline,
) : AIAgentContextBase {

/**
* Mutable wrapper for AI agent context properties.
*/
internal class MutableAIAgentContext(
var llm: AIAgentLLMContext,
var stateManager: AIAgentStateManager,
var storage: AIAgentStorage,
) {
private val rwLock = RWLock()

/**
* Creates a copy of the current [MutableAIAgentContext].
* @return A new instance of [MutableAIAgentContext] with copies of all mutable properties.
*/
suspend fun copy(): MutableAIAgentContext {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is not used

return rwLock.withReadLock {
MutableAIAgentContext(llm.copy(), stateManager.copy(), storage.copy())
}
}

/**
* Replaces the current context with the provided context.
* @param llm The LLM context to replace the current context with.
* @param stateManager The state manager to replace the current context with.
* @param storage The storage to replace the current context with.
*/
suspend fun replace(llm: AIAgentLLMContext?, stateManager: AIAgentStateManager?, storage: AIAgentStorage?) {
rwLock.withWriteLock {
llm?.let { this.llm = llm }
stateManager?.let { this.stateManager = stateManager }
storage?.let { this.storage = storage }
}
}
}

private val mutableAIAgentContext = MutableAIAgentContext(llm, stateManager, storage)

override val llm: AIAgentLLMContext
get() = mutableAIAgentContext.llm
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, for all these properties there's no RW lock. Meaning, when the replace operation is in progress, it's still possible to read from these fields, potentially getting inconsistent data (although the chances are probably very low)


override val storage: AIAgentStorage
get() = mutableAIAgentContext.storage

override val stateManager: AIAgentStateManager
get() = mutableAIAgentContext.stateManager

/**
* A map storing features associated with the current AI agent context.
* The keys represent unique identifiers for specific features, defined as [AIAgentStorageKey].
Expand Down Expand Up @@ -119,4 +167,30 @@ internal class AIAgentContext(
strategyId = strategyId ?: this.strategyId,
pipeline = pipeline ?: @OptIn(InternalAgentsApi::class) this.pipeline,
)
}

/**
* Creates a copy of the current [AIAgentContext] with deep copies of all mutable properties.
*
* @return A new instance of [AIAgentContext] with copies of all mutable properties.
*/
override suspend fun fork(): AIAgentContextBase = copy(
llm = this.llm.copy(),
storage = this.storage.copy(),
stateManager = this.stateManager.copy(),
)

/**
* Replaces the current context with the provided context.
* This method is used to update the current context with values from another context,
* particularly useful in scenarios like parallel node execution where contexts need to be merged.
*
* @param context The context to replace the current context with.]]
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a potential problem with this method that might confuse someone (it confused me initially, at least). From the method signature and according to KDocs it looks like the whole context will be replaced, but in fact it's not the case and only three properties are actually replaced.

override suspend fun replace(context: AIAgentContextBase) {
mutableAIAgentContext.replace(
context.llm,
context.stateManager,
context.storage
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ public interface AIAgentContextBase {
*/
public val storage: AIAgentStorage

// TODO: use Uuid?
/**
* A unique identifier for the current session associated with the AI agent context.
* Used to track and differentiate sessions within the execution of the agent pipeline.
Expand Down Expand Up @@ -111,6 +110,7 @@ public interface AIAgentContextBase {
@InternalAgentsApi
public val pipeline: AIAgentPipeline


/**
* Retrieves a feature from the agent's storage using the specified key.
*
Expand Down Expand Up @@ -180,4 +180,22 @@ public interface AIAgentContextBase {
strategyId: String? = null,
pipeline: AIAgentPipeline? = null,
): AIAgentContextBase
}

/**
* Creates a copy of the current [AIAgentContext] with deep copies of all mutable properties.
* This method is particularly useful in scenarios like parallel node execution
* where contexts need to be sent to different threads and then merged back together using [replace].
*
* @return A new instance of [AIAgentContext] with copies of all mutable properties.
*/
public suspend fun fork(): AIAgentContextBase

/**
* Replaces the current context with the provided context.
* This method is used to update the current context with values from another context,
* particularly useful in scenarios like parallel node execution where contexts need to be merged.
*
* @param context The context to replace the current context with.
*/
public suspend fun replace(context: AIAgentContextBase)
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ public data class AIAgentLLMContext(
private val clock: Clock
) {

/**
* Creates a deep copy of this LLM context.
*
* @return A new instance of [AIAgentLLMContext] with deep copies of mutable properties.
*/
public suspend fun copy(): AIAgentLLMContext {
return rwLock.withReadLock {
AIAgentLLMContext(
tools.toList(),
toolRegistry,
prompt.copy(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to unnecessary copy prompt and tools, these are already fully immutable

model.copy(),
promptExecutor,
environment,
config,
clock
)
}
}

private val rwLock = RWLock()

/**
Expand All @@ -44,7 +64,8 @@ public data class AIAgentLLMContext(
*/
@OptIn(ExperimentalStdlibApi::class)
public suspend fun <T> writeSession(block: suspend AIAgentLLMWriteSession.() -> T): T = rwLock.withWriteLock {
val session = AIAgentLLMWriteSession(environment, promptExecutor, tools, toolRegistry, prompt, model, config, clock)
val session =
AIAgentLLMWriteSession(environment, promptExecutor, tools, toolRegistry, prompt, model, config, clock)

session.use {
val result = it.block()
Expand All @@ -68,4 +89,4 @@ public data class AIAgentLLMContext(

session.use { block(it) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ internal class AIAgentNode<Input, Output> internal constructor(
*
* @param Input The type of input data this node processes and produces as output.
*/
public open class StartAIAgentNodeBase<Input>() : AIAgentNodeBase<Input, Input>() {
public open class AIAgentStartNodeBase<Input>() : AIAgentNodeBase<Input, Input>() {
/**
* The name of the subgraph associated with the AI agent's starting node.
*
Expand All @@ -160,7 +160,7 @@ public open class StartAIAgentNodeBase<Input>() : AIAgentNodeBase<Input, Input>(
*
* @param Output The type of data this node processes and produces.
*/
public open class FinishAIAgentNodeBase<Output>() : AIAgentNodeBase<Output, Output>() {
public open class AIAgentFinishNodeBase<Output>() : AIAgentNodeBase<Output, Output>() {
/**
* Stores the name of the subgraph associated with the node.
*
Expand Down Expand Up @@ -196,7 +196,7 @@ public open class FinishAIAgentNodeBase<Output>() : AIAgentNodeBase<Output, Outp
* This node effectively passes its input as-is to the next node in the execution
* pipeline, allowing downstream nodes to transform or handle the data further.
*/
internal class StartNode internal constructor() : StartAIAgentNodeBase<String>()
internal class StartNode internal constructor() : AIAgentStartNodeBase<String>()

/**
* A specialized implementation of [FinishNode] that finalizes the execution of an AI agent subgraph.
Expand All @@ -210,5 +210,5 @@ internal class StartNode internal constructor() : StartAIAgentNodeBase<String>()
*
* This node is critical to denote the completion of localized processing within a subgraph context.
*/
internal class FinishNode internal constructor() : FinishAIAgentNodeBase<String>()
internal class FinishNode internal constructor() : AIAgentFinishNodeBase<String>()

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ internal class AIAgentState internal constructor(
override fun close() {
isActive = false
}

internal fun copy(): AIAgentState {
return AIAgentState(
iterations = iterations
)
}
}

/**
Expand Down Expand Up @@ -44,4 +50,10 @@ public class AIAgentStateManager internal constructor(

result
}

internal suspend fun copy(): AIAgentStateManager {
return withStateLock {
AIAgentStateManager(state.copy())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
private val mutex = Mutex()
private val storage = mutableMapOf<AIAgentStorageKey<*>, Any>()

/**
* Creates a deep copy of this storage.
*
* @return A new instance of [AIAgentStorage] with the same content as this one.
*/
internal suspend fun copy(): AIAgentStorage {
val newStorage = AIAgentStorage()
newStorage.putAll(this.toMap())
return newStorage
}

/**
* Sets the value associated with the given key in the storage.
*
Expand Down Expand Up @@ -61,7 +72,7 @@
* @return The value associated with the key, of type [T].
* @throws NoSuchElementException if the key does not exist in the storage.
*/
public suspend fun <T : Any> getValue(key: AIAgentStorageKey<T>): T {

Check warning on line 75 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentStorage.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getValue` coverage is below the threshold 50%
return get(key) ?: throw NoSuchElementException("Key $key not found in storage")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import ai.koog.agents.core.utils.runCatchingCancellable
@OptIn(InternalAgentsApi::class)
public class AIAgentStrategy(
override val name: String,
public val nodeStart: StartAIAgentNodeBase<String>,
public val nodeFinish: FinishAIAgentNodeBase<String>,
public val nodeStart: AIAgentStartNodeBase<String>,
public val nodeFinish: AIAgentFinishNodeBase<String>,
toolSelectionStrategy: ToolSelectionStrategy
) : AIAgentSubgraph<String, String>(
name, nodeStart, nodeFinish, toolSelectionStrategy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import kotlin.uuid.ExperimentalUuidApi
*/
public open class AIAgentSubgraph<Input, Output>(
override val name: String,
public val start: StartAIAgentNodeBase<Input>,
public val finish: FinishAIAgentNodeBase<Output>,
public val start: AIAgentStartNodeBase<Input>,
public val finish: AIAgentFinishNodeBase<Output>,
private val toolSelectionStrategy: ToolSelectionStrategy,
) : AIAgentNodeBase<Input, Output>() {
private companion object {
Expand Down Expand Up @@ -99,7 +99,7 @@ public open class AIAgentSubgraph<Input, Output>(
"Invalid finish node output type: ${currentInput?.let { it::class.simpleName }}"
)
}
throw IllegalStateException("${FinishAIAgentNodeBase::class.simpleName} should always return String")
throw IllegalStateException("${AIAgentFinishNodeBase::class.simpleName} should always return String")
}
}

Expand Down
Loading
Loading