Skip to content

Commit 842274c

Browse files
committed
Support AWS_SESSION_TOKEN in BedrockClient
1 parent 96b1b6e commit 842274c

File tree

10 files changed

+66
-7
lines changed

10 files changed

+66
-7
lines changed

.github/workflows/heavy-tests.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ name: Heavy Tests
1010
on:
1111
workflow_dispatch: # Manual trigger
1212
push:
13-
branches: [ "main", "develop" ]
13+
branches: [ main, develop ]
1414

1515
env:
1616
AWS_REGION: us-west-2
@@ -66,6 +66,12 @@ jobs:
6666
}
6767
echo "Bedrock access verified!"
6868
69+
- name: Run Bedrock Credentials Smoke Test
70+
run: |
71+
echo "Running Bedrock credentials smoke test..."
72+
./gradlew :integration-tests:cleanJvmTest :integration-tests:jvmTest --tests "ai.koog.integration.tests.BedrockCredentialsSmokeTest"
73+
echo "Bedrock credentials smoke test passed!"
74+
6975
- name: JvmIntegrationTest with Gradle Wrapper
7076
env:
7177
ANTHROPIC_API_TEST_KEY: ${{ secrets.ANTHROPIC_API_TEST_KEY }}

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ subprojects {
9090
"OPEN_AI_API_TEST_KEY" to System.getenv("OPEN_AI_API_TEST_KEY"),
9191
"GEMINI_API_TEST_KEY" to System.getenv("GEMINI_API_TEST_KEY"),
9292
"OPEN_ROUTER_API_TEST_KEY" to System.getenv("OPEN_ROUTER_API_TEST_KEY"),
93-
"AWS_SECRET_KEY" to System.getenv("AWS_SECRET_KEY"),
93+
"AWS_SECRET_ACCESS_KEY" to System.getenv("AWS_SECRET_ACCESS_KEY"),
9494
"AWS_ACCESS_KEY_ID" to System.getenv("AWS_ACCESS_KEY_ID"),
9595
)
9696
)

gradle/libs.versions.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ opentelemetry-bom = { module = "io.opentelemetry:opentelemetry-bom", version.ref
5656
opentelemetry-sdk = { module = "io.opentelemetry:opentelemetry-sdk" }
5757
opentelemetry-exporter-otlp = { module = "io.opentelemetry:opentelemetry-exporter-otlp" }
5858
opentelemetry-exporter-logging = { module = "io.opentelemetry:opentelemetry-exporter-logging" }
59+
aws-sdk-kotlin-bedrock = { module = "aws.sdk.kotlin:bedrock", version.ref = "aws-sdk-kotlin" }
5960
aws-sdk-kotlin-bedrockruntime = { module = "aws.sdk.kotlin:bedrockruntime", version.ref = "aws-sdk-kotlin" }
61+
aws-sdk-kotlin-sts = { module = "aws.sdk.kotlin:sts", version.ref = "aws-sdk-kotlin" }
6062

6163
# Spring
6264
spring-boot-bom = { module = "org.springframework.boot:spring-boot-dependencies", version.ref = "spring-boot" }

integration-tests/build.gradle.kts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ kotlin {
3131
implementation(libs.junit.jupiter.params)
3232
implementation(libs.kotlinx.coroutines.test)
3333
implementation(libs.kotlinx.serialization.json)
34+
implementation(libs.aws.sdk.kotlin.sts)
35+
implementation(libs.aws.sdk.kotlin.bedrock)
36+
implementation(libs.aws.sdk.kotlin.bedrockruntime)
3437
implementation(libs.ktor.client.content.negotiation)
3538
runtimeOnly(libs.slf4j.simple)
3639
}

integration-tests/env.template.properties

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,9 @@
77
OPEN_AI_API_TEST_KEY=
88
ANTHROPIC_API_TEST_KEY=
99
GEMINI_API_TEST_KEY=
10-
OPEN_ROUTER_API_TEST_KEY=
10+
OPEN_ROUTER_API_TEST_KEY=
11+
12+
# AWS Bedrock credentials
13+
AWS_ACCESS_KEY_ID=
14+
AWS_SECRET_ACCESS_KEY=
15+
AWS_SESSION_TOKEN=
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package ai.koog.integration.tests
2+
3+
import aws.sdk.kotlin.services.bedrock.BedrockClient
4+
import aws.sdk.kotlin.services.bedrock.listFoundationModels
5+
import kotlinx.coroutines.runBlocking
6+
import kotlin.test.Test
7+
import kotlin.test.assertTrue
8+
9+
/**
10+
* Verifies that the credentials provided to the build can access the Bedrock control-plane
11+
* API (ListFoundationModels). If the call fails we abort early instead of waiting for a
12+
* full prompt execution to blow up.
13+
*/
14+
class BedrockCredentialsSmokeTest {
15+
@Test
16+
fun listFoundationModelsWorks() = runBlocking {
17+
val region = System.getenv("AWS_REGION") ?: "us-west-2"
18+
19+
BedrockClient { this.region = region }.use { bedrock ->
20+
val resp = bedrock.listFoundationModels { }
21+
assertTrue(
22+
(resp.modelSummaries?.isNotEmpty()) == true,
23+
"Bedrock ListFoundationModels returned no models – credentials/region might be wrong"
24+
)
25+
}
26+
}
27+
}

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/SingleLLMPromptExecutorIntegrationTest.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import ai.koog.integration.tests.utils.RetryUtils.withRetry
1616
import ai.koog.integration.tests.utils.TestUtils
1717
import ai.koog.integration.tests.utils.TestUtils.readAwsAccessKeyIdFromEnv
1818
import ai.koog.integration.tests.utils.TestUtils.readAwsSecretAccessKeyFromEnv
19+
import ai.koog.integration.tests.utils.TestUtils.readAwsSessionTokenFromEnv
1920
import ai.koog.integration.tests.utils.TestUtils.readTestAnthropicKeyFromEnv
2021
import ai.koog.integration.tests.utils.TestUtils.readTestGoogleAIKeyFromEnv
2122
import ai.koog.integration.tests.utils.TestUtils.readTestOpenAIKeyFromEnv
@@ -44,6 +45,7 @@ import kotlinx.coroutines.test.runTest
4445
import org.junit.jupiter.api.Assertions.assertNotNull
4546
import org.junit.jupiter.api.Assumptions.assumeTrue
4647
import org.junit.jupiter.api.BeforeAll
48+
import org.junit.jupiter.api.Disabled
4749
import org.junit.jupiter.params.ParameterizedTest
4850
import org.junit.jupiter.params.provider.Arguments
4951
import org.junit.jupiter.params.provider.MethodSource
@@ -79,6 +81,7 @@ class SingleLLMPromptExecutorIntegrationTest {
7981
val bedrockClientInstance = BedrockLLMClient(
8082
readAwsAccessKeyIdFromEnv(),
8183
readAwsSecretAccessKeyFromEnv(),
84+
readAwsSessionTokenFromEnv(),
8285
BedrockClientSettings()
8386
)
8487
// val openRouterClientInstance = OpenRouterLLMClient(readTestOpenRouterKeyFromEnv())
@@ -101,6 +104,7 @@ class SingleLLMPromptExecutorIntegrationTest {
101104
val bedrockClientInstance = BedrockLLMClient(
102105
readAwsAccessKeyIdFromEnv(),
103106
readAwsSecretAccessKeyFromEnv(),
107+
readAwsSessionTokenFromEnv(),
104108
BedrockClientSettings(),
105109
)
106110

@@ -1044,7 +1048,8 @@ class SingleLLMPromptExecutorIntegrationTest {
10441048
fun integration_testSimpleBedrockExecutor(model: LLModel) = runTest(timeout = 300.seconds) {
10451049
val executor = simpleBedrockExecutor(
10461050
readAwsAccessKeyIdFromEnv(),
1047-
readAwsSecretAccessKeyFromEnv()
1051+
readAwsSecretAccessKeyFromEnv(),
1052+
readAwsSessionTokenFromEnv() ?: "",
10481053
)
10491054

10501055
val prompt = Prompt.build("test-simple-bedrock-executor") {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/TestUtils.kt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,15 @@ object TestUtils {
3232
}
3333

3434
fun readAwsSecretAccessKeyFromEnv(): String {
35-
return System.getenv("AWS_SECRET_KEY")
36-
?: error("ERROR: environment variable `AWS_SECRET_KEY` is not set")
35+
return System.getenv("AWS_SECRET_ACCESS_KEY")
36+
?: error("ERROR: environment variable `AWS_SECRET_ACCESS_KEY` is not set")
37+
}
38+
39+
fun readAwsSessionTokenFromEnv(): String? {
40+
return System.getenv("AWS_SESSION_TOKEN")
41+
?: null.also {
42+
println("WARNING: environment variable `AWS_SESSION_TOKEN` is not set, using default session token")
43+
}
3744
}
3845

3946
@Serializable

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,15 @@ public class BedrockLLMClient(
9393
*
9494
* @param awsAccessKeyId The AWS access key ID for authentication
9595
* @param awsSecretAccessKey The AWS secret access key for authentication
96+
* @param awsSessionToken Optional session token for temporary credentials
9697
* @param settings Configuration settings for the Bedrock client, such as region and endpoint
9798
* @param clock A clock used for time-based operations
9899
* @return A configured [LLMClient] instance for Bedrock
99100
*/
100101
public constructor(
101102
awsAccessKeyId: String,
102103
awsSecretAccessKey: String,
104+
awsSessionToken: String? = null,
103105
settings: BedrockClientSettings = BedrockClientSettings(),
104106
clock: Clock = Clock.System,
105107
) : this(
@@ -108,6 +110,7 @@ public class BedrockLLMClient(
108110
this.credentialsProvider = StaticCredentialsProvider {
109111
this.accessKeyId = awsAccessKeyId
110112
this.secretAccessKey = awsSecretAccessKey
113+
awsSessionToken?.let { this.sessionToken = it }
111114
}
112115

113116
// Configure custom endpoint if provided

prompt/prompt-executor/prompt-executor-llms-all/src/jvmMain/kotlin/ai/koog/prompt/executor/llms/all/SimplePromptExecutors.jvm.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor
1414
public fun simpleBedrockExecutor(
1515
awsAccessKeyId: String,
1616
awsSecretAccessKey: String,
17+
awsSessionToken: String? = null,
1718
settings: BedrockClientSettings = BedrockClientSettings()
1819
): SingleLLMPromptExecutor =
19-
SingleLLMPromptExecutor(BedrockLLMClient(awsAccessKeyId, awsSecretAccessKey, settings))
20+
SingleLLMPromptExecutor(BedrockLLMClient(awsAccessKeyId, awsSecretAccessKey, awsSessionToken, settings))

0 commit comments

Comments
 (0)