Skip to content

Conditional

ConditionalUnionExtractor

Bases: UnionExtractor

Extractor that repeats extraction with message history and aggregates the results per key. The base extraction function is called once for each entry in `overrides` on the same input text, and each call after the first receives the messages of all previous calls as history.

See UnionExtractor for accepted parameters and details about the aggregation logic.

Source code in `src/kibad_llm/extractors/conditional.py` (lines 11–61):
class ConditionalUnionExtractor(UnionExtractor):
    """Extractor that repeats extraction multiple times with history and aggregates results per key.

    This extractor calls the base extraction function once for each entry in ``overrides``
    on the same input, passing the messages of all previous calls (formatted prompt
    messages plus the assistant response or error) as history to each subsequent call.

    See UnionExtractor for accepted parameters and details about the aggregation logic.
    """

    def __call__(self, *args, **kwargs) -> dict[str, Any]:
        """Run one extraction per override with accumulated history and aggregate the results.

        Args:
            *args: Positional arguments forwarded unchanged to ``extract_from_text_lenient``.
            **kwargs: Keyword arguments that take precedence over ``self.default_kwargs``;
                the parameters of the current override take precedence over both.

        Returns:
            A dict with the aggregated output under ``"structured"`` and, for every
            ``field`` in ``self.return_as_list``, the per-call values under
            ``"<field>_list"`` (``None`` where a call's result lacks the field).

        Raises:
            KeyError: If more than one override is configured and the merged keyword
                arguments contain no ``"prompt_template"`` entry.
        """
        combined_kwargs = {**self.default_kwargs, **kwargs}
        results = []
        history: list[SimpleChatMessage] = []
        for override_params in self.overrides.values():
            # Always request the formatted messages, because they are needed to build
            # the history for the following calls.
            current_kwargs = {
                **combined_kwargs,
                **override_params,
                "return_messages_formatted": True,
                "truncate_user_message_formatted": None,
            }
            # From the second call on: pass the history and disable the system message.
            if history:
                # The merge above is shallow, so the nested prompt template is still the
                # very object held by the defaults / overrides / caller kwargs. Replace
                # it with a copy instead of mutating it in place; otherwise the system
                # message would be erased for every later invocation of this extractor.
                current_kwargs["prompt_template"] = {
                    **current_kwargs["prompt_template"],
                    "system_message": None,
                }
                current_kwargs["history"] = history

            current_result = extract_from_text_lenient(*args, **current_kwargs)

            # Collect the formatted prompt messages of this call for the history ...
            for role_str, content in current_result["messages_formatted"].items():
                history.append(SimpleChatMessage(role=MessageRole(role_str), content=content))
            # ... followed by the assistant response, or the error if there is no response.
            history.append(
                SimpleChatMessage(
                    role=MessageRole.ASSISTANT,
                    content=current_result["response_content"] or current_result["error"],
                )
            )

            results.append(current_result)

        structured_outputs = [v.get("structured") for v in results]
        aggregated_structured = self.aggregator(structured_outputs)

        result: dict[str, Any] = {
            "structured": aggregated_structured,
        }
        for field in self.return_as_list:
            result[f"{field}_list"] = [v.get(field) for v in results]
        return result