Skip to content

Prediction

load_with_metadata(log=None, file=None, skip_by_id=None, **load_kwargs)

Load a dataset from a prediction log directory, extracting metadata such as Hydra overrides and attaching it to the dataset. If log is not provided, load directly from file (backward compatibility).

Parameters:

Name Type Description Default
log str | None

Path to the prediction log directory.

None
file str | None

Path to the dataset file.

None
skip_by_id list[str] | None

Optional list of IDs to drop respective entries from the dataset after loading.

None
**load_kwargs

Additional keyword arguments to pass to read_and_preprocess.

{}

Returns: The loaded dataset, possibly wrapped in DictWithMetadata if loaded from a log directory.

Source code in src/kibad_llm/dataset/prediction.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def load_with_metadata(
    log: str | None = None,
    file: str | None = None,
    skip_by_id: list[str] | None = None,
    **load_kwargs,
) -> dict:
    """Load a dataset from a prediction log directory, extracting metadata such as Hydra overrides
    and attaching it to the dataset. If `log` is not provided, load directly from `file`
    (backward compatibility).

    Args:
        log: Path to the prediction log directory.
        file: Path to the dataset file.
        skip_by_id: Optional list of IDs to drop respective entries from the dataset after loading.
        **load_kwargs: Additional keyword arguments to pass to `read_and_preprocess`.

    Returns:
        The loaded dataset, possibly wrapped in `DictWithMetadata` if loaded from a log directory.

    Raises:
        ValueError: If both `log` and `file` are given, or if neither is given.
    """

    metadata = None

    if log is not None:
        if file is not None:
            raise ValueError("Specify either 'log' or 'file', not both.")

        # get job return value from json file
        job_return_value_file = os.path.join(log, "job_return_value.json")
        with open(job_return_value_file) as f:
            job_return_value = json.load(f)

        # get overrides from hydra overrides yaml file
        overrides_file = os.path.join(log, ".hydra", "overrides.yaml")
        with open(overrides_file) as f:
            overrides = yaml.safe_load(f)
        metadata = {"overrides": overrides, "job_return_value": job_return_value}

        # use output file from job return value
        file = job_return_value["output_file"]
    else:
        if file is None:
            raise ValueError("Either 'log' or 'file' must be specified.")

    dataset = read_and_preprocess(file=file, **load_kwargs)
    if skip_by_id is not None:
        # Use a set for O(1) membership tests instead of scanning the list once per key.
        skip_set = set(skip_by_id)
        skipped = skip_set & set(dataset.keys())
        dataset = {k: v for k, v in dataset.items() if k not in skip_set}
        if skipped:  # avoid a noisy "Skipped 0 entries" warning when nothing matched
            logger.warning(f"Skipped {len(skipped)} entries from loaded data: {sorted(skipped)}")
    if metadata is not None:
        dataset = DictWithMetadata(dataset, metadata=metadata)
    return dataset