Skip to content

Commit f8e4fbb

Browse files
committed
Add unit tests for ContextWindowStrategy in OllamaClient
1 parent 21cf358 commit f8e4fbb

File tree

1 file changed

+231
-0
lines changed
  • prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client

1 file changed

+231
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
package ai.koog.prompt.executor.ollama.client
2+
3+
import ai.koog.prompt.dsl.Prompt
4+
import ai.koog.prompt.dsl.prompt
5+
import ai.koog.prompt.executor.ollama.client.dto.OllamaChatMessageDTO
6+
import ai.koog.prompt.executor.ollama.client.dto.OllamaChatRequestDTO
7+
import ai.koog.prompt.executor.ollama.client.dto.OllamaChatResponseDTO
8+
import ai.koog.prompt.llm.OllamaModels
9+
import ai.koog.prompt.message.Message
10+
import ai.koog.prompt.message.ResponseMetaInfo
11+
import ai.koog.prompt.tokenizer.PromptTokenizer
12+
import io.ktor.client.HttpClient
13+
import io.ktor.client.engine.mock.MockEngine
14+
import io.ktor.client.engine.mock.respond
15+
import io.ktor.client.request.HttpRequestData
16+
import io.ktor.http.HttpHeaders
17+
import io.ktor.http.HttpStatusCode
18+
import io.ktor.http.content.TextContent
19+
import io.ktor.http.headersOf
20+
import kotlinx.coroutines.test.runTest
21+
import kotlinx.datetime.Clock
22+
import kotlinx.serialization.json.Json
23+
import kotlin.test.Test
24+
import kotlin.test.assertEquals
25+
import kotlin.test.assertNotNull
26+
import kotlin.test.assertNull
27+
28+
/**
 * Verifies that [OllamaClient] translates each [ContextWindowStrategy] into the
 * expected `numCtx` option on the outgoing chat request.
 *
 * Each test spins up a [MockOllamaChatServer], issues a single `execute` call,
 * and then inspects the recorded request's options.
 */
class ContextWindowStrategyTest {
    @Test
    fun `test None strategy`() = runTest {
        val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) }

        val ollamaClient = OllamaClient(
            baseClient = HttpClient(mockServer.mockEngine),
            contextWindowStrategy = ContextWindowStrategy.None,
        )

        ollamaClient.execute(
            prompt = prompt("test-prompt") { },
            model = OllamaModels.Meta.LLAMA_3_2,
        )

        val requestHistory = mockServer.requestHistory
        // kotlin.test assertEquals is (expected, actual) — expected value goes first.
        assertEquals(1, requestHistory.size)

        val request = requestHistory.first()
        val options = assertNotNull(request.options)
        // None must leave numCtx unset so the Ollama server applies its own default.
        assertNull(options.numCtx)
    }

    @Test
    fun `test Fixed strategy`() = runTest {
        val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) }

        val ollamaClient = OllamaClient(
            baseClient = HttpClient(mockServer.mockEngine),
            contextWindowStrategy = ContextWindowStrategy.Fixed(42),
        )

        ollamaClient.execute(
            prompt = prompt("test-prompt") { },
            model = OllamaModels.Meta.LLAMA_3_2,
        )

        val requestHistory = mockServer.requestHistory
        assertEquals(1, requestHistory.size)

        val request = requestHistory.first()
        val options = assertNotNull(request.options)
        // Fixed forwards its configured value verbatim.
        assertEquals(42, options.numCtx)
    }

    @Test
    fun `test FitPrompt strategy with tokenizer`() = runTest {
        val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) }

        val ollamaClient = OllamaClient(
            baseClient = HttpClient(mockServer.mockEngine),
            contextWindowStrategy = ContextWindowStrategy.FitPrompt(
                promptTokenizer = object : PromptTokenizer {
                    override fun tokenCountFor(message: Message): Int = error("Not needed")
                    override fun tokenCountFor(prompt: Prompt): Int = 3000
                },
                granularity = 1024,
                minimumContextLength = 2048,
            ),
        )

        ollamaClient.execute(
            prompt = prompt("test-prompt") { },
            model = OllamaModels.Meta.LLAMA_3_2,
        )

        val requestHistory = mockServer.requestHistory
        assertEquals(1, requestHistory.size)

        val request = requestHistory.first()
        val options = assertNotNull(request.options)
        // 3000 tokens rounded up to the next multiple of the 1024 granularity.
        assertEquals(3072, options.numCtx)
    }

    @Test
    fun `test FitPrompt strategy without tokenizer and no previous token usage`() = runTest {
        val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) }

        val ollamaClient = OllamaClient(
            baseClient = HttpClient(mockServer.mockEngine),
            contextWindowStrategy = ContextWindowStrategy.FitPrompt(
                promptTokenizer = null,
                granularity = 1024,
                minimumContextLength = 2048,
            ),
        )

        ollamaClient.execute(
            prompt = prompt("test-prompt") { },
            model = OllamaModels.Meta.LLAMA_3_2,
        )

        val requestHistory = mockServer.requestHistory
        assertEquals(1, requestHistory.size)

        val request = requestHistory.first()
        val options = assertNotNull(request.options)
        // With no tokenizer and no usage data, the strategy falls back to the minimum.
        assertEquals(2048, options.numCtx)
    }

    @Test
    fun `test FitPrompt strategy without tokenizer and existing token usage`() = runTest {
        val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) }

        val ollamaClient = OllamaClient(
            baseClient = HttpClient(mockServer.mockEngine),
            contextWindowStrategy = ContextWindowStrategy.FitPrompt(
                promptTokenizer = null,
                granularity = 1024,
                minimumContextLength = 2048,
            ),
        )

        ollamaClient.execute(
            prompt = prompt("test-prompt") {
                message(
                    Message.Assistant(
                        "Dummy message",
                        metaInfo = ResponseMetaInfo(
                            timestamp = Clock.System.now(),
                            totalTokensCount = 5000,
                        )
                    )
                )
            },
            model = OllamaModels.Meta.LLAMA_3_2,
        )

        val requestHistory = mockServer.requestHistory
        assertEquals(1, requestHistory.size)

        val request = requestHistory.first()
        val options = assertNotNull(request.options)
        // Reported usage of 5000 tokens rounded up to the next multiple of 1024.
        assertEquals(5120, options.numCtx)
    }

    @Test
    fun `test FitPrompt strategy with tokenizer and too long prompt`() = runTest {
        val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) }

        val ollamaClient = OllamaClient(
            baseClient = HttpClient(mockServer.mockEngine),
            contextWindowStrategy = ContextWindowStrategy.FitPrompt(
                promptTokenizer = object : PromptTokenizer {
                    override fun tokenCountFor(message: Message): Int = error("Not needed")
                    override fun tokenCountFor(prompt: Prompt): Int = 9000
                },
                granularity = 1024,
                minimumContextLength = 2048,
            ),
        )

        ollamaClient.execute(
            prompt = prompt("test-prompt") { },
            model = OllamaModels.Meta.LLAMA_3_2.copy(
                contextLength = 8192
            ),
        )

        val requestHistory = mockServer.requestHistory
        assertEquals(1, requestHistory.size)

        val request = requestHistory.first()
        val options = assertNotNull(request.options)
        // 9000 tokens exceeds the model's limit, so numCtx is clamped to 8192.
        assertEquals(8192, options.numCtx)
    }
}
195+
196+
/**
 * Builds a minimal successful chat completion for [request].
 *
 * The reply echoes the requested model, carries a single assistant message with
 * [content], and reports the given [promptEvalCount]/[evalCount] token counts.
 */
private fun makeDummyResponse(
    request: OllamaChatRequestDTO,
    content: String = "OK",
    promptEvalCount: Int = 10,
    evalCount: Int = 100,
): OllamaChatResponseDTO {
    val assistantMessage = OllamaChatMessageDTO(role = "assistant", content = content)
    return OllamaChatResponseDTO(
        model = request.model,
        message = assistantMessage,
        done = true,
        promptEvalCount = promptEvalCount,
        evalCount = evalCount,
    )
}
208+
209+
/**
 * In-memory stand-in for the Ollama chat endpoint.
 *
 * Every incoming request body is decoded into an [OllamaChatRequestDTO], passed
 * to [handler], and the handler's response is returned as a JSON payload.
 * Received requests can be inspected afterwards via [requestHistory].
 */
private class MockOllamaChatServer(
    private val handler: (OllamaChatRequestDTO) -> OllamaChatResponseDTO,
) {
    val mockEngine = MockEngine { requestData ->
        val chatRequest = requestData.extractChatRequest()
        respond(
            content = Json.encodeToString<OllamaChatResponseDTO>(handler(chatRequest)),
            status = HttpStatusCode.OK,
            headers = headersOf(HttpHeaders.ContentType to listOf("application/json")),
        )
    }

    /** All chat requests received so far, in arrival order. */
    val requestHistory: List<OllamaChatRequestDTO>
        get() = mockEngine.requestHistory.map { it.extractChatRequest() }

    /** Deserializes the recorded request's JSON body into its DTO form. */
    private fun HttpRequestData.extractChatRequest(): OllamaChatRequestDTO =
        Json.decodeFromString((body as TextContent).text)
}

0 commit comments

Comments
 (0)