Skip to content

Commit e75cb39

Browse files
committed
Add validation for maxTokens and update tests
1 parent 34cb5ea commit e75cb39

File tree

7 files changed

+161
-51
lines changed

7 files changed

+161
-51
lines changed

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/anthropic/DataModel.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ public data class AnthropicMessageRequest(
3434
val stream: Boolean = false,
3535
val toolChoice: AnthropicToolChoice? = null,
3636
) {
37+
init {
38+
require(maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
39+
if (temperature != null) {
40+
require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
41+
}
42+
}
43+
3744
/**
3845
* Companion object with default values for request
3946
*/

prompt/prompt-executor/prompt-executor-clients/prompt-executor-anthropic-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/anthropic/AnthropicModelsTest.kt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package ai.koog.prompt.executor.clients.anthropic
33
import ai.koog.prompt.executor.clients.list
44
import ai.koog.prompt.llm.LLMProvider
55
import kotlin.test.Test
6+
import kotlin.test.assertEquals
7+
import kotlin.test.assertFailsWith
68
import kotlin.test.assertSame
79

810
class AnthropicModelsTest {
@@ -19,4 +21,38 @@ class AnthropicModelsTest {
1921
)
2022
}
2123
}
24+
25+
@Test
26+
fun `AnthropicMessageRequest should use custom maxTokens when provided`() {
27+
val customMaxTokens = 4000
28+
val request = AnthropicMessageRequest(
29+
model = AnthropicModels.Opus_3.id,
30+
messages = emptyList(),
31+
maxTokens = customMaxTokens
32+
)
33+
34+
assertEquals(customMaxTokens, request.maxTokens)
35+
}
36+
37+
@Test
38+
fun `AnthropicMessageRequest should use default maxTokens when not provided`() {
39+
val request = AnthropicMessageRequest(
40+
model = AnthropicModels.Opus_3.id,
41+
messages = emptyList()
42+
)
43+
44+
assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
45+
}
46+
47+
@Test
48+
fun `AnthropicMessageRequest should reject zero maxTokens`() {
49+
val exception = assertFailsWith<IllegalArgumentException> {
50+
AnthropicMessageRequest(
51+
model = AnthropicModels.Opus_3.id,
52+
messages = emptyList(),
53+
maxTokens = 0
54+
)
55+
}
56+
assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
57+
}
2258
}

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModel.kt

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,27 @@ import kotlinx.serialization.json.JsonObject
66

77
@Serializable
88
internal data class JambaRequest(
9-
public val model: String,
10-
public val messages: List<JambaMessage>,
11-
@SerialName("max_tokens")
12-
public val maxTokens: Int? = MAX_TOKENS_DEFAULT,
13-
public val temperature: Double? = null,
14-
@SerialName("top_p")
15-
public val topP: Double? = null,
16-
public val stop: List<String>? = null,
17-
public val n: Int? = null,
18-
public val stream: Boolean? = null,
19-
public val tools: List<JambaTool>? = null,
20-
@SerialName("response_format")
21-
public val responseFormat: JambaResponseFormat? = null
9+
val model: String,
10+
val messages: List<JambaMessage>,
11+
@SerialName("max_tokens") val maxTokens: Int? = MAX_TOKENS_DEFAULT,
12+
val temperature: Double? = null,
13+
@SerialName("top_p") val topP: Double? = null,
14+
val stop: List<String>? = null,
15+
val n: Int? = null,
16+
val stream: Boolean? = null,
17+
val tools: List<JambaTool>? = null,
18+
@SerialName("response_format") val responseFormat: JambaResponseFormat? = null
2219
) {
20+
init {
21+
require(maxTokens != null && maxTokens > 0) { "maxTokens must be greater than 0, but was $maxTokens" }
22+
if (temperature != null) {
23+
require(temperature >= 0) { "temperature must be greater than or equal to 0, but was $temperature" }
24+
}
25+
if (topP != null) {
26+
require(topP in 0.0..1.0) { "topP must be between 0 and 1, but was $topP" }
27+
}
28+
}
29+
2330
/**
2431
* Companion object with default values for request
2532
*/
@@ -33,90 +40,90 @@ internal data class JambaRequest(
3340

3441
@Serializable
3542
internal data class JambaMessage(
36-
public val role: String,
37-
public val content: String? = null,
43+
val role: String,
44+
val content: String? = null,
3845
@SerialName("tool_calls")
39-
public val toolCalls: List<JambaToolCall>? = null,
46+
val toolCalls: List<JambaToolCall>? = null,
4047
@SerialName("tool_call_id")
41-
public val toolCallId: String? = null
48+
val toolCallId: String? = null
4249
)
4350

4451
@Serializable
4552
internal data class JambaTool(
46-
public val type: String = "function",
47-
public val function: JambaFunction
53+
val type: String = "function",
54+
val function: JambaFunction
4855
)
4956

5057
@Serializable
5158
internal data class JambaFunction(
52-
public val name: String,
53-
public val description: String,
54-
public val parameters: JsonObject
59+
val name: String,
60+
val description: String,
61+
val parameters: JsonObject
5562
)
5663

5764
@Serializable
5865
internal data class JambaToolCall(
59-
public val id: String,
60-
public val type: String = "function",
61-
public val function: JambaFunctionCall
66+
val id: String,
67+
val type: String = "function",
68+
val function: JambaFunctionCall
6269
)
6370

6471
@Serializable
6572
internal data class JambaFunctionCall(
66-
public val name: String,
67-
public val arguments: String
73+
val name: String,
74+
val arguments: String
6875
)
6976

7077
@Serializable
7178
internal data class JambaResponseFormat(
72-
public val type: String
79+
val type: String
7380
)
7481

7582
@Serializable
7683
internal data class JambaResponse(
77-
public val id: String,
78-
public val model: String,
79-
public val choices: List<JambaChoice>,
80-
public val usage: JambaUsage? = null
84+
val id: String,
85+
val model: String,
86+
val choices: List<JambaChoice>,
87+
val usage: JambaUsage? = null
8188
)
8289

8390
@Serializable
8491
internal data class JambaChoice(
85-
public val index: Int,
86-
public val message: JambaMessage,
92+
val index: Int,
93+
val message: JambaMessage,
8794
@SerialName("finish_reason")
88-
public val finishReason: String? = null
95+
val finishReason: String? = null
8996
)
9097

9198
@Serializable
9299
internal data class JambaUsage(
93100
@SerialName("prompt_tokens")
94-
public val promptTokens: Int,
101+
val promptTokens: Int,
95102
@SerialName("completion_tokens")
96-
public val completionTokens: Int,
103+
val completionTokens: Int,
97104
@SerialName("total_tokens")
98-
public val totalTokens: Int
105+
val totalTokens: Int
99106
)
100107

101108
@Serializable
102109
internal data class JambaStreamResponse(
103-
public val id: String,
104-
public val choices: List<JambaStreamChoice>,
105-
public val usage: JambaUsage? = null
110+
val id: String,
111+
val choices: List<JambaStreamChoice>,
112+
val usage: JambaUsage? = null
106113
)
107114

108115
@Serializable
109116
internal data class JambaStreamChoice(
110-
public val index: Int,
111-
public val delta: JambaStreamDelta,
117+
val index: Int,
118+
val delta: JambaStreamDelta,
112119
@SerialName("finish_reason")
113-
public val finishReason: String? = null
120+
val finishReason: String? = null
114121
)
115122

116123
@Serializable
117124
internal data class JambaStreamDelta(
118-
public val role: String? = null,
119-
public val content: String? = null,
125+
val role: String? = null,
126+
val content: String? = null,
120127
@SerialName("tool_calls")
121-
public val toolCalls: List<JambaToolCall>? = null
128+
val toolCalls: List<JambaToolCall>? = null
122129
)

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/BedrockAI21JambaSerializationTest.kt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
55
import ai.koog.agents.core.tools.ToolParameterType
66
import ai.koog.prompt.dsl.Prompt
77
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
8+
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
89
import ai.koog.prompt.llm.LLMCapability
910
import ai.koog.prompt.llm.LLMProvider
1011
import ai.koog.prompt.llm.LLModel
@@ -43,7 +44,7 @@ class BedrockAI21JambaSerializationTest {
4344

4445
assertNotNull(request)
4546
assertEquals(model.id, request.model)
46-
assertEquals(4096, request.maxTokens)
47+
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
4748
assertEquals(temperature, request.temperature)
4849

4950
assertEquals(2, request.messages.size)
@@ -55,6 +56,22 @@ class BedrockAI21JambaSerializationTest {
5556
assertEquals(userMessage, request.messages[1].content)
5657
}
5758

59+
@Test
60+
fun `createJambaRequest with custom maxTokens`() {
61+
val maxTokens = 1000
62+
63+
val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
64+
system(systemMessage)
65+
user(userMessage)
66+
}
67+
68+
val request = BedrockAI21JambaSerialization.createJambaRequest(prompt, model, emptyList())
69+
70+
assertNotNull(request)
71+
assertEquals(model.id, request.model)
72+
assertEquals(maxTokens, request.maxTokens)
73+
}
74+
5875
@Test
5976
fun `createJambaRequest with conversation history`() {
6077
val userNewMessage = "Hello, who are you?"

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/ai21/JambaDataModelsTest.kt

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21
22

3+
import ai.koog.prompt.executor.clients.anthropic.AnthropicMessageRequest
4+
import ai.koog.prompt.executor.clients.anthropic.AnthropicModels
5+
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.ai21.JambaRequest.Companion.MAX_TOKENS_DEFAULT
36
import kotlinx.serialization.json.Json
47
import kotlinx.serialization.json.buildJsonObject
58
import kotlinx.serialization.json.put
69
import kotlin.test.Test
710
import kotlin.test.assertEquals
11+
import kotlin.test.assertFailsWith
812
import kotlin.test.assertNotNull
913
import kotlin.test.assertNull
1014

@@ -51,6 +55,30 @@ class JambaDataModelsTest {
5155
assert(serialized.contains("0.7")) { "Serialized JSON should contain the temperature value: $serialized" }
5256
}
5357

58+
@Test
59+
fun `JambaRequest serialization with default maxTokens`() {
60+
val request = JambaRequest(
61+
model = "ai21.jamba-1-5-large-v1:0",
62+
messages = listOf(
63+
JambaMessage(role = "user", content = "Tell me about Paris")
64+
),
65+
temperature = 0.7
66+
)
67+
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
68+
}
69+
70+
@Test
71+
fun `JambaRequest serialization with maxTokens less than 1`() {
72+
val exception = assertFailsWith<IllegalArgumentException> {
73+
JambaRequest(
51+
model = "ai21.jamba-1-5-large-v1:0",
52+
messages = emptyList(),
76+
maxTokens = 0
77+
)
78+
}
79+
assertEquals("maxTokens must be greater than 0, but was 0", exception.message)
80+
}
81+
5482
@Test
5583
fun `JambaRequest serialization with null fields`() {
5684
val request = JambaRequest(
@@ -121,7 +149,7 @@ class JambaDataModelsTest {
121149
assertEquals(1, request.messages.size)
122150
assertEquals("user", request.messages[0].role)
123151
assertEquals("Tell me about Paris", request.messages[0].content)
124-
assertNull(request.maxTokens)
152+
assertEquals(MAX_TOKENS_DEFAULT, request.maxTokens)
125153
assertNull(request.temperature)
126154
}
127155

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/amazon/BedrockAmazonNovaSerializationTest.kt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ai.koog.prompt.executor.clients.bedrock.modelfamilies.amazon
22

33
import ai.koog.prompt.dsl.Prompt
44
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
5+
import ai.koog.prompt.executor.clients.bedrock.modelfamilies.amazon.NovaInferenceConfig.Companion.MAX_TOKENS_DEFAULT
56
import ai.koog.prompt.llm.LLMCapability
67
import ai.koog.prompt.llm.LLMProvider
78
import ai.koog.prompt.llm.LLModel
@@ -50,10 +51,23 @@ class BedrockAmazonNovaSerializationTest {
5051
assertEquals(userMessage, request.messages[0].content[0].text)
5152

5253
assertNotNull(request.inferenceConfig)
53-
assertEquals(4096, request.inferenceConfig.maxTokens)
54+
assertEquals(MAX_TOKENS_DEFAULT, request.inferenceConfig.maxTokens)
5455
assertEquals(temperature, request.inferenceConfig.temperature)
5556
}
5657

58+
@Test
59+
fun `createNovaRequest with custom maxTokens`() {
60+
val maxTokens = 1000
61+
62+
val prompt = Prompt.build("test", params = LLMParams(maxTokens = maxTokens)) {
63+
system(systemMessage)
64+
user(userMessage)
65+
}
66+
67+
val request = BedrockAmazonNovaSerialization.createNovaRequest(prompt, model)
68+
assertEquals(maxTokens, request.inferenceConfig!!.maxTokens)
69+
}
70+
5771
@Test
5872
fun `createNovaRequest with conversation history`() {
5973
val prompt = Prompt.build("test") {

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/bedrock/modelfamilies/anthropic/BedrockAnthropicClaudeSerializationTest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ai.koog.agents.core.tools.ToolParameterDescriptor
55
import ai.koog.agents.core.tools.ToolParameterType
66
import ai.koog.prompt.dsl.Prompt
77
import ai.koog.prompt.executor.clients.anthropic.AnthropicContent
8+
import ai.koog.prompt.executor.clients.anthropic.AnthropicMessageRequest
89
import ai.koog.prompt.executor.clients.anthropic.AnthropicToolChoice
910
import ai.koog.prompt.executor.clients.bedrock.BedrockModels
1011
import ai.koog.prompt.message.Message
@@ -47,7 +48,7 @@ class BedrockAnthropicClaudeSerializationTest {
4748

4849
assertNotNull(request)
4950
assertEquals(model.id, request.model)
50-
assertEquals(4096, request.maxTokens)
51+
assertEquals(AnthropicMessageRequest.MAX_TOKENS_DEFAULT, request.maxTokens)
5152
assertEquals(temperature, request.temperature)
5253

5354
assertNotNull(request.system)

0 commit comments

Comments
 (0)