diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt
index 7caf8b393..9e213a88f 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt
@@ -282,7 +282,7 @@ public open class AnthropicLLMClient(
             model = settings.modelVersionsMap[model]
                 ?: throw IllegalArgumentException("Unsupported model: $model"),
             messages = messages,
-            maxTokens = 2048, // This is required by the API
+            maxTokens = prompt.params.maxTokens ?: AnthropicMessageRequest.MAX_TOKENS_DEFAULT,
             temperature = prompt.params.temperature,
             system = systemMessage,
             tools = if (tools.isNotEmpty()) anthropicTools else emptyList(), // Always provide a list for tools
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt
index bea04976a..f5ec4e7ab 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt
@@ -27,13 +27,30 @@ import kotlinx.serialization.json.JsonObject
 public data class AnthropicMessageRequest(
     val model: String,
     val messages: List<AnthropicMessage>,
-    val maxTokens: Int = 2048,
+    val maxTokens: Int = MAX_TOKENS_DEFAULT,
     val temperature: Double? = null,
     val system: List<AnthropicSystemMessage>? = null,
     val tools: List<AnthropicTool>? = null,
     val stream: Boolean = false,
     val toolChoice: AnthropicToolChoice? = null,
-)
+) {
+    init {
+        require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
+        if (temperature != null) {
+            require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
+        }
+    }
+
+    /**
+     * Companion object with default values for the request
+     */
+    public companion object {
+        /**
+         * Default max tokens
+         */
+        public const val MAX_TOKENS_DEFAULT: Int = 2048
+    }
+}

 /**
  * Represents a message within the Anthropic LLM system. This data class encapsulates
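
The two hunks above replace the hardcoded 2048 with a per-prompt value: `prompt.params.maxTokens` (added to `LLMParams` at the bottom of this diff) now drives the Anthropic request, with `MAX_TOKENS_DEFAULT` as the fallback. A minimal sketch of the caller-facing flow, using the same `Prompt.build`/`LLMParams` API the tests in this patch exercise; the client/executor invocation itself is elided:

import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.params.LLMParams

// Per-prompt cap: flows through prompt.params.maxTokens into AnthropicMessageRequest.
val capped = Prompt.build("demo", params = LLMParams(maxTokens = 1024)) {
    system("You are terse.")
    user("Summarize this change in one sentence.")
}

// No cap set: the client falls back to AnthropicMessageRequest.MAX_TOKENS_DEFAULT (2048).
val uncapped = Prompt.build("demo") {
    user("Same question, default budget.")
}
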
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt
index d6ad7ee08..ec8a2a08c 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt
@@ -3,6 +3,8 @@ package ai.koog.prompt.executor.clients.anthropic
 import ai.koog.prompt.executor.clients.list
 import ai.koog.prompt.llm.LLMProvider
 import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertSame

 class AnthropicModelsTest {
@@ -19,4 +21,38 @@ class AnthropicModelsTest {
             )
         }
     }
+
+    @Test
+    fun `AnthropicMessageRequest should use custom maxTokens when provided`() {
+        val customMaxTokens = 4000
+        val request = AnthropicMessageRequest(
+            model = AnthropicModels.Opus_3.id,
+            messages = emptyList(),
+            maxTokens = customMaxTokens
+        )
+
+        assertEquals(customMaxTokens, request.maxTokens)
+    }
+
+    @Test
+    fun `AnthropicMessageRequest should use default maxTokens when not provided`() {
+        val request = AnthropicMessageRequest(
+            model = AnthropicModels.Opus_3.id,
+            messages = emptyList()
+        )
+
+        assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
+    }
+
+    @Test
+    fun `AnthropicMessageRequest should reject zero maxTokens`() {
+        val exception = assertFailsWith<IllegalArgumentException> {
+            AnthropicMessageRequest(
+                model = AnthropicModels.Opus_3.id,
+                messages = emptyList(),
+                maxTokens = 0
+            )
+        }
+        assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
+    }
 }
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerialization.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerialization.kt
index 6e1791d33..a110c0dbb 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerialization.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerialization.kt
@@ -123,7 +123,7 @@ internal object BedrockAI21JambaSerialization {
         return JambaRequest(
             model = model.id,
             messages = messages,
-            maxTokens = 4096,
+            maxTokens = prompt.params.maxTokens ?: JambaRequest.MAX_TOKENS_DEFAULT,
             temperature = if (model.capabilities.contains(
                     LLMCapability.Temperature
                 )
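
Construction-time validation: Kotlin's `require` throws `IllegalArgumentException` with the supplied message, so an invalid request fails fast and never reaches the wire. The same pattern guards `JambaRequest` in the next file. A small sketch of the observable contract, mirroring the tests above:

import ai.koog.prompt.executor.clients.anthropic.AnthropicMessageRequest
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels

fun main() {
    // require(maxTokens > 0) fires in the init block, at construction time.
    val error = runCatching {
        AnthropicMessageRequest(
            model = AnthropicModels.Opus_3.id,
            messages = emptyList(),
            maxTokens = -1
        )
    }.exceptionOrNull()
    check(error is IllegalArgumentException)
    check(error.message == "maxTokens must be greater than 0, but was -1")
}
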
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt
index ae6b1b991..aeb75ded7 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt
@@ -6,107 +6,126 @@ import kotlinx.serialization.json.JsonObject

 @Serializable
 internal data class JambaRequest(
-    public val model: String,
-    public val messages: List<JambaMessage>,
-    @SerialName("max_tokens")
-    public val maxTokens: Int? = null,
-    public val temperature: Double? = null,
-    @SerialName("top_p")
-    public val topP: Double? = null,
-    public val stop: List<String>? = null,
-    public val n: Int? = null,
-    public val stream: Boolean? = null,
-    public val tools: List<JambaTool>? = null,
-    @SerialName("response_format")
-    public val responseFormat: JambaResponseFormat? = null
-)
+    val model: String,
+    val messages: List<JambaMessage>,
+    @SerialName("max_tokens") val maxTokens: Int? = MAX_TOKENS_DEFAULT,
+    val temperature: Double? = null,
+    @SerialName("top_p") val topP: Double? = null,
+    val stop: List<String>? = null,
+    val n: Int? = null,
+    val stream: Boolean? = null,
+    val tools: List<JambaTool>? = null,
+    @SerialName("response_format") val responseFormat: JambaResponseFormat? = null
+) {
+    init {
+        if (maxTokens != null) {
+            require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
+        }
+        if (temperature != null) {
+            require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
+        }
+        if (topP != null) {
+            require(topP in 0.0..1.0) { "topP must be between 0 and 1, but was $topP" }
+        }
+    }
+
+    /**
+     * Companion object with default values for the request
+     */
+    companion object {
+        /**
+         * Default max tokens
+         */
+        const val MAX_TOKENS_DEFAULT: Int = 4096
+    }
+}

 @Serializable
 internal data class JambaMessage(
-    public val role: String,
-    public val content: String? = null,
+    val role: String,
+    val content: String? = null,
     @SerialName("tool_calls")
-    public val toolCalls: List<JambaToolCall>? = null,
+    val toolCalls: List<JambaToolCall>? = null,
     @SerialName("tool_call_id")
-    public val toolCallId: String? = null
+    val toolCallId: String? = null
 )

 @Serializable
 internal data class JambaTool(
-    public val type: String = "function",
-    public val function: JambaFunction
+    val type: String = "function",
+    val function: JambaFunction
 )

 @Serializable
 internal data class JambaFunction(
-    public val name: String,
-    public val description: String,
-    public val parameters: JsonObject
+    val name: String,
+    val description: String,
+    val parameters: JsonObject
 )

 @Serializable
 internal data class JambaToolCall(
-    public val id: String,
-    public val type: String = "function",
-    public val function: JambaFunctionCall
+    val id: String,
+    val type: String = "function",
+    val function: JambaFunctionCall
 )

 @Serializable
 internal data class JambaFunctionCall(
-    public val name: String,
-    public val arguments: String
+    val name: String,
+    val arguments: String
 )

 @Serializable
 internal data class JambaResponseFormat(
-    public val type: String
+    val type: String
 )

 @Serializable
 internal data class JambaResponse(
-    public val id: String,
-    public val model: String,
-    public val choices: List<JambaChoice>,
-    public val usage: JambaUsage? = null
+    val id: String,
+    val model: String,
+    val choices: List<JambaChoice>,
+    val usage: JambaUsage? = null
 )

 @Serializable
 internal data class JambaChoice(
-    public val index: Int,
-    public val message: JambaMessage,
+    val index: Int,
+    val message: JambaMessage,
     @SerialName("finish_reason")
-    public val finishReason: String? = null
+    val finishReason: String? = null
 )

 @Serializable
 internal data class JambaUsage(
     @SerialName("prompt_tokens")
-    public val promptTokens: Int,
+    val promptTokens: Int,
     @SerialName("completion_tokens")
-    public val completionTokens: Int,
+    val completionTokens: Int,
     @SerialName("total_tokens")
-    public val totalTokens: Int
+    val totalTokens: Int
 )

 @Serializable
 internal data class JambaStreamResponse(
-    public val id: String,
-    public val choices: List<JambaStreamChoice>,
-    public val usage: JambaUsage? = null
+    val id: String,
+    val choices: List<JambaStreamChoice>,
+    val usage: JambaUsage? = null
 )

 @Serializable
 internal data class JambaStreamChoice(
-    public val index: Int,
-    public val delta: JambaStreamDelta,
+    val index: Int,
+    val delta: JambaStreamDelta,
     @SerialName("finish_reason")
-    public val finishReason: String? = null
+    val finishReason: String? = null
 )

 @Serializable
 internal data class JambaStreamDelta(
-    public val role: String? = null,
-    public val content: String? = null,
+    val role: String? = null,
+    val content: String? = null,
     @SerialName("tool_calls")
-    public val toolCalls: List<JambaToolCall>? = null
+    val toolCalls: List<JambaToolCall>? = null
 )
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt
index 6f35d7fe9..47c1bf2be 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt
@@ -44,7 +44,7 @@ internal object BedrockAmazonNovaSerialization {
         }

         val inferenceConfig = NovaInferenceConfig(
-            maxTokens = 4096,
+            maxTokens = prompt.params.maxTokens ?: NovaInferenceConfig.MAX_TOKENS_DEFAULT,
             temperature = if (model.capabilities.contains(LLMCapability.Temperature)) {
                 prompt.params.temperature
            } else {
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/NovaDataModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/NovaDataModels.kt
index 2f5da19ac..ccd96c5b5 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/NovaDataModels.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/NovaDataModels.kt
@@ -45,8 +45,18 @@ internal data class NovaInferenceConfig(
     @SerialName("topK")
     val topK: Int? = null,
     @SerialName("maxTokens")
-    val maxTokens: Int? = null
-)
+    val maxTokens: Int? = MAX_TOKENS_DEFAULT
+) {
+    /**
+     * Companion object with default values for the request
+     */
+    companion object {
+        /**
+         * Default max tokens
+         */
+        const val MAX_TOKENS_DEFAULT: Int = 4096
+    }
+}

 /**
  * Response data classes for Amazon Nova models
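
Because the Nova default now lives on the constructor parameter rather than at the call site, both direct construction and decoding of a payload that omits `maxTokens` resolve to `MAX_TOKENS_DEFAULT`. A sketch of that behavior, assuming the config's remaining fields (such as `temperature`) are nullable with defaults, as the surrounding declarations suggest; it is runnable only from inside the module since these types are `internal`:

import kotlinx.serialization.json.Json

fun main() {
    // Direct construction: the Kotlin default argument applies.
    val config = NovaInferenceConfig(temperature = 0.5)
    check(config.maxTokens == NovaInferenceConfig.MAX_TOKENS_DEFAULT) // 4096

    // Decoding a payload without "maxTokens" falls back to the same default.
    val decoded = Json.decodeFromString<NovaInferenceConfig>("""{"temperature":0.5}""")
    check(decoded.maxTokens == NovaInferenceConfig.MAX_TOKENS_DEFAULT)
}
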
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt
index 87e990479..12ce01b03 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt
@@ -170,7 +170,7 @@ internal object BedrockAnthropicClaudeSerialization {
         return AnthropicMessageRequest(
             model = model.id,
             messages = messages,
-            maxTokens = 4096,
+            maxTokens = prompt.params.maxTokens ?: AnthropicMessageRequest.MAX_TOKENS_DEFAULT,
             temperature = if (model.capabilities.contains(
                     LLMCapability.Temperature
                 )
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt
index 245ee1b11..636b44bd0 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
 import ai.koog.agents.core.tools.ToolParameterType
 import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.clients.bedrock.BedrockModels
+import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
 import ai.koog.prompt.llm.LLMCapability
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.llm.LLModel
@@ -43,7 +44,7 @@ class BedrockAI21JambaSerializationTest {

         assertNotNull(request)
         assertEquals(model.id, request.model)
-        assertEquals(4096, request.maxTokens)
+        assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
         assertEquals(temperature, request.temperature)

         assertEquals(2, request.messages.size)
@@ -55,6 +56,22 @@ class BedrockAI21JambaSerializationTest {
         assertEquals(userMessage, request.messages[1].content)
     }

+    @Test
+    fun `createJambaRequest with custom maxTokens`() {
+        val maxTokens = 1000
+
+        val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
+            system(systemMessage)
+            user(userMessage)
+        }
+
+        val request = BedrockAI21JambaSerialization.createJambaRequest(prompt, model, emptyList())
+
+        assertNotNull(request)
+        assertEquals(model.id, request.model)
+        assertEquals(maxTokens, request.maxTokens)
+    }
+
     @Test
     fun `createJambaRequest with conversation history`() {
         val userNewMessage = "Hello, who are you?"
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModelsTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModelsTest.kt
index ea61cfaf4..3fef7c697 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModelsTest.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModelsTest.kt
@@ -1,10 +1,12 @@
 package ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21

+import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
 import kotlinx.serialization.json.Json
 import kotlinx.serialization.json.buildJsonObject
 import kotlinx.serialization.json.put
 import kotlin.test.Test
 import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertNotNull
 import kotlin.test.assertNull

@@ -51,6 +53,30 @@ class JambaDataModelsTest {
         assert(serialized.contains("0.7")) { "Serialized JSON should contain the temperature value: $serialized" }
     }

+    @Test
+    fun `JambaRequest serialization with default maxTokens`() {
+        val request = JambaRequest(
+            model = "ai21.jamba-1-5-large-v1:0",
+            messages = listOf(
+                JambaMessage(role = "user", content = "Tell me about Paris")
+            ),
+            temperature = 0.7
+        )
+        assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
+    }
+
+    @Test
+    fun `JambaRequest should reject maxTokens less than 1`() {
+        val exception = assertFailsWith<IllegalArgumentException> {
+            JambaRequest(
+                model = "ai21.jamba-1-5-large-v1:0",
+                messages = listOf(JambaMessage(role = "user", content = "Tell me about Paris")),
+                maxTokens = 0
+            )
+        }
+        assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
+    }
+
     @Test
     fun `JambaRequest serialization with null fields`() {
         val request = JambaRequest(
@@ -121,7 +147,7 @@ class JambaDataModelsTest {
         assertEquals(1, request.messages.size)
         assertEquals("user", request.messages[0].role)
         assertEquals("Tell me about Paris", request.messages[0].content)
-        assertNull(request.maxTokens)
+        assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
         assertNull(request.temperature)
     }
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerializationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerializationTest.kt
index 1091f929f..c3d2dba42 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerializationTest.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerializationTest.kt
@@ -2,6 +2,7 @@
 package ai.koog.prompt.executor.clients.bedrock.modelfamilies.amazon

 import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.clients.bedrock.BedrockModels
+import ai.koog.prompt.executor.clients.bedrock.modelfamilies.amazon.NovaInferenceConfig.Companion.MAX_TOKENS_DEFAULT
 import ai.koog.prompt.llm.LLMCapability
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.llm.LLModel
@@ -50,10 +51,23 @@ class BedrockAmazonNovaSerializationTest {
         assertEquals(userMessage, request.messages[0].content[0].text)

         assertNotNull(request.inferenceConfig)
-        assertEquals(4096, request.inferenceConfig.maxTokens)
+        assertEquals(MAX_TOKENS_DEFAULT, request.inferenceConfig.maxTokens)
         assertEquals(temperature, request.inferenceConfig.temperature)
     }

+    @Test
+    fun `createNovaRequest with custom maxTokens`() {
+        val maxTokens = 1000
+
+        val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
+            system(systemMessage)
+            user(userMessage)
+        }
+
+        val request = BedrockAmazonNovaSerialization.createNovaRequest(prompt, model)
+        assertEquals(maxTokens, request.inferenceConfig!!.maxTokens)
+    }
+
     @Test
     fun `createNovaRequest with conversation history`() {
         val prompt = Prompt.build("test") {
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt
index c17af8912..2d6a059bd 100644
--- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt
+++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
 import ai.koog.agents.core.tools.ToolParameterType
 import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.clients.anthropic.AnthropicContent
+import ai.koog.prompt.executor.clients.anthropic.AnthropicMessageRequest
 import ai.koog.prompt.executor.clients.anthropic.AnthropicToolChoice
 import ai.koog.prompt.executor.clients.bedrock.BedrockModels
 import ai.koog.prompt.message.Message
@@ -47,7 +48,7 @@ class BedrockAnthropicClaudeSerializationTest {

         assertNotNull(request)
         assertEquals(model.id, request.model)
-        assertEquals(4096, request.maxTokens)
+        assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
         assertEquals(temperature, request.temperature)

         assertNotNull(request.system)
diff --git a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/params/LLMParams.kt b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/params/LLMParams.kt
index 830b1f234..cfa0dca3d 100644
--- a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/params/LLMParams.kt
+++ b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/params/LLMParams.kt
@@ -36,6 +36,7 @@ import kotlinx.serialization.json.JsonObject
 @Serializable
 public data class LLMParams(
     val temperature: Double? = null,
+    val maxTokens: Int? = null,
     val numberOfChoices: Int? = null,
     val speculation: String? = null,
     val schema: Schema? = null,
@@ -74,6 +75,7 @@ public data class LLMParams(
      */
     public fun default(default: LLMParams): LLMParams = copy(
         temperature = temperature ?: default.temperature,
+        maxTokens = maxTokens ?: default.maxTokens,
         numberOfChoices = numberOfChoices ?: default.numberOfChoices,
         speculation = speculation ?: default.speculation,
         schema = schema ?: default.schema,
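
`default()` merges per-prompt params with executor-level fallbacks field by field, so `maxTokens` now participates in the same elvis-style resolution as `temperature` and the other fields. A brief illustration of the merge semantics:

import ai.koog.prompt.params.LLMParams

fun main() {
    val perCall = LLMParams(temperature = 0.7)         // caller set only temperature
    val executorDefaults = LLMParams(maxTokens = 2048) // executor-wide fallback

    val effective = perCall.default(executorDefaults)
    check(effective.temperature == 0.7)  // the caller's value wins
    check(effective.maxTokens == 2048)   // the fallback fills the gap
}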