Skip to content

Commit bec4690

Browse files
committed
Add search agent test
1 parent b51b4cb commit bec4690

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

chat/test/agent/test_search_agent.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
from agent.search_agent import SearchAgent
55
from langchain_core.language_models.fake_chat_models import FakeListChatModel
6+
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
67
from langgraph.checkpoint.memory import MemorySaver
8+
from langchain_core.messages import AIMessage
9+
710

811

912
class TestSearchAgent(TestCase):
@@ -26,6 +29,62 @@ def test_search_agent_invoke_simple(self, mock_create_saver):
2629
self.assertGreater(len(result["messages"]), 0)
2730
self.assertEqual(result["messages"][-1].content, expected_response)
2831

32+
# @patch("agent.tools.search")
33+
# @patch("agent.search_agent.checkpoint_saver", return_value=MemorySaver())
34+
# def test_search_agent_invoke_with_facets(self, mock_create_saver, mock_search):
35+
# expected_response = "This is a mocked LLM response."
36+
# chat_model = FakeListChatModel(responses=[expected_response])
37+
# search_agent = SearchAgent(model=chat_model, streaming=True)
38+
# facets = [{"country": "France"}]
39+
# result = search_agent.invoke(
40+
# question="What is the capital of France?",
41+
# ref="test_ref",
42+
# facets=facets,
43+
# )
44+
# self.assertIn("messages", result)
45+
# self.assertGreater(len(result["messages"]), 0)
46+
# self.assertEqual(result["messages"][-1].content, expected_response)
47+
# # Check that the search tool was called with the correct facets
48+
# called = False
49+
# for call in mock_search.call_args_list:
50+
# args, kwargs = call
51+
# if "facets" in kwargs and kwargs["facets"] == facets:
52+
# called = True
53+
# break
54+
# print("mock_search.call_args_list:", mock_search.call_args_list)
55+
# self.assertTrue(called, "search tool was not called with the expected facets")
56+
57+
@patch("agent.search_agent.checkpoint_saver", return_value=MemorySaver())
58+
@patch("agent.tools.opensearch_vector_store")
59+
def test_search_agent_invoke_with_facets(self, mock_opensearch_store, mock_create_saver):
60+
mock_store_instance = mock_opensearch_store.return_value
61+
mock_store_instance.similarity_search.return_value = []
62+
63+
tool_call_response = AIMessage(
64+
content="I'll search for information.",
65+
tool_calls=[{
66+
"name": "search",
67+
"args": {"query": "capital of France"},
68+
"id": "test_call_id"
69+
}]
70+
)
71+
final_response = AIMessage(content="This is the final response.")
72+
73+
chat_model = FakeMessagesListChatModel(responses=[tool_call_response, final_response])
74+
search_agent = SearchAgent(model=chat_model)
75+
76+
facets = [{"country": "France"}]
77+
search_agent.invoke(
78+
question="What is the capital of France?",
79+
ref="test_ref",
80+
facets=facets,
81+
)
82+
83+
mock_store_instance.similarity_search.assert_called()
84+
call_kwargs = mock_store_instance.similarity_search.call_args[1]
85+
self.assertIn("facets", call_kwargs)
86+
self.assertEqual(call_kwargs["facets"], facets)
87+
2988
@patch("agent.search_agent.checkpoint_saver")
3089
def test_search_agent_invocation(self, mock_create_saver):
3190
# Create a memory saver instance with a Mock for delete_checkpoints

0 commit comments

Comments
 (0)