3
3
4
4
from agent .search_agent import SearchAgent
5
5
from langchain_core .language_models .fake_chat_models import FakeListChatModel
6
+ from langchain_core .language_models .fake_chat_models import FakeMessagesListChatModel
6
7
from langgraph .checkpoint .memory import MemorySaver
8
+ from langchain_core .messages import AIMessage
9
+
7
10
8
11
9
12
class TestSearchAgent (TestCase ):
@@ -26,6 +29,62 @@ def test_search_agent_invoke_simple(self, mock_create_saver):
26
29
self .assertGreater (len (result ["messages" ]), 0 )
27
30
self .assertEqual (result ["messages" ][- 1 ].content , expected_response )
28
31
32
+ # @patch("agent.tools.search")
33
+ # @patch("agent.search_agent.checkpoint_saver", return_value=MemorySaver())
34
+ # def test_search_agent_invoke_with_facets(self, mock_create_saver, mock_search):
35
+ # expected_response = "This is a mocked LLM response."
36
+ # chat_model = FakeListChatModel(responses=[expected_response])
37
+ # search_agent = SearchAgent(model=chat_model, streaming=True)
38
+ # facets = [{"country": "France"}]
39
+ # result = search_agent.invoke(
40
+ # question="What is the capital of France?",
41
+ # ref="test_ref",
42
+ # facets=facets,
43
+ # )
44
+ # self.assertIn("messages", result)
45
+ # self.assertGreater(len(result["messages"]), 0)
46
+ # self.assertEqual(result["messages"][-1].content, expected_response)
47
+ # # Check that the search tool was called with the correct facets
48
+ # called = False
49
+ # for call in mock_search.call_args_list:
50
+ # args, kwargs = call
51
+ # if "facets" in kwargs and kwargs["facets"] == facets:
52
+ # called = True
53
+ # break
54
+ # print("mock_search.call_args_list:", mock_search.call_args_list)
55
+ # self.assertTrue(called, "search tool was not called with the expected facets")
56
+
57
+ @patch ("agent.search_agent.checkpoint_saver" , return_value = MemorySaver ())
58
+ @patch ("agent.tools.opensearch_vector_store" )
59
+ def test_search_agent_invoke_with_facets (self , mock_opensearch_store , mock_create_saver ):
60
+ mock_store_instance = mock_opensearch_store .return_value
61
+ mock_store_instance .similarity_search .return_value = []
62
+
63
+ tool_call_response = AIMessage (
64
+ content = "I'll search for information." ,
65
+ tool_calls = [{
66
+ "name" : "search" ,
67
+ "args" : {"query" : "capital of France" },
68
+ "id" : "test_call_id"
69
+ }]
70
+ )
71
+ final_response = AIMessage (content = "This is the final response." )
72
+
73
+ chat_model = FakeMessagesListChatModel (responses = [tool_call_response , final_response ])
74
+ search_agent = SearchAgent (model = chat_model )
75
+
76
+ facets = [{"country" : "France" }]
77
+ search_agent .invoke (
78
+ question = "What is the capital of France?" ,
79
+ ref = "test_ref" ,
80
+ facets = facets ,
81
+ )
82
+
83
+ mock_store_instance .similarity_search .assert_called ()
84
+ call_kwargs = mock_store_instance .similarity_search .call_args [1 ]
85
+ self .assertIn ("facets" , call_kwargs )
86
+ self .assertEqual (call_kwargs ["facets" ], facets )
87
+
29
88
@patch ("agent.search_agent.checkpoint_saver" )
30
89
def test_search_agent_invocation (self , mock_create_saver ):
31
90
# Create a memory saver instance with a Mock for delete_checkpoints
0 commit comments