Skip to content

Commit 0efa287

Browse files
authored
AIAgent: complete a deferred result on error (#10)
Fixed the issue with AIAgent::runAndGetResult handing if an error occurred during agent execution. Also, updated a test to avoid a deadlock.
1 parent 70b4265 commit 0efa287

File tree

2 files changed

+116
-65
lines changed

2 files changed

+116
-65
lines changed

agents/agents-core/src/commonMain/kotlin/ai/grazie/code/agents/core/AIAgent.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ abstract class AIAgent<TStrategy : AIAgentStrategy<TConfig>, TConfig : AIAgentCo
8383
}
8484

8585
null -> {
86+
// If execution is stopped by an error (and we didn't throw an exception up to this point),
87+
// let's complete the deferred with null to unblock any awaiting coroutines.
88+
if (!agentResultDeferred.isCompleted) {
89+
agentResultDeferred.complete(null)
90+
}
8691
logger.debug { "Agent execution completed. Stopping..." }
8792
runningMutex.withLock {
8893
isRunning = false

prompt/prompt-executor/prompt-executor-llms-all/src/jvmTest/kotlin/ai/jetbrains/code/prompt/executor/llms/all/KotlinAIAgentWithMultipleLLMTest.kt

Lines changed: 111 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ import kotlinx.serialization.Serializable
3131
import org.junit.jupiter.api.Disabled
3232
import kotlin.coroutines.coroutineContext
3333
import kotlin.test.Test
34+
import kotlin.test.assertEquals
3435
import kotlin.test.assertNotNull
36+
import kotlin.test.assertNull
3537
import kotlin.test.assertTrue
3638
import kotlin.time.Duration.Companion.seconds
3739

@@ -284,16 +286,116 @@ class KotlinAIAgentWithMultipleLLMTest {
284286
}
285287
}
286288

289+
// TODO: pass the `OPEN_AI_API_TEST_KEY` and `ANTHROPIC_API_TEST_KEY`
287290
@Disabled("This test requires valid API keys")
288-
@OptIn(DelicateCoroutinesApi::class)
289291
@Test
290292
fun testKotlinAIAgentWithOpenAIAndAnthropic() = runTest(timeout = 600.seconds) {
291-
// TODO: pass the `OPEN_AI_API_TEST_KEY` and `ANTHROPIC_API_TEST_KEY`
292-
return@runTest
293-
294293
// Create the clients
295-
val eventsChannel = Channel<Event>()
294+
val eventsChannel = Channel<Event>(Channel.UNLIMITED)
295+
val fs = MockFileSystem()
296+
val eventHandler = EventHandler {
297+
onToolCall { stage, tool, arguments ->
298+
println(
299+
"[$stage] Calling tool ${tool.name} with arguments ${
300+
arguments.toString().lines().first().take(100)
301+
}"
302+
)
303+
}
304+
305+
handleResult {
306+
eventsChannel.send(Event.Termination)
307+
}
308+
}
309+
val agent = createTestAgent(eventsChannel, fs, eventHandler, maxAgentIterations = 42)
310+
311+
val result = agent.runAndGetResult(
312+
"Generate me a project in Ktor that has a GET endpoint that returns the capital of France. Write a test"
313+
)
314+
315+
assertNotNull(result)
316+
317+
assertTrue(
318+
fs.fileCount() > 0,
319+
"Agent must have created at least one file"
320+
)
321+
322+
val messages = mutableListOf<Event.Message>()
323+
for (msg in eventsChannel) {
324+
if (msg is Event.Message) messages.add(msg)
325+
else break
326+
}
327+
328+
assertTrue(
329+
messages.any { it.llmClient == "AnthropicDirectLLMClient" },
330+
"At least one message must be delegated to Anthropic client"
331+
)
332+
333+
assertTrue(
334+
messages.any { it.llmClient == "OpenAIDirectLLMClient" },
335+
"At least one message must be delegated to OpenAI client"
336+
)
337+
338+
assertTrue(
339+
messages
340+
.filter { it.llmClient == "AnthropicDirectLLMClient" }
341+
.all { it.prompt.model.provider == LLMProvider.Anthropic },
342+
"All prompts with Anthropic model must be delegated to Anthropic client"
343+
)
344+
345+
assertTrue(
346+
messages
347+
.filter { it.llmClient == "OpenAIDirectLLMClient" }
348+
.all { it.prompt.model.provider == LLMProvider.OpenAI },
349+
"All prompts with OpenAI model must be delegated to OpenAI client"
350+
)
351+
}
352+
353+
// TODO: pass the `OPEN_AI_API_TEST_KEY` and `ANTHROPIC_API_TEST_KEY`
354+
@Disabled("This test requires valid API keys")
355+
@Test
356+
fun testTerminationOnIterationsLimitExhaustion() = runTest(timeout = 600.seconds) {
357+
val eventsChannel = Channel<Event>(Channel.UNLIMITED)
358+
val fs = MockFileSystem()
359+
var errorMessage: String? = null
360+
val eventHandler = EventHandler {
361+
onToolCall { stage, tool, arguments ->
362+
println(
363+
"[$stage] Calling tool ${tool.name} with arguments ${
364+
arguments.toString().lines().first().take(100)
365+
}"
366+
)
367+
}
368+
369+
handleResult {
370+
eventsChannel.send(Event.Termination)
371+
}
296372

373+
handleError {
374+
errorMessage = it.message
375+
true
376+
}
377+
}
378+
val steps = 10
379+
val agent = createTestAgent(eventsChannel, fs, eventHandler, maxAgentIterations = steps)
380+
381+
val result = agent.runAndGetResult(
382+
"Generate me a project in Ktor that has a GET endpoint that returns the capital of France. Write a test"
383+
)
384+
assertNull(result)
385+
assertEquals(
386+
"Local AI Agent has run into a problem: agent couldn't finish in given number of steps ($steps). " +
387+
"Please, consider increasing `maxAgentIterations` value in agent's configuration",
388+
errorMessage
389+
)
390+
}
391+
392+
@OptIn(DelicateCoroutinesApi::class)
393+
private fun createTestAgent(
394+
eventsChannel: Channel<Event>,
395+
fs: MockFileSystem,
396+
eventHandler: EventHandler,
397+
maxAgentIterations: Int
398+
): KotlinAIAgent {
297399
val openAIClient = OpenAIDirectLLMClient(openAIApiKey).reportingTo(eventsChannel)
298400
val anthropicClient = AnthropicDirectLLMClient(anthropicApiKey).reportingTo(eventsChannel)
299401

@@ -361,8 +463,6 @@ class KotlinAIAgentWithMultipleLLMTest {
361463
}
362464
}
363465

364-
val fs = MockFileSystem()
365-
366466
val tools = ToolRegistry {
367467
stage("anthropic") {
368468
tool(CreateFile(fs))
@@ -375,72 +475,18 @@ class KotlinAIAgentWithMultipleLLMTest {
375475
}
376476
}
377477

378-
379478
// Create the agent
380-
val agent = KotlinAIAgent(
479+
return KotlinAIAgent(
381480
toolRegistry = tools,
382481
strategy = strategy,
383-
eventHandler = EventHandler {
384-
onToolCall { stage, tool, arguments ->
385-
println(
386-
"[$stage] Calling tool ${tool.name} with arguments ${
387-
arguments.toString().lines().first().take(100)
388-
}"
389-
)
390-
}
391-
392-
handleResult {
393-
eventsChannel.send(Event.Termination)
394-
}
395-
},
396-
agentConfig = LocalAgentConfig(prompt(OpenAIModels.GPT4o, "test") {}, 15),
482+
eventHandler = eventHandler,
483+
agentConfig = LocalAgentConfig(prompt(OpenAIModels.GPT4o, "test") {}, maxAgentIterations),
397484
promptExecutor = executor,
398485
cs = CoroutineScope(newFixedThreadPoolContext(2, "TestAgent"))
399486
) {
400487
install(TraceFeature) {
401488
addMessageProcessor(TestLogPrinter())
402489
}
403490
}
404-
405-
val result = agent.runAndGetResult(
406-
"Generate me a project in Ktor that has a GET endpoint that returns the capital of France. Write a test"
407-
)
408-
409-
assertNotNull(result)
410-
411-
assertTrue(
412-
fs.fileCount() > 0,
413-
"Agent must have created at least one file"
414-
)
415-
416-
val messages = mutableListOf<Event.Message>()
417-
for (msg in eventsChannel) {
418-
if (msg is Event.Message) messages.add(msg)
419-
else break
420-
}
421-
422-
assertTrue(
423-
messages.any { it.llmClient == "AnthropicSuspendableDirectClient" },
424-
"At least one message must be delegated to Anthropic client"
425-
)
426-
427-
assertTrue(
428-
messages.any { it.llmClient == "OpenAISuspendableDirectClient" },
429-
"At least one message must be delegated to OpenAI client"
430-
)
431-
432-
assertTrue(
433-
messages
434-
.filter { it.llmClient == "AnthropicSuspendableDirectClient" }
435-
.all { it.prompt.model.provider == LLMProvider.Anthropic },
436-
"All prompts with Anthropic model must be delegated to Anthropic client"
437-
)
438-
439-
assertTrue(
440-
messages
441-
.filter { it.llmClient == "OpenAISuspendableDirectClient" }
442-
.all { it.prompt.model.provider == LLMProvider.OpenAI },
443-
"All prompts with OpenAI model must be delegated to OpenAI client"
444-
)
445491
}
446-
}
492+
}

0 commit comments

Comments
 (0)