Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ public sealed class AIAgentLLMSession(
fixingModel: LLModel = OpenAIModels.Chat.GPT4o
): Result<StructuredResponse<T>> {
validateSession()
val preparedPrompt = preparePrompt(prompt, tools)
val preparedPrompt = preparePrompt(prompt, tools = emptyList())
return executor.executeStructured(preparedPrompt, model, structure, retries, fixingModel)
}

Expand All @@ -242,7 +242,7 @@ public sealed class AIAgentLLMSession(
*/
public open suspend fun <T> requestLLMStructuredOneShot(structure: StructuredData<T>): StructuredResponse<T> {
validateSession()
val preparedPrompt = preparePrompt(prompt, tools)
val preparedPrompt = preparePrompt(prompt, tools = emptyList())
return executor.executeStructuredOneShot(preparedPrompt, model, structure)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,26 @@ public data class SafeTool<TArgs : ToolArgs, TResult : ToolResult>(
* Casts the current instance of `Result` to a `Success` type if it is a successful result.
*
* @return The current instance cast to `Success<TResult>`.
* @throws ClassCastException If the current instance is not of type `Success<TResult>`.
* @throws IllegalStateException if not [Success]
*/
public fun asSuccessful(): Success<TResult> = this as Success<TResult>
public fun asSuccessful(): Success<TResult> = when (this) {
is Success<TResult> -> this
is Failure<TResult> -> throw IllegalStateException("Result is not a success: $this")
}

/**
* Casts the current object to a `Failure` type.
*
* This function assumes that the calling instance is of type `Failure<TResult>`.
* Use it to retrieve the object as a `Failure` and access its specific properties and behaviors.
*
* @return The current instance cast to `Failure<TResult>`.
* @throws ClassCastException if the current instance is not of type `Failure<TResult>`.
* @throws IllegalStateException if not [Failure]
*/
public fun asFailure(): Failure<TResult> = this as Failure<TResult>
public fun asFailure(): Failure<TResult> = when (this) {
is Success<TResult> -> throw IllegalStateException("Result is not a failure: $this")
is Failure<TResult> -> this
}

/**
* Represents a successful result of an operation, wrapping a specific tool result and its corresponding content.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphDelegateBase
import ai.koog.agents.core.dsl.builder.forwardTo
import ai.koog.agents.core.dsl.extension.*
import ai.koog.agents.core.environment.ReceivedToolResult
import ai.koog.agents.core.environment.SafeTool
import ai.koog.agents.core.environment.result
import ai.koog.agents.core.environment.toSafeResult
import ai.koog.agents.core.tools.*
import ai.koog.prompt.llm.LLModel
Expand Down Expand Up @@ -218,8 +220,15 @@ public fun <Input, ProvidedResult : SubgraphResult> AIAgentSubgraphBuilderBase<*
defineTask(input)
}

val finalizeTask by node<ProvidedResult, ProvidedResult> { input ->
val finalizeTask by node<ReceivedToolResult, ProvidedResult> { input ->
llm.writeSession {
// Append final tool call result to the prompt for further LLM calls to see it (otherwise they would fail)
updatePrompt {
tool {
result(input)
}
}

// Remove finish tool from tools
tools = tools - finishTool.descriptor

Expand All @@ -228,7 +237,7 @@ public fun <Input, ProvidedResult : SubgraphResult> AIAgentSubgraphBuilderBase<*
changeLLMParams(storage.getValue(origParamsKey))
}

input
input.toSafeResult<ProvidedResult>().asSuccessful().result
}

// Helper node to overcome problems of the current api and repeat less code when writing routing conditions
Expand All @@ -251,12 +260,7 @@ public fun <Input, ProvidedResult : SubgraphResult> AIAgentSubgraphBuilderBase<*
}
)

edge(
callTool forwardTo finalizeTask
onCondition { it.tool == finishTool.name }
// result should always be successful, otherwise throw
transformed { it.toSafeResult<ProvidedResult>().asSuccessful().result }
)
edge(callTool forwardTo finalizeTask onCondition { it.tool == finishTool.name })
edge(callTool forwardTo sendToolResult)

edge(sendToolResult forwardTo nodeDecide)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ public class PromptBuilder internal constructor(
*
* @param call The tool call message to add
*/
@Deprecated("Use call(id, tool, content) instead", ReplaceWith("call(id, tool, content)"))
public fun call(call: Message.Tool.Call) {
[email protected](call)
}
Expand All @@ -250,7 +249,6 @@ public class PromptBuilder internal constructor(
*
* @param result The tool result message to add
*/
@Deprecated("Use result(id, tool, content) instead", ReplaceWith("result(id, tool, content)"))
public fun result(result: Message.Tool.Result) {
[email protected]
.indexOfLast { it is Message.Tool.Call && it.id == result.id }
Expand Down
Loading