@@ -190,7 +190,7 @@ public open class GoogleLLMClient(
190
190
* @return A formatted GoogleAI request
191
191
*/
192
192
private fun createGoogleRequest (prompt : Prompt , model : LLModel , tools : List <ToolDescriptor >): GoogleRequest {
193
- val (systemMessages, convMessages) = prompt.messages.partition { it is Message . System }
193
+ val systemMessageParts = mutableListOf< GooglePart . Text >()
194
194
val contents = mutableListOf<GoogleContent >()
195
195
val pendingCalls = mutableListOf<GooglePart .FunctionCall >()
196
196
@@ -201,18 +201,55 @@ public open class GoogleLLMClient(
201
201
}
202
202
}
203
203
204
- convMessages.forEach { message ->
205
- if (message is Message .Tool .Call ) {
206
- pendingCalls + = GooglePart .FunctionCall (
207
- functionCall = GoogleData .FunctionCall (
208
- id = message.id,
209
- name = message.tool,
210
- args = json.decodeFromString(message.content)
204
+ for (message in prompt.messages) {
205
+ when (message) {
206
+ is Message .System -> {
207
+ systemMessageParts.add(GooglePart .Text (message.content))
208
+ }
209
+
210
+ is Message .User -> {
211
+ flushCalls()
212
+ // User messages become 'user' role content
213
+ contents.add(message.toGoogleContent(model))
214
+ }
215
+
216
+ is Message .Assistant -> {
217
+ flushCalls()
218
+ contents.add(
219
+ GoogleContent (
220
+ role = " model" ,
221
+ parts = listOf (GooglePart .Text (message.content))
222
+ )
211
223
)
212
- )
213
- } else {
214
- flushCalls()
215
- contents + = message.toGoogleContent(model) ? : return @forEach
224
+ }
225
+
226
+ is Message .Tool .Result -> {
227
+ flushCalls()
228
+ contents.add(
229
+ GoogleContent (
230
+ role = " user" ,
231
+ parts = listOf (
232
+ GooglePart .FunctionResponse (
233
+ functionResponse = GoogleData .FunctionResponse (
234
+ id = message.id,
235
+ name = message.tool,
236
+ response = buildJsonObject { put(" result" , message.content) }
237
+ )
238
+ )
239
+ )
240
+ )
241
+ )
242
+ }
243
+
244
+ is Message .Tool .Call -> {
245
+ pendingCalls + = GooglePart .FunctionCall (
246
+ functionCall = GoogleData .FunctionCall (
247
+ id = message.id,
248
+ name = message.tool,
249
+ args = json.decodeFromString(message.content)
250
+ )
251
+ )
252
+ }
216
253
}
217
254
}
218
255
flushCalls()
@@ -236,9 +273,9 @@ public open class GoogleLLMClient(
236
273
.takeIf { it.isNotEmpty() }
237
274
?.let { declarations -> listOf (GoogleTool (functionDeclarations = declarations)) }
238
275
239
- val googleSystemInstruction = systemMessages
276
+ val googleSystemInstruction = systemMessageParts
240
277
.takeIf { it.isNotEmpty() }
241
- ?.let { GoogleContent (parts = it.map { message -> GooglePart . Text (message.content) } ) }
278
+ ?.let { GoogleContent (parts = it) }
242
279
243
280
val generationConfig = GoogleGenerationConfig (
244
281
temperature = if (model.capabilities.contains(LLMCapability .Temperature )) prompt.params.temperature else null ,
@@ -269,84 +306,69 @@ public open class GoogleLLMClient(
269
306
)
270
307
}
271
308
272
/**
 * Converts this user message into a Gemini content entry with the `user` role.
 *
 * The text body is emitted as a [GooglePart.Text] when it is non-empty, or when the message
 * carries no media at all (so the resulting content always has at least one part).
 * Every media attachment is emitted as a [GooglePart.InlineData] blob built from the
 * attachment's MIME type and Base64 payload.
 *
 * @param model The target model; its capabilities are checked before image, audio and
 *   video attachments are accepted.
 * @return The `user`-role [GoogleContent] holding the text and inline media parts.
 * @throws IllegalArgumentException if the model lacks the capability required by an
 *   attachment, if an image or file attachment is a URL, or if an image/audio format
 *   is not in the accepted list.
 */
private fun Message.User.toGoogleContent(model: LLModel): GoogleContent {
    // Accepted formats are built once per call instead of once per attachment.
    val imageFormats = setOf("png", "jpg", "jpeg", "webp", "heic", "heif")
    val audioFormats = setOf("wav", "mp3", "aiff", "aac", "ogg", "flac")

    val contentParts = buildList {
        // Keep an (empty) text part when there is no media so the part list is never empty.
        if (content.isNotEmpty() || mediaContent.isEmpty()) {
            add(GooglePart.Text(content))
        }

        for (media in mediaContent) {
            when (media) {
                is MediaContent.Image -> {
                    require(model.capabilities.contains(LLMCapability.Vision.Image)) {
                        "Model ${model.id} does not support image"
                    }
                    // Images are sent inline, so URL-backed images are rejected.
                    require(!media.isUrl()) { "URL images not supported for Gemini models" }
                    require(media.format in imageFormats) {
                        "Image format ${media.format} not supported"
                    }
                    add(
                        GooglePart.InlineData(
                            GoogleData.Blob(
                                mimeType = media.getMimeType(),
                                data = media.toBase64(),
                            )
                        )
                    )
                }

                is MediaContent.Audio -> {
                    require(model.capabilities.contains(LLMCapability.Audio)) {
                        "Model ${model.id} does not support audio"
                    }
                    require(media.format in audioFormats) {
                        "Audio format ${media.format} not supported"
                    }
                    add(GooglePart.InlineData(GoogleData.Blob(media.getMimeType(), media.toBase64())))
                }

                is MediaContent.File -> {
                    // Files are sent inline, so URL-backed files are rejected.
                    require(!media.isUrl()) { "URL files not supported for Gemini models" }
                    add(
                        GooglePart.InlineData(
                            GoogleData.Blob(
                                mimeType = media.getMimeType(),
                                data = media.toBase64(),
                            )
                        )
                    )
                }

                is MediaContent.Video -> {
                    require(model.capabilities.contains(LLMCapability.Vision.Video)) {
                        "Model ${model.id} does not support video"
                    }
                    add(GooglePart.InlineData(GoogleData.Blob(media.getMimeType(), media.toBase64())))
                }
            }
        }
    }

    return GoogleContent(role = "user", parts = contentParts)
}
351
373
352
374
/* *
0 commit comments