Allow adjusting context window sizes for Ollama dynamically #335


Draft
wants to merge 3 commits into develop
@@ -5,9 +5,10 @@ import ai.koog.agents.core.agent.entity.AIAgentStorageKey
import ai.koog.agents.core.feature.AIAgentFeature
import ai.koog.agents.core.feature.AIAgentPipeline
import ai.koog.agents.core.feature.config.FeatureConfig
import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.message.Message
import ai.koog.prompt.tokenizer.CachingTokenizer
import ai.koog.prompt.tokenizer.NoTokenizer
import ai.koog.prompt.tokenizer.OnDemandTokenizer
import ai.koog.prompt.tokenizer.PromptTokenizer
import ai.koog.prompt.tokenizer.Tokenizer

/**
@@ -46,111 +47,6 @@ public class MessageTokenizerConfig : FeatureConfig() {
public var enableCaching: Boolean = true
}

/**
* An interface that provides utilities for tokenizing and calculating token usage in messages and prompts.
*/
public interface PromptTokenizer {
/**
* Calculates the number of tokens required for a given message.
*
* @param message The message for which the token count should be determined.
* @return The number of tokens required to encode the message.
*/
public fun tokenCountFor(message: Message): Int

/**
* Calculates the total number of tokens spent in a given prompt.
*
* @param prompt The prompt for which the total tokens spent need to be calculated.
* @return The total number of tokens spent as an integer.
*/
public fun tokenCountFor(prompt: Prompt): Int
}

/**
* An implementation of the [PromptTokenizer] interface that delegates token counting
* to an instance of the [Tokenizer] interface. The class provides methods to estimate
* the token count for individual messages and for the entirety of a prompt.
*
* This is useful in contexts where token-based costs or limitations are significant,
* such as when interacting with large language models (LLMs).
*
* @property tokenizer The [Tokenizer] instance used for token counting.
*/
public class OnDemandTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer {

/**
* Computes the number of tokens in a given message.
*
* @param message The message for which the token count needs to be calculated.
* The content of the message is analyzed to estimate the token count.
* @return The estimated number of tokens in the message content.
*/
public override fun tokenCountFor(message: Message): Int = tokenizer.countTokens(message.content)

/**
* Calculates the total number of tokens spent for the given prompt based on its messages.
*
* @param prompt The `Prompt` instance containing the list of messages for which the total token count will be calculated.
* @return The total number of tokens across all messages in the prompt.
*/
public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor)
}

/**
* A caching implementation of the `PromptTokenizer` interface that optimizes token counting
* by storing previously computed token counts for messages. This reduces redundant computations
* when the same message is processed multiple times.
*
* @constructor Creates an instance of `CachingTokenizer` with a provided `Tokenizer` instance
* that performs the actual token counting.
* @property tokenizer The underlying `Tokenizer` used for counting tokens in the message content.
*/
public class CachingTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer {
/**
* A cache that maps a `Message` to its corresponding token count.
*
* This is used to store the results of token computations for reuse, optimizing performance
* by avoiding repeated invocations of the token counting process on the same message content.
*
* Token counts are computed lazily and stored in the cache when requested via the `tokensFor`
* method. This cache can be cleared using the `clearCache` method.
*/
internal val cache = mutableMapOf<Message, Int>()

/**
* Retrieves the number of tokens contained in the content of the given message.
* This method utilizes caching to improve performance, storing previously
* computed token counts and reusing them for identical messages.
*
* @param message The message whose content's token count is to be retrieved
* @return The number of tokens in the content of the message
*/
public override fun tokenCountFor(message: Message): Int = cache.getOrPut(message) {
tokenizer.countTokens(message.content)
}

/**
* Calculates the total number of tokens spent on the given prompt by summing the token usage
* of all messages associated with the prompt.
*
* @param prompt The prompt containing the list of messages whose token usage will be calculated.
* @return The total number of tokens spent across all messages in the provided prompt.
*/
public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor)

/**
* Clears all cached token counts from the internal cache.
*
* This method is useful when the state of the cached data becomes invalid
* or needs resetting. After calling this, any subsequent token count
* calculations will be recomputed rather than retrieved from the cache.
*/
public fun clearCache() {
cache.clear()
}
}
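
For illustration only (not part of this change set), a minimal usage sketch of the two implementations above. It assumes `Tokenizer` exposes only the `countTokens(text)` function used above; the word-counting tokenizer and the demo prompt are made up for the example:

import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.tokenizer.CachingTokenizer
import ai.koog.prompt.tokenizer.OnDemandTokenizer
import ai.koog.prompt.tokenizer.Tokenizer

fun main() {
    // Toy tokenizer for the sketch: one token per whitespace-separated word.
    val wordTokenizer = object : Tokenizer {
        override fun countTokens(text: String): Int =
            text.split(Regex("\\s+")).count { it.isNotBlank() }
    }

    val demoPrompt = prompt("tokenizer-demo") {
        system("You are a helpful assistant.")
        user("What is the capital of France?")
    }

    // OnDemandTokenizer recounts every message on every call.
    val onDemand = OnDemandTokenizer(wordTokenizer)
    println(onDemand.tokenCountFor(demoPrompt))

    // CachingTokenizer counts each distinct message once and serves repeats from its cache.
    val caching = CachingTokenizer(wordTokenizer)
    println(caching.tokenCountFor(demoPrompt)) // populates the cache (one entry per message)
    println(caching.tokenCountFor(demoPrompt)) // same result, served from the cache
    caching.clearCache()                       // subsequent calls recompute from scratch
}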

/**
* The [MessageTokenizer] feature is responsible for handling tokenization of messages using a provided [Tokenizer]
* implementation. It serves as a feature that can be installed into an `AIAgentPipeline`. The tokenizer behavior can be configured
@@ -15,16 +15,10 @@ import ai.koog.agents.testing.tools.getMockExecutor
import ai.koog.agents.testing.tools.mockLLMAnswer
import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.llm.OllamaModels
import ai.koog.prompt.message.Message
import ai.koog.prompt.message.RequestMetaInfo
import ai.koog.prompt.message.ResponseMetaInfo
import ai.koog.prompt.tokenizer.Tokenizer
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlinx.datetime.Clock
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

/**
* Test for the MessageTokenizer feature.
@@ -72,84 +66,6 @@ class MessageTokenizerTest {
}
}

@Test
fun testPromptTokenizer() = runTest {
// Create a mock tokenizer to track token usage
val mockTokenizer = MockTokenizer()

// Create a prompt tokenizer with our mock tokenizer
val promptTokenizer = OnDemandTokenizer(mockTokenizer)

// Create a prompt with some messages
val testPrompt = prompt("test-prompt") {
system("You are a helpful assistant.")
user("What is the capital of France?")
assistant("Paris is the capital of France.")
}

// Count tokens in the prompt
val totalTokens = promptTokenizer.tokenCountFor(testPrompt)

// Verify that tokens were counted
assertTrue(totalTokens > 0, "Total tokens should be greater than 0")

// Verify that the tokenizer was used and counted tokens
assertTrue(mockTokenizer.totalTokens > 0, "Tokenizer should have counted tokens")

// Verify that the total tokens match what we expect
assertEquals(totalTokens, mockTokenizer.totalTokens, "Total tokens should match the tokenizer's count")

// Print the total tokens spent
println("[DEBUG_LOG] Total tokens spent: ${mockTokenizer.totalTokens}")

val requestMetainfo = RequestMetaInfo.create(Clock.System)
val responseMetainfo = ResponseMetaInfo.create(Clock.System)
// Count tokens for individual messages
val systemTokens = promptTokenizer.tokenCountFor(
Message.System("You are a helpful assistant.", requestMetainfo)
)
val userTokens = promptTokenizer.tokenCountFor(Message.User("What is the capital of France?", requestMetainfo))
val assistantTokens = promptTokenizer.tokenCountFor(
Message.Assistant("Paris is the capital of France.", responseMetainfo)
)

// Print token counts for each message
println("[DEBUG_LOG] System message tokens: $systemTokens")
println("[DEBUG_LOG] User message tokens: $userTokens")
println("[DEBUG_LOG] Assistant message tokens: $assistantTokens")

// Verify that the sum of individual message tokens equals the total
val sumOfMessageTokens = systemTokens + userTokens + assistantTokens
assertEquals(sumOfMessageTokens, totalTokens, "Sum of message tokens should equal total tokens")
}

@Test
fun testCachingPromptTokenizer() = runTest {
// Create a mock tokenizer to track token usage
val mockTokenizer = MockTokenizer()

// Create a prompt tokenizer with our mock tokenizer
val promptTokenizer = CachingTokenizer(mockTokenizer)

// Create a prompt with some messages
val testPrompt = prompt("test-prompt") {
system("You are a helpful assistant.")
user("What is the capital of France?")
assistant("Paris is the capital of France.")
}

assertEquals(0, promptTokenizer.cache.size)
promptTokenizer.tokenCountFor(testPrompt)
assertEquals(3, promptTokenizer.cache.size)
promptTokenizer.clearCache()
assertEquals(0, promptTokenizer.cache.size)
promptTokenizer.tokenCountFor(testPrompt.messages[1])
promptTokenizer.tokenCountFor(testPrompt.messages[2])
assertEquals(2, promptTokenizer.cache.size)
promptTokenizer.tokenCountFor(testPrompt)
assertEquals(3, promptTokenizer.cache.size)
}

@Test
fun testTokenizerInAgents() {
val testToolRegistry = ToolRegistry {
@@ -15,6 +15,7 @@ kotlin {
api(project(":agents:agents-tools"))
api(project(":prompt:prompt-llm"))
api(project(":prompt:prompt-model"))
api(project(":prompt:prompt-tokenizer"))
api(project(":agents:agents-tools"))
api(project(":prompt:prompt-executor:prompt-executor-model"))
api(project(":prompt:prompt-executor:prompt-executor-clients"))
@@ -43,6 +44,7 @@ kotlin {
implementation(project(":agents:agents-features:agents-features-event-handler"))
implementation(libs.kotlinx.coroutines.core)
implementation(libs.kotlinx.coroutines.test)
implementation(libs.ktor.client.mock)
}
}

@@ -0,0 +1,112 @@
package ai.koog.prompt.executor.ollama.client

import ai.koog.prompt.dsl.Prompt
import ai.koog.prompt.llm.LLModel
import ai.koog.prompt.tokenizer.PromptTokenizer
import io.github.oshai.kotlinlogging.KotlinLogging

private val logger = KotlinLogging.logger { }

/**
* Represents a strategy for computing the context window length used by `OllamaClient`.
* Different implementations define specific approaches to computing this length,
* and Ollama sizes (and, if needed, truncates) its context window according to the value the strategy produces.
*
* To decide the context window length, Ollama proceeds as follows:
* - If a `num_ctx` parameter is specified in the chat request, the context window length is set to that value.
* - If the model definition contains a `num_ctx` parameter, the context window length is set to that value.
* - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value.
* - Otherwise, the context window length is set to the default value of 2048.
*
* Effectively, this strategy allows you to specify what `num_ctx` value will be set in chat requests sent to Ollama,
* for a given prompt and model.
*
* Important: prefer a context window length that does not change often for a given model,
* because Ollama reloads the model every time the context window length changes.
*
* Example implementations:
* - [ContextWindowStrategy.None]
* - [ContextWindowStrategy.Fixed]
* - [ContextWindowStrategy.FitPrompt]
*/
public interface ContextWindowStrategy {

/**
* Computes the context window length (the `num_ctx` value) to request for the given [prompt] and [model].
*
* @return The context window length to set in the chat request, or null to let the Ollama server decide.
*/
public fun computeContextLength(prompt: Prompt, model: LLModel): Long?

public companion object {
/**
* A strategy for letting the Ollama server decide the context window length.
* To decide the context window length, Ollama proceeds as follows:
* - If the model definition contains a `num_ctx` parameter, the context window length is set to that value.
* - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value.
* - Otherwise, the context window length is set to the default value of 2048.
*/
public data object None : ContextWindowStrategy {
override fun computeContextLength(prompt: Prompt, model: LLModel): Long? = null
}

/**
* A strategy for specifying a fixed context window length.
* If the given [contextLength] exceeds the maximum context window length supported by the model,
* the model's maximum context window length is used instead.
*
* @param contextLength The context window length to use.
*/
public data class Fixed(val contextLength: Long) : ContextWindowStrategy {
override fun computeContextLength(prompt: Prompt, model: LLModel): Long {
if (contextLength > model.contextLength) {
logger.warn {
"Context length $contextLength was more than what is supported by model '${model.id}'," +
" falling back to the model's maximum context length ${model.contextLength}"
}
return model.contextLength
}
return contextLength
}
}

/**
* A strategy for computing the context window length based on the prompt length.
*
* @param promptTokenizer The [PromptTokenizer] to use for computing the prompt length,
* or null to use the token usage last reported for the prompt.
* @param granularity The granularity to use: the prompt length is rounded up to the next multiple
* of this value. Defaults to 2048.
* @param minimumContextLength The minimum context window length to use when the prompt length
* is below it or cannot be computed yet.
* If not null, [minimumContextLength] must be a multiple of [granularity].
* If null and the prompt length cannot be computed, Ollama decides the context window length.
*/
public data class FitPrompt(
val promptTokenizer: PromptTokenizer? = null,
val granularity: Long = 2048,
val minimumContextLength: Long? = null,
) : ContextWindowStrategy {

init {
require(granularity > 0) { "Granularity must be greater than 0" }
require(minimumContextLength == null || minimumContextLength % granularity == 0L) {
"Minimum context length must be a multiple of granularity"
}
}

override fun computeContextLength(prompt: Prompt, model: LLModel): Long? {
val promptLength = when {
promptTokenizer != null -> promptTokenizer.tokenCountFor(prompt)
prompt.latestTokenUsage != 0 -> prompt.latestTokenUsage
else -> null
}

if (promptLength == null) return minimumContextLength
if (promptLength > model.contextLength) {
logger.warn {
"Prompt length $promptLength was more than the maximum context length of model '${model.id}'," +
" falling back to the model's maximum context length ${model.contextLength}"
}
return model.contextLength
}

// Round up to the next multiple of the granularity, but never go below the configured minimum.
return maxOf((promptLength / granularity + 1) * granularity, minimumContextLength ?: 0L)
}
}
}
}
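
For illustration only (not part of this diff), a minimal sketch of how the three strategies behave when invoked directly. The model constant, the demo prompt, and the assumption that a freshly built prompt reports zero token usage are illustrative assumptions, not guarantees from this change:

import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.ollama.client.ContextWindowStrategy
import ai.koog.prompt.llm.OllamaModels

fun main() {
    val demoPrompt = prompt("context-window-demo") {
        system("You are a helpful assistant.")
        user("What is the capital of France?")
    }
    val model = OllamaModels.Meta.LLAMA_3_2 // assumed model constant; any Ollama LLModel works

    // None: no num_ctx is sent, so the Ollama server decides the context window length.
    println(ContextWindowStrategy.None.computeContextLength(demoPrompt, model)) // prints null

    // Fixed: always request 8192 tokens, capped at the model's maximum context length.
    println(ContextWindowStrategy.Fixed(8192).computeContextLength(demoPrompt, model))

    // FitPrompt without a tokenizer: falls back to minimumContextLength until the prompt
    // carries token usage reported by a previous response, then rounds that usage up
    // to the next multiple of the granularity.
    val fitPrompt = ContextWindowStrategy.FitPrompt(granularity = 2048, minimumContextLength = 4096)
    println(fitPrompt.computeContextLength(demoPrompt, model)) // 4096 for a fresh prompt
}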