Skip to content

Commit 76bc3a6

Browse files
authored
Use Inference Profiles on AWS Bedrock Models (#506)
1 parent a38991d commit 76bc3a6

File tree

9 files changed

+436
-297
lines changed

9 files changed

+436
-297
lines changed

agents/agents-features/agents-features-memory/src/jvmMain/kotlin/ai/koog/agents/memory/storage/Aes256GCMStorageEncryptor.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package ai.koog.agents.memory.storage
22

33
import java.security.SecureRandom
4-
import java.util.*
4+
import java.util.Base64
55
import javax.crypto.Cipher
66
import javax.crypto.KeyGenerator
77
import javax.crypto.SecretKey
@@ -75,8 +75,8 @@ public class Aes256GCMEncryptor(secretKey: String) : Encryption {
7575
return String(plaintext)
7676
}
7777

78-
override fun encrypt(text: String): String {
79-
val (nonce, ciphertext) = encryptImpl(text)
78+
override fun encrypt(value: String): String {
79+
val (nonce, ciphertext) = encryptImpl(value)
8080
return Base64.getEncoder().encodeToString(nonce + ciphertext)
8181
}
8282

examples/src/main/kotlin/ai/koog/agents/example/client/BedrockAgent.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import ai.koog.agents.example.simpleapi.Switch
88
import ai.koog.agents.example.simpleapi.SwitchTools
99
import ai.koog.prompt.executor.clients.bedrock.BedrockClientSettings
1010
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
11+
import ai.koog.prompt.executor.clients.bedrock.BedrockRegions
1112
import ai.koog.prompt.executor.llms.all.simpleBedrockExecutor
1213
import kotlinx.coroutines.runBlocking
1314

@@ -20,7 +21,7 @@ fun main(): Unit = runBlocking {
2021

2122
// Create Bedrock client settings
2223
val bedrockSettings = BedrockClientSettings(
23-
region = "us-east-1", // Change this to your preferred region
24+
region = BedrockRegions.US_WEST_2.regionCode, // Change this to your preferred region
2425
maxRetries = 3
2526
)
2627

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/MultipleLLMPromptExecutorIntegrationTest.kt

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@ import ai.koog.integration.tests.utils.TestUtils
1717
import ai.koog.integration.tests.utils.TestUtils.readTestAnthropicKeyFromEnv
1818
import ai.koog.integration.tests.utils.TestUtils.readTestGoogleAIKeyFromEnv
1919
import ai.koog.integration.tests.utils.TestUtils.readTestOpenAIKeyFromEnv
20+
import ai.koog.integration.tests.utils.annotations.Retry
2021
import ai.koog.prompt.dsl.ModerationCategory
2122
import ai.koog.prompt.dsl.prompt
2223
import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient
24+
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels
2325
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
26+
import ai.koog.prompt.executor.clients.google.GoogleModels
2427
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
2528
import ai.koog.prompt.executor.clients.openai.OpenAIModels
2629
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
@@ -35,6 +38,7 @@ import ai.koog.prompt.message.AttachmentContent
3538
import ai.koog.prompt.message.Message
3639
import ai.koog.prompt.params.LLMParams.ToolChoice
3740
import kotlinx.coroutines.flow.toList
41+
import kotlinx.coroutines.runBlocking
3842
import kotlinx.coroutines.test.runTest
3943
import org.junit.jupiter.api.Assumptions.assumeTrue
4044
import org.junit.jupiter.api.BeforeAll
@@ -105,14 +109,14 @@ class MultipleLLMPromptExecutorIntegrationTest {
105109
}
106110

107111
// API keys for testing
108-
private val geminiApiKey: String get() = readTestGoogleAIKeyFromEnv()
109112
private val openAIApiKey: String get() = readTestOpenAIKeyFromEnv()
110113
private val anthropicApiKey: String get() = readTestAnthropicKeyFromEnv()
114+
private val googleApiKey: String get() = readTestGoogleAIKeyFromEnv()
111115

112116
// LLM clients
113117
private val openAIClient get() = OpenAILLMClient(openAIApiKey)
114118
private val anthropicClient get() = AnthropicLLMClient(anthropicApiKey)
115-
private val googleClient get() = GoogleLLMClient(geminiApiKey)
119+
private val googleClient get() = GoogleLLMClient(googleApiKey)
116120
val executor = DefaultMultiLLMPromptExecutor(openAIClient, anthropicClient, googleClient)
117121

118122
private fun createCalculatorTool(): ToolDescriptor {
@@ -1144,4 +1148,44 @@ class MultipleLLMPromptExecutorIntegrationTest {
11441148
)
11451149
) { "Violence must be detected!" }
11461150
}
1151+
1152+
@Retry
1153+
@Test
1154+
fun integration_testMultipleSystemMessages() = runBlocking {
1155+
Models.assumeAvailable(LLMProvider.OpenAI)
1156+
Models.assumeAvailable(LLMProvider.Anthropic)
1157+
Models.assumeAvailable(LLMProvider.Google)
1158+
1159+
val openAIClient = OpenAILLMClient(openAIApiKey)
1160+
val anthropicClient = AnthropicLLMClient(anthropicApiKey)
1161+
val googleClient = GoogleLLMClient(googleApiKey)
1162+
1163+
val executor = MultiLLMPromptExecutor(
1164+
LLMProvider.OpenAI to openAIClient,
1165+
LLMProvider.Anthropic to anthropicClient,
1166+
LLMProvider.Google to googleClient
1167+
)
1168+
1169+
val prompt = prompt("multiple-system-messages-test") {
1170+
system("You are a helpful assistant.")
1171+
user("Hi")
1172+
system("You can handle multiple system messages.")
1173+
user("Respond with a short message.")
1174+
}
1175+
1176+
val modelOpenAI = OpenAIModels.CostOptimized.GPT4oMini
1177+
val modelAnthropic = AnthropicModels.Haiku_3_5
1178+
val modelGemini = GoogleModels.Gemini2_0Flash
1179+
1180+
val responseOpenAI = executor.execute(prompt, modelOpenAI)
1181+
val responseAnthropic = executor.execute(prompt, modelAnthropic)
1182+
val responseGemini = executor.execute(prompt, modelGemini)
1183+
1184+
assertTrue(responseOpenAI.content.isNotEmpty(), "OpenAI response should not be empty")
1185+
assertTrue(responseAnthropic.content.isNotEmpty(), "Anthropic response should not be empty")
1186+
assertTrue(responseGemini.content.isNotEmpty(), "Gemini response should not be empty")
1187+
println("OpenAI Response: ${responseOpenAI.content}")
1188+
println("Anthropic Response: ${responseAnthropic.content}")
1189+
println("Gemini Response: ${responseGemini.content}")
1190+
}
11471191
}

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/MultipleSystemMessagesPromptIntegrationTest.kt

Lines changed: 0 additions & 69 deletions
This file was deleted.

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/SingleLLMPromptExecutorIntegrationTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,7 @@ class SingleLLMPromptExecutorIntegrationTest {
10931093
val executor = simpleBedrockExecutor(
10941094
readAwsAccessKeyIdFromEnv(),
10951095
readAwsSecretAccessKeyFromEnv(),
1096-
readAwsSessionTokenFromEnv() ?: "",
1096+
readAwsSessionTokenFromEnv(),
10971097
)
10981098

10991099
val prompt = Prompt.build("test-simple-bedrock-executor") {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/RetryUtils.kt

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai.koog.integration.tests.utils
22

33
import kotlinx.coroutines.delay
4+
import kotlinx.coroutines.runBlocking
45
import org.junit.jupiter.api.Assumptions
56

67
/*
@@ -36,20 +37,18 @@ object RetryUtils {
3637
}
3738
}
3839

39-
suspend fun <T> withRetry(
40+
fun <T> withRetry(
4041
times: Int = 3,
4142
delayMs: Long = 1000,
4243
testName: String = "test",
4344
action: suspend () -> T
44-
): T {
45+
): T = runBlocking {
4546
var lastException: Throwable? = null
4647

4748
for (attempt in 1..times) {
4849
try {
49-
println("[DEBUG_LOG] Test '$testName' - attempt $attempt of $times")
5050
val result = action()
51-
println("[DEBUG_LOG] Test '$testName' succeeded on attempt $attempt")
52-
return result
51+
return@runBlocking result
5352
} catch (throwable: Throwable) {
5453
lastException = throwable
5554

@@ -59,14 +58,10 @@ object RetryUtils {
5958
false,
6059
"Skipping test due to third-party service error: ${throwable.message}"
6160
)
62-
return action()
61+
return@runBlocking action()
6362
}
6463

65-
println("[DEBUG_LOG] Test '$testName' failed on attempt $attempt: ${throwable.message}")
66-
6764
if (attempt < times) {
68-
println("[DEBUG_LOG] Retrying test '$testName' (attempt ${attempt + 1} of $times)")
69-
7065
if (delayMs > 0) {
7166
delay(delayMs)
7267
}

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ import kotlinx.serialization.json.Json
6262
* @property moderationGuardrailsSettings Optional settings of the AWS bedrock Guardrails (see [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-independent-api.html) ) that would be used for the [LLMClient.moderate] request
6363
*/
6464
public class BedrockClientSettings(
65-
internal val region: String = "us-east-1",
65+
internal val region: String = BedrockRegions.US_WEST_2.regionCode,
6666
internal val timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(),
6767
internal val endpointUrl: String? = null,
6868
internal val maxRetries: Int = 3,
@@ -124,7 +124,7 @@ public class BedrockLLMClient(
124124
awsSessionToken?.let { this.sessionToken = it }
125125
}
126126

127-
// Configure custom endpoint if provided
127+
// Configure a custom endpoint if provided
128128
settings.endpointUrl?.let { url ->
129129
this.endpointUrl = Url.parse(url)
130130
}
@@ -147,10 +147,10 @@ public class BedrockLLMClient(
147147
internal fun getBedrockModelFamily(model: LLModel): BedrockModelFamilies {
148148
require(model.provider == LLMProvider.Bedrock) { "Model ${model.id} is not a Bedrock model" }
149149
return when {
150-
model.id.startsWith("anthropic.claude") -> BedrockModelFamilies.AnthropicClaude
151-
model.id.startsWith("amazon.nova") -> BedrockModelFamilies.AmazonNova
152-
model.id.startsWith("ai21.jamba") -> BedrockModelFamilies.AI21Jamba
153-
model.id.startsWith("meta.llama") -> BedrockModelFamilies.Meta
150+
model.id.contains("anthropic.claude") -> BedrockModelFamilies.AnthropicClaude
151+
model.id.contains("amazon.nova") -> BedrockModelFamilies.AmazonNova
152+
model.id.contains("ai21.jamba") -> BedrockModelFamilies.AI21Jamba
153+
model.id.contains("meta.llama") -> BedrockModelFamilies.Meta
154154
else -> throw IllegalArgumentException("Model ${model.id} is not a supported Bedrock model")
155155
}
156156
}
@@ -368,14 +368,14 @@ public class BedrockLLMClient(
368368
)
369369

370370
val inputIsHarmful = inputGuardrailResponse.action is GuardrailAction.GuardrailIntervened
371-
val outputputIsHarmful = inputGuardrailResponse.action is GuardrailAction.GuardrailIntervened
371+
val outputIsHarmful = outputGuardrailResponse.action is GuardrailAction.GuardrailIntervened
372372

373373
val categories = buildMap {
374374
fillCategoriesMap(inputGuardrailResponse)
375375
fillCategoriesMap(outputGuardrailResponse)
376376
}
377377

378-
return ModerationResult(inputIsHarmful || outputputIsHarmful, categories)
378+
return ModerationResult(inputIsHarmful || outputIsHarmful, categories)
379379
}
380380

381381
private fun MutableMap<ModerationCategory, ModerationCategoryResult>.fillCategoriesMap(

0 commit comments

Comments (0)