Commit 69e2881

Add ContextWindowStrategy to control context length in OllamaClient

1 parent e5cddfe

5 files changed: +139 −15 lines

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts

Lines changed: 2 additions & 0 deletions
@@ -12,9 +12,11 @@ kotlin {
     sourceSets {
         commonMain {
             dependencies {
+                api(project(":agents:agents-core"))
                 api(project(":agents:agents-tools"))
                 api(project(":prompt:prompt-llm"))
                 api(project(":prompt:prompt-model"))
+                api(project(":prompt:prompt-tokenizer"))
                 api(project(":agents:agents-tools"))
                 api(project(":prompt:prompt-executor:prompt-executor-model"))
                 api(project(":prompt:prompt-executor:prompt-executor-clients"))

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+package ai.koog.prompt.executor.ollama.client
+
+import ai.koog.agents.core.annotation.ExperimentalAgentsApi
+import ai.koog.prompt.dsl.Prompt
+import ai.koog.prompt.llm.LLModel
+import ai.koog.prompt.tokenizer.PromptTokenizer
+import io.github.oshai.kotlinlogging.KotlinLogging
+
+private val logger = KotlinLogging.logger { }
+
+/**
+ * Represents a strategy for computing the context window length for `OllamaClient`.
+ * Different implementations define specific approaches to computing the context window length.
+ * Based on the context window length computed by this strategy, Ollama will truncate the context window accordingly.
+ *
+ * To decide the context window length, Ollama proceeds as follows:
+ * - If a `num_ctx` parameter is specified in the chat request, the context window length is set to that value.
+ * - If the model definition contains a `num_ctx` parameter, the context window length is set to that value.
+ * - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value.
+ * - Otherwise, the context window length is set to the default value of 2048.
+ *
+ * Effectively, this strategy allows you to specify what `num_ctx` value will be set in chat requests sent to Ollama,
+ * for a given prompt and model.
+ *
+ * Important: You will want to have a context window length that does not change often for a specific model.
+ * Indeed, Ollama will reload the model every time the context window length changes.
+ *
+ * Example implementations:
+ * - [ContextWindowStrategy.None]
+ * - [ContextWindowStrategy.Fixed]
+ * - [ContextWindowStrategy.FitPrompt]
+ */
+@ExperimentalAgentsApi
+public interface ContextWindowStrategy {
+
+    public fun computeContextLength(prompt: Prompt, model: LLModel): Long?
+
+    public companion object {
+        /**
+         * A strategy for letting the Ollama server decide the context window length.
+         * To decide the context window length, Ollama proceeds as follows:
+         * - If the model definition contains a `num_ctx` parameter, the context window length is set to that value.
+         * - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value.
+         * - Otherwise, the context window length is set to the default value of 2048.
+         */
+        public data object None : ContextWindowStrategy {
+            override fun computeContextLength(prompt: Prompt, model: LLModel): Long? = null
+        }
+
+        /**
+         * A strategy for specifying a fixed context window length.
+         * If the given [contextLength] is more than the maximum context window length supported by the model,
+         * the context window length will be set to the maximum context window length supported by the model.
+         *
+         * @param contextLength The context window length to use.
+         */
+        public data class Fixed(val contextLength: Long) : ContextWindowStrategy {
+            override fun computeContextLength(prompt: Prompt, model: LLModel): Long {
+                if (contextLength > model.contextLength) {
+                    logger.warn {
+                        "Context length $contextLength was more than what is supported by model '${model.id}'," +
+                            " falling back to the model's maximum context length ${model.contextLength}"
+                    }
+                    return model.contextLength
+                }
+                return contextLength
+            }
+        }
+
+        /**
+         * A strategy for computing the context window length based on the prompt length.
+         *
+         * @param promptTokenizer The [PromptTokenizer] to use for computing the prompt length,
+         * or null to use the last reported token usage.
+         * @param granularity The granularity to use for computing the context window length. Defaults to 2048.
+         * @param minimumContextLength The minimum context window length,
+         * if the prompt length is less than it or cannot be computed yet.
+         * If not null, [minimumContextLength] must be a multiple of the [granularity].
+         * If null, we let Ollama decide the context window length.
+         */
+        public data class FitPrompt(
+            val promptTokenizer: PromptTokenizer? = null,
+            val granularity: Long = 2048,
+            val minimumContextLength: Long? = null,
+        ) : ContextWindowStrategy {
+
+            init {
+                require(granularity > 0) { "Granularity must be greater than 0" }
+                require(minimumContextLength == null || minimumContextLength % granularity == 0L) {
+                    "Minimum context length must be a multiple of granularity"
+                }
+            }
+
+            override fun computeContextLength(prompt: Prompt, model: LLModel): Long? {
+                val promptLength = when {
+                    promptTokenizer != null -> promptTokenizer.tokenCountFor(prompt)
+                    prompt.latestTokenUsage != 0 -> prompt.latestTokenUsage
+                    else -> null
+                }
+
+                if (promptLength == null) return minimumContextLength
+                if (promptLength > model.contextLength) {
+                    logger.warn {
+                        "Prompt length $promptLength was more than the maximum context length of model '${model.id}'," +
+                            " falling back to the model's maximum context length ${model.contextLength}"
+                    }
+                    return model.contextLength
+                }
+
+                return (promptLength / granularity + 1) * granularity
+            }
+        }
+    }
+}
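
To make the `FitPrompt` granularity behavior concrete, here is a minimal self-contained sketch (not part of this commit) that mirrors the rounding in `computeContextLength`, assuming a hypothetical prompt length of 3000 tokens and the default granularity of 2048:

```kotlin
fun main() {
    val granularity = 2048L
    val promptLength = 3000L // hypothetical token count from a PromptTokenizer or latestTokenUsage
    // Round up to the next multiple of the granularity, as FitPrompt.computeContextLength does.
    val numCtx = (promptLength / granularity + 1) * granularity
    println(numCtx) // 4096
}
```

Because every prompt length from 2048 to 4095 tokens maps to the same `num_ctx` of 4096, the requested context window only changes in coarse steps, which limits how often Ollama has to reload the model.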

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt

Lines changed: 22 additions & 7 deletions
@@ -1,5 +1,6 @@
 package ai.koog.prompt.executor.ollama.client
 
+import ai.koog.agents.core.annotation.ExperimentalAgentsApi
 import ai.koog.agents.core.tools.ToolDescriptor
 import ai.koog.prompt.dsl.ModerationCategory
 import ai.koog.prompt.dsl.ModerationCategoryResult
@@ -19,7 +20,6 @@ import ai.koog.prompt.executor.ollama.client.dto.OllamaPullModelResponseDTO
 import ai.koog.prompt.executor.ollama.client.dto.OllamaShowModelRequestDTO
 import ai.koog.prompt.executor.ollama.client.dto.OllamaShowModelResponseDTO
 import ai.koog.prompt.executor.ollama.client.dto.extractOllamaJsonFormat
-import ai.koog.prompt.executor.ollama.client.dto.extractOllamaOptions
 import ai.koog.prompt.executor.ollama.client.dto.getToolCalls
 import ai.koog.prompt.executor.ollama.client.dto.toOllamaChatMessages
 import ai.koog.prompt.executor.ollama.client.dto.toOllamaModelCard
@@ -53,19 +53,24 @@ import kotlinx.serialization.json.Json
 /**
  * Client for interacting with the Ollama API with comprehensive model support.
  *
+ * Implements:
+ * - [LLMClient] for executing prompts and streaming responses.
+ * - [LLMEmbeddingProvider] for generating embeddings from input text.
+ *
  * @param baseUrl The base URL of the Ollama server. Defaults to "http://localhost:11434".
  * @param baseClient The underlying HTTP client used for making requests.
  * @param timeoutConfig Configuration for connection, request, and socket timeouts.
  * @param clock Clock instance used for tracking response metadata timestamps.
- * Implements:
- * - LLMClient for executing prompts and streaming responses.
- * - LLMEmbeddingProvider for generating embeddings from input text.
+ * @param contextWindowStrategy The [ContextWindowStrategy] to use for computing context window lengths.
+ * Defaults to [ContextWindowStrategy.None].
  */
+@OptIn(ExperimentalAgentsApi::class)
 public class OllamaClient(
     public val baseUrl: String = "http://localhost:11434",
     baseClient: HttpClient = HttpClient(engineFactoryProvider()),
     timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(),
-    private val clock: Clock = Clock.System
+    private val clock: Clock = Clock.System,
+    private val contextWindowStrategy: ContextWindowStrategy = ContextWindowStrategy.Companion.None,
 ) : LLMClient, LLMEmbeddingProvider {
 
     private companion object {
@@ -155,7 +160,7 @@ public class OllamaClient(
                 messages = prompt.toOllamaChatMessages(model),
                 tools = if (tools.isNotEmpty()) tools.map { it.toOllamaTool() } else null,
                 format = prompt.extractOllamaJsonFormat(),
-                options = prompt.extractOllamaOptions(),
+                options = extractOllamaOptions(prompt, model),
                 stream = false,
             )
         )
@@ -230,7 +235,7 @@ public class OllamaClient(
             OllamaChatRequestDTO(
                 model = model.id,
                 messages = prompt.toOllamaChatMessages(model),
-                options = prompt.extractOllamaOptions(),
+                options = extractOllamaOptions(prompt, model),
                 stream = true,
            )
        )
@@ -256,6 +261,16 @@ public class OllamaClient(
         }
     }
 
+    /**
+     * Prepare Ollama chat request options from the given prompt and model.
+     */
+    internal fun extractOllamaOptions(prompt: Prompt, model: LLModel): OllamaChatRequestDTO.Options {
+        return OllamaChatRequestDTO.Options(
+            temperature = prompt.params.temperature,
+            numCtx = contextWindowStrategy.computeContextLength(prompt, model),
+        )
+    }
+
     /**
      * Embeds the given text using the Ollama model.
      *
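
Based on the constructor signature above, a usage sketch (the `buildClient` helper name is hypothetical) that opts in to the experimental API and pins the context window with the `Fixed` strategy:

```kotlin
import ai.koog.agents.core.annotation.ExperimentalAgentsApi
import ai.koog.prompt.executor.ollama.client.ContextWindowStrategy
import ai.koog.prompt.executor.ollama.client.OllamaClient

@OptIn(ExperimentalAgentsApi::class)
fun buildClient(): OllamaClient = OllamaClient(
    baseUrl = "http://localhost:11434",
    // Request an 8192-token context window; Fixed falls back to the model's maximum if this exceeds it.
    contextWindowStrategy = ContextWindowStrategy.Fixed(contextLength = 8192L),
)
```

With the default `ContextWindowStrategy.None`, `extractOllamaOptions` sets `numCtx` to null, so no explicit context length is requested and Ollama falls back to its own resolution order.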

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt

Lines changed: 0 additions & 8 deletions
@@ -111,14 +111,6 @@ internal fun Prompt.extractOllamaJsonFormat(): JsonObject? {
     return if (schema is LLMParams.Schema.JSON) schema.schema else null
 }
 
-/**
- * Extracts options from the prompt, if temperature is defined.
- */
-internal fun Prompt.extractOllamaOptions(): OllamaChatRequestDTO.Options? {
-    val temperature = params.temperature
-    return temperature?.let { OllamaChatRequestDTO.Options(temperature = temperature) }
-}
-
 /**
  * Extracts tool calls from a ChatMessage.
  * Returns the first tool call for compatibility, but logs if multiple calls exist.

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ internal data class OllamaChatRequestDTO(
     @Serializable
     internal data class Options(
         val temperature: Double? = null,
+        @SerialName("num_ctx") val numCtx: Long? = null,
     )
 }
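
Since `Options` is internal to the client module, the following self-contained sketch uses a local copy of the data class purely to illustrate the JSON shape that the new `num_ctx` field produces inside the chat request's `options`:

```kotlin
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json

// Mirrors OllamaChatRequestDTO.Options for illustration only.
@Serializable
data class OptionsSketch(
    val temperature: Double? = null,
    @SerialName("num_ctx") val numCtx: Long? = null,
)

fun main() {
    println(Json.encodeToString(OptionsSketch(temperature = 0.7, numCtx = 8192L)))
    // prints: {"temperature":0.7,"num_ctx":8192}
}
```

When the strategy returns null, `numCtx` stays at its default and (with this default `Json` configuration) the key is omitted, which matches the intent of `None`: let the Ollama server choose the context length.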
