@@ -32,7 +32,9 @@ import ai.koog.prompt.params.LLMParams
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.DelicateCoroutinesApi
 import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.runBlocking
 import kotlinx.coroutines.test.runTest
@@ -78,8 +80,8 @@ internal class ReportingLLMLLMClient(
         return underlyingClient.execute(prompt, model, tools)
     }
 
-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
-        CoroutineScope(coroutineContext).launch {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
+        coroutineScope {
             eventsChannel.send(
                 Event.Message(
                     llmClient = underlyingClient::class.simpleName ?: "null",
@@ -90,7 +92,8 @@
                 )
             )
         }
-        return underlyingClient.executeStreaming(prompt, model)
+        underlyingClient.executeStreaming(prompt, model)
+            .collect(this)
     }
 }
 
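Note: the reporting client above now returns a cold flow and forwards the underlying stream by collecting it into the enclosing flow builder's collector. A minimal self-contained sketch of the same pattern; the names reportingStream and report are illustrative and not part of this change:

import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf

// Hypothetical wrapper mirroring the ReportingLLMLLMClient change:
// run setup inside the flow builder (so it happens lazily, per collection),
// then forward every upstream value to the outer flow's collector.
fun reportingStream(
    upstream: Flow<String>,
    report: suspend (String) -> Unit,
): Flow<String> = flow {
    coroutineScope {
        report("stream started") // side effect runs only when the flow is collected
    }
    // `this` is the FlowCollector<String> of the enclosing flow builder,
    // so upstream emissions are re-emitted downstream unchanged.
    upstream.collect(this)
}

suspend fun main() {
    reportingStream(flowOf("chunk-1", " chunk-2", " chunk-3")) { println("event: $it") }
        .collect { print(it) }
}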
@@ -142,44 +142,42 @@ public open class AnthropicLLMClient(
         }
     }
 
-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt with model: $model without tools" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }
 
         val request = createAnthropicRequest(prompt, emptyList(), model, true)
 
-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = DEFAULT_MESSAGE_PATH,
-                    request = {
-                        method = HttpMethod.Post
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.event == "content_block_delta" }
-                            ?.data?.trim()?.let { json.decodeFromString<AnthropicStreamResponse>(it) }
-                            ?.delta?.text?.let { emit(it) }
-                    }
-                }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    logger.error { "Error from Anthropic API: ${response.status}: ${e.message}" }
-                    error("Error from Anthropic API: ${response.status}: ${e.message}")
-                }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
-            }
-        }
+        try {
+            httpClient.sse(
+                urlString = DEFAULT_MESSAGE_PATH,
+                request = {
+                    method = HttpMethod.Post
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
+                    }
+                    setBody(request)
+                }
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.event == "content_block_delta" }
+                        ?.data?.trim()?.let { json.decodeFromString<AnthropicStreamResponse>(it) }
+                        ?.delta?.text?.let { emit(it) }
+                }
+            }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                logger.error { "Error from Anthropic API: ${response.status}: ${e.message}" }
+                error("Error from Anthropic API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
+        }
     }
 
@@ -138,46 +138,44 @@ public open class GoogleLLMClient(
         }
     }
 
-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt with model: $model" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }
 
         val request = createGoogleRequest(prompt, model, emptyList())
 
-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = "$DEFAULT_PATH/${model.id}:$DEFAULT_METHOD_STREAM_GENERATE_CONTENT",
-                    request = {
-                        method = HttpMethod.Post
-                        parameter("alt", "sse")
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.data != "[DONE]" }
-                            ?.data?.trim()?.let { json.decodeFromString<GoogleResponse>(it) }
-                            ?.candidates?.firstOrNull()?.content
-                            ?.parts?.forEach { part -> if (part is GooglePart.Text) emit(part.text) }
-                    }
-                }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    logger.error { "Error from GoogleAI API: ${response.status}: ${e.message}" }
-                    error("Error from GoogleAI API: ${response.status}: ${e.message}")
-                }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
-            }
-        }
+        try {
+            httpClient.sse(
+                urlString = "$DEFAULT_PATH/${model.id}:$DEFAULT_METHOD_STREAM_GENERATE_CONTENT",
+                request = {
+                    method = HttpMethod.Post
+                    parameter("alt", "sse")
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
+                    }
+                    setBody(request)
+                }
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.data != "[DONE]" }
+                        ?.data?.trim()?.let { json.decodeFromString<GoogleResponse>(it) }
+                        ?.candidates?.firstOrNull()?.content
+                        ?.parts?.forEach { part -> if (part is GooglePart.Text) emit(part.text) }
+                }
+            }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                logger.error { "Error from GoogleAI API: ${response.status}: ${e.message}" }
+                error("Error from GoogleAI API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
+        }
     }
 
@@ -144,7 +144,7 @@ public class OllamaClient(
         }
     }
 
-    override suspend fun executeStreaming(
+    override fun executeStreaming(
         prompt: Prompt, model: LLModel
     ): Flow<String> = flow {
         require(model.provider == LLMProvider.Ollama) { "Model not supported by Ollama" }
@@ -145,45 +145,43 @@ public open class OpenAILLMClient(
         }
     }
 
-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt with model: $model" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }
 
         val request = createOpenAIRequest(prompt, emptyList(), model, true)
 
-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = settings.chatCompletionsPath,
-                    request = {
-                        method = HttpMethod.Post
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.data != "[DONE]" }
-                            ?.data?.trim()?.let { json.decodeFromString<OpenAIStreamResponse>(it) }
-                            ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
-                    }
-                }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    val body = response.readRawBytes().decodeToString()
-                    logger.error(e) { "Error from OpenAI API: ${response.status}: ${e.message}.\nBody:\n$body" }
-                    error("Error from OpenAI API: ${response.status}: ${e.message}")
-                }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
-            }
-        }
+        try {
+            httpClient.sse(
+                urlString = settings.chatCompletionsPath,
+                request = {
+                    method = HttpMethod.Post
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
+                    }
+                    setBody(request)
+                }
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.data != "[DONE]" }
+                        ?.data?.trim()?.let { json.decodeFromString<OpenAIStreamResponse>(it) }
+                        ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
+                }
+            }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                val body = response.readRawBytes().decodeToString()
+                logger.error(e) { "Error from OpenAI API: ${response.status}: ${e.message}.\nBody:\n$body" }
+                error("Error from OpenAI API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
+        }
     }
 
@@ -330,7 +328,10 @@ public open class OpenAILLMClient(
             null -> null
         }
 
-        val modalities = if (model.capabilities.contains(LLMCapability.Audio)) listOf(OpenAIModalities.Text, OpenAIModalities.Audio) else null
+        val modalities = if (model.capabilities.contains(LLMCapability.Audio)) listOf(
+            OpenAIModalities.Text,
+            OpenAIModalities.Audio
+        ) else null
         // TODO allow passing this externally and actually controlling this behavior
         val audio = modalities?.let {
             OpenAIAudioConfig(
@@ -123,44 +123,42 @@ public class OpenRouterLLMClient(
         }
     }
 
-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt" }
        require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }
 
         val request = createOpenRouterRequest(prompt, model, emptyList(), true)
 
-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = DEFAULT_MESSAGE_PATH,
-                    request = {
-                        method = HttpMethod.Post
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.data != "[DONE]" }
-                            ?.data?.trim()?.let { json.decodeFromString<OpenRouterStreamResponse>(it) }
-                            ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
-                    }
-                }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    logger.error { "Error from OpenRouter API: ${response.status}: ${e.message}" }
-                    error("Error from OpenRouter API: ${response.status}: ${e.message}")
-                }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
-            }
-        }
+        try {
+            httpClient.sse(
+                urlString = DEFAULT_MESSAGE_PATH,
+                request = {
+                    method = HttpMethod.Post
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
+                    }
+                    setBody(request)
+                }
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.data != "[DONE]" }
+                        ?.data?.trim()?.let { json.decodeFromString<OpenRouterStreamResponse>(it) }
+                        ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
+                }
+            }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                logger.error { "Error from OpenRouter API: ${response.status}: ${e.message}" }
+                error("Error from OpenRouter API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
+        }
     }
 
@@ -32,7 +32,7 @@ public interface LLMClient {
      * @param model The LLM model to use
      * @return Flow of response chunks
      */
-    public suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String>
+    public fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String>
 }
 
 public data class ConnectionTimeoutConfig(
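Note: with executeStreaming no longer a suspend function, callers can obtain the cold flow outside any coroutine and suspend only when collecting it; the request runs, and errors surface, at collection time. A hedged usage sketch: the client, prompt, and model values are assumed to exist in the caller's code, and the Koog types LLMClient, Prompt, and LLModel are assumed to be imported from the library; none of this is defined in the PR itself.

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.runBlocking

fun printCompletionStream(client: LLMClient, prompt: Prompt, model: LLModel) {
    // Building the flow needs no suspend context and performs no network I/O yet.
    val chunks: Flow<String> = client.executeStreaming(prompt, model)

    // Suspension happens only here, when the flow is collected.
    runBlocking {
        chunks.collect { chunk -> print(chunk) }
    }
}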
@@ -42,7 +42,7 @@ class MultipleLLMPromptExecutorMockTest {
             return listOf(Message.Assistant("OpenAI response", ResponseMetaInfo.create(mockClock)))
         }
 
-        override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+        override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
             return flowOf("OpenAI", " streaming", " response")
         }
     }
@@ -57,7 +57,7 @@
             return listOf(Message.Assistant("Anthropic response", ResponseMetaInfo.create(mockClock)))
         }
 
-        override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+        override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
             return flowOf("Anthropic", " streaming", " response")
         }
     }
@@ -72,7 +72,7 @@
             return listOf(Message.Assistant("Gemini response", ResponseMetaInfo.create(mockClock)))
         }
 
-        override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+        override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
             return flowOf("Gemini", " streaming", " response")
        }
    }