Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,32 @@
* @property results The results of the parallel node executions
*/
@OptIn(InternalAgentsApi::class)
public class AIAgentParallelNodesMergeContext<Input, Output>(

Check warning on line 26 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Class `AIAgentParallelNodesMergeContext` coverage is below the threshold 50%
private val underlyingContextBase: AIAgentContextBase,
public val results: List<ParallelResult<Input, Output>>
) : AIAgentContextBase {
// Delegate all properties to the underlying context
override val environment: AIAgentEnvironment get() = underlyingContextBase.environment

Check warning on line 31 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getEnvironment` coverage is below the threshold 50%
override val agentInput: Any? get() = underlyingContextBase.agentInput

Check warning on line 32 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getAgentInput` coverage is below the threshold 50%
override val config: AIAgentConfigBase get() = underlyingContextBase.config

Check warning on line 33 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getConfig` coverage is below the threshold 50%
override val llm: AIAgentLLMContext get() = underlyingContextBase.llm
override val stateManager: AIAgentStateManager get() = underlyingContextBase.stateManager
override val storage: AIAgentStorage get() = underlyingContextBase.storage
override val runId: String get() = underlyingContextBase.runId

Check warning on line 37 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getRunId` coverage is below the threshold 50%
override val strategyName: String get() = underlyingContextBase.strategyName

Check warning on line 38 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getStrategyName` coverage is below the threshold 50%
override val pipeline: AIAgentPipeline get() = underlyingContextBase.pipeline

Check warning on line 39 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `getPipeline` coverage is below the threshold 50%

// Delegate all methods to the underlying context
override fun <Feature : Any> feature(key: AIAgentStorageKey<Feature>): Feature? =

Check warning on line 42 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `feature` coverage is below the threshold 50%
underlyingContextBase.feature(key)

override fun <Feature : Any> feature(feature: AIAgentFeature<*, Feature>): Feature? =

Check warning on line 45 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `feature` coverage is below the threshold 50%
underlyingContextBase.feature(feature)

override fun <Feature : Any> featureOrThrow(feature: AIAgentFeature<*, Feature>): Feature =

Check warning on line 48 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `featureOrThrow` coverage is below the threshold 50%
underlyingContextBase.featureOrThrow(feature)

override fun copy(

Check warning on line 51 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `copy` coverage is below the threshold 50%
environment: AIAgentEnvironment,
agentInput: Any?,
config: AIAgentConfigBase,
Expand All @@ -70,9 +70,9 @@
pipeline = pipeline
)

override suspend fun fork(): AIAgentContextBase = underlyingContextBase.fork()

Check warning on line 73 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `fork` coverage is below the threshold 50%

override suspend fun replace(context: AIAgentContextBase): Unit = underlyingContextBase.replace(context)

Check warning on line 75 in agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/builder/AIAgentParallelNodesMergeContext.kt

View workflow job for this annotation

GitHub Actions / Qodana for JVM

Check Kotlin and Java source code coverage

Method `replace` coverage is below the threshold 50%

/**
* Selects a result based on a predicate.
Expand All @@ -82,7 +82,7 @@
* @throws NoSuchElementException if no result matches the predicate
*/
public suspend fun selectBy(predicate: suspend (Output) -> Boolean): NodeExecutionResult<Output> {
return results.first(predicate = { predicate(it.result.output) }).result
return results.first(predicate = { predicate(it.nodeResult.output) }).nodeResult
}

/**
Expand All @@ -96,8 +96,8 @@
* @throws NoSuchElementException if the results list is empty.
*/
public suspend fun <T : Comparable<T>> selectByMax(function: suspend (Output) -> T): NodeExecutionResult<Output> {
return results.maxBy { function(it.result.output) }
.let { NodeExecutionResult(it.result.output, it.result.context) }
return results.maxBy { function(it.nodeResult.output) }
.let { NodeExecutionResult(it.nodeResult.output, it.nodeResult.context) }
}

/**
Expand All @@ -108,8 +108,8 @@
* @throws IndexOutOfBoundsException if the index returned by the selectIndex function is out of bounds.
*/
public suspend fun selectByIndex(selectIndex: suspend (List<Output>) -> Int): NodeExecutionResult<Output> {
val indexOfBest = selectIndex(results.map { it.result.output })
return NodeExecutionResult(results[indexOfBest].result.output, results[indexOfBest].result.context)
val indexOfBest = selectIndex(results.map { it.nodeResult.output })
return NodeExecutionResult(results[indexOfBest].nodeResult.output, results[indexOfBest].nodeResult.context)
}

/**
Expand All @@ -124,7 +124,7 @@
initial: R,
operation: suspend (acc: R, result: Output) -> R
): NodeExecutionResult<R> {
val folded = results.map { it.result.output }.fold(initial) { r, t -> operation(r, t) }
val folded = results.map { it.nodeResult.output }.fold(initial) { r, t -> operation(r, t) }
return NodeExecutionResult(folded, underlyingContextBase)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public abstract class AIAgentSubgraphBuilderBase<Input, Output> {
* @param Input The type of input data that this starting node processes.
*/
public abstract val nodeStart: StartNode<Input>

/**
* Represents the "finish" node in the AI agent's subgraph structure. This node indicates
* the endpoint of the subgraph and acts as a terminal stage where the workflow stops.
Expand Down Expand Up @@ -71,7 +72,12 @@ public abstract class AIAgentSubgraphBuilderBase<Input, Output> {
llmParams: LLMParams? = null,
define: AIAgentSubgraphBuilderBase<Input, Output>.() -> Unit
): AIAgentSubgraphDelegate<Input, Output> {
return AIAgentSubgraphBuilder<Input, Output>(name, toolSelectionStrategy, llmModel, llmParams).also { it.define() }.build()
return AIAgentSubgraphBuilder<Input, Output>(
name,
toolSelectionStrategy,
llmModel,
llmParams
).also { it.define() }.build()
}

/**
Expand Down Expand Up @@ -107,41 +113,15 @@ public abstract class AIAgentSubgraphBuilderBase<Input, Output> {
* @param nodes List of nodes to execute in parallel
* @param dispatcher Coroutine dispatcher to use for parallel execution
* @param name Optional node name
* @param merge A suspendable lambda that defines how the outputs from the parallel nodes should be merged
*/
public fun <Input, Output> parallel(
vararg nodes: AIAgentNodeBase<Input, Output>,
dispatcher: CoroutineDispatcher = Dispatchers.Default,
name: String? = null,
): AIAgentNodeDelegate<Input, List<AsyncParallelResult<Input, Output>>> {
return AIAgentNodeDelegate(name, AIAgentParallelNodeBuilder(nodes.asList(), dispatcher))
}

/**
* Creates a node that applies a transform function to the output of parallel node executions.
*
* @param name Optional name for the node. If not provided, the property name of the delegate will be used.
* @param dispatcher The coroutine dispatcher used for executing the transform function. Defaults to `Dispatchers.Default`.
* @param transform A suspendable function defining the transformation logic. It processes each `OldOutput` and produces a `NewOutput`.
* @return A delegate representing the node with the transformed parallel results.
*/
public fun <Input, OldOutput, NewOutput> transform(
name: String? = null,
dispatcher: CoroutineDispatcher = Dispatchers.Default,
transform: suspend AIAgentContextBase.(OldOutput) -> NewOutput,
): AIAgentNodeDelegate<List<AsyncParallelResult<Input, OldOutput>>, List<AsyncParallelResult<Input, NewOutput>>> {
return AIAgentNodeDelegate(name, AIAgentParallelTransformNodeBuilder(transform, dispatcher))
}

/**
* Creates a node that merges the results of the forked nodes.
* @param execute Function to merge the contexts and outputs after parallel execution
* @param name Optional node name
*/
public fun <Input, Output> merge(
name: String? = null,
execute: suspend AIAgentParallelNodesMergeContext<Input, Output>.() -> NodeExecutionResult<Output>,
): AIAgentNodeDelegate<List<AsyncParallelResult<Input, Output>>, Output> {
return AIAgentNodeDelegate(name, AIAgentParallelMergeNodeBuilder(execute))
merge: suspend AIAgentParallelNodesMergeContext<Input, Output>.() -> NodeExecutionResult<Output>,
): AIAgentNodeDelegate<Input, Output> {
return AIAgentNodeDelegate(name, AIAgentParallelNodeBuilder(nodes.asList(), merge, dispatcher))
}

/**
Expand Down Expand Up @@ -265,112 +245,68 @@ public open class AIAgentSubgraphDelegate<Input, Output> internal constructor(
}
}


/**
* Output and context of parallel node execution.
* Represents the result of a parallel node execution, containing both the output value and the execution context.
*
* This class is used to capture the complete state of a node's execution, including both the
* produced output value and the context in which it was executed. This allows for both the result
* and any side effects or state changes to be preserved and utilized in subsequent operations.
*
* @param Output The type of the output value produced by the node execution.
* @property output The output value produced by the node execution.
* @property context The agent context in which the node was executed, containing any state changes.
*/
public data class NodeExecutionResult<Output>(val output: Output, val context: AIAgentContextBase)

/**
* Async result of parallel node execution.
* Represents the completed result of a parallel node execution.
*
* @property nodeName Name of the node
* @property input Input to the node
* @property asyncResult Output and context of the parallel pipeline step
*/
public data class AsyncParallelResult<Input, Output>(
val nodeName: String,
val input: Input,
val asyncResult: Deferred<NodeExecutionResult<Output>>
) {
/**
* Awaits for the asynchronous execution of a parallel node and converts it into a [ParallelResult].
*
* @return A [ParallelResult] instance that contains the node's name, its input, and the result of its execution.
*/
public suspend fun await(): ParallelResult<Input, Output> {
return ParallelResult(nodeName, input, asyncResult.await())
}
}

/**
* Result of parallel node execution.
* This class encapsulates the final state of a node that was executed as part of a parallel
* execution strategy. It contains the node's name, the input that was provided to it, and the
* final execution result including both output and context.
*
* @property nodeName Name of the node
* @property input Input to the node
* @property result Output and context of the node on the parallel pipeline termination state
* @param Input The type of input that was provided to the node.
* @param Output The type of output produced by the node.
* @property nodeName The name of the node that was executed.
* @property nodeInput The input value that was provided to the node.
* @property nodeResult The final execution result containing both output and context.
*/
public data class ParallelResult<Input, Output>(
val nodeName: String,
val input: Input,
val result: NodeExecutionResult<Output>
val nodeInput: Input,
val nodeResult: NodeExecutionResult<Output>
)


/**
* Builder for a node that executes multiple nodes in parallel.
*
* @param nodes List of nodes to execute in parallel
* @param merge A suspendable lambda that defines how the outputs from the parallel nodes should be merged
* @param dispatcher Coroutine dispatcher to use for parallel execution
*/
public class AIAgentParallelNodeBuilder<Input, Output> internal constructor(
private val nodes: List<AIAgentNodeBase<Input, Output>>,
private val merge: suspend AIAgentParallelNodesMergeContext<Input, Output>.() -> NodeExecutionResult<Output>,
private val dispatcher: CoroutineDispatcher
) : AIAgentNodeBuilder<Input, List<AsyncParallelResult<Input, Output>>>(
) : AIAgentNodeBuilder<Input, Output>(
execute = { input ->
val initialContext: AIAgentContextBase = this
val mapResults = supervisorScope {

// Execute all nodes in parallel using the provided dispatcher
val nodeResults = supervisorScope {
nodes.map { node ->
val asyncResult = async(dispatcher) {
async(dispatcher) {
val nodeContext = initialContext.fork()
val result = node.execute(nodeContext, input)
NodeExecutionResult(result, nodeContext)
}
AsyncParallelResult(node.name, input, asyncResult)
}
}
mapResults
}
)


/**
* Builder for constructing a parallel outputs transformation node.
*
* @param transform A suspend function defining the transformation logic to be applied to the elements in the output list.
* @param dispatcher The [CoroutineDispatcher] used to control the parallel execution of the transformation operations.
*/
public class AIAgentParallelTransformNodeBuilder<Input, OldOutput, NewOutput> internal constructor(
transform: suspend AIAgentContextBase.(OldOutput) -> NewOutput,
private val dispatcher: CoroutineDispatcher
) : AIAgentNodeBuilder<List<AsyncParallelResult<Input, OldOutput>>, List<AsyncParallelResult<Input, NewOutput>>>(
execute = { input ->
val transformedResults = supervisorScope {
input.map {
val asyncResult = async(dispatcher) {
val result = it.asyncResult.await()
with(result.context) {
NodeExecutionResult(transform(result.output), this@with)
}
val nodeOutput = node.execute(nodeContext, input)
val executionResult = NodeExecutionResult(nodeOutput, nodeContext)
ParallelResult(node.name, input, executionResult)
}
AsyncParallelResult(it.nodeName, it.input, asyncResult)
}
}.awaitAll()
}
transformedResults
}
)

/**
* Builder for a node that merges the parallel tool results.
*
* @param merge Function to merge the contexts after parallel execution
*/
public class AIAgentParallelMergeNodeBuilder<Input, Output> internal constructor(
private val merge: suspend AIAgentParallelNodesMergeContext<Input, Output>.() -> NodeExecutionResult<Output>,
) : AIAgentNodeBuilder<List<AsyncParallelResult<Input, Output>>, Output>(
execute = { input ->
val parallelResults = input.map { it.await() }
val mergeContext = AIAgentParallelNodesMergeContext(this, parallelResults)
// Merge parallel node results
val mergeContext = AIAgentParallelNodesMergeContext(this, nodeResults)
val result = with(mergeContext) { merge() }
this.replace(result.context)

Expand Down
Loading
Loading