Add maxTokens as prompt parameters #579

Merged: 3 commits, Aug 13, 2025
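
As a rough usage sketch of the feature this PR adds (the `Prompt.build` and `LLMParams` names follow the test code further down; the import path for `LLMParams` is assumed), a caller can now set the token limit per prompt instead of relying on each client's hard-coded default:

import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.params.LLMParams // assumed package; only the class name LLMParams appears in the tests below

// Sketch: pass maxTokens through prompt parameters; clients fall back to their
// MAX_TOKENS_DEFAULT companion constant when the parameter is null.
val prompt = Prompt.build("demo", params = LLMParams(maxTokens = 1000)) {
    system("You are a helpful assistant.")
    user("Summarize this pull request in one sentence.")
}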
prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt
@@ -282,7 +282,7 @@
model = settings.modelVersionsMap[model]
?: throw IllegalArgumentException("Unsupported model: $model"),
messages = messages,
maxTokens = 2048, // This is required by the API
maxTokens = prompt.params.maxTokens ?: AnthropicMessageRequest.MAX_TOKENS_DEFAULT,
temperature = prompt.params.temperature,
system = systemMessage,
tools = if (tools.isNotEmpty()) anthropicTools else emptyList(), // Always provide a list for tools
@@ -467,7 +467,7 @@
* @return This method does not return a value as it always throws an exception.
* @throws UnsupportedOperationException Always thrown, as moderation is not supported by the Anthropic API.
*/
public override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult {

GitHub Actions / Qodana for JVM — Kotlin and Java source code coverage check: Method `moderate` coverage is below the threshold 50% (line 470).
logger.warn { "Moderation is not supported by Anthropic API" }
throw UnsupportedOperationException("Moderation is not supported by Anthropic API.")
}
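
For callers, a minimal handling sketch (hypothetical helper; imports and client construction are not part of this diff):

// Sketch: moderation is unsupported for Anthropic, so a caller can degrade gracefully
// instead of propagating the UnsupportedOperationException.
suspend fun moderateOrNull(
    client: AnthropicLLMClient, // assumed to be constructed and configured elsewhere
    prompt: Prompt,
    model: LLModel
): ModerationResult? = try {
    client.moderate(prompt, model)
} catch (e: UnsupportedOperationException) {
    null // Anthropic exposes no moderation endpoint; skip this step
}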
@@ -27,13 +27,30 @@ import kotlinx.serialization.json.JsonObject
public data class AnthropicMessageRequest(
val model: String,
val messages: List<AnthropicMessage>,
val maxTokens: Int = 2048,
val maxTokens: Int = MAX_TOKENS_DEFAULT,
val temperature: Double? = null,
val system: List<SystemAnthropicMessage>? = null,
val tools: List<AnthropicTool>? = null,
val stream: Boolean = false,
val toolChoice: AnthropicToolChoice? = null,
)
) {
init {
require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
if (temperature != null) {
require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
}
}

/**
* Companion object with default values for request
*/
public companion object {
/**
* Default max tokens
*/
public const val MAX_TOKENS_DEFAULT: Int = 2048
}
}

/**
* Represents a message within the Anthropic LLM system. This data class encapsulates
@@ -3,6 +3,8 @@ package ai.koog.prompt.executor.clients.anthropic
import ai.koog.prompt.executor.clients.list
import ai.koog.prompt.llm.LLMProvider
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertSame

class AnthropicModelsTest {
@@ -19,4 +21,38 @@
)
}
}

@Test
fun `AnthropicMessageRequest should use custom maxTokens when provided`() {
val customMaxTokens = 4000
val request = AnthropicMessageRequest(
model = AnthropicModels.Opus_3.id,
messages = emptyList(),
maxTokens = customMaxTokens
)

assertEquals(customMaxTokens, request.maxTokens)
}

@Test
fun `AnthropicMessageRequest should use default maxTokens when not provided`() {
val request = AnthropicMessageRequest(
model = AnthropicModels.Opus_3.id,
messages = emptyList()
)

assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
}

@Test
fun `AnthropicMessageRequest should reject zero maxTokens`() {
val exception = assertFailsWith<IllegalArgumentException> {
AnthropicMessageRequest(
model = AnthropicModels.Opus_3.id,
messages = emptyList(),
maxTokens = 0
)
}
assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
}
}
prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerialization.kt
@@ -123,7 +123,7 @@
return JambaRequest(
model = model.id,
messages = messages,
maxTokens = 4096,
maxTokens = JambaRequest.MAX_TOKENS_DEFAULT,
temperature = if (model.capabilities.contains(
LLMCapability.Temperature
)
@@ -137,7 +137,7 @@
}

@OptIn(ExperimentalUuidApi::class)
internal fun parseJambaResponse(responseBody: String, clock: Clock = Clock.System): List<Message.Response> {

GitHub Actions / Qodana for JVM — Kotlin and Java source code coverage check: Method `parseJambaResponse$prompt_executor_bedrock_client` coverage is below the threshold 50% (line 140).
val response = json.decodeFromString<JambaResponse>(responseBody)

val inputTokens = response.usage?.promptTokens
prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt
@@ -6,107 +6,126 @@

@Serializable
internal data class JambaRequest(
public val model: String,
public val messages: List<JambaMessage>,
@SerialName("max_tokens")
public val maxTokens: Int? = null,
public val temperature: Double? = null,
@SerialName("top_p")
public val topP: Double? = null,
public val stop: List<String>? = null,
public val n: Int? = null,
public val stream: Boolean? = null,
public val tools: List<JambaTool>? = null,
@SerialName("response_format")
public val responseFormat: JambaResponseFormat? = null
)
val model: String,
val messages: List<JambaMessage>,
@SerialName("max_tokens") val maxTokens: Int? = MAX_TOKENS_DEFAULT,
val temperature: Double? = null,
@SerialName("top_p") val topP: Double? = null,
val stop: List<String>? = null,
val n: Int? = null,
val stream: Boolean? = null,
val tools: List<JambaTool>? = null,
@SerialName("response_format") val responseFormat: JambaResponseFormat? = null
) {
init {
if (maxTokens != null) {
require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
}
if (temperature != null) {
require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
}
if (topP != null) {
require(topP in 0.0..1.0) { "topP must be between 0 and 1, but was $topP" }
}
}

/**
* Companion object with default values for request
*/
companion object {
/**
* Default max tokens
*/
const val MAX_TOKENS_DEFAULT: Int = 4096
}
}

@Serializable
internal data class JambaMessage(
public val role: String,
public val content: String? = null,
val role: String,
val content: String? = null,
@SerialName("tool_calls")
public val toolCalls: List<JambaToolCall>? = null,
val toolCalls: List<JambaToolCall>? = null,
@SerialName("tool_call_id")
public val toolCallId: String? = null
val toolCallId: String? = null
)

@Serializable
internal data class JambaTool(
public val type: String = "function",
public val function: JambaFunction
val type: String = "function",
val function: JambaFunction
)

@Serializable
internal data class JambaFunction(
public val name: String,
public val description: String,
public val parameters: JsonObject
val name: String,
val description: String,
val parameters: JsonObject
)

@Serializable
internal data class JambaToolCall(
public val id: String,
public val type: String = "function",
public val function: JambaFunctionCall
val id: String,
val type: String = "function",
val function: JambaFunctionCall
)

@Serializable
internal data class JambaFunctionCall(
public val name: String,
public val arguments: String
val name: String,
val arguments: String
)

@Serializable
internal data class JambaResponseFormat(

GitHub Actions / Qodana for JVM — Kotlin and Java source code coverage check: Constructor `JambaResponseFormat` and class `JambaResponseFormat` coverage are below the threshold 50% (line 80).
public val type: String
val type: String

GitHub Actions / Qodana for JVM — Kotlin and Java source code coverage check: Method `getType` coverage is below the threshold 50% (line 81).
)

@Serializable
internal data class JambaResponse(
public val id: String,
public val model: String,
public val choices: List<JambaChoice>,
public val usage: JambaUsage? = null
val id: String,
val model: String,
val choices: List<JambaChoice>,
val usage: JambaUsage? = null
)

@Serializable
internal data class JambaChoice(
public val index: Int,
public val message: JambaMessage,
val index: Int,
val message: JambaMessage,
@SerialName("finish_reason")
public val finishReason: String? = null
val finishReason: String? = null
)

@Serializable
internal data class JambaUsage(
@SerialName("prompt_tokens")
public val promptTokens: Int,
val promptTokens: Int,
@SerialName("completion_tokens")
public val completionTokens: Int,
val completionTokens: Int,
@SerialName("total_tokens")
public val totalTokens: Int
val totalTokens: Int
)

@Serializable
internal data class JambaStreamResponse(
public val id: String,
public val choices: List<JambaStreamChoice>,
public val usage: JambaUsage? = null
val id: String,
val choices: List<JambaStreamChoice>,
val usage: JambaUsage? = null
)

@Serializable
internal data class JambaStreamChoice(
public val index: Int,
public val delta: JambaStreamDelta,
val index: Int,
val delta: JambaStreamDelta,
@SerialName("finish_reason")
public val finishReason: String? = null
val finishReason: String? = null
)

@Serializable
internal data class JambaStreamDelta(
public val role: String? = null,
public val content: String? = null,
val role: String? = null,
val content: String? = null,
@SerialName("tool_calls")
public val toolCalls: List<JambaToolCall>? = null
val toolCalls: List<JambaToolCall>? = null
)
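
A small sketch of how the new `init` validation in `JambaRequest` behaves (illustrative values only; `JambaRequest` is `internal`, so this would only compile inside the same module, and the model id string is made up):

// Defaults apply: maxTokens falls back to MAX_TOKENS_DEFAULT (4096), topP is inside 0..1.
val ok = JambaRequest(model = "jamba-placeholder", messages = emptyList(), topP = 0.9)

// Out-of-range topP fails fast in init with IllegalArgumentException.
val bad = runCatching {
    JambaRequest(model = "jamba-placeholder", messages = emptyList(), topP = 1.5)
}
check(bad.exceptionOrNull() is IllegalArgumentException)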
prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt
@@ -44,7 +44,7 @@
}

val inferenceConfig = NovaInferenceConfig(
maxTokens = 4096,
maxTokens = prompt.params.maxTokens ?: NovaInferenceConfig.MAX_TOKENS_DEFAULT,
temperature = if (model.capabilities.contains(LLMCapability.Temperature)) {
prompt.params.temperature
} else {
@@ -59,7 +59,7 @@
)
}

internal fun parseNovaResponse(responseBody: String, clock: Clock = Clock.System): List<Message.Response> {

GitHub Actions / Qodana for JVM — Kotlin and Java source code coverage check: Method `parseNovaResponse$prompt_executor_bedrock_client` coverage is below the threshold 50% (line 62).
val response = json.decodeFromString<NovaResponse>(responseBody)
val messageContent = response.output.message.content.firstOrNull()?.text ?: ""
val outputTokens = response.usage?.outputTokens
@@ -45,8 +45,18 @@ internal data class NovaInferenceConfig(
@SerialName("topK")
val topK: Int? = null,
@SerialName("maxTokens")
val maxTokens: Int? = null
)
val maxTokens: Int? = MAX_TOKENS_DEFAULT
) {
/**
* Companion object with default values for request
*/
companion object {
/**
* Default max tokens
*/
const val MAX_TOKENS_DEFAULT: Int = 4096
}
}

/**
* Response data classes for Amazon Nova models
prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt
@@ -170,7 +170,7 @@
return AnthropicMessageRequest(
model = model.id,
messages = messages,
maxTokens = 4096,
maxTokens = prompt.params.maxTokens ?: AnthropicMessageRequest.MAX_TOKENS_DEFAULT,
temperature = if (model.capabilities.contains(
LLMCapability.Temperature
)
@@ -186,7 +186,7 @@
}

@OptIn(ExperimentalUuidApi::class)
internal fun parseAnthropicResponse(responseBody: String, clock: Clock = Clock.System): List<Message.Response> {

GitHub Actions / Qodana for JVM — Kotlin and Java source code coverage check: Method `parseAnthropicResponse$prompt_executor_bedrock_client` coverage is below the threshold 50% (line 189).
val response = json.decodeFromString<AnthropicResponse>(responseBody)

val inputTokens = response.usage?.inputTokens
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
import ai.koog.agents.core.tools.ToolParameterType
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
import ai.koog.prompt.llm.LLMCapability
import ai.koog.prompt.llm.LLMProvider
import ai.koog.prompt.llm.LLModel
@@ -43,7 +44,7 @@ class BedrockAI21JambaSerializationTest {

assertNotNull(request)
assertEquals(model.id, request.model)
assertEquals(4096, request.maxTokens)
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
assertEquals(temperature, request.temperature)

assertEquals(2, request.messages.size)
@@ -55,6 +56,22 @@
assertEquals(userMessage, request.messages[1].content)
}

@Test
fun `createJambaRequest with custom maxTokens`() {
val maxTokens = 1000

val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
system(systemMessage)
user(userMessage)
}

val request = BedrockAI21JambaSerialization.createJambaRequest(prompt, model, emptyList())

assertNotNull(request)
assertEquals(model.id, request.model)
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
}

@Test
fun `createJambaRequest with conversation history`() {
val userNewMessage = "Hello, who are you?"