Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8ba041f

Browse files
authored Nov 5, 2023
Merge pull request #270 from confident-ai/feature/tracing
added tracing
2 parents 72122a9 + 2171df0 commit 8ba041f

File tree

6 files changed

+432
-71
lines changed

6 files changed

+432
-71
lines changed
 

‎deepeval/api.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from deepeval.key_handler import KEY_FILE_HANDLER
2020
from deepeval.metrics.base_metric import BaseMetric
2121
from deepeval.test_case import LLMTestCase
22+
from deepeval.tracing import TraceData, get_trace_stack
2223

2324
API_BASE_URL = "https://app.confident-ai.com/api"
2425
# API_BASE_URL = "http://localhost:3000/api"
@@ -45,9 +46,9 @@ class APITestCase(BaseModel):
4546
metrics_metadata: List[MetricsMetadata] = Field(
4647
..., alias="metricsMetadata"
4748
)
48-
threshold: float
4949
run_duration: float = Field(..., alias="runDuration")
5050
context: Optional[list] = Field(None)
51+
traceStack: Optional[dict] = Field(None)
5152

5253

5354
class MetricScore(BaseModel):
@@ -144,14 +145,12 @@ def add_llm_test_case(
144145
metrics_metadata_dict.add_metric(metric)
145146
metrics_metadata = metrics_metadata_dict.get_metrics_metadata()
146147
success = all([metric.is_successful() for metric in metrics])
147-
threshold = metrics[0].minimum_score
148148

149149
if existing_test_case:
150150
# If it exists, append the metrics to the existing test case
151151
existing_test_case.metrics_metadata.extend(metrics_metadata)
152-
# Update the success status and threshold
153-
existing_test_case.success = success
154-
existing_test_case.threshold = threshold
152+
# Update the success status
153+
existing_test_case.success = success and existing_test_case.success
155154
else:
156155
# If it doesn't exist, create a new test case
157156
# Adding backwards compatibility to ensure context still works.
@@ -167,9 +166,9 @@ def add_llm_test_case(
167166
expectedOutput=test_case.expected_output,
168167
success=success,
169168
metricsMetadata=metrics_metadata,
170-
threshold=threshold,
171169
runDuration=run_duration,
172170
context=context,
171+
traceStack=get_trace_stack(),
173172
)
174173
)
175174

‎deepeval/test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
import hashlib
44
from dataclasses import dataclass
5-
from typing import Any, List, Optional, Union
5+
from typing import List, Optional, Union
66

77

88
@dataclass

‎deepeval/tracing.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from functools import wraps
4+
from typing import Any, Callable, List, Union, Optional
5+
from time import perf_counter
6+
import traceback
7+
from inspect import signature, isfunction, ismethod
8+
import threading
9+
from deepeval.utils import dataclass_to_dict
10+
11+
12+
class TraceType(Enum):
    """Built-in trace categories.

    The `trace` decorator also accepts any plain str as a custom type
    (see `BaseTrace.type`, annotated `Union[TraceType, str]`).
    """

    LLM = "LLM"
    RETRIEVER = "Retriever"
    EMBEDDING = "Embedding"
    TOOL = "Tool"
    AGENT = "Agent"
    CHAIN = "Chain"
19+
20+
21+
class TraceStatus(Enum):
    """Outcome of a traced call: SUCCESS, or ERROR when it raised."""

    SUCCESS = "Success"
    ERROR = "Error"
24+
25+
26+
@dataclass
class LlmMetadata:
    """Token-usage and cost accounting for a single traced LLM call."""

    model: str
    # Estimated in the trace decorator as len(text) * characters_per_token,
    # so these are floats whenever characters_per_token is a float.
    inputTokenUsage: float
    outputTokenUsage: float
    # (inputTokenUsage + outputTokenUsage) * cost_per_token
    cost: float
32+
33+
34+
@dataclass
class EmbeddingMetadata:
    """Metadata for a traced embedding call; currently only the model name."""

    model: str
37+
38+
39+
@dataclass
class BaseTrace:
    """Common fields shared by every trace node.

    Field names are camelCase — presumably to match the JSON payload sent
    via `traceStack` in deepeval/api.py; verify before renaming.
    """

    # TraceType member for built-in types, or a free-form str for custom ones.
    type: Union[TraceType, str]
    # Wall-clock duration in seconds, measured with perf_counter.
    executionTime: float
    # Explicit `name` given to the decorator, else the function's __name__.
    name: str
    # {"args": ..., "kwargs": ...} of the call (LlmTrace narrows this to str).
    input: dict
    # The function's return value, or an error-details dict on failure.
    output: dict
    status: TraceStatus
    # Child traces produced by nested @trace-decorated calls.
    traces: List["TraceData"]
48+
49+
50+
@dataclass
class LlmTrace(BaseTrace):
    """Trace node for a TraceType.LLM call."""

    # The prompt string (narrows BaseTrace.input from dict to str).
    input: str
    # Populated only after the call succeeds; None until then.
    llmMetadata: Optional[LlmMetadata] = None
54+
55+
56+
@dataclass
class EmbeddingTrace(BaseTrace):
    """Trace node for a TraceType.EMBEDDING call."""

    embeddingMetadata: EmbeddingMetadata
59+
60+
61+
@dataclass
class GenericTrace(BaseTrace):
    """Trace node for any non-LLM, non-embedding call.

    Used both for the remaining TraceType members (Retriever, Tool, Agent,
    Chain) and for custom str trace types.
    """

    # Free-form type label (loosens BaseTrace.type).
    type: str


# Any concrete trace node; used for the `traces` children list.
TraceData = Union[LlmTrace, EmbeddingTrace, GenericTrace]
67+
68+
69+
class TraceManager:
    """Thread-local bookkeeping for the stack of in-flight traces.

    Every thread gets its own stack, so traced functions running on
    different threads never see each other's traces. When a root trace
    completes, its dict snapshot is parked on the thread until collected
    via get_and_reset_dict_trace_stack().
    """

    def __init__(self):
        self._local = threading.local()

    def get_trace_stack(self):
        """Return this thread's trace stack, creating it on first access."""
        try:
            return self._local.trace_stack
        except AttributeError:
            self._local.trace_stack = []
            self._local.dict_trace_stack = None
            return self._local.trace_stack

    def clear_trace_stack(self):
        """Empty this thread's trace stack in place."""
        self.get_trace_stack().clear()

    def pop_trace_stack(self):
        """Discard the most recent trace, if the stack is non-empty."""
        stack = self.get_trace_stack()
        if stack:
            stack.pop()

    def append_to_trace_stack(self, trace_instance):
        """Push a newly started trace onto this thread's stack."""
        self.get_trace_stack().append(trace_instance)

    def set_dict_trace_stack(self, dict_trace_stack):
        """Store the dict snapshot of a completed root trace."""
        self._local.dict_trace_stack = dict_trace_stack

    def get_and_reset_dict_trace_stack(self):
        """Return the stored snapshot (or None) and clear it."""
        snapshot = getattr(self._local, "dict_trace_stack", None)
        self._local.dict_trace_stack = None
        return snapshot


# Module-level singleton shared by the trace decorator and get_trace_stack().
trace_manager = TraceManager()
99+
100+
101+
def trace(
    type: str,
    name: Optional[str] = None,
    model: Optional[str] = None,
    characters_per_token: Optional[Union[float, int]] = None,
    cost_per_token: Optional[float] = None,
):
    """Decorator factory that records a (possibly nested) trace of the call.

    Args:
        type: A TraceType member, or any str for a custom trace type.
        name: Display name of the trace; defaults to the function's __name__.
        model: Model identifier. Required for LLM and EMBEDDING traces,
            forbidden for every other type.
        characters_per_token: Factor used to estimate token usage from string
            length. Required for, and only valid with, LLM traces.
        cost_per_token: Per-token cost used to estimate call cost. Required
            for, and only valid with, LLM traces.

    Raises:
        ValueError: When the parameter combination is inconsistent with
            `type`, or (at call time) when an LLM-traced function receives
            or returns a non-string.

    NOTE: the parameter `type` shadows the builtin of the same name
    throughout this function — do not call the builtin `type` inside it.
    """
    # BUG FIX: isinstance(x, Union[...]) raises TypeError on Python < 3.10;
    # a plain tuple of types is equivalent and works on every version.
    assert isinstance(
        type, (TraceType, str)
    ), "'type' must be a 'TraceType' or str"

    if type in [TraceType.LLM, TraceType.EMBEDDING] and model is None:
        raise ValueError(f"{type} trace type requires a model.")
    assert model is None or isinstance(
        model, str
    ), "'model' must be a str or None"

    if type not in [TraceType.LLM, TraceType.EMBEDDING] and model is not None:
        raise ValueError(
            f"Parameter 'model' should not be provided for {type} trace types."
        )

    if type == TraceType.LLM and characters_per_token is None:
        raise ValueError(
            "LLM trace type requires 'characters_per_token' as a parameter."
        )
    assert characters_per_token is None or isinstance(
        characters_per_token, (float, int)
    ), "'characters_per_token' must be an int, float or None"

    if type == TraceType.LLM and cost_per_token is None:
        raise ValueError(
            "LLM trace type requires 'cost_per_token' as a parameter."
        )
    assert cost_per_token is None or isinstance(
        cost_per_token, (int, float)
    ), "'cost_per_token' must be an int, float or None"

    if type != TraceType.LLM and (
        characters_per_token is not None or cost_per_token is not None
    ):
        raise ValueError(
            "Parameters 'characters_per_token' and 'cost_per_token' should not be provided for non-LLM trace types."
        )

    def decorator_trace(func: Callable):
        if type == TraceType.LLM:
            # An LLM-traced callable must take exactly one (string) argument
            # besides an optional self/cls.
            params = [
                p
                for p in signature(func).parameters.values()
                if p.name not in ("self", "cls")
            ]
            if len(params) != 1:
                raise ValueError(
                    "Function of type `TraceType.LLM` must have exactly one parameter of type str"
                )

        @wraps(func)
        def wrapper(*args, **kwargs):
            sig = signature(func)
            input_str = None
            if type == TraceType.LLM:
                # The prompt is the first positional arg (after self/cls for
                # bound methods). NOTE(review): passing the prompt as a
                # keyword argument raises IndexError here, same as the
                # original implementation — confirm whether that is intended.
                input_str = (
                    args[1]
                    if "self" in sig.parameters or "cls" in sig.parameters
                    else args[0]
                )
                if not isinstance(input_str, str):
                    raise ValueError(
                        "Argument type for `TraceType.LLM` must be a string"
                    )

            # Strip self/cls so the recorded input contains only the
            # caller-supplied arguments.
            bound_method = False
            if args:
                first_param = next(iter(sig.parameters))
                if first_param in ("self", "cls"):
                    bound_method = True
            trace_args = args[1:] if bound_method else args

            trace_instance_input = {"args": trace_args, "kwargs": kwargs}

            effective_name = name if name is not None else func.__name__
            if type == TraceType.LLM:
                trace_instance = LlmTrace(
                    type=type,
                    executionTime=0,
                    name=effective_name,
                    input=input_str,
                    output=None,
                    status=TraceStatus.SUCCESS,
                    traces=[],
                    llmMetadata=None,
                )
            elif type == TraceType.EMBEDDING:
                trace_instance = EmbeddingTrace(
                    type=type,
                    executionTime=0,
                    name=effective_name,
                    input=trace_instance_input,
                    output=None,
                    status=TraceStatus.SUCCESS,
                    traces=[],
                    embeddingMetadata=EmbeddingMetadata(model=model),
                )
            else:
                trace_instance = GenericTrace(
                    type=type,
                    executionTime=0,
                    name=effective_name,
                    input=trace_instance_input,
                    output=None,
                    status=TraceStatus.SUCCESS,
                    traces=[],
                )

            trace_manager.append_to_trace_stack(trace_instance)
            start_time = perf_counter()
            try:
                result = func(*args, **kwargs)
                trace_instance.output = result

                if type == TraceType.LLM:
                    if not isinstance(result, str):
                        raise ValueError(
                            "Methods/functions of type 'TraceType.LLM' must return only a string"
                        )

                    # Token usage is estimated from character counts.
                    input_token_usage = len(input_str) * characters_per_token
                    output_token_usage = len(result) * characters_per_token
                    trace_instance.llmMetadata = LlmMetadata(
                        model=model,
                        inputTokenUsage=input_token_usage,
                        outputTokenUsage=output_token_usage,
                        cost=(input_token_usage + output_token_usage)
                        * cost_per_token,
                    )

            except Exception as e:
                trace_instance.status = TraceStatus.ERROR
                # BUG FIX: the original used __builtins__["type"](e), which
                # fails when __builtins__ is a module (e.g. in __main__).
                # The builtin `type` is shadowed by the decorator parameter,
                # so read the exception's class directly.
                trace_instance.output = {
                    "type": e.__class__.__name__,
                    "message": str(e),
                    "traceback": traceback.format_exc(),
                }
                raise e

            finally:
                trace_instance.executionTime = perf_counter() - start_time

                current_trace_stack = trace_manager.get_trace_stack()
                if len(current_trace_stack) > 1:
                    # Nest this trace under its parent (the trace below it).
                    parent_trace = current_trace_stack[-2]
                    parent_trace.traces.append(trace_instance)

                if len(current_trace_stack) == 1:
                    # Root trace finished: snapshot the whole tree as a dict
                    # and reset the stack for the next traced call.
                    dict_representation = dataclass_to_dict(
                        current_trace_stack[0]
                    )
                    trace_manager.set_dict_trace_stack(dict_representation)
                    trace_manager.clear_trace_stack()
                else:
                    trace_manager.pop_trace_stack()

            return result

        return wrapper

    return decorator_trace
280+
281+
282+
def get_trace_stack():
    """Return the finished root trace as a dict (or None) and clear it.

    Consumes the snapshot stored by the trace decorator, so a second call
    returns None until another root trace completes on this thread.
    """
    latest_stack = trace_manager.get_and_reset_dict_trace_stack()
    return latest_stack

‎deepeval/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
1+
from enum import Enum
12
import os
23
import time
4+
from typing import Any
35
import tqdm
46
import re
57
import string
68
import numpy as np
9+
from dataclasses import asdict, is_dataclass
10+
11+
12+
def dataclass_to_dict(instance: Any) -> Any:
    """Recursively convert a value into plain, JSON-friendly Python objects.

    Dataclasses become dicts, Enum members become their values, and lists,
    tuples and dicts are converted element-wise. Anything else is returned
    unchanged.
    """
    if isinstance(instance, Enum):
        return instance.value
    if is_dataclass(instance):
        # asdict() already recurses, but re-walking its output converts any
        # Enum members buried inside the dataclass as well.
        return {
            key: dataclass_to_dict(value)
            for key, value in asdict(instance).items()
        }
    if isinstance(instance, dict):
        return {
            key: dataclass_to_dict(value) for key, value in instance.items()
        }
    if isinstance(instance, list):
        return [dataclass_to_dict(item) for item in instance]
    if isinstance(instance, tuple):
        return tuple(dataclass_to_dict(item) for item in instance)
    return instance
725

826

927
def softmax(x):

‎poetry.lock

Lines changed: 124 additions & 64 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ ragas = "^0.0.19"
3232
pytest-rerunfailures = "^12.0"
3333
pytest-asyncio = "^0.21.1"
3434
coverage = "*"
35+
black = "*"
3536

3637
[tool.black]
3738
line-length = 80

0 commit comments

Comments
 (0)
Please sign in to comment.