
Commit 38a8424

Add maxTokens as prompt parameters (#579)
Co-authored-by: Anastasiia.Zarechneva <[email protected]>
1 parent 0c92e85 commit 38a8424
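
This commit replaces hardcoded token limits (2048 for the Anthropic client, 4096 for the Bedrock model families) with a per-prompt maxTokens parameter that falls back to a named default, and moves validation into the request types' init blocks so an invalid limit fails at construction rather than at the API boundary. Based on the test code in this diff, a caller sets the limit when building a prompt; a minimal sketch (the LLMParams import path is assumed, as it is not shown in this diff):

```kotlin
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.params.LLMParams // assumed import path; not shown in this diff

// Per-prompt token limit: clients now read prompt.params.maxTokens and fall
// back to their MAX_TOKENS_DEFAULT when it is null.
val prompt = Prompt.build("demo", params = LLMParams(maxTokens = 1000)) {
    system("You are a helpful assistant.")
    user("Summarize this in one sentence.")
}
```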

File tree

13 files changed: +204 -60 lines changed


prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicLLMClient.kt

Lines changed: 1 addition & 1 deletion

```diff
@@ -282,7 +282,7 @@ public open class AnthropicLLMClient(
             model = settings.modelVersionsMap[model]
                 ?: throw IllegalArgumentException("Unsupported model: $model"),
             messages = messages,
-            maxTokens = 2048, // This is required by the API
+            maxTokens = prompt.params.maxTokens ?: AnthropicMessageRequest.MAX_TOKENS_DEFAULT,
             temperature = prompt.params.temperature,
            system = systemMessage,
             tools = if (tools.isNotEmpty()) anthropicTools else emptyList(), // Always provide a list for tools
```
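
The hardcoded `maxTokens = 2048` gives way to a per-prompt value with a named fallback, combined with Kotlin's elvis operator. A standalone sketch of the fallback semantics (function and parameter names here are illustrative stand-ins, not the library API):

```kotlin
// Elvis-operator fallback as used in the diff: a non-null per-prompt value wins,
// otherwise the named default applies.
fun resolveMaxTokens(promptMaxTokens: Int?, default: Int = 2048): Int =
    promptMaxTokens ?: default

fun main() {
    println(resolveMaxTokens(null))  // prints 2048: falls back to the default
    println(resolveMaxTokens(4000))  // prints 4000: the per-prompt override wins
}
```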

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt

Lines changed: 19 additions & 2 deletions

```diff
@@ -27,13 +27,30 @@ import kotlinx.serialization.json.JsonObject
 public data class AnthropicMessageRequest(
     val model: String,
     val messages: List<AnthropicMessage>,
-    val maxTokens: Int = 2048,
+    val maxTokens: Int = MAX_TOKENS_DEFAULT,
     val temperature: Double? = null,
     val system: List<SystemAnthropicMessage>? = null,
     val tools: List<AnthropicTool>? = null,
     val stream: Boolean = false,
     val toolChoice: AnthropicToolChoice? = null,
-)
+) {
+    init {
+        require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
+        if (temperature != null) {
+            require(temperature >= 0) { "temperature must be greater than 0, but was $temperature" }
+        }
+    }
+
+    /**
+     * Companion object with default values for request
+     */
+    public companion object {
+        /**
+         * Default max tokens
+         */
+        public const val MAX_TOKENS_DEFAULT: Int = 2048
+    }
+}
 
 /**
  * Represents a message within the Anthropic LLM system. This data class encapsulates
```
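
The new init block fails fast on invalid values. Note that the temperature check is `>= 0` while its message says "greater than 0"; the check is what is enforced, so a temperature of 0.0 is accepted. A self-contained sketch of the pattern (the class below is an illustrative stand-in, not the library type):

```kotlin
// Illustrative stand-in for the validation added to AnthropicMessageRequest.
data class RequestLimits(val maxTokens: Int = 2048, val temperature: Double? = null) {
    init {
        require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
        if (temperature != null) {
            // The diff's check is >= 0, i.e. temperature must be non-negative.
            require(temperature >= 0) { "temperature must be non-negative, but was $temperature" }
        }
    }
}

fun main() {
    RequestLimits(maxTokens = 4000, temperature = 0.0)  // accepted
    runCatching { RequestLimits(maxTokens = 0) }        // rejected at construction
        .onFailure { println(it.message) }              // maxTokens must be greater than 0, but was 0
}
```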

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt

Lines changed: 36 additions & 0 deletions

```diff
@@ -3,6 +3,8 @@ package ai.koog.prompt.executor.clients.anthropic
 import ai.koog.prompt.executor.clients.list
 import ai.koog.prompt.llm.LLMProvider
 import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertSame
 
 class AnthropicModelsTest {
@@ -19,4 +21,38 @@ class AnthropicModelsTest {
             )
         }
     }
+
+    @Test
+    fun `AnthropicMessageRequest should use custom maxTokens when provided`() {
+        val customMaxTokens = 4000
+        val request = AnthropicMessageRequest(
+            model = AnthropicModels.Opus_3.id,
+            messages = emptyList(),
+            maxTokens = customMaxTokens
+        )
+
+        assertEquals(customMaxTokens, request.maxTokens)
+    }
+
+    @Test
+    fun `AnthropicMessageRequest should use default maxTokens when not provided`() {
+        val request = AnthropicMessageRequest(
+            model = AnthropicModels.Opus_3.id,
+            messages = emptyList()
+        )
+
+        assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
+    }
+
+    @Test
+    fun `AnthropicMessageRequest should reject zero maxTokens`() {
+        val exception = assertFailsWith<IllegalArgumentException> {
+            AnthropicMessageRequest(
+                model = AnthropicModels.Opus_3.id,
+                messages = emptyList(),
+                maxTokens = 0
+            )
+        }
+        assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
+    }
 }
```

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerialization.kt

Lines changed: 1 addition & 1 deletion

```diff
@@ -123,7 +123,7 @@ internal object BedrockAI21JambaSerialization {
         return JambaRequest(
             model = model.id,
             messages = messages,
-            maxTokens = 4096,
+            maxTokens = JambaRequest.MAX_TOKENS_DEFAULT,
             temperature = if (model.capabilities.contains(
                 LLMCapability.Temperature
             )
```

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt

Lines changed: 67 additions & 48 deletions

```diff
@@ -6,107 +6,126 @@ import kotlinx.serialization.json.JsonObject
 
 @Serializable
 internal data class JambaRequest(
-    public val model: String,
-    public val messages: List<JambaMessage>,
-    @SerialName("max_tokens")
-    public val maxTokens: Int? = null,
-    public val temperature: Double? = null,
-    @SerialName("top_p")
-    public val topP: Double? = null,
-    public val stop: List<String>? = null,
-    public val n: Int? = null,
-    public val stream: Boolean? = null,
-    public val tools: List<JambaTool>? = null,
-    @SerialName("response_format")
-    public val responseFormat: JambaResponseFormat? = null
-)
+    val model: String,
+    val messages: List<JambaMessage>,
+    @SerialName("max_tokens") val maxTokens: Int? = MAX_TOKENS_DEFAULT,
+    val temperature: Double? = null,
+    @SerialName("top_p") val topP: Double? = null,
+    val stop: List<String>? = null,
+    val n: Int? = null,
+    val stream: Boolean? = null,
+    val tools: List<JambaTool>? = null,
+    @SerialName("response_format") val responseFormat: JambaResponseFormat? = null
+) {
+    init {
+        if (maxTokens != null) {
+            require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
+        }
+        if (temperature != null) {
+            require(temperature >= 0) { "temperature must be greater than 0, but was $temperature" }
+        }
+        if (topP != null) {
+            require(topP in 0.0..1.0) { "topP must be between 0 and 1, but was $topP" }
+        }
+    }
+
+    /**
+     * Companion object with default values for request
+     */
+    companion object {
+        /**
+         * Default max tokens
+         */
+        const val MAX_TOKENS_DEFAULT: Int = 4096
+    }
+}
 
 @Serializable
 internal data class JambaMessage(
-    public val role: String,
-    public val content: String? = null,
+    val role: String,
+    val content: String? = null,
     @SerialName("tool_calls")
-    public val toolCalls: List<JambaToolCall>? = null,
+    val toolCalls: List<JambaToolCall>? = null,
     @SerialName("tool_call_id")
-    public val toolCallId: String? = null
+    val toolCallId: String? = null
 )
 
 @Serializable
 internal data class JambaTool(
-    public val type: String = "function",
-    public val function: JambaFunction
+    val type: String = "function",
+    val function: JambaFunction
 )
 
 @Serializable
 internal data class JambaFunction(
-    public val name: String,
-    public val description: String,
-    public val parameters: JsonObject
+    val name: String,
+    val description: String,
+    val parameters: JsonObject
 )
 
 @Serializable
 internal data class JambaToolCall(
-    public val id: String,
-    public val type: String = "function",
-    public val function: JambaFunctionCall
+    val id: String,
+    val type: String = "function",
+    val function: JambaFunctionCall
 )
 
 @Serializable
 internal data class JambaFunctionCall(
-    public val name: String,
-    public val arguments: String
+    val name: String,
+    val arguments: String
 )
 
 @Serializable
 internal data class JambaResponseFormat(
-    public val type: String
+    val type: String
 )
 
 @Serializable
 internal data class JambaResponse(
-    public val id: String,
-    public val model: String,
-    public val choices: List<JambaChoice>,
-    public val usage: JambaUsage? = null
+    val id: String,
+    val model: String,
+    val choices: List<JambaChoice>,
+    val usage: JambaUsage? = null
 )
 
 @Serializable
 internal data class JambaChoice(
-    public val index: Int,
-    public val message: JambaMessage,
+    val index: Int,
+    val message: JambaMessage,
     @SerialName("finish_reason")
-    public val finishReason: String? = null
+    val finishReason: String? = null
 )
 
 @Serializable
 internal data class JambaUsage(
     @SerialName("prompt_tokens")
-    public val promptTokens: Int,
+    val promptTokens: Int,
     @SerialName("completion_tokens")
-    public val completionTokens: Int,
+    val completionTokens: Int,
     @SerialName("total_tokens")
-    public val totalTokens: Int
+    val totalTokens: Int
 )
 
 @Serializable
 internal data class JambaStreamResponse(
-    public val id: String,
-    public val choices: List<JambaStreamChoice>,
-    public val usage: JambaUsage? = null
+    val id: String,
+    val choices: List<JambaStreamChoice>,
+    val usage: JambaUsage? = null
 )
 
 @Serializable
 internal data class JambaStreamChoice(
-    public val index: Int,
-    public val delta: JambaStreamDelta,
+    val index: Int,
+    val delta: JambaStreamDelta,
     @SerialName("finish_reason")
-    public val finishReason: String? = null
+    val finishReason: String? = null
 )
 
 @Serializable
 internal data class JambaStreamDelta(
-    public val role: String? = null,
-    public val content: String? = null,
+    val role: String? = null,
+    val content: String? = null,
     @SerialName("tool_calls")
-    public val toolCalls: List<JambaToolCall>? = null
+    val toolCalls: List<JambaToolCall>? = null
 )
```
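
JambaRequest gains the same fail-fast init block, plus a closed-interval check on topP. A standalone sketch of the range validation (an illustrative function, not the library API):

```kotlin
// Kotlin's `in 0.0..1.0` tests a closed interval, so 0.0 and 1.0 both pass.
fun validateTopP(topP: Double?) {
    if (topP != null) {
        require(topP in 0.0..1.0) { "topP must be between 0 and 1, but was $topP" }
    }
}

fun main() {
    validateTopP(null)                 // skipped: null means "not set"
    validateTopP(1.0)                  // accepted: endpoint of the closed range
    runCatching { validateTopP(1.5) }  // rejected
        .onFailure { println(it.message) }
}
```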

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerialization.kt

Lines changed: 1 addition & 1 deletion

```diff
@@ -44,7 +44,7 @@ internal object BedrockAmazonNovaSerialization {
         }
 
         val inferenceConfig = NovaInferenceConfig(
-            maxTokens = 4096,
+            maxTokens = prompt.params.maxTokens ?: NovaInferenceConfig.MAX_TOKENS_DEFAULT,
             temperature = if (model.capabilities.contains(LLMCapability.Temperature)) {
                 prompt.params.temperature
             } else {
```

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/NovaDataModels.kt

Lines changed: 12 additions & 2 deletions

```diff
@@ -45,8 +45,18 @@ internal data class NovaInferenceConfig(
     @SerialName("topK")
     val topK: Int? = null,
     @SerialName("maxTokens")
-    val maxTokens: Int? = null
-)
+    val maxTokens: Int? = MAX_TOKENS_DEFAULT
+) {
+    /**
+     * Companion object with default values for request
+     */
+    companion object {
+        /**
+         * Default max tokens
+         */
+        const val MAX_TOKENS_DEFAULT: Int = 4096
+    }
+}
 
 /**
  * Response data classes for Amazon Nova models
```
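
One subtlety of moving the default into the property declaration: kotlinx.serialization omits properties that still equal their default unless encodeDefaults is enabled, so whether `maxTokens: 4096` actually appears on the wire depends on the Json configuration, which is not shown in this diff. A self-contained sketch with an illustrative stand-in class:

```kotlin
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

@Serializable
data class InferenceConfigSketch( // illustrative stand-in for NovaInferenceConfig
    @SerialName("maxTokens") val maxTokens: Int? = 4096
)

fun main() {
    val cfg = InferenceConfigSketch()
    // prints {} because properties equal to their default are omitted by default
    println(Json.encodeToString(InferenceConfigSketch.serializer(), cfg))
    // prints {"maxTokens":4096} once defaults are encoded explicitly
    println(Json { encodeDefaults = true }.encodeToString(InferenceConfigSketch.serializer(), cfg))
}
```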

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerialization.kt

Lines changed: 1 addition & 1 deletion

```diff
@@ -170,7 +170,7 @@ internal object BedrockAnthropicClaudeSerialization {
         return AnthropicMessageRequest(
             model = model.id,
             messages = messages,
-            maxTokens = 4096,
+            maxTokens = prompt.params.maxTokens ?: AnthropicMessageRequest.MAX_TOKENS_DEFAULT,
             temperature = if (model.capabilities.contains(
                 LLMCapability.Temperature
             )
```

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt

Lines changed: 18 additions & 1 deletion

```diff
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
 import ai.koog.agents.core.tools.ToolParameterType
 import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.clients.bedrock.BedrockModels
+import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
 import ai.koog.prompt.llm.LLMCapability
 import ai.koog.prompt.llm.LLMProvider
 import ai.koog.prompt.llm.LLModel
@@ -43,7 +44,7 @@ class BedrockAI21JambaSerializationTest {
 
         assertNotNull(request)
         assertEquals(model.id, request.model)
-        assertEquals(4096, request.maxTokens)
+        assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
         assertEquals(temperature, request.temperature)
 
         assertEquals(2, request.messages.size)
@@ -55,6 +56,22 @@
         assertEquals(userMessage, request.messages[1].content)
     }
 
+    @Test
+    fun `createJambaRequest with custom maxTokens`() {
+        val maxTokens = 1000
+
+        val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
+            system(systemMessage)
+            user(userMessage)
+        }
+
+        val request = BedrockAI21JambaSerialization.createJambaRequest(prompt, model, emptyList())
+
+        assertNotNull(request)
+        assertEquals(model.id, request.model)
+        assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
+    }
+
     @Test
     fun `createJambaRequest with conversation history`() {
         val userNewMessage = "Hello, who are you?"
```
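
Note what the new custom-maxTokens test asserts: MAX_TOKENS_DEFAULT, not the prompt's 1000. That matches the serializer change above, which still passes JambaRequest.MAX_TOKENS_DEFAULT rather than prompt.params.maxTokens, so a per-prompt limit is not yet forwarded on the Jamba path. A simplified sketch of that effective behavior (stand-in names, not the library API):

```kotlin
// Stand-in model of the Jamba path in this commit: the per-prompt value is
// ignored and the request default always applies.
const val JAMBA_MAX_TOKENS_DEFAULT = 4096

fun jambaMaxTokens(promptMaxTokens: Int?): Int =
    JAMBA_MAX_TOKENS_DEFAULT // promptMaxTokens is not consulted

fun main() {
    println(jambaMaxTokens(1000)) // prints 4096: the custom value is not forwarded
}
```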
