Skip to content

Commit fa071bd

Browse files
committed
Add validation for maxTokens and update tests
1 parent 34cb5ea commit fa071bd

File tree

7 files changed

+163
-51
lines changed

7 files changed

+163
-51
lines changed

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ public data class AnthropicMessageRequest(
3434
val stream: Boolean = false,
3535
val toolChoice: AnthropicToolChoice? = null,
3636
) {
37+
init {
38+
require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
39+
if (temperature != null) {
40+
require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
41+
}
42+
}
43+
3744
/**
3845
* Companion object with default values for request
3946
*/

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package ai.koog.prompt.executor.clients.anthropic
33
import ai.koog.prompt.executor.clients.list
44
import ai.koog.prompt.llm.LLMProvider
55
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertFailsWith
68
import kotlin.test.assertSame
79

810
class AnthropicModelsTest {
@@ -19,4 +21,38 @@ class AnthropicModelsTest {
1921
)
2022
}
2123
}
24+
25+
@Test
26+
fun `AnthropicMessageRequest should use custom maxTokens when provided`() {
27+
val customMaxTokens = 4000
28+
val request = AnthropicMessageRequest(
29+
model = AnthropicModels.Opus_3.id,
30+
messages = emptyList(),
31+
maxTokens = customMaxTokens
32+
)
33+
34+
assertEquals(customMaxTokens, request.maxTokens)
35+
}
36+
37+
@Test
38+
fun `AnthropicMessageRequest should use default maxTokens when not provided`() {
39+
val request = AnthropicMessageRequest(
40+
model = AnthropicModels.Opus_3.id,
41+
messages = emptyList()
42+
)
43+
44+
assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
45+
}
46+
47+
@Test
48+
fun `AnthropicMessageRequest should reject zero maxTokens`() {
49+
val exception = assertFailsWith<IllegalArgumentException> {
50+
AnthropicMessageRequest(
51+
model = AnthropicModels.Opus_3.id,
52+
messages = emptyList(),
53+
maxTokens = 0
54+
)
55+
}
56+
assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
57+
}
2258
}

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,29 @@ import kotlinx.serialization.json.JsonObject
66

77
@Serializable
88
internal data class JambaRequest(
9-
public val model: String,
10-
public val messages: List<JambaMessage>,
11-
@SerialName("max_tokens")
12-
public val maxTokens: Int? = MAX_TOKENS_DEFAULT,
13-
public val temperature: Double? = null,
14-
@SerialName("top_p")
15-
public val topP: Double? = null,
16-
public val stop: List<String>? = null,
17-
public val n: Int? = null,
18-
public val stream: Boolean? = null,
19-
public val tools: List<JambaTool>? = null,
20-
@SerialName("response_format")
21-
public val responseFormat: JambaResponseFormat? = null
9+
val model: String,
10+
val messages: List<JambaMessage>,
11+
@SerialName("max_tokens") val maxTokens: Int? = MAX_TOKENS_DEFAULT,
12+
val temperature: Double? = null,
13+
@SerialName("top_p") val topP: Double? = null,
14+
val stop: List<String>? = null,
15+
val n: Int? = null,
16+
val stream: Boolean? = null,
17+
val tools: List<JambaTool>? = null,
18+
@SerialName("response_format") val responseFormat: JambaResponseFormat? = null
2219
) {
20+
init {
21+
if (maxTokens != null) {
22+
require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
23+
}
24+
if (temperature != null) {
25+
require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
26+
}
27+
if (topP != null) {
28+
require(topP in 0.0..1.0) { "topP must be between 0 and 1, but was $topP" }
29+
}
30+
}
31+
2332
/**
2433
* Companion object with default values for request
2534
*/
@@ -33,90 +42,90 @@ internal data class JambaRequest(
3342

3443
@Serializable
3544
internal data class JambaMessage(
36-
public val role: String,
37-
public val content: String? = null,
45+
val role: String,
46+
val content: String? = null,
3847
@SerialName("tool_calls")
39-
public val toolCalls: List<JambaToolCall>? = null,
48+
val toolCalls: List<JambaToolCall>? = null,
4049
@SerialName("tool_call_id")
41-
public val toolCallId: String? = null
50+
val toolCallId: String? = null
4251
)
4352

4453
@Serializable
4554
internal data class JambaTool(
46-
public val type: String = "function",
47-
public val function: JambaFunction
55+
val type: String = "function",
56+
val function: JambaFunction
4857
)
4958

5059
@Serializable
5160
internal data class JambaFunction(
52-
public val name: String,
53-
public val description: String,
54-
public val parameters: JsonObject
61+
val name: String,
62+
val description: String,
63+
val parameters: JsonObject
5564
)
5665

5766
@Serializable
5867
internal data class JambaToolCall(
59-
public val id: String,
60-
public val type: String = "function",
61-
public val function: JambaFunctionCall
68+
val id: String,
69+
val type: String = "function",
70+
val function: JambaFunctionCall
6271
)
6372

6473
@Serializable
6574
internal data class JambaFunctionCall(
66-
public val name: String,
67-
public val arguments: String
75+
val name: String,
76+
val arguments: String
6877
)
6978

7079
@Serializable
7180
internal data class JambaResponseFormat(
72-
public val type: String
81+
val type: String
7382
)
7483

7584
@Serializable
7685
internal data class JambaResponse(
77-
public val id: String,
78-
public val model: String,
79-
public val choices: List<JambaChoice>,
80-
public val usage: JambaUsage? = null
86+
val id: String,
87+
val model: String,
88+
val choices: List<JambaChoice>,
89+
val usage: JambaUsage? = null
8190
)
8291

8392
@Serializable
8493
internal data class JambaChoice(
85-
public val index: Int,
86-
public val message: JambaMessage,
94+
val index: Int,
95+
val message: JambaMessage,
8796
@SerialName("finish_reason")
88-
public val finishReason: String? = null
97+
val finishReason: String? = null
8998
)
9099

91100
@Serializable
92101
internal data class JambaUsage(
93102
@SerialName("prompt_tokens")
94-
public val promptTokens: Int,
103+
val promptTokens: Int,
95104
@SerialName("completion_tokens")
96-
public val completionTokens: Int,
105+
val completionTokens: Int,
97106
@SerialName("total_tokens")
98-
public val totalTokens: Int
107+
val totalTokens: Int
99108
)
100109

101110
@Serializable
102111
internal data class JambaStreamResponse(
103-
public val id: String,
104-
public val choices: List<JambaStreamChoice>,
105-
public val usage: JambaUsage? = null
112+
val id: String,
113+
val choices: List<JambaStreamChoice>,
114+
val usage: JambaUsage? = null
106115
)
107116

108117
@Serializable
109118
internal data class JambaStreamChoice(
110-
public val index: Int,
111-
public val delta: JambaStreamDelta,
119+
val index: Int,
120+
val delta: JambaStreamDelta,
112121
@SerialName("finish_reason")
113-
public val finishReason: String? = null
122+
val finishReason: String? = null
114123
)
115124

116125
@Serializable
117126
internal data class JambaStreamDelta(
118-
public val role: String? = null,
119-
public val content: String? = null,
127+
val role: String? = null,
128+
val content: String? = null,
120129
@SerialName("tool_calls")
121-
public val toolCalls: List<JambaToolCall>? = null
130+
val toolCalls: List<JambaToolCall>? = null
122131
)

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
55
import ai.koog.agents.core.tools.ToolParameterType
66
import ai.koog.prompt.dsl.Prompt
77
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
8+
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
89
import ai.koog.prompt.llm.LLMCapability
910
import ai.koog.prompt.llm.LLMProvider
1011
import ai.koog.prompt.llm.LLModel
@@ -43,7 +44,7 @@ class BedrockAI21JambaSerializationTest {
4344

4445
assertNotNull(request)
4546
assertEquals(model.id, request.model)
46-
assertEquals(4096, request.maxTokens)
47+
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
4748
assertEquals(temperature, request.temperature)
4849

4950
assertEquals(2, request.messages.size)
@@ -55,6 +56,22 @@ class BedrockAI21JambaSerializationTest {
5556
assertEquals(userMessage, request.messages[1].content)
5657
}
5758

59+
@Test
60+
fun `createJambaRequest with custom maxTokens`() {
61+
val maxTokens = 1000
62+
63+
val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
64+
system(systemMessage)
65+
user(userMessage)
66+
}
67+
68+
val request = BedrockAI21JambaSerialization.createJambaRequest(prompt, model, emptyList())
69+
70+
assertNotNull(request)
71+
assertEquals(model.id, request.model)
72+
assertEquals(maxTokens, request.maxTokens)
73+
}
74+
5875
@Test
5976
fun `createJambaRequest with conversation history`() {
6077
val userNewMessage = "Hello, who are you?"

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModelsTest.kt

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21
22

3+
import ai.koog.prompt.executor.clients.anthropic.AnthropicMessageRequest
4+
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels
5+
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
36
import kotlinx.serialization.json.Json
47
import kotlinx.serialization.json.buildJsonObject
58
import kotlinx.serialization.json.put
69
import kotlin.test.Test
710
import kotlin.test.assertEquals
11+
import kotlin.test.assertFailsWith
812
import kotlin.test.assertNotNull
913
import kotlin.test.assertNull
1014

@@ -51,6 +55,30 @@ class JambaDataModelsTest {
5155
assert(serialized.contains("0.7")) { "Serialized JSON should contain the temperature value: $serialized" }
5256
}
5357

58+
@Test
59+
fun `JambaRequest serialization with default maxTokens`() {
60+
val request = JambaRequest(
61+
model = "ai21.jamba-1-5-large-v1:0",
62+
messages = listOf(
63+
JambaMessage(role = "user", content = "Tell me about Paris")
64+
),
65+
temperature = 0.7
66+
)
67+
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
68+
}
69+
70+
@Test
71+
fun `JambaRequest serialization with maxTokens less than 1`() {
72+
val exception = assertFailsWith<IllegalArgumentException> {
73+
JambaRequest(
74+
model = "ai21.jamba-1-5-large-v1:0",
75+
messages = emptyList(),
76+
maxTokens = 0
77+
)
78+
}
79+
assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
80+
}
81+
5482
@Test
5583
fun `JambaRequest serialization with null fields`() {
5684
val request = JambaRequest(
@@ -121,7 +149,7 @@ class JambaDataModelsTest {
121149
assertEquals(1, request.messages.size)
122150
assertEquals("user", request.messages[0].role)
123151
assertEquals("Tell me about Paris", request.messages[0].content)
124-
assertNull(request.maxTokens)
152+
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
125153
assertNull(request.temperature)
126154
}
127155

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerializationTest.kt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ai.koog.prompt.executor.clients.bedrock.modelfamilies.amazon
22

33
import ai.koog.prompt.dsl.Prompt
44
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
5+
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.amazon.NovaInferenceConfig.Companion.MAX_TOKENS_DEFAULT
56
import ai.koog.prompt.llm.LLMCapability
67
import ai.koog.prompt.llm.LLMProvider
78
import ai.koog.prompt.llm.LLModel
@@ -50,10 +51,23 @@ class BedrockAmazonNovaSerializationTest {
5051
assertEquals(userMessage, request.messages[0].content[0].text)
5152

5253
assertNotNull(request.inferenceConfig)
53-
assertEquals(4096, request.inferenceConfig.maxTokens)
54+
assertEquals(MAX_TOKENS_DEFAULT, request.inferenceConfig.maxTokens)
5455
assertEquals(temperature, request.inferenceConfig.temperature)
5556
}
5657

58+
@Test
59+
fun `createNovaRequest with custom maxTokens`() {
60+
val maxTokens = 1000
61+
62+
val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
63+
system(systemMessage)
64+
user(userMessage)
65+
}
66+
67+
val request = BedrockAmazonNovaSerialization.createNovaRequest(prompt, model)
68+
assertEquals(maxTokens, request.inferenceConfig!!.maxTokens)
69+
}
70+
5771
@Test
5872
fun `createNovaRequest with conversation history`() {
5973
val prompt = Prompt.build("test") {

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
55
import ai.koog.agents.core.tools.ToolParameterType
66
import ai.koog.prompt.dsl.Prompt
77
import ai.koog.prompt.executor.clients.anthropic.AnthropicContent
8+
import ai.koog.prompt.executor.clients.anthropic.AnthropicMessageRequest
89
import ai.koog.prompt.executor.clients.anthropic.AnthropicToolChoice
910
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
1011
import ai.koog.prompt.message.Message
@@ -47,7 +48,7 @@ class BedrockAnthropicClaudeSerializationTest {
4748

4849
assertNotNull(request)
4950
assertEquals(model.id, request.model)
50-
assertEquals(4096, request.maxTokens)
51+
assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
5152
assertEquals(temperature, request.temperature)
5253

5354
assertNotNull(request.system)

0 commit comments

Comments
 (0)