Skip to content

Commit 3e4f286

Browse files
authored
Refactor and add mocked tests for SimpleAgent functionality (#19)
1 parent 8fa323e commit 3e4f286

File tree

2 files changed

+239
-145
lines changed

2 files changed

+239
-145
lines changed

agents/agents-test/src/jvmTest/kotlin/ai/grazie/code/agents/test/SimpleAgentIntegrationTest.kt

Lines changed: 81 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import ai.jetbrains.code.prompt.executor.llms.all.simpleOpenAIExecutor
1010
import kotlinx.coroutines.CoroutineScope
1111
import kotlinx.coroutines.runBlocking
1212
import org.junit.jupiter.api.Test
13+
import kotlin.test.AfterTest
1314
import kotlin.test.assertTrue
14-
import kotlin.test.junit5.JUnit5Asserter
1515

1616
class SimpleAgentIntegrationTest {
1717
val systemPrompt = """
@@ -27,179 +27,115 @@ class SimpleAgentIntegrationTest {
2727
}
2828
block(apiToken)
2929
}
30-
3130

32-
@Test
33-
fun `simpleChatAgent should call default tools`() = runBlockingWithToken {
34-
val actualToolCalls = mutableListOf<String>()
35-
val eventHandler = EventHandler {
36-
onToolCall { stage, tool, args ->
37-
println("Tool called: stage ${stage.name}, tool ${tool.name}, args $args")
38-
actualToolCalls.add(tool.name)
39-
}
40-
41-
handleError {
42-
JUnit5Asserter.fail("An error occurred: ${it.message}\n${it.stackTraceToString()}")
43-
}
31+
val eventHandler = EventHandler {
32+
onToolCall { stage, tool, args ->
33+
println("Tool called: stage ${stage.name}, tool ${tool.name}, args $args")
34+
actualToolCalls.add(tool.name)
35+
}
4436

45-
handleResult {
46-
println("Agent result: $it")
47-
}
37+
handleError {
38+
errors.add(it)
4839
}
4940

50-
try {
51-
actualToolCalls.clear()
52-
53-
val agent = simpleChatAgent(
54-
executor = simpleOpenAIExecutor(apiToken),
55-
cs = this,
56-
systemPrompt = systemPrompt,
57-
llmModel = OpenAIModels.GPT4o,
58-
temperature = 1.0,
59-
eventHandler = eventHandler,
60-
maxIterations = 10,
61-
)
62-
63-
agent.run("Please exit.")
64-
assertTrue(actualToolCalls.isNotEmpty(), "No tools were called")
65-
} catch (e: Exception) {
66-
println("An error occurred: ${e.message}\n${e.stackTraceToString()}")
41+
handleResult {
42+
results.add(it)
6743
}
6844
}
6945

70-
@Test
71-
fun `simpleChatAgent should call a custom tool`() = runBlockingWithToken {
72-
val actualToolCalls = mutableListOf<String>()
73-
val eventHandler = EventHandler {
74-
onToolCall { stage, tool, args ->
75-
println("Tool called: stage ${stage.name}, tool ${tool.name}, args $args")
76-
actualToolCalls.add(tool.name)
77-
}
46+
val actualToolCalls = mutableListOf<String>()
47+
val errors = mutableListOf<Throwable>()
48+
val results = mutableListOf<String?>()
7849

79-
handleError {
80-
JUnit5Asserter.fail("An error occurred: ${it.message}\n${it.stackTraceToString()}")
81-
}
50+
@AfterTest
51+
fun teardown() {
52+
actualToolCalls.clear()
53+
errors.clear()
54+
results.clear()
55+
}
8256

83-
handleResult {
84-
println("Agent result: $it")
85-
}
86-
}
8757

58+
@Test
59+
fun `simpleChatAgent should call default tools`() = runBlockingWithToken {
60+
val agent = simpleChatAgent(
61+
executor = simpleOpenAIExecutor(apiToken),
62+
cs = this,
63+
systemPrompt = systemPrompt,
64+
llmModel = OpenAIModels.GPT4o,
65+
temperature = 1.0,
66+
eventHandler = eventHandler,
67+
maxIterations = 10,
68+
)
69+
70+
agent.run("Please exit.")
71+
assertTrue(actualToolCalls.isNotEmpty(), "No tools were called")
72+
}
73+
74+
@Test
75+
fun `simpleChatAgent should call a custom tool`() = runBlockingWithToken {
8876
val toolRegistry = ToolRegistry {
8977
stage {
9078
tool(SayToUser)
9179
}
9280
}
9381

94-
try {
95-
actualToolCalls.clear()
96-
97-
val agent = simpleChatAgent(
98-
executor = simpleOpenAIExecutor(apiToken),
99-
cs = this,
100-
systemPrompt = systemPrompt,
101-
llmModel = OpenAIModels.GPT4oMini,
102-
temperature = 1.0,
103-
eventHandler = eventHandler,
104-
maxIterations = 10,
105-
toolRegistry = toolRegistry,
106-
)
107-
108-
agent.run("Hello, how are you?")
109-
110-
assertTrue(actualToolCalls.isNotEmpty(), "No tools were called")
111-
assertTrue(actualToolCalls.contains("__say_to_user__"), "The __say_to_user__ tool was not called")
112-
} catch (e: Exception) {
113-
JUnit5Asserter.fail("An error occurred: ${e.message}\n${e.stackTraceToString()}")
114-
}
82+
val agent = simpleChatAgent(
83+
executor = simpleOpenAIExecutor(apiToken),
84+
cs = this,
85+
systemPrompt = systemPrompt,
86+
llmModel = OpenAIModels.GPT4oMini,
87+
temperature = 1.0,
88+
eventHandler = eventHandler,
89+
maxIterations = 10,
90+
toolRegistry = toolRegistry,
91+
)
92+
93+
agent.run("Hello, how are you?")
94+
95+
assertTrue(actualToolCalls.isNotEmpty(), "No tools were called")
96+
assertTrue(actualToolCalls.contains("__say_to_user__"), "The __say_to_user__ tool was not called")
11597
}
11698

11799
@Test
118100
fun `simpleSingleRunAgent should not call tools by default`() = runBlockingWithToken {
119-
val actualToolCalls = mutableListOf<String>()
120-
121-
val eventHandler = EventHandler {
122-
onToolCall { stage, tool, args ->
123-
println("Tool called: stage ${stage.name}, tool ${tool.name}, args $args")
124-
actualToolCalls.add(tool.name)
125-
}
126-
127-
handleError {
128-
JUnit5Asserter.fail("An error occurred: ${it.message}\n${it.stackTraceToString()}")
129-
}
130-
131-
handleResult {
132-
println("Agent result: $it")
133-
}
134-
}
135-
136-
try {
137-
actualToolCalls.clear()
138-
139-
val agent = simpleSingleRunAgent(
140-
executor = simpleOpenAIExecutor(apiToken),
141-
cs = this,
142-
systemPrompt = systemPrompt,
143-
llmModel = OpenAIModels.GPT4oMini,
144-
temperature = 1.0,
145-
eventHandler = eventHandler,
146-
maxIterations = 10,
147-
)
148-
149-
agent.run("Repeat what I say: hello, I'm good.")
150-
151-
// by default, simpleSingleRunAgent has no tools underneath
152-
assertTrue(actualToolCalls.isEmpty(), "No tools should be called")
153-
} catch (e: Exception) {
154-
JUnit5Asserter.fail("An error occurred: ${e.message}\n${e.stackTraceToString()}")
155-
}
101+
val agent = simpleSingleRunAgent(
102+
executor = simpleOpenAIExecutor(apiToken),
103+
cs = this,
104+
systemPrompt = systemPrompt,
105+
llmModel = OpenAIModels.GPT4oMini,
106+
temperature = 1.0,
107+
eventHandler = eventHandler,
108+
maxIterations = 10,
109+
)
110+
111+
agent.run("Repeat what I say: hello, I'm good.")
112+
113+
// by default, simpleSingleRunAgent has no tools underneath
114+
assertTrue(actualToolCalls.isEmpty(), "No tools should be called")
156115
}
157116

158117
@Test
159118
fun `simpleSingleRunAgent should call a custom tool`() = runBlockingWithToken {
160-
val actualToolCalls = mutableListOf<String>()
161-
162-
val eventHandler = EventHandler {
163-
onToolCall { stage, tool, args ->
164-
println("Tool called: stage ${stage.name}, tool ${tool.name}, args $args")
165-
actualToolCalls.add(tool.name)
166-
}
167-
168-
handleError {
169-
JUnit5Asserter.fail("An error occurred: ${it.message}\n${it.stackTraceToString()}")
170-
}
171-
172-
handleResult {
173-
println("Agent result: $it")
174-
}
175-
}
176-
177119
val toolRegistry = ToolRegistry {
178120
stage {
179121
tool(SayToUser)
180122
}
181123
}
182124

183-
try {
184-
actualToolCalls.clear()
185-
186-
val agent = simpleSingleRunAgent(
187-
executor = simpleOpenAIExecutor(apiToken),
188-
cs = this,
189-
systemPrompt = systemPrompt,
190-
llmModel = OpenAIModels.GPT4oMini,
191-
temperature = 1.0,
192-
eventHandler = eventHandler,
193-
toolRegistry = toolRegistry,
194-
maxIterations = 10,
195-
)
196-
197-
agent.run("Write a Kotlin function to calculate factorial.")
198-
199-
assertTrue(actualToolCalls.isNotEmpty(), "No tools were called")
200-
assertTrue(actualToolCalls.contains("__say_to_user__"), "The __say_to_user__ tool was not called")
201-
} catch (e: Exception) {
202-
JUnit5Asserter.fail("An error occurred: ${e.message}\n${e.stackTraceToString()}")
203-
}
125+
val agent = simpleSingleRunAgent(
126+
executor = simpleOpenAIExecutor(apiToken),
127+
cs = this,
128+
systemPrompt = systemPrompt,
129+
llmModel = OpenAIModels.GPT4oMini,
130+
temperature = 1.0,
131+
eventHandler = eventHandler,
132+
toolRegistry = toolRegistry,
133+
maxIterations = 10,
134+
)
135+
136+
agent.run("Write a Kotlin function to calculate factorial.")
137+
138+
assertTrue(actualToolCalls.isNotEmpty(), "No tools were called")
139+
assertTrue(actualToolCalls.contains("__say_to_user__"), "The __say_to_user__ tool was not called")
204140
}
205141
}

0 commit comments

Comments
 (0)