67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def predict(cfg: DictConfig) -> dict[str, Any]:
    """Run classification based information extraction on PDF files.

    Reads all PDF files from cfg.pdf_directory, converts them to markdown,
    extracts structured information using an LLM model and a prompt template,
    and writes the results to cfg.output_file in JSON lines format. Caches
    intermediate results by using map from the datasets library.

    Args:
        cfg: OmegaConf configuration. See configs/predict.yaml for details.

    Returns:
        A JSON-serializable summary dict with output paths, per-stage timings,
        git metadata, and (when present) the SLURM job id.
    """
    # Timestamp becomes part of the output folder so repeated runs never clobber each other.
    run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
    wall_clock_start = time.time()

    # Surface the hydra run log directory, if one is active.
    run_log_dir = get_run_log_dir()
    if run_log_dir is not None:
        logger.info(f"Run log directory: {run_log_dir}")

    # Record repository state for reproducibility; warn on a dirty working tree.
    repo_state = get_git_info()
    logger.info(f"Git commit: {repo_state['commit_hash']}, branch: {repo_state['branch']}")
    if repo_state["is_dirty"]:
        logger.warning("Git repository has uncommitted changes!")

    pdf_base_dir = Path(cfg.pdf_directory)
    # Create the dataset based on the sorted file names. This will define the cache key.
    # IMPORTANT: This means the whole dataset will be re-processed if any file is added/removed!
    sorted_pdf_names = sorted(entry.name for entry in pdf_base_dir.glob("*.pdf"))
    ds = Dataset.from_generator(
        generator=_file_name_generator, gen_kwargs={"file_names": sorted_pdf_names}
    )
    logger.info(f"Processing {len(ds)} PDF files from {cfg.pdf_directory} ...")

    # Dev shortcut: keep a single document, but enable verbose instrumentation.
    if cfg.get("fast_dev_run", False) and len(ds) > 0:
        logger.warning(
            f"fast_dev_run is set: only processing the first PDF file ({ds[0]['file_name']}) but show extended logging ..."
        )
        ds = ds.take(1)
        set_global_handler("simple")

    # Stage 1: PDF -> markdown conversion, parallelized via datasets.map.
    logger.info("Converting PDF to markdown ...")
    logger.info(f"PDF reader config: {OmegaConf.to_container(cfg.pdf_reader, resolve=True)}")
    reader = instantiate(cfg.pdf_reader, _convert_="all")
    reader_map_fn = wrap_map_func(func=reader, result_key="text")
    conversion_started = time.perf_counter()
    ds = ds.map(
        function=reader_map_fn,
        input_columns=["file_name"],
        fn_kwargs={"base_path": pdf_base_dir},
        num_proc=cfg.pdf_reader_num_proc,
    )
    conversion_seconds = time.perf_counter() - conversion_started

    # Stage 2: structured information extraction from the markdown text.
    logger.info("Instantiating Extractor ...")
    logger.info(f"Extractor config: {OmegaConf.to_container(cfg.extractor, resolve=True)}")
    # The extractor gets the text and file_name as input
    # and should return a json serializable dictionary with the extracted information.
    extractor: Callable[[str, str], dict[str, Any]] = instantiate(cfg.extractor, _convert_="all")
    logger.info("Extract information from markdown ...")
    caching_enabled = cfg.get("extraction_caching", False)
    extraction_started = time.perf_counter()
    ds = ds.map(
        function=extractor,
        input_columns=["text", "file_name"],
        load_from_cache_file=caching_enabled,
        # prevent hashing based on the extractor code/content if caching is disabled
        new_fingerprint=None if caching_enabled else "dummy",
        num_proc=cfg.extractor_num_proc,
    )
    extraction_seconds = time.perf_counter() - extraction_started
    # delete the extractor (may free up GPU memory if the extractor contains a vllm in-process LLM)
    del extractor

    # Optionally drop the (potentially large) markdown text from the persisted predictions.
    if not cfg.get("store_text_in_predictions", True):
        ds = ds.remove_columns("text")

    # use output dir with timestamp to avoid overwriting previous results
    output_file = os.path.join(cfg.output_dir, run_timestamp, cfg.output_file_name)
    logger.info(f"Writing results to {output_file} ...")
    ds.to_json(output_file, force_ascii=False)

    result = {
        "output_file": os.path.relpath(output_file, start=os.getcwd()),
        "output_file_absolute": os.path.abspath(output_file),
        "time_pdf_conversion": conversion_seconds,
        "time_extraction": extraction_seconds,
        "time_start": int(wall_clock_start),
        "time_end": int(time.time()),
        **repo_state,
    }
    # if running on a SLURM cluster, also log the job ID for easier tracking
    slurm_job_id = os.getenv("SLURM_JOB_ID")
    if slurm_job_id is not None:
        result["slurm_job_id"] = slurm_job_id
    return result
|