
Use Inference Profiles on AWS Bedrock Models #506


Merged · 6 commits · Aug 7, 2025
@@ -1,7 +1,7 @@
package ai.koog.agents.memory.storage

import java.security.SecureRandom
-import java.util.*
+import java.util.Base64
Contributor:

A good practice would be to use the Kotlin API where possible. In this case it doesn't change anything, since Kotlin's Base64 simply wraps the Java implementation on the JVM, but this approach reduces potential issues if the code ever needs to be moved to common multiplatform code.
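As an illustration of that suggestion, a minimal sketch using the standard library's `kotlin.io.encoding.Base64`; the helper function names here are hypothetical, not part of this PR:

```kotlin
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi

// kotlin.io.encoding.Base64 wraps java.util.Base64 on the JVM, but is
// also available in Kotlin Multiplatform common code. The API was
// marked experimental in earlier Kotlin releases, hence the opt-in.
@OptIn(ExperimentalEncodingApi::class)
fun encodeToBase64(bytes: ByteArray): String = Base64.encode(bytes)

@OptIn(ExperimentalEncodingApi::class)
fun decodeFromBase64(encoded: String): ByteArray = Base64.decode(encoded)
```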

import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
@@ -75,8 +75,8 @@ public class Aes256GCMEncryptor(secretKey: String) : Encryption {
return String(plaintext)
}

-override fun encrypt(text: String): String {
-val (nonce, ciphertext) = encryptImpl(text)
+override fun encrypt(value: String): String {
+val (nonce, ciphertext) = encryptImpl(value)
return Base64.getEncoder().encodeToString(nonce + ciphertext)
}
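For context, a minimal sketch of the nonce-prepending scheme that `encrypt`/`encryptImpl` appear to implement; the 12-byte nonce and 128-bit tag are common AES-GCM choices assumed here, not confirmed by this diff:

```kotlin
import java.security.SecureRandom
import java.util.Base64
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.spec.GCMParameterSpec

// Encrypt with AES-256-GCM and prepend the nonce, so decryption can
// split the Base64-decoded payload back into nonce + ciphertext.
fun encryptAndEncode(key: SecretKey, plaintext: String): String {
    val nonce = ByteArray(12).also { SecureRandom().nextBytes(it) } // assumed 96-bit nonce
    val cipher = Cipher.getInstance("AES/GCM/NoPadding")
    cipher.init(Cipher.ENCRYPT_MODE, key, GCMParameterSpec(128, nonce)) // 128-bit auth tag
    val ciphertext = cipher.doFinal(plaintext.toByteArray())
    return Base64.getEncoder().encodeToString(nonce + ciphertext)
}
```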

@@ -8,6 +8,7 @@ import ai.koog.agents.example.simpleapi.Switch
import ai.koog.agents.example.simpleapi.SwitchTools
import ai.koog.prompt.executor.clients.bedrock.BedrockClientSettings
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
+import ai.koog.prompt.executor.clients.bedrock.BedrockRegions
import ai.koog.prompt.executor.llms.all.simpleBedrockExecutor
import kotlinx.coroutines.runBlocking

@@ -20,7 +21,7 @@ fun main(): Unit = runBlocking {

// Create Bedrock client settings
val bedrockSettings = BedrockClientSettings(
region = "us-east-1", // Change this to your preferred region
region = BedrockRegions.US_WEST_2.regionCode, // Change this to your preferred region
maxRetries = 3
)
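For reference, a hedged sketch of how an executor is then created from these settings. The exact `simpleBedrockExecutor` signature isn't shown in this excerpt, so the trailing `settings` argument is an assumption, and the environment-variable reads stand in for your own credential handling:

```kotlin
// Sketch only: positional credentials follow the (accessKey, secretKey,
// sessionToken) order seen in this PR's integration test; passing the
// settings object is assumed, not confirmed by this diff.
val executor = simpleBedrockExecutor(
    System.getenv("AWS_ACCESS_KEY_ID"),
    System.getenv("AWS_SECRET_ACCESS_KEY"),
    System.getenv("AWS_SESSION_TOKEN"),
    settings = bedrockSettings,
)
```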

@@ -17,10 +17,13 @@ import ai.koog.integration.tests.utils.TestUtils
import ai.koog.integration.tests.utils.TestUtils.readTestAnthropicKeyFromEnv
import ai.koog.integration.tests.utils.TestUtils.readTestGoogleAIKeyFromEnv
import ai.koog.integration.tests.utils.TestUtils.readTestOpenAIKeyFromEnv
import ai.koog.integration.tests.utils.annotations.Retry
import ai.koog.prompt.dsl.ModerationCategory
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
import ai.koog.prompt.executor.clients.google.GoogleModels
import ai.koog.prompt.executor.clients.openai.OpenAILLMClient
import ai.koog.prompt.executor.clients.openai.OpenAIModels
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
@@ -35,6 +38,7 @@ import ai.koog.prompt.message.AttachmentContent
import ai.koog.prompt.message.Message
import ai.koog.prompt.params.LLMParams.ToolChoice
import kotlinx.coroutines.flow.toList
+import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assumptions.assumeTrue
import org.junit.jupiter.api.BeforeAll
@@ -105,14 +109,14 @@ class MultipleLLMPromptExecutorIntegrationTest {
}

// API keys for testing
-private val geminiApiKey: String get() = readTestGoogleAIKeyFromEnv()
private val openAIApiKey: String get() = readTestOpenAIKeyFromEnv()
private val anthropicApiKey: String get() = readTestAnthropicKeyFromEnv()
+private val googleApiKey: String get() = readTestGoogleAIKeyFromEnv()

// LLM clients
private val openAIClient get() = OpenAILLMClient(openAIApiKey)
private val anthropicClient get() = AnthropicLLMClient(anthropicApiKey)
-private val googleClient get() = GoogleLLMClient(geminiApiKey)
+private val googleClient get() = GoogleLLMClient(googleApiKey)
val executor = DefaultMultiLLMPromptExecutor(openAIClient, anthropicClient, googleClient)

private fun createCalculatorTool(): ToolDescriptor {
@@ -1144,4 +1148,44 @@ class MultipleLLMPromptExecutorIntegrationTest {
)
) { "Violence must be detected!" }
}

@Retry
@Test
fun integration_testMultipleSystemMessages() = runBlocking {
Models.assumeAvailable(LLMProvider.OpenAI)
Models.assumeAvailable(LLMProvider.Anthropic)
Models.assumeAvailable(LLMProvider.Google)

val openAIClient = OpenAILLMClient(openAIApiKey)
val anthropicClient = AnthropicLLMClient(anthropicApiKey)
val googleClient = GoogleLLMClient(googleApiKey)

val executor = MultiLLMPromptExecutor(
LLMProvider.OpenAI to openAIClient,
LLMProvider.Anthropic to anthropicClient,
LLMProvider.Google to googleClient
)

val prompt = prompt("multiple-system-messages-test") {
system("You are a helpful assistant.")
user("Hi")
system("You can handle multiple system messages.")
user("Respond with a short message.")
}

val modelOpenAI = OpenAIModels.CostOptimized.GPT4oMini
val modelAnthropic = AnthropicModels.Haiku_3_5
val modelGemini = GoogleModels.Gemini2_0Flash

val responseOpenAI = executor.execute(prompt, modelOpenAI)
val responseAnthropic = executor.execute(prompt, modelAnthropic)
val responseGemini = executor.execute(prompt, modelGemini)

assertTrue(responseOpenAI.content.isNotEmpty(), "OpenAI response should not be empty")
assertTrue(responseAnthropic.content.isNotEmpty(), "Anthropic response should not be empty")
assertTrue(responseGemini.content.isNotEmpty(), "Gemini response should not be empty")
println("OpenAI Response: ${responseOpenAI.content}")
println("Anthropic Response: ${responseAnthropic.content}")
println("Gemini Response: ${responseGemini.content}")
}
}

This file was deleted.

@@ -1093,7 +1093,7 @@ class SingleLLMPromptExecutorIntegrationTest {
val executor = simpleBedrockExecutor(
readAwsAccessKeyIdFromEnv(),
readAwsSecretAccessKeyFromEnv(),
-readAwsSessionTokenFromEnv() ?: "",
+readAwsSessionTokenFromEnv(),
)

val prompt = Prompt.build("test-simple-bedrock-executor") {
@@ -1,6 +1,7 @@
package ai.koog.integration.tests.utils

import kotlinx.coroutines.delay
+import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assumptions

/*
@@ -36,20 +37,18 @@ object RetryUtils {
}
}

-suspend fun <T> withRetry(
+fun <T> withRetry(
times: Int = 3,
delayMs: Long = 1000,
testName: String = "test",
action: suspend () -> T
-): T {
+): T = runBlocking {
var lastException: Throwable? = null

for (attempt in 1..times) {
try {
println("[DEBUG_LOG] Test '$testName' - attempt $attempt of $times")
val result = action()
println("[DEBUG_LOG] Test '$testName' succeeded on attempt $attempt")
-return result
+return@runBlocking result
} catch (throwable: Throwable) {
lastException = throwable

@@ -59,14 +58,10 @@
false,
"Skipping test due to third-party service error: ${throwable.message}"
)
-return action()
+return@runBlocking action()
}

println("[DEBUG_LOG] Test '$testName' failed on attempt $attempt: ${throwable.message}")

if (attempt < times) {
println("[DEBUG_LOG] Retrying test '$testName' (attempt ${attempt + 1} of $times)")

if (delayMs > 0) {
delay(delayMs)
}
@@ -62,7 +62,7 @@
* @property moderationGuardrailsSettings Optional settings for AWS Bedrock Guardrails (see [AWS documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-independent-api.html)) used for the [LLMClient.moderate] request
*/
public class BedrockClientSettings(
internal val region: String = "us-east-1",
internal val region: String = BedrockRegions.US_WEST_2.regionCode,
internal val timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(),
internal val endpointUrl: String? = null,
internal val maxRetries: Int = 3,
@@ -124,7 +124,7 @@
awsSessionToken?.let { this.sessionToken = it }
}

-// Configure custom endpoint if provided
+// Configure a custom endpoint if provided
settings.endpointUrl?.let { url ->
this.endpointUrl = Url.parse(url)
}
@@ -144,18 +144,18 @@
explicitNulls = false
}

[Qodana for JVM (GitHub Actions), line 147 in prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt: Method `getBedrockModelFamily$prompt_executor_bedrock_client` coverage is below the threshold 50%]
internal fun getBedrockModelFamily(model: LLModel): BedrockModelFamilies {
require(model.provider == LLMProvider.Bedrock) { "Model ${model.id} is not a Bedrock model" }
return when {
model.id.startsWith("anthropic.claude") -> BedrockModelFamilies.AnthropicClaude
model.id.startsWith("amazon.nova") -> BedrockModelFamilies.AmazonNova
model.id.startsWith("ai21.jamba") -> BedrockModelFamilies.AI21Jamba
model.id.startsWith("meta.llama") -> BedrockModelFamilies.Meta
model.id.contains("anthropic.claude") -> BedrockModelFamilies.AnthropicClaude
model.id.contains("amazon.nova") -> BedrockModelFamilies.AmazonNova
model.id.contains("ai21.jamba") -> BedrockModelFamilies.AI21Jamba
model.id.contains("meta.llama") -> BedrockModelFamilies.Meta
else -> throw IllegalArgumentException("Model ${model.id} is not a supported Bedrock model")
}
}
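This switch from `startsWith` to `contains` is the heart of inference-profile support: cross-region inference profile IDs prefix the base model ID with a region qualifier (e.g. `us.anthropic.claude-...`), so a prefix match fails while a substring match still resolves the model family. A small sketch with illustrative IDs:

```kotlin
// Illustrative model IDs: a base model and its US cross-region
// inference profile, which prepends a region qualifier.
val baseId = "anthropic.claude-3-5-sonnet-20241022-v2:0"
val profileId = "us.anthropic.claude-3-5-sonnet-20241022-v2:0"

check(baseId.startsWith("anthropic.claude"))      // old check: matches base IDs only
check(!profileId.startsWith("anthropic.claude"))  // old check misses profile IDs
check(profileId.contains("anthropic.claude"))     // new check matches both forms
```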

[Qodana for JVM (GitHub Actions), line 158: Method `execute` coverage is below the threshold 50%]
override suspend fun execute(
prompt: Prompt,
model: LLModel,
tools: List<ToolDescriptor>
@@ -238,7 +238,7 @@
}

@OptIn(ExperimentalCoroutinesApi::class, FlowPreview::class)
[Qodana for JVM (GitHub Actions), line 241: Method `executeStreaming` coverage is below the threshold 50%]
override fun executeStreaming(prompt: Prompt, model: LLModel): Flow<String> {
logger.debug { "Executing streaming prompt for model: ${model.id}" }
val modelFamily = getBedrockModelFamily(model)

@@ -368,14 +368,14 @@
)

val inputIsHarmful = inputGuardrailResponse.action is GuardrailAction.GuardrailIntervened
-val outputputIsHarmful = inputGuardrailResponse.action is GuardrailAction.GuardrailIntervened
+val outputIsHarmful = inputGuardrailResponse.action is GuardrailAction.GuardrailIntervened

val categories = buildMap {
fillCategoriesMap(inputGuardrailResponse)
fillCategoriesMap(outputGuardrailResponse)
}

-return ModerationResult(inputIsHarmful || outputputIsHarmful, categories)
+return ModerationResult(inputIsHarmful || outputIsHarmful, categories)
}

private fun MutableMap<ModerationCategory, ModerationCategoryResult>.fillCategoriesMap(