Bases: Metric
Collects error messages from model predictions and provides statistics on their occurrences.
Parameters:
| Name |
Type |
Description |
Default |
show_errors
|
bool
|
If True, logs each collected error message.
|
False
|
type_separator
|
str
|
Separator used to split error messages into type and details.
|
': '
|
Source code in src/kibad_llm/metrics/statistics.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84 | class ErrorCollector(Metric):
"""Collects error messages from model predictions and provides statistics on their occurrences.
Args:
show_errors (bool): If True, logs each collected error message.
type_separator (str): Separator used to split error messages into type and details.
"""
def __init__(self, show_errors: bool = False, type_separator: str = ": ") -> None:
self.show_errors = show_errors
self.type_separator = type_separator
self.reset()
def reset(self) -> None:
self.state: list[list[str]] = []
def _update(self, prediction: Any, reference: Any, record_id: Hashable | None = None) -> None:
errors: list[list[str]] = []
found_error_keys = []
if "error" in prediction:
if prediction["error"] is not None:
errors.append([prediction["error"]])
else:
errors.append([])
found_error_keys.append("error")
if "error_list" in prediction and isinstance(prediction["error_list"], list):
for e in prediction["error_list"]:
if e is not None:
errors.append([e])
else:
errors.append([])
found_error_keys.append("error_list")
if "errors" in prediction:
errors.append(prediction["errors"])
found_error_keys.append("errors")
if "errors_list" in prediction and isinstance(prediction["errors_list"], list):
errors.extend(prediction["errors_list"])
found_error_keys.append("errors_list")
if len(found_error_keys) == 0:
raise ValueError(
f"No error keys found in prediction for record id={record_id}. "
f"Expected one of: 'error', 'error_list', 'errors', 'errors_list'."
)
if len(found_error_keys) > 1:
raise ValueError(
f"Multiple error keys found in prediction for record id={record_id}: "
f"{found_error_keys}. Please provide only one of these keys."
)
if self.show_errors:
for errors_per_request in errors:
for error in errors_per_request:
logger.info(f"Collected error (id={record_id}): {error}")
self.state.extend(errors)
def _compute(self) -> dict[str, Any]:
# group errors by message type and count occurrences
errors_grouped = defaultdict(list)
for errors in self.state:
# overall no error / with error
if len(errors) == 0:
errors_grouped["no_error"].append("no_error")
else:
errors_grouped["with_error"].append("with_error")
# detailed error types
for error in errors:
err_parts = error.split(self.type_separator, 1)
errors_grouped[err_parts[0]].append(error)
# just return counts for now
counts = {k: len(v) for k, v in errors_grouped.items()}
return counts
|