Bases: UnionExtractor
Extractor that repeats extraction multiple times with history and aggregates results per key.
This extractor calls the base extraction function multiple times (for each entry in overrides)
on the same input text, passing the history of previous messages to each subsequent call.
See UnionExtractor for accepted parameters and details about the aggregation logic.
Source code in src/kibad_llm/extractors/conditional.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class ConditionalUnionExtractor(UnionExtractor):
    """Extractor that repeats extraction multiple times with history and aggregates results per key.

    This extractor calls the base extraction function multiple times (once for each entry in
    ``overrides``) on the same input text, passing the history of previous messages to each
    subsequent call.

    See UnionExtractor for accepted parameters and details about the aggregation logic.
    """

    def __call__(self, *args, **kwargs) -> dict[str, Any]:
        """Run one extraction per override, chaining the conversation history, and aggregate.

        Args:
            *args: Forwarded unchanged to ``extract_from_text_lenient`` on every run.
            **kwargs: Merged over ``self.default_kwargs``; the parameters of the current
                override are merged on top of that for each run.

        Returns:
            A dict with the key ``"structured"`` holding the result of ``self.aggregator``
            applied to the list of per-run ``"structured"`` values (``None`` where a run
            has none), plus one ``"<field>_list"`` entry for each field in
            ``self.return_as_list`` holding that field's per-run values.

        Raises:
            KeyError: If a run that follows existing history has no ``"prompt_template"``
                in its keyword arguments, or if a run result lacks ``"messages_formatted"``
                or ``"response_content"``.
        """
        import copy

        combined_kwargs = {**self.default_kwargs, **kwargs}
        results = []
        history: list[SimpleChatMessage] = []
        # The override names are not needed here, only their parameters.
        for override_params in self.overrides.values():
            # adjust kwargs:
            # 1) to return formatted messages for history
            current_kwargs = {
                **combined_kwargs,
                **override_params,
                "return_messages_formatted": True,
                # NOTE(review): presumably None disables truncation so that the full
                # user message ends up in the history - confirm against the callee.
                "truncate_user_message_formatted": None,
            }
            # 2) if history exists, pass it and disable system message
            if history:
                # current_kwargs is only a shallow merge, so the prompt template object
                # is still shared with self.default_kwargs / kwargs / the override.
                # Work on a copy; assigning into the shared object would strip the
                # system message from every later call of this extractor as well.
                prompt_template = copy.copy(current_kwargs["prompt_template"])
                prompt_template["system_message"] = None
                current_kwargs["prompt_template"] = prompt_template
                current_kwargs["history"] = history
            current_result = extract_from_text_lenient(*args, **current_kwargs)
            # collect messages for history
            for role_str, content in current_result["messages_formatted"].items():
                role = MessageRole(role_str)
                history.append(SimpleChatMessage(role=role, content=content))
            # append assistant response or error to history
            history.append(
                SimpleChatMessage(
                    role=MessageRole.ASSISTANT,
                    content=current_result["response_content"] or current_result["error"],
                )
            )
            results.append(current_result)

        structured_outputs = [v.get("structured", None) for v in results]
        aggregated_structured = self.aggregator(structured_outputs)
        result: dict[str, Any] = {
            "structured": aggregated_structured,
        }
        for field in self.return_as_list:
            result[f"{field}_list"] = [v.get(field, None) for v in results]
        return result
|