Skip to content

Commit a733113

Browse files
Support providing custom LLM for fact retrieval in the history (#289)
1 parent f1c1e32 commit a733113

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentLLMActions.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ public fun AIAgentLLMWriteSession.leaveLastNMessages(n: Int) {
2424
prompt = prompt.withMessages { it.takeLast(n) }
2525
}
2626

27+
/**
28+
* Removes the last `n` messages from the current prompt in the write session.
29+
*
30+
* @param n The number of messages to remove from the end of the current message list.
31+
*/
32+
public fun AIAgentLLMWriteSession.dropLastNMessages(n: Int) {
33+
prompt = prompt.withMessages { it.dropLast(n) }
34+
}
35+
2736
/**
2837
* Removes all messages from the current session's prompt that have a timestamp
2938
* earlier than the specified timestamp.

agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import ai.koog.prompt.message.Message
1818
import ai.koog.prompt.structure.StructuredData
1919
import ai.koog.prompt.structure.StructuredDataDefinition
2020
import ai.koog.prompt.structure.StructuredResponse
21+
import io.github.oshai.kotlinlogging.KotlinLogging.logger
2122
import kotlinx.coroutines.flow.Flow
2223

2324
/**
@@ -262,16 +263,26 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestMultiple(name: String?
262263
*
263264
* @param name Optional node name.
264265
* @param strategy Determines which messages to include in compression.
266+
* @param retrievalModel An optional [LLModel] that will be used for retrieval of the facts from memory.
267+
* By default, the same model will be used as the current one in the agent's strategy.
265268
* @param preserveMemory Specifies whether to retain message memory after compression.
266269
*/
267270
@AIAgentBuilderDslMarker
268271
public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeLLMCompressHistory(
269272
name: String? = null,
270273
strategy: HistoryCompressionStrategy = HistoryCompressionStrategy.WholeHistory,
274+
retrievalModel: LLModel? = null,
271275
preserveMemory: Boolean = true
272276
): AIAgentNodeDelegate<T, T> = node(name) { input ->
273277
llm.writeSession {
278+
val initialModel = model
279+
if (retrievalModel != null) {
280+
model = retrievalModel
281+
}
282+
274283
replaceHistoryWithTLDR(strategy, preserveMemory)
284+
285+
model = initialModel
275286
}
276287

277288
input

agents/agents-features/agents-features-memory/src/commonMain/kotlin/ai/koog/agents/memory/feature/AgentMemory.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,18 +281,29 @@ public class AgentMemory(
281281
* @param concept The concept to extract facts about
282282
* @param subject The subject categorization for the facts (e.g., User, Project)
283283
* @param scope The visibility scope for the facts (e.g., Agent, Feature, Product)
284+
* @param retrievalModel LLM that will be used for fact retrieval from the history (by default, the same model as the current one will be used)
284285
*/
285286
public suspend fun saveFactsFromHistory(
286287
concept: Concept,
287288
subject: MemorySubject,
288289
scope: MemoryScope,
290+
retrievalModel: LLModel? = null
289291
) {
290292
llm.writeSession {
293+
val initialModel = model
294+
if (retrievalModel != null) {
295+
model = retrievalModel
296+
logger.info { "Using model: ${retrievalModel.id}" }
297+
}
291298
val facts = retrieveFactsFromHistory(concept)
292299

293300
// Save facts to memory
294301
agentMemory.save(facts, subject, scope)
295302
logger.info { "Saved fact for concept '${concept.keyword}' in scope $scope: $facts" }
303+
if (retrievalModel != null) {
304+
model = initialModel
305+
logger.info { "Switching back to model: ${initialModel.id}" }
306+
}
296307
}
297308
}
298309

agents/agents-features/agents-features-memory/src/commonMain/kotlin/ai/koog/agents/memory/feature/nodes/MemoryNodes.kt

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@ package ai.koog.agents.memory.feature.nodes
33
import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker
44
import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate
55
import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase
6+
import ai.koog.agents.core.dsl.extension.dropLastNMessages
7+
import ai.koog.agents.core.dsl.extension.leaveLastNMessages
68
import ai.koog.agents.memory.config.MemoryScopeType
79
import ai.koog.agents.memory.feature.withMemory
810
import ai.koog.agents.memory.model.*
911
import ai.koog.agents.memory.prompts.MemoryPrompts
12+
import ai.koog.prompt.llm.LLModel
13+
import io.github.oshai.kotlinlogging.KotlinLogging.logger
1014
import kotlinx.serialization.Serializable
1115
import kotlinx.serialization.json.Json
1216

@@ -93,20 +97,23 @@ public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeLoadAllFactsFromMemory(
9397
* @param subject The subject scope of the memory (USER, PROJECT, etc.)
9498
* @param scope The scope of the memory (Agent, Feature, etc.)
9599
* @param concepts List of concepts to save in memory
100+
* @param retrievalModel LLM that will be used for fact retrieval from the history (by default, the same model as the current one will be used)
96101
*/
97102
@AIAgentBuilderDslMarker
98103
public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeSaveToMemory(
99104
name: String? = null,
100105
subject: MemorySubject,
101106
scope: MemoryScopeType,
102107
concepts: List<Concept>,
108+
retrievalModel: LLModel? = null
103109
): AIAgentNodeDelegate<T, T> = node(name) { input ->
104110
withMemory {
105111
concepts.forEach { concept ->
106112
saveFactsFromHistory(
107113
concept = concept,
108114
subject = subject,
109115
scope = scopesProfile.getScope(scope) ?: return@forEach,
116+
retrievalModel = retrievalModel
110117
)
111118
}
112119
}
@@ -120,14 +127,16 @@ public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeSaveToMemory(
120127
* @param subject The subject scope of the memory (USER, PROJECT, etc.)
121128
* @param scope The scope of the memory (Agent, Feature, etc.)
122129
* @param concept The concept to save in memory
130+
* @param retrievalModel LLM that will be used for fact retrieval from the history (by default, the same model as the current one will be used)
123131
*/
124132
@AIAgentBuilderDslMarker
125133
public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeSaveToMemory(
126134
name: String? = null,
127135
concept: Concept,
128136
subject: MemorySubject,
129137
scope: MemoryScopeType,
130-
): AIAgentNodeDelegate<T, T> = nodeSaveToMemory(name, subject, scope, listOf(concept))
138+
retrievalModel: LLModel? = null
139+
): AIAgentNodeDelegate<T, T> = nodeSaveToMemory(name, subject, scope, listOf(concept), retrievalModel)
131140

132141
/**
133142
* Node that automatically detects and extracts facts from the chat history and saves them to memory.
@@ -137,14 +146,21 @@ public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeSaveToMemory(
137146
* @param scopes List of memory scopes (Agent, Feature, etc.). By default only Agent scope would be chosen
138147
* @param subjects List of subjects (user, project, organization, etc.) to look for.
139148
* By default, all subjects will be included and looked for.
149+
* @param retrievalModel LLM that will be used for fact retrieval from the history (by default, the same model as the current one will be used)
140150
*/
141151
@AIAgentBuilderDslMarker
142152
public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeSaveToMemoryAutoDetectFacts(
143153
name: String? = null,
144154
scopes: List<MemoryScopeType> = listOf(MemoryScopeType.AGENT),
145-
subjects: List<MemorySubject> = MemorySubject.registeredSubjects
155+
subjects: List<MemorySubject> = MemorySubject.registeredSubjects,
156+
retrievalModel: LLModel? = null
146157
): AIAgentNodeDelegate<T, T> = node(name) { input ->
147158
llm.writeSession {
159+
val initialModel = model
160+
val initialPrompt = prompt.copy()
161+
if (retrievalModel != null) {
162+
model = retrievalModel
163+
}
148164
updatePrompt {
149165
val prompt = MemoryPrompts.autoDetectFacts(subjects)
150166
user(prompt)
@@ -160,6 +176,11 @@ public fun <T> AIAgentSubgraphBuilderBase<*, *>.nodeSaveToMemoryAutoDetectFacts(
160176
}
161177
}
162178
}
179+
180+
rewritePrompt { initialPrompt } // Revert the prompt to the original one
181+
if (retrievalModel != null) {
182+
model = initialModel
183+
}
163184
}
164185

165186
input

0 commit comments

Comments
 (0)