Skip to content

Commit 6f0532f

Browse files
authored
[prompt] Fix LLM clients after #195, make LLM request construction again more explicit in LLM clients (#229)
1 parent 8a2d21f commit 6f0532f

File tree

6 files changed

+352
-266
lines changed
  • examples/src/main/kotlin/ai/koog/agents/example/media
  • prompt/prompt-executor/prompt-executor-clients
    • prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic
    • prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google
    • prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai
    • prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter

6 files changed

+352
-266
lines changed

examples/src/main/kotlin/ai/koog/agents/example/media/InstagramPostDescriber.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ fun main() {
4747
}
4848

4949
runBlocking {
50+
println("OpenAI response:")
5051
openaiExecutor.execute(prompt, OpenAIModels.Chat.GPT4_1).content.also(::println)
52+
// println("Anthropic response:")
5153
// anthropicExecutor.execute(prompt, AnthropicModels.Sonnet_4).content.also(::println)
54+
// println("Google response:")
5255
// googleExecutor.execute(prompt, GoogleModels.Gemini2_0Flash).content.also(::println)
5356
}
5457
}

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,21 @@ public open class AnthropicLLMClient(
190190
model: LLModel,
191191
stream: Boolean
192192
): AnthropicMessageRequest {
193-
val (systemMessages, convMessages) = prompt.messages.partition { it is Message.System }
193+
val systemMessage = mutableListOf<SystemAnthropicMessage>()
194+
val messages = mutableListOf<AnthropicMessage>()
194195

195-
val messages = convMessages.fold(mutableListOf<AnthropicMessage>()) { acc, message ->
196+
for (message in prompt.messages) {
196197
when (message) {
197-
is Message.User -> acc.add(message.toAnthropicUserMessage(model))
198+
is Message.System -> {
199+
systemMessage.add(SystemAnthropicMessage(message.content))
200+
}
201+
202+
is Message.User -> {
203+
messages.add(message.toAnthropicUserMessage(model))
204+
}
205+
198206
is Message.Assistant -> {
199-
acc.add(
207+
messages.add(
200208
AnthropicMessage(
201209
role = "assistant",
202210
content = listOf(AnthropicContent.Text(message.content))
@@ -205,31 +213,35 @@ public open class AnthropicLLMClient(
205213
}
206214

207215
is Message.Tool.Result -> {
208-
val toolResult = AnthropicContent.ToolResult(
209-
toolUseId = message.id.orEmpty(),
210-
content = message.content
216+
messages.add(
217+
AnthropicMessage(
218+
role = "user",
219+
content = listOf(
220+
AnthropicContent.ToolResult(
221+
toolUseId = message.id ?: "",
222+
content = message.content
223+
)
224+
)
225+
)
211226
)
212-
acc.lastOrNull { it.role == "user" }?.let { lastUserMessage ->
213-
acc[acc.lastIndex] = lastUserMessage.copy(content = lastUserMessage.content + toolResult)
214-
} ?: acc.add(AnthropicMessage(role = "user", content = listOf(toolResult)))
215227
}
216228

217229
is Message.Tool.Call -> {
218-
val toolUse = AnthropicContent.ToolUse(
219-
id = message.id ?: Uuid.random().toString(),
220-
name = message.tool,
221-
input = Json.parseToJsonElement(message.content).jsonObject
230+
// Create a new assistant message with the tool call
231+
messages.add(
232+
AnthropicMessage(
233+
role = "assistant",
234+
content = listOf(
235+
AnthropicContent.ToolUse(
236+
id = message.id ?: Uuid.random().toString(),
237+
name = message.tool,
238+
input = Json.parseToJsonElement(message.content).jsonObject
239+
)
240+
)
241+
)
222242
)
223-
acc.lastOrNull { it.role == "assistant" }?.let { lastAssistantMessage ->
224-
acc[acc.lastIndex] = lastAssistantMessage.copy(content = lastAssistantMessage.content + toolUse)
225-
} ?: acc.add(AnthropicMessage(role = "assistant", content = listOf(toolUse)))
226-
}
227-
228-
is Message.System -> {
229-
logger.warn { "System messages already prepares for Anthropic. Ignoring: ${message.content}" }
230243
}
231244
}
232-
acc
233245
}
234246

235247
val anthropicTools = tools.map { tool ->
@@ -269,7 +281,7 @@ public open class AnthropicLLMClient(
269281
maxTokens = 2048, // This is required by the API
270282
// TODO why 0.7 and not 0.0?
271283
temperature = prompt.params.temperature ?: 0.7, // Default temperature if not provided
272-
system = systemMessages.map { SystemAnthropicMessage(it.content) },
284+
system = systemMessage,
273285
tools = if (tools.isNotEmpty()) anthropicTools else emptyList(), // Always provide a list for tools
274286
stream = stream,
275287
toolChoice = toolChoice,

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt

Lines changed: 99 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ public open class GoogleLLMClient(
190190
* @return A formatted GoogleAI request
191191
*/
192192
private fun createGoogleRequest(prompt: Prompt, model: LLModel, tools: List<ToolDescriptor>): GoogleRequest {
193-
val (systemMessages, convMessages) = prompt.messages.partition { it is Message.System }
193+
val systemMessageParts = mutableListOf<GooglePart.Text>()
194194
val contents = mutableListOf<GoogleContent>()
195195
val pendingCalls = mutableListOf<GooglePart.FunctionCall>()
196196

@@ -201,18 +201,55 @@ public open class GoogleLLMClient(
201201
}
202202
}
203203

204-
convMessages.forEach { message ->
205-
if (message is Message.Tool.Call) {
206-
pendingCalls += GooglePart.FunctionCall(
207-
functionCall = GoogleData.FunctionCall(
208-
id = message.id,
209-
name = message.tool,
210-
args = json.decodeFromString(message.content)
204+
for (message in prompt.messages) {
205+
when (message) {
206+
is Message.System -> {
207+
systemMessageParts.add(GooglePart.Text(message.content))
208+
}
209+
210+
is Message.User -> {
211+
flushCalls()
212+
// User messages become 'user' role content
213+
contents.add(message.toGoogleContent(model))
214+
}
215+
216+
is Message.Assistant -> {
217+
flushCalls()
218+
contents.add(
219+
GoogleContent(
220+
role = "model",
221+
parts = listOf(GooglePart.Text(message.content))
222+
)
211223
)
212-
)
213-
} else {
214-
flushCalls()
215-
contents += message.toGoogleContent(model) ?: return@forEach
224+
}
225+
226+
is Message.Tool.Result -> {
227+
flushCalls()
228+
contents.add(
229+
GoogleContent(
230+
role = "user",
231+
parts = listOf(
232+
GooglePart.FunctionResponse(
233+
functionResponse = GoogleData.FunctionResponse(
234+
id = message.id,
235+
name = message.tool,
236+
response = buildJsonObject { put("result", message.content) }
237+
)
238+
)
239+
)
240+
)
241+
)
242+
}
243+
244+
is Message.Tool.Call -> {
245+
pendingCalls += GooglePart.FunctionCall(
246+
functionCall = GoogleData.FunctionCall(
247+
id = message.id,
248+
name = message.tool,
249+
args = json.decodeFromString(message.content)
250+
)
251+
)
252+
}
216253
}
217254
}
218255
flushCalls()
@@ -236,9 +273,9 @@ public open class GoogleLLMClient(
236273
.takeIf { it.isNotEmpty() }
237274
?.let { declarations -> listOf(GoogleTool(functionDeclarations = declarations)) }
238275

239-
val googleSystemInstruction = systemMessages
276+
val googleSystemInstruction = systemMessageParts
240277
.takeIf { it.isNotEmpty() }
241-
?.let { GoogleContent(parts = it.map { message -> GooglePart.Text(message.content) }) }
278+
?.let { GoogleContent(parts = it) }
242279

243280
val generationConfig = GoogleGenerationConfig(
244281
temperature = if (model.capabilities.contains(LLMCapability.Temperature)) prompt.params.temperature else null,
@@ -269,84 +306,69 @@ public open class GoogleLLMClient(
269306
)
270307
}
271308

272-
private fun Message.toGoogleContent(model: LLModel): GoogleContent? = when (this) {
273-
is Message.User -> {
274-
val contentParts = buildList {
275-
if (content.isNotEmpty() || mediaContent.isEmpty()) {
276-
add(GooglePart.Text(content))
277-
}
278-
mediaContent.forEach { media ->
279-
when (media) {
280-
is MediaContent.Image -> {
281-
require(model.capabilities.contains(LLMCapability.Vision.Image)) {
282-
"Model ${model.id} does not support image"
283-
}
284-
if (media.isUrl()) {
285-
throw IllegalArgumentException("URL images not supported for Gemini models")
286-
}
287-
require(media.format in listOf("png", "jpg", "jpeg", "webp", "heic", "heif")) {
288-
"Image format ${media.format} not supported"
289-
}
290-
add(
291-
GooglePart.InlineData(
292-
GoogleData.Blob(
293-
mimeType = media.getMimeType(),
294-
data = media.toBase64()
295-
)
309+
private fun Message.User.toGoogleContent(model: LLModel): GoogleContent {
310+
val contentParts = buildList {
311+
if (content.isNotEmpty() || mediaContent.isEmpty()) {
312+
add(GooglePart.Text(content))
313+
}
314+
mediaContent.forEach { media ->
315+
when (media) {
316+
is MediaContent.Image -> {
317+
require(model.capabilities.contains(LLMCapability.Vision.Image)) {
318+
"Model ${model.id} does not support image"
319+
}
320+
if (media.isUrl()) {
321+
throw IllegalArgumentException("URL images not supported for Gemini models")
322+
}
323+
require(media.format in listOf("png", "jpg", "jpeg", "webp", "heic", "heif")) {
324+
"Image format ${media.format} not supported"
325+
}
326+
add(
327+
GooglePart.InlineData(
328+
GoogleData.Blob(
329+
mimeType = media.getMimeType(),
330+
data = media.toBase64()
296331
)
297332
)
333+
)
298334

299-
}
335+
}
300336

301-
is MediaContent.Audio -> {
302-
require(model.capabilities.contains(LLMCapability.Audio)) {
303-
"Model ${model.id} does not support audio"
304-
}
305-
require(media.format in listOf("wav", "mp3", "aiff", "aac", "ogg", "flac")) {
306-
"Audio format ${media.format} not supported"
307-
}
308-
add(GooglePart.InlineData(GoogleData.Blob(media.getMimeType(), media.toBase64())))
337+
is MediaContent.Audio -> {
338+
require(model.capabilities.contains(LLMCapability.Audio)) {
339+
"Model ${model.id} does not support audio"
309340
}
341+
require(media.format in listOf("wav", "mp3", "aiff", "aac", "ogg", "flac")) {
342+
"Audio format ${media.format} not supported"
343+
}
344+
add(GooglePart.InlineData(GoogleData.Blob(media.getMimeType(), media.toBase64())))
345+
}
310346

311-
is MediaContent.File -> {
312-
if (media.isUrl()) {
313-
throw IllegalArgumentException("URL files not supported for Gemini models")
314-
}
315-
add(
316-
GooglePart.InlineData(
317-
GoogleData.Blob(
318-
mimeType = media.getMimeType(),
319-
data = media.toBase64()
320-
)
347+
is MediaContent.File -> {
348+
if (media.isUrl()) {
349+
throw IllegalArgumentException("URL files not supported for Gemini models")
350+
}
351+
add(
352+
GooglePart.InlineData(
353+
GoogleData.Blob(
354+
mimeType = media.getMimeType(),
355+
data = media.toBase64()
321356
)
322357
)
323-
}
358+
)
359+
}
324360

325-
is MediaContent.Video -> {
326-
require(model.capabilities.contains(LLMCapability.Vision.Video)) {
327-
"Model ${model.id} does not support video"
328-
}
329-
add(GooglePart.InlineData(GoogleData.Blob(media.getMimeType(), media.toBase64())))
361+
is MediaContent.Video -> {
362+
require(model.capabilities.contains(LLMCapability.Vision.Video)) {
363+
"Model ${model.id} does not support video"
330364
}
365+
add(GooglePart.InlineData(GoogleData.Blob(media.getMimeType(), media.toBase64())))
331366
}
332367
}
333368
}
334-
GoogleContent(role = "user", parts = contentParts)
335369
}
336370

337-
is Message.Assistant -> GoogleContent(role = "model", parts = listOf(GooglePart.Text(content)))
338-
is Message.Tool.Result -> GoogleContent(
339-
role = "user",
340-
parts = listOf(
341-
GooglePart.FunctionResponse(
342-
functionResponse = GoogleData.FunctionResponse(
343-
id = id, name = tool, response = buildJsonObject { put("result", content) })
344-
)
345-
)
346-
)
347-
348-
is Message.Tool.Call -> null
349-
is Message.System -> null
371+
return GoogleContent(role = "user", parts = contentParts)
350372
}
351373

352374
/**

prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/DataModel.kt

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
11
package ai.koog.prompt.executor.clients.openai
22

3-
import kotlinx.serialization.InternalSerializationApi
4-
import kotlinx.serialization.KSerializer
5-
import kotlinx.serialization.SerialName
6-
import kotlinx.serialization.Serializable
7-
import kotlinx.serialization.SerializationException
3+
import kotlinx.serialization.*
84
import kotlinx.serialization.builtins.ListSerializer
95
import kotlinx.serialization.descriptors.PolymorphicKind
106
import kotlinx.serialization.descriptors.SerialDescriptor
117
import kotlinx.serialization.descriptors.buildSerialDescriptor
128
import kotlinx.serialization.encoding.Decoder
139
import kotlinx.serialization.encoding.Encoder
14-
import kotlinx.serialization.json.JsonArray
15-
import kotlinx.serialization.json.JsonDecoder
16-
import kotlinx.serialization.json.JsonEncoder
17-
import kotlinx.serialization.json.JsonNull
18-
import kotlinx.serialization.json.JsonObject
19-
import kotlinx.serialization.json.JsonPrimitive
10+
import kotlinx.serialization.json.*
2011
import kotlin.jvm.JvmInline
2112

2213
@Serializable
@@ -25,12 +16,20 @@ internal data class OpenAIRequest(
2516
val messages: List<OpenAIMessage>,
2617
val temperature: Double? = null,
2718
val tools: List<OpenAITool>? = null,
28-
val modalities: List<String>? = null,
19+
val modalities: List<OpenAIModalities>? = null,
2920
val audio: OpenAIAudioConfig? = null,
3021
val stream: Boolean = false,
3122
val toolChoice: OpenAIToolChoice? = null
3223
)
3324

25+
@Serializable
26+
internal enum class OpenAIModalities {
27+
@SerialName("text")
28+
Text,
29+
@SerialName("audio")
30+
Audio,
31+
}
32+
3433
@Serializable
3534
internal data class OpenAIMessage(
3635
val role: String,
@@ -206,10 +205,24 @@ internal sealed interface OpenAIToolChoice {
206205

207206
@Serializable
208207
internal data class OpenAIAudioConfig(
209-
val format: String = "wav",
210-
val voice: String = "alloy"
208+
val format: OpenAIAudioFormat,
209+
val voice: OpenAIAudioVoice,
211210
)
212211

212+
@Serializable
213+
internal enum class OpenAIAudioFormat {
214+
@SerialName("wav")
215+
WAV,
216+
@SerialName("pcm16")
217+
PCM16,
218+
}
219+
220+
@Serializable
221+
internal enum class OpenAIAudioVoice {
222+
@SerialName("alloy")
223+
Alloy,
224+
}
225+
213226
@Serializable
214227
internal data class OpenAIAudio(
215228
val data: String,

0 commit comments

Comments (0)