Skip to content

Commit dc62da6

Browse files
Retry component (subgraph with retry)
1 parent a578e9f commit dc62da6

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package ai.koog.agents.ext.agent
2+
3+
import ai.koog.agents.core.agent.context.AIAgentContextBase
4+
import ai.koog.agents.core.agent.entity.ToolSelectionStrategy
5+
import ai.koog.agents.core.agent.entity.createStorageKey
6+
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
7+
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphDelegate
8+
import ai.koog.agents.core.dsl.builder.forwardTo
9+
10+
/**
11+
* Represents the result of [subgraphWithRetry].
12+
*
13+
* @param output The result of the subgraph operation.
14+
* @param success A boolean indicating whether the action was successful.
15+
* @param retryCount The number of retries attempted.
16+
*/
17+
public data class RetrySubgraphResult<Output>(
18+
val output: Output,
19+
val success: Boolean,
20+
val retryCount: Int,
21+
)
22+
23+
/**
24+
* Creates a subgraph with retry mechanism, allowing a specified action subgraph to be retried multiple
25+
* times until a given condition is met or the maximum number of retries is reached.
26+
*
27+
* @param condition A function that evaluates whether the output meets the desired condition.
28+
* @param maxRetries The maximum number of allowed retries. Must be greater than 0.
29+
* @param toolSelectionStrategy The strategy used to select a tool for executing the action.
30+
* @param name The optional name of the subgraph.
31+
* @param defineAction A lambda defining the action subgraph to perform within the retry subgraph.
32+
*/
33+
public fun <Input : Any, Output> AIAgentSubgraphBuilderBase<*, *>.subgraphWithRetry(
34+
condition: suspend (Output) -> Boolean,
35+
maxRetries: Int,
36+
toolSelectionStrategy: ToolSelectionStrategy = ToolSelectionStrategy.ALL,
37+
name: String? = null,
38+
defineAction: AIAgentSubgraphBuilderBase<Input, Output>.() -> Unit,
39+
): AIAgentSubgraphDelegate<Input, RetrySubgraphResult<Output>> {
40+
require(maxRetries > 0) { "maxRetries must be greater than 0" }
41+
42+
return subgraph(name = name) {
43+
val retriesKey = createStorageKey<Int>("${name}_retires")
44+
val initialInputKey = createStorageKey<Any>("${name}_initial_input")
45+
val initialContextKey = createStorageKey<AIAgentContextBase>("${name}_initial_context")
46+
47+
val beforeAction by node<Input, Input> { input ->
48+
val retries = storage.get(retriesKey) ?: 0
49+
50+
// Store initial input on the first run
51+
if (retries == 0) {
52+
storage.set(initialInputKey, input)
53+
} else {
54+
// return the initial context
55+
this.replace(storage.getValue(initialContextKey))
56+
}
57+
// store the initial context
58+
storage.set(initialContextKey, this.fork())
59+
60+
// Increment retries
61+
storage.set(retriesKey, retries + 1)
62+
63+
input
64+
}
65+
66+
val actionSubgraph by subgraph(
67+
name = "${name}_retryableAction",
68+
toolSelectionStrategy = toolSelectionStrategy,
69+
define = defineAction
70+
)
71+
72+
val decide by node<Output, RetrySubgraphResult<Output>> { output ->
73+
val retries = storage.getValue(retriesKey)
74+
val success = condition(output)
75+
76+
RetrySubgraphResult(
77+
output = output,
78+
success = success,
79+
retryCount = retries
80+
)
81+
}
82+
83+
val cleanup by node<RetrySubgraphResult<Output>, RetrySubgraphResult<Output>> { result ->
84+
storage.remove(retriesKey)
85+
storage.remove(initialInputKey)
86+
storage.remove(initialContextKey)
87+
result
88+
}
89+
90+
nodeStart then beforeAction then actionSubgraph then decide
91+
92+
// Repeat the action with initial input when condition is not met and the number of retries does not exceed max retries.
93+
edge(
94+
decide forwardTo beforeAction
95+
onCondition { result -> !result.success && result.retryCount < maxRetries }
96+
transformed {
97+
@Suppress("UNCHECKED_CAST")
98+
storage.getValue(initialInputKey) as Input
99+
}
100+
)
101+
102+
// Otherwise return the last iteration result.
103+
edge(
104+
decide forwardTo cleanup
105+
onCondition { result -> result.success || result.retryCount >= maxRetries }
106+
)
107+
108+
cleanup then nodeFinish
109+
}
110+
}
111+
112+
/**
113+
* Creates a subgraph that includes retry functionality based on a given condition and a maximum number of retries.
114+
* If the condition is not met after the specified retries and strict mode is enabled, an exception is thrown.
115+
* Unlike [subgraphWithRetry], this function directly returns the output value instead of a [RetrySubgraphResult].
116+
*
117+
* @param condition A suspendable function that determines whether the condition is met, based on the output.
118+
* @param maxRetries The maximum number of retries allowed if the condition is not met.
119+
* @param toolSelectionStrategy The strategy used to select tools for this subgraph.
120+
* @param strict If true, an exception is thrown if the condition is not met after the maximum retries.
121+
* @param name An optional name for the subgraph.
122+
* @param defineAction A lambda defining the actions within the subgraph.
123+
*
124+
* Example usage:
125+
* ```
126+
* val subgraphRetryCallLLM by subgraphWithRetrySimple(
127+
* condition = { it is Message.Tool.Call},
128+
* maxRetries = 2,
129+
* ) {
130+
* val nodeCallLLM by nodeLLMRequest("sendInput")
131+
* nodeStart then nodeCallLLM then nodeFinish
132+
* }
133+
* val nodeExecuteTool by nodeExecuteTool("nodeExecuteTool")
134+
* edge(subgraphRetryCallLLM forwardTo nodeExecuteTool onToolCall { true })
135+
* ```
136+
*/
137+
public fun <Input : Any, Output> AIAgentSubgraphBuilderBase<*, *>.subgraphWithRetrySimple(
138+
condition: suspend (Output) -> Boolean,
139+
maxRetries: Int,
140+
toolSelectionStrategy: ToolSelectionStrategy = ToolSelectionStrategy.ALL,
141+
strict: Boolean = true,
142+
name: String? = null,
143+
defineAction: AIAgentSubgraphBuilderBase<Input, Output>.() -> Unit,
144+
): AIAgentSubgraphDelegate<Input, Output> {
145+
return subgraph(name = name) {
146+
val retrySubgraph by subgraphWithRetry(
147+
toolSelectionStrategy = toolSelectionStrategy,
148+
condition = condition,
149+
maxRetries = maxRetries,
150+
name = name,
151+
defineAction = defineAction
152+
)
153+
154+
val extractResult by node<RetrySubgraphResult<Output>, Output> { result ->
155+
if (strict && !result.success) {
156+
throw IllegalStateException("Failed to meet condition after ${result.retryCount} retries")
157+
}
158+
result.output
159+
}
160+
161+
nodeStart then retrySubgraph then extractResult then nodeFinish
162+
}
163+
}

0 commit comments

Comments
 (0)