
Commit ebd08c8

nomisRev and Bruno Lannoo authored and committed
Removed suspend modifier from LLMClient.executeStreaming (JetBrains#240)
1 parent 7042544 commit ebd08c8

9 files changed: +121 -123 lines changed
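The change repeated across these files is a single API adjustment: LLMClient.executeStreaming loses its suspend modifier and becomes a plain function returning a cold Flow<String>, so obtaining the stream no longer requires a coroutine, while collecting it still does. The following stand-alone Kotlin sketch (stand-in types, not the Koog classes) illustrates the before/after shape of the signature and what it means at a call site.

// A minimal, self-contained sketch (not the Koog sources): StreamingClient stands in for
// LLMClient, and String stands in for Prompt/LLModel, to show only the shape of the change.
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.runBlocking

// Before: obtaining the Flow itself required a suspend context.
interface OldStreamingClient {
    suspend fun executeStreaming(prompt: String, model: String): Flow<String>
}

// After: executeStreaming is a plain function returning a cold Flow;
// only collecting the flow needs a coroutine.
interface NewStreamingClient {
    fun executeStreaming(prompt: String, model: String): Flow<String>
}

class EchoClient : NewStreamingClient {
    override fun executeStreaming(prompt: String, model: String): Flow<String> = flow {
        // Work happens lazily, when the flow is collected.
        emit("echo: ")
        emit(prompt)
    }
}

fun main() = runBlocking {
    val stream = EchoClient().executeStreaming("hello", "demo-model") // no suspend needed here
    stream.collect { chunk -> print(chunk) }                          // collection still suspends
    println()
}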

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/KotlinAIAgentWithMultipleLLMIntegrationTest.kt

Lines changed: 6 additions & 3 deletions

@@ -32,7 +32,9 @@ import ai.koog.prompt.params.LLMParams
 import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.DelicateCoroutinesApi
 import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.coroutineScope
 import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.runBlocking
 import kotlinx.coroutines.test.runTest
@@ -78,8 +80,8 @@ internal class ReportingLLMLLMClient(
         return underlyingClient.execute(prompt, model, tools)
     }

-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
-        CoroutineScope(coroutineContext).launch {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
+        coroutineScope {
             eventsChannel.send(
                 Event.Message(
                     llmClient = underlyingClient::class.simpleName ?: "null",
@@ -90,7 +92,8 @@ internal class ReportingLLMLLMClient(
                 )
             )
         }
-        return underlyingClient.executeStreaming(prompt, model)
+        underlyingClient.executeStreaming(prompt, model)
+            .collect(this)
     }
 }

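In the test wrapper above, the old implementation launched a reporting coroutine and returned the underlying flow; the new one builds a cold flow that sends the reporting event when collection starts and then forwards every chunk from the underlying client with collect(this). A minimal, self-contained sketch of that wrapper pattern (stand-in names, not the Koog test code):

// Stand-alone sketch of the wrapper pattern used above: a cold flow performs a side
// effect when collection starts, then forwards every chunk from the underlying flow
// to its own collector (emitAll is equivalent to underlying.collect(this)).
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.runBlocking

class ReportingFlow(
    private val underlying: Flow<String>,
    private val events: Channel<String>,
) {
    fun streaming(): Flow<String> = flow {
        events.send("streaming started")   // side effect runs per collection, not per call
        emitAll(underlying)                // forward all chunks to this flow's collector
    }
}

fun main() = runBlocking {
    val events = Channel<String>(capacity = Channel.UNLIMITED)
    val reporting = ReportingFlow(flowOf("a", "b", "c"), events)
    val chunks = mutableListOf<String>()
    reporting.streaming().collect { chunks += it }
    println(chunks)                           // [a, b, c]
    println(events.tryReceive().getOrNull())  // streaming started
}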
prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt

Lines changed: 25 additions & 27 deletions

@@ -142,44 +142,42 @@ public open class AnthropicLLMClient(
         }
     }

-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt with model: $model without tools" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }

         val request = createAnthropicRequest(prompt, emptyList(), model, true)

-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = DEFAULT_MESSAGE_PATH,
-                    request = {
-                        method = HttpMethod.Post
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.event == "content_block_delta" }
-                            ?.data?.trim()?.let { json.decodeFromString<AnthropicStreamResponse>(it) }
-                            ?.delta?.text?.let { emit(it) }
+        try {
+            httpClient.sse(
+                urlString = DEFAULT_MESSAGE_PATH,
+                request = {
+                    method = HttpMethod.Post
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
                     }
+                    setBody(request)
                 }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    logger.error { "Error from Anthropic API: ${response.status}: ${e.message}" }
-                    error("Error from Anthropic API: ${response.status}: ${e.message}")
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.event == "content_block_delta" }
+                        ?.data?.trim()?.let { json.decodeFromString<AnthropicStreamResponse>(it) }
+                        ?.delta?.text?.let { emit(it) }
                 }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
             }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                logger.error { "Error from Anthropic API: ${response.status}: ${e.message}" }
+                error("Error from Anthropic API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
         }
     }

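In this and the following client files the streaming pattern itself is unchanged apart from moving the SSE call directly into the flow builder: incoming events are filtered, decoded with kotlinx.serialization, and only the text deltas are emitted. The sketch below shows just that decoding step over a plain Flow<String> of raw payloads; the event shape ("type", "delta.text") is a simplified assumption for illustration, not the exact Anthropic wire format or the AnthropicStreamResponse type used above.

// Sketch of the event-decoding step only (assumed shapes, not the Anthropic wire format):
// keep "content_block_delta"-like events, decode them with kotlinx.serialization,
// and emit just the text deltas.
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.Serializable
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json

@Serializable
data class Delta(val text: String? = null)

@Serializable
data class StreamEvent(val type: String, val delta: Delta? = null)

private val json = Json { ignoreUnknownKeys = true }

fun textDeltas(rawEvents: Flow<String>): Flow<String> = flow {
    rawEvents.collect { data ->
        json.decodeFromString<StreamEvent>(data)
            .takeIf { it.type == "content_block_delta" }
            ?.delta?.text
            ?.let { emit(it) }
    }
}

fun main() = runBlocking {
    val raw = flowOf(
        """{"type":"message_start"}""",
        """{"type":"content_block_delta","delta":{"text":"Hel"}}""",
        """{"type":"content_block_delta","delta":{"text":"lo"}}""",
    )
    textDeltas(raw).collect { print(it) }  // prints: Hello
    println()
}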
prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt

Lines changed: 27 additions & 29 deletions

@@ -138,46 +138,44 @@ public open class GoogleLLMClient(
         }
     }

-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt with model: $model" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }

         val request = createGoogleRequest(prompt, model, emptyList())

-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = "$DEFAULT_PATH/${model.id}:$DEFAULT_METHOD_STREAM_GENERATE_CONTENT",
-                    request = {
-                        method = HttpMethod.Post
-                        parameter("alt", "sse")
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.data != "[DONE]" }
-                            ?.data?.trim()?.let { json.decodeFromString<GoogleResponse>(it) }
-                            ?.candidates?.firstOrNull()?.content
-                            ?.parts?.forEach { part -> if (part is GooglePart.Text) emit(part.text) }
+        try {
+            httpClient.sse(
+                urlString = "$DEFAULT_PATH/${model.id}:$DEFAULT_METHOD_STREAM_GENERATE_CONTENT",
+                request = {
+                    method = HttpMethod.Post
+                    parameter("alt", "sse")
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
                     }
+                    setBody(request)
                 }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    logger.error { "Error from GoogleAI API: ${response.status}: ${e.message}" }
-                    error("Error from GoogleAI API: ${response.status}: ${e.message}")
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.data != "[DONE]" }
+                        ?.data?.trim()?.let { json.decodeFromString<GoogleResponse>(it) }
+                        ?.candidates?.firstOrNull()?.content
+                        ?.parts?.forEach { part -> if (part is GooglePart.Text) emit(part.text) }
                 }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
             }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                logger.error { "Error from GoogleAI API: ${response.status}: ${e.message}" }
+                error("Error from GoogleAI API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
         }
     }

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt

Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ public class OllamaClient(
         }
     }

-    override suspend fun executeStreaming(
+    override fun executeStreaming(
         prompt: Prompt, model: LLModel
     ): Flow<String> = flow {
         require(model.provider == LLMProvider.Ollama) { "Model not supported by Ollama" }

prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt

Lines changed: 30 additions & 29 deletions

@@ -145,45 +145,43 @@ public open class OpenAILLMClient(
         }
     }

-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt with model: $model" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }

         val request = createOpenAIRequest(prompt, emptyList(), model, true)

-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = settings.chatCompletionsPath,
-                    request = {
-                        method = HttpMethod.Post
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.data != "[DONE]" }
-                            ?.data?.trim()?.let { json.decodeFromString<OpenAIStreamResponse>(it) }
-                            ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
+        try {
+            httpClient.sse(
+                urlString = settings.chatCompletionsPath,
+                request = {
+                    method = HttpMethod.Post
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
                     }
+                    setBody(request)
                 }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    val body = response.readRawBytes().decodeToString()
-                    logger.error(e) { "Error from OpenAI API: ${response.status}: ${e.message}.\nBody:\n$body" }
-                    error("Error from OpenAI API: ${response.status}: ${e.message}")
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.data != "[DONE]" }
+                        ?.data?.trim()?.let { json.decodeFromString<OpenAIStreamResponse>(it) }
+                        ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
                 }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
             }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                val body = response.readRawBytes().decodeToString()
+                logger.error(e) { "Error from OpenAI API: ${response.status}: ${e.message}.\nBody:\n$body" }
+                error("Error from OpenAI API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
         }
     }
@@ -330,7 +328,10 @@ public open class OpenAILLMClient(
             null -> null
         }

-        val modalities = if (model.capabilities.contains(LLMCapability.Audio)) listOf(OpenAIModalities.Text, OpenAIModalities.Audio) else null
+        val modalities = if (model.capabilities.contains(LLMCapability.Audio)) listOf(
+            OpenAIModalities.Text,
+            OpenAIModalities.Audio
+        ) else null
         // TODO allow passing this externally and actually controlling this behavior
         val audio = modalities?.let {
             OpenAIAudioConfig(

prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt

Lines changed: 25 additions & 27 deletions

@@ -123,44 +123,42 @@ public class OpenRouterLLMClient(
         }
     }

-    override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+    override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> = flow {
         logger.debug { "Executing streaming prompt: $prompt" }
         require(model.capabilities.contains(LLMCapability.Completion)) {
             "Model ${model.id} does not support chat completions"
         }

         val request = createOpenRouterRequest(prompt, model, emptyList(), true)

-        return flow {
-            try {
-                httpClient.sse(
-                    urlString = DEFAULT_MESSAGE_PATH,
-                    request = {
-                        method = HttpMethod.Post
-                        accept(ContentType.Text.EventStream)
-                        headers {
-                            append(HttpHeaders.CacheControl, "no-cache")
-                            append(HttpHeaders.Connection, "keep-alive")
-                        }
-                        setBody(request)
-                    }
-                ) {
-                    incoming.collect { event ->
-                        event
-                            .takeIf { it.data != "[DONE]" }
-                            ?.data?.trim()?.let { json.decodeFromString<OpenRouterStreamResponse>(it) }
-                            ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
+        try {
+            httpClient.sse(
+                urlString = DEFAULT_MESSAGE_PATH,
+                request = {
+                    method = HttpMethod.Post
+                    accept(ContentType.Text.EventStream)
+                    headers {
+                        append(HttpHeaders.CacheControl, "no-cache")
+                        append(HttpHeaders.Connection, "keep-alive")
                     }
+                    setBody(request)
                 }
-            } catch (e: SSEClientException) {
-                e.response?.let { response ->
-                    logger.error { "Error from OpenRouter API: ${response.status}: ${e.message}" }
-                    error("Error from OpenRouter API: ${response.status}: ${e.message}")
+            ) {
+                incoming.collect { event ->
+                    event
+                        .takeIf { it.data != "[DONE]" }
+                        ?.data?.trim()?.let { json.decodeFromString<OpenRouterStreamResponse>(it) }
+                        ?.choices?.forEach { choice -> choice.delta.content?.let { emit(it) } }
                 }
-            } catch (e: Exception) {
-                logger.error { "Exception during streaming: $e" }
-                error(e.message ?: "Unknown error during streaming")
             }
+        } catch (e: SSEClientException) {
+            e.response?.let { response ->
+                logger.error { "Error from OpenRouter API: ${response.status}: ${e.message}" }
+                error("Error from OpenRouter API: ${response.status}: ${e.message}")
+            }
+        } catch (e: Exception) {
+            logger.error { "Exception during streaming: $e" }
+            error(e.message ?: "Unknown error during streaming")
         }
     }

prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/LLMClient.kt

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ public interface LLMClient {
     * @param model The LLM model to use
     * @return Flow of response chunks
     */
-    public suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String>
+    public fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String>
 }

 /**

prompt/prompt-executor/prompt-executor-llms-all/src/jvmTest/kotlin/ai/koog/prompt/executor/llms/all/MultipleLLMPromptExecutorMockTest.kt

Lines changed: 3 additions & 3 deletions

@@ -42,7 +42,7 @@ class MultipleLLMPromptExecutorMockTest {
            return listOf(Message.Assistant("OpenAI response", ResponseMetaInfo.create(mockClock)))
        }

-       override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+       override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
            return flowOf("OpenAI", " streaming", " response")
        }
    }
@@ -57,7 +57,7 @@ class MultipleLLMPromptExecutorMockTest {
            return listOf(Message.Assistant("Anthropic response", ResponseMetaInfo.create(mockClock)))
        }

-       override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+       override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
            return flowOf("Anthropic", " streaming", " response")
        }
    }
@@ -72,7 +72,7 @@ class MultipleLLMPromptExecutorMockTest {
            return listOf(Message.Assistant("Gemini response", ResponseMetaInfo.create(mockClock)))
        }

-       override suspend fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
+       override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
            return flowOf("Gemini", " streaming", " response")
        }
    }

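Because executeStreaming is no longer suspending, a mock like the ones above can build its Flow eagerly with flowOf, and the test only needs a coroutine to collect it. A sketch of such a test (assumed shape, not the actual MultipleLLMPromptExecutorMockTest) could look like this:

// Sketch of verifying a non-suspending streaming mock: the flow is built eagerly with
// flowOf and collected inside runTest.
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals

class StreamingMockSketchTest {

    // Stand-in for an LLMClient mock: no suspend modifier is needed to hand back the Flow.
    private fun executeStreaming(): Flow<String> =
        flowOf("OpenAI", " streaming", " response")

    @Test
    fun streamingChunksAreReturnedInOrder() = runTest {
        val chunks = executeStreaming().toList()
        assertEquals(listOf("OpenAI", " streaming", " response"), chunks)
    }
}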