Skip to content

Predict

get_git_info()

Get current git commit hash, branch name, and dirty status.

Returns:

Type Description
dict[str, str | bool]

Dictionary with 'commit_hash', 'branch', and 'is_dirty' keys.

Source code in src/kibad_llm/predict.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def get_git_info() -> dict[str, str | bool]:
    """Get current git commit hash, branch name, and dirty status.

    Returns:
        Dictionary with 'commit_hash', 'branch', and 'is_dirty' keys.
        If no git repository is available, falls back to
        'unknown'/'unknown'/False; on a detached HEAD the branch is
        reported as 'detached'.
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        try:
            branch = repo.active_branch.name
        except TypeError:
            # active_branch raises TypeError on a detached HEAD (e.g. CI
            # checkouts of a specific commit); degrade gracefully instead
            # of crashing the whole prediction run.
            branch = "detached"
        return {
            "commit_hash": repo.head.object.hexsha,
            "branch": branch,
            "is_dirty": repo.is_dirty(),
        }
    except (git.InvalidGitRepositoryError, git.GitCommandError) as e:
        logger.warning(f"Could not get git info: {e}")
        return {
            "commit_hash": "unknown",
            "branch": "unknown",
            "is_dirty": False,
        }

get_run_log_dir()

Get the Hydra run log directory from the HydraConfig. If not possible (e.g. during integration tests), returns None.

Returns:

Type Description
str | None

The Hydra run log directory as a string, or None if it cannot be determined.

Source code in src/kibad_llm/predict.py
53
54
55
56
57
58
59
60
61
62
63
64
def get_run_log_dir() -> str | None:
    """Return the Hydra run output directory, if one is available.

    When Hydra has not been initialized (e.g. during integration tests),
    a warning is logged and None is returned instead of raising.

    Returns:
        The Hydra run log directory as a string, or None if unavailable.
    """
    try:
        output_dir = HydraConfig.get().runtime.output_dir
    except (ValueError, OmegaConfBaseException) as e:
        logger.warning(f"Could not get Hydra run log directory: {e}")
        return None
    return output_dir

predict(cfg)

Run classification based information extraction on PDF files.

Reads all PDF files from cfg.pdf_directory, converts them to markdown, extracts structured information using an LLM model and a prompt template, and writes the results to cfg.output_file in JSON lines format. Caches intermediate results by using map from the datasets library.

Parameters:

Name Type Description Default
cfg DictConfig

OmegaConf configuration. See configs/predict.yaml for details.

required
Source code in src/kibad_llm/predict.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def predict(cfg: DictConfig) -> dict[str, Any]:
    """Run classification based information extraction on PDF files.

    Reads all PDF files from cfg.pdf_directory, converts them to markdown,
    extracts structured information using an LLM model and a prompt template,
    and writes the results to cfg.output_file in JSON lines format. Caches
    intermediate results by using map from the datasets library.

    Args:
        cfg: OmegaConf configuration. See configs/predict.yaml for details.

    Returns:
        Summary dictionary with the output file path (relative and absolute),
        per-stage timings ('time_pdf_conversion', 'time_extraction'),
        start/end unix timestamps, the git info keys from get_git_info(),
        and 'slurm_job_id' when running under SLURM.
    """
    # use start time as part of output folder to avoid overwriting previous results
    formatted_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")

    # wall-clock start, reported in the result dict as 'time_start'
    time_start = time.time()

    # show hydra run log directory (None when Hydra is not initialized)
    run_log_dir = get_run_log_dir()
    if run_log_dir is not None:
        logger.info(f"Run log directory: {run_log_dir}")

    # log git commit info so results can be traced back to a code state
    git_info = get_git_info()
    logger.info(f"Git commit: {git_info['commit_hash']}, branch: {git_info['branch']}")
    if git_info["is_dirty"]:
        logger.warning("Git repository has uncommitted changes!")

    data_base_path = Path(cfg.pdf_directory)

    # Create the dataset based on the sorted file names. This will define the cache key.
    # IMPORTANT: This means the whole dataset will be re-processed if any file is added/removed!
    file_names = sorted(p.name for p in data_base_path.glob("*.pdf"))
    dataset = Dataset.from_generator(
        generator=_file_name_generator, gen_kwargs={"file_names": file_names}
    )
    logger.info(f"Processing {len(dataset)} PDF files from {cfg.pdf_directory} ...")
    if cfg.get("fast_dev_run", False) and len(dataset) > 0:
        logger.warning(
            f"fast_dev_run is set: only processing the first PDF file ({dataset[0]['file_name']}) but show extended logging ..."
        )
        # shrink the run to a single file and enable verbose llama-index logging
        dataset = dataset.take(1)
        set_global_handler("simple")

    logger.info("Converting PDF to markdown ...")
    logger.info(f"PDF reader config: {OmegaConf.to_container(cfg.pdf_reader, resolve=True)}")
    pdf_reader = instantiate(cfg.pdf_reader, _convert_="all")
    # adapt the reader callable to the datasets.map interface; result lands in column "text"
    pdf_reader_wrapped = wrap_map_func(func=pdf_reader, result_key="text")
    t_start_pdf_conversion = time.perf_counter()
    dataset = dataset.map(
        function=pdf_reader_wrapped,
        input_columns=["file_name"],
        fn_kwargs={"base_path": data_base_path},
        num_proc=cfg.pdf_reader_num_proc,
    )
    t_delta_pdf_conversion = time.perf_counter() - t_start_pdf_conversion

    logger.info("Instantiating Extractor ...")
    logger.info(f"Extractor config: {OmegaConf.to_container(cfg.extractor, resolve=True)}")
    # The extractor gets the text and file_name as input
    # and should return a json serializable dictionary with the extracted information.
    extractor: Callable[[str, str], dict[str, Any]] = instantiate(cfg.extractor, _convert_="all")

    logger.info("Extract information from markdown ...")
    t_start_extraction = time.perf_counter()
    dataset = dataset.map(
        function=extractor,
        input_columns=["text", "file_name"],
        load_from_cache_file=cfg.get("extraction_caching", False),
        # prevent hashing based on the extractor code/content if caching is disabled
        new_fingerprint="dummy" if not cfg.get("extraction_caching", False) else None,
        num_proc=cfg.extractor_num_proc,
    )
    t_delta_extraction = time.perf_counter() - t_start_extraction

    # delete the extractor (may free up GPU memory if the extractor contains a vllm in-process LLM)
    del extractor

    # optionally drop the (potentially large) markdown text from the persisted predictions
    if not cfg.get("store_text_in_predictions", True):
        dataset = dataset.remove_columns("text")

    # use output dir with timestamp to avoid overwriting previous results
    output_file = os.path.join(cfg.output_dir, formatted_time, cfg.output_file_name)
    logger.info(f"Writing results to {output_file} ...")
    # force_ascii=False keeps non-ASCII characters readable in the JSON lines output
    dataset.to_json(output_file, force_ascii=False)
    result = {
        "output_file": os.path.relpath(output_file, start=os.getcwd()),
        "output_file_absolute": os.path.abspath(output_file),
        "time_pdf_conversion": t_delta_pdf_conversion,
        "time_extraction": t_delta_extraction,
        "time_start": int(time_start),
        "time_end": int(time.time()),
        **git_info,
    }
    # if running on a SLURM cluster, also log the job ID for easier tracking
    job_id = os.getenv("SLURM_JOB_ID")
    if job_id is not None:
        result["slurm_job_id"] = job_id
    return result