From 04d308ef2734d117f0247d6e78ab599811cc847c Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Tue, 15 Oct 2024 14:18:15 -0400 Subject: [PATCH] Initial table integration --- .github/workflows/tests.yml | 4 - CLA.md | 4 +- README.md | 11 +- benchmarks/table.py | 77 ----- marker/convert.py | 13 +- marker/models.py | 34 ++- marker/ocr/recognition.py | 19 +- marker/pdf/extract_text.py | 2 +- marker/postprocessors/editor.py | 123 -------- marker/postprocessors/t5.py | 141 --------- marker/settings.py | 10 +- marker/tables/cells.py | 112 ------- marker/tables/edges.py | 122 -------- marker/tables/table.py | 237 ++++++--------- poetry.lock | 523 +++++++++++++++++++------------- pyproject.toml | 16 +- 16 files changed, 452 insertions(+), 996 deletions(-) delete mode 100644 benchmarks/table.py delete mode 100644 marker/postprocessors/editor.py delete mode 100644 marker/postprocessors/t5.py delete mode 100644 marker/tables/cells.py delete mode 100644 marker/tables/edges.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fb524f92..4aeca375 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,10 +29,6 @@ jobs: run: | poetry run python benchmarks/overall.py benchmark_data/pdfs benchmark_data/references report.json poetry run python scripts/verify_benchmark_scores.py report.json --type marker - - name: Run table benchmark - run: | - poetry run python benchmarks/table.py tables.json - poetry run python scripts/verify_benchmark_scores.py tables.json --type table diff --git a/CLA.md b/CLA.md index d80b275d..296e7fa1 100644 --- a/CLA.md +++ b/CLA.md @@ -1,6 +1,6 @@ Marker Contributor Agreement -This Marker Contributor Agreement ("MCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Vikas Paruchuri. The term "you" shall mean the person or entity identified below. +This Marker Contributor Agreement ("MCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement. @@ -20,5 +20,5 @@ If you or your affiliates institute patent litigation against any entity (includ - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this MCA; - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws. -You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Vikas Paruchuri may publicly disclose your participation in the project, including the fact that you have signed the MCA. +You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc. may publicly disclose your participation in the project, including the fact that you have signed the MCA. 6. This MCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply. \ No newline at end of file diff --git a/README.md b/README.md index a546b633..313651bf 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ See [below](#benchmarks) for detailed speed and accuracy benchmarks, and instruc I want marker to be as widely accessible as possible, while still funding my development/training costs. Research and personal usage is always okay, but there are some restrictions on commercial usage. -The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to). +The weights for the models are licensed `cc-by-nc-sa-4.0`, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period AND under $5M in lifetime VC/angel funding raised. You also must not be competitive with the [Datalab API](https://www.datalab.to/). If you want to remove the GPL license requirements (dual-license) and/or use the weights commercially over the revenue limit, check out the options [here](https://www.datalab.to). # Hosted API @@ -217,14 +217,6 @@ This will benchmark marker against other text extraction methods. It sets up ba Omit `--nougat` to exclude nougat from the benchmark. I don't recommend running nougat on CPU, since it is very slow. -### Table benchmark - -There is a benchmark for table parsing, which you can run with: - -```shell -python benchmarks/table.py test_data/tables.json -``` - # Thanks This work would not have been possible without amazing open source models and datasets, including (but not limited to): @@ -233,6 +225,5 @@ This work would not have been possible without amazing open source models and da - Texify - Pypdfium2/pdfium - DocLayNet from IBM -- ByT5 from Google Thank you to the authors of these models and datasets for making them available to the community! \ No newline at end of file diff --git a/benchmarks/table.py b/benchmarks/table.py deleted file mode 100644 index 45cd3888..00000000 --- a/benchmarks/table.py +++ /dev/null @@ -1,77 +0,0 @@ -import argparse -import json - -import datasets -from surya.schema import LayoutResult, LayoutBox -from tqdm import tqdm - -from marker.benchmark.table import score_table -from marker.schema.bbox import rescale_bbox -from marker.schema.page import Page -from marker.tables.table import format_tables - - - -def main(): - parser = argparse.ArgumentParser(description="Benchmark table conversion.") - parser.add_argument("out_file", help="Output filename for results") - parser.add_argument("--dataset", type=str, help="Dataset to use", default="vikp/table_bench") - args = parser.parse_args() - - ds = datasets.load_dataset(args.dataset, split="train") - - results = [] - for i in tqdm(range(len(ds)), desc="Evaluating tables"): - row = ds[i] - marker_page = Page(**json.loads(row["marker_page"])) - table_bbox = row["table_bbox"] - gpt4_table = json.loads(row["gpt_4_table"])["markdown_table"] - - # Counterclockwise polygon from top left - table_poly = [ - [table_bbox[0], table_bbox[1]], - [table_bbox[2], table_bbox[1]], - [table_bbox[2], table_bbox[3]], - [table_bbox[0], table_bbox[3]], - ] - - # Remove all other tables from the layout results - layout_result = LayoutResult( - bboxes=[ - LayoutBox( - label="Table", - polygon=table_poly - ) - ], - segmentation_map="", - image_bbox=marker_page.text_lines.image_bbox - ) - - marker_page.layout = layout_result - format_tables([marker_page]) - - table_blocks = [block for block in marker_page.blocks if block.block_type == "Table"] - if len(table_blocks) != 1: - continue - - table_block = table_blocks[0] - table_md = table_block.lines[0].spans[0].text - - results.append({ - "score": score_table(table_md, gpt4_table), - "arxiv_id": row["arxiv_id"], - "page_idx": row["page_idx"], - "marker_table": table_md, - "gpt4_table": gpt4_table, - "table_bbox": table_bbox - }) - - avg_score = sum([r["score"] for r in results]) / len(results) - print(f"Evaluated {len(results)} tables, average score is {avg_score}.") - - with open(args.out_file, "w+") as f: - json.dump(results, f, indent=2) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/marker/convert.py b/marker/convert.py index 3c56fc9c..a20752c0 100644 --- a/marker/convert.py +++ b/marker/convert.py @@ -20,7 +20,6 @@ from marker.cleaners.headers import filter_header_footer, filter_common_titles from marker.equations.equations import replace_equations from marker.pdf.utils import find_filetype -from marker.postprocessors.editor import edit_full_text from marker.cleaners.code import identify_code_blocks, indent_blocks from marker.cleaners.bullets import replace_bullets from marker.cleaners.headings import split_heading_blocks @@ -83,7 +82,7 @@ def convert_single_pdf( doc.del_page(0) # Unpack models from list - texify_model, layout_model, order_model, edit_model, detection_model, ocr_model = model_lst + texify_model, layout_model, order_model, detection_model, ocr_model, table_rec_model = model_lst # Identify text lines on pages surya_detection(doc, pages, detection_model, batch_multiplier=batch_multiplier) @@ -123,7 +122,7 @@ def convert_single_pdf( indent_blocks(pages) # Fix table blocks - table_count = format_tables(pages) + table_count = format_tables(pages, doc, fname, detection_model, table_rec_model, ocr_model) out_meta["block_stats"]["table"] = table_count for page in pages: @@ -160,14 +159,6 @@ def convert_single_pdf( # Replace bullet characters with a - full_text = replace_bullets(full_text) - # Postprocess text with editor model - full_text, edit_stats = edit_full_text( - full_text, - edit_model, - batch_multiplier=batch_multiplier - ) - flush_cuda_memory() - out_meta["postprocess_stats"] = {"edit": edit_stats} doc_images = images_to_dict(pages) return full_text, doc_images, out_meta \ No newline at end of file diff --git a/marker/models.py b/marker/models.py index 4764d84e..877d95e1 100644 --- a/marker/models.py +++ b/marker/models.py @@ -2,7 +2,6 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS -from marker.postprocessors.editor import load_editing_model from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor from texify.model.model import load_model as load_texify_model from texify.model.processor import load_processor as load_texify_processor @@ -11,6 +10,17 @@ from surya.model.recognition.processor import load_processor as load_recognition_processor from surya.model.ordering.model import load_model as load_order_model from surya.model.ordering.processor import load_processor as load_order_processor +from surya.model.table_rec.model import load_model as load_table_model +from surya.model.table_rec.processor import load_processor as load_table_processor + + +def setup_table_rec_model(device=None, dtype=None): + if device: + table_model = load_table_model(device=device, dtype=dtype) + else: + table_model = load_table_model() + table_model.processor = load_table_processor() + return table_model def setup_recognition_model(device=None, dtype=None): @@ -18,8 +28,7 @@ def setup_recognition_model(device=None, dtype=None): rec_model = load_recognition_model(device=device, dtype=dtype) else: rec_model = load_recognition_model() - rec_processor = load_recognition_processor() - rec_model.processor = rec_processor + rec_model.processor = load_recognition_processor() return rec_model @@ -28,9 +37,7 @@ def setup_detection_model(device=None, dtype=None): model = load_detection_model(device=device, dtype=dtype) else: model = load_detection_model() - - processor = load_detection_processor() - model.processor = processor + model.processor = load_detection_processor() return model @@ -39,8 +46,7 @@ def setup_texify_model(device=None, dtype=None): texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype) else: texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE) - texify_processor = load_texify_processor() - texify_model.processor = texify_processor + texify_model.processor = load_texify_processor() return texify_model @@ -49,8 +55,7 @@ def setup_layout_model(device=None, dtype=None): model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT, device=device, dtype=dtype) else: model = load_detection_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) - processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) - model.processor = processor + model.processor = load_detection_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) return model @@ -59,12 +64,11 @@ def setup_order_model(device=None, dtype=None): model = load_order_model(device=device, dtype=dtype) else: model = load_order_model() - processor = load_order_processor() - model.processor = processor + model.processor = load_order_processor() return model -def load_all_models(device=None, dtype=None, force_load_ocr=False): +def load_all_models(device=None, dtype=None): if device is not None: assert dtype is not None, "Must provide dtype if device is provided" @@ -72,10 +76,10 @@ def load_all_models(device=None, dtype=None, force_load_ocr=False): detection = setup_detection_model(device, dtype) layout = setup_layout_model(device, dtype) order = setup_order_model(device, dtype) - edit = load_editing_model(device, dtype) # Only load recognition model if we'll need it for all pdfs ocr = setup_recognition_model(device, dtype) texify = setup_texify_model(device, dtype) - model_lst = [texify, layout, order, edit, detection, ocr] + table_model = setup_table_rec_model(device, dtype) + model_lst = [texify, layout, order, detection, ocr, table_model] return model_lst \ No newline at end of file diff --git a/marker/ocr/recognition.py b/marker/ocr/recognition.py index 44c22b12..f5ceee50 100644 --- a/marker/ocr/recognition.py +++ b/marker/ocr/recognition.py @@ -65,7 +65,10 @@ def run_ocr(doc, pages: List[Page], langs: List[str], rec_model, batch_multiplie def surya_recognition(doc, page_idxs, langs: List[str], rec_model, pages: List[Page], batch_multiplier=1) -> List[Optional[Page]]: + # Slice images in higher resolution than detection happened in images = [render_image(doc[pnum], dpi=settings.SURYA_OCR_DPI) for pnum in page_idxs] + box_scale = settings.SURYA_OCR_DPI / settings.SURYA_DETECTOR_DPI + processor = rec_model.processor selected_pages = [p for i, p in enumerate(pages) if i in page_idxs] @@ -73,6 +76,12 @@ def surya_recognition(doc, page_idxs, langs: List[str], rec_model, pages: List[P detection_results = [p.text_lines.bboxes for p in selected_pages] polygons = [[b.polygon for b in bboxes] for bboxes in detection_results] + # Scale polygons to get correct image slices + for poly in polygons: + for p in poly: + for i in range(len(p)): + p[i] = [int(p[i][0] * box_scale), int(p[i][1] * box_scale)] + results = run_recognition(images, surya_langs, rec_model, processor, polygons=polygons, batch_size=int(get_batch_size() * batch_multiplier)) new_pages = [] @@ -81,14 +90,15 @@ def surya_recognition(doc, page_idxs, langs: List[str], rec_model, pages: List[P ocr_results = result.text_lines blocks = [] for i, line in enumerate(ocr_results): + scaled_bbox = [b / box_scale for b in line.bbox] block = Block( - bbox=line.bbox, + bbox=scaled_bbox, pnum=page_idx, lines=[Line( - bbox=line.bbox, + bbox=scaled_bbox, spans=[Span( text=line.text, - bbox=line.bbox, + bbox=scaled_bbox, span_id=f"{page_idx}_{i}", font="", font_weight=0, @@ -98,10 +108,11 @@ def surya_recognition(doc, page_idxs, langs: List[str], rec_model, pages: List[P )] ) blocks.append(block) + scaled_image_bbox = [b / box_scale for b in result.image_bbox] page = Page( blocks=blocks, pnum=page_idx, - bbox=result.image_bbox, + bbox=scaled_image_bbox, rotation=0, text_lines=text_lines, ocr_method="surya" diff --git a/marker/pdf/extract_text.py b/marker/pdf/extract_text.py index 937ded31..20a510a2 100644 --- a/marker/pdf/extract_text.py +++ b/marker/pdf/extract_text.py @@ -90,7 +90,7 @@ def get_text_blocks(doc, fname, max_pages: Optional[int] = None, start_page: Opt page_range = range(start_page, start_page + max_pages) - char_blocks = dictionary_output(fname, page_range=page_range, keep_chars=True, workers=settings.PDFTEXT_CPU_WORKERS) + char_blocks = dictionary_output(fname, page_range=page_range, keep_chars=False, workers=settings.PDFTEXT_CPU_WORKERS) marker_blocks = [pdftext_format_to_blocks(page, pnum) for pnum, page in enumerate(char_blocks)] return marker_blocks, toc diff --git a/marker/postprocessors/editor.py b/marker/postprocessors/editor.py deleted file mode 100644 index 48695fbb..00000000 --- a/marker/postprocessors/editor.py +++ /dev/null @@ -1,123 +0,0 @@ -from collections import defaultdict -from itertools import chain -from typing import Optional - -from marker.settings import settings -import torch -import torch.nn.functional as F -from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize - - -def get_batch_size(): - if settings.EDITOR_BATCH_SIZE is not None: - return settings.EDITOR_BATCH_SIZE - elif settings.TORCH_DEVICE_MODEL == "cuda": - return 12 - return 6 - - -def load_editing_model(device=None, dtype=None): - if not settings.ENABLE_EDITOR_MODEL: - return None - - if device: - model = T5ForTokenClassification.from_pretrained( - settings.EDITOR_MODEL_NAME, - torch_dtype=dtype, - device=device, - ) - else: - model = T5ForTokenClassification.from_pretrained( - settings.EDITOR_MODEL_NAME, - torch_dtype=settings.MODEL_DTYPE, - ).to(settings.TORCH_DEVICE_MODEL) - model.eval() - - model.config.label2id = { - "equal": 0, - "delete": 1, - "newline-1": 2, - "space-1": 3, - } - model.config.id2label = {v: k for k, v in model.config.label2id.items()} - return model - - -def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_multiplier=1) -> (str, dict): - if not model: - return text, {} - - batch_size = get_batch_size() * batch_multiplier - tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH) - input_ids = tokenized["input_ids"] - char_token_lengths = tokenized["char_token_lengths"] - - # Run model - token_masks = [] - for i in range(0, len(input_ids), batch_size): - batch_input_ids = tokenized["input_ids"][i: i + batch_size] - batch_input_ids = torch.tensor(batch_input_ids, device=model.device) - batch_attention_mask = tokenized["attention_mask"][i: i + batch_size] - batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device) - with torch.inference_mode(): - predictions = model(batch_input_ids, attention_mask=batch_attention_mask) - - logits = predictions.logits.cpu() - - # If the max probability is less than a threshold, we assume it's a bad prediction - # We want to be conservative to not edit the text too much - probs = F.softmax(logits, dim=-1) - max_prob = torch.max(probs, dim=-1) - cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH - labels = logits.argmax(-1) - labels[cutoff_prob] = model.config.label2id["equal"] - labels = labels.squeeze().tolist() - if len(labels) == settings.EDITOR_MAX_LENGTH: - labels = [labels] - labels = list(chain.from_iterable(labels)) - token_masks.extend(labels) - - # List of characters in the text - flat_input_ids = list(chain.from_iterable(input_ids)) - - # Strip special tokens 0,1. Keep unknown token, although it should never be used - assert len(token_masks) == len(flat_input_ids) - token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2] - - assert len(token_masks) == len(list(text.encode("utf-8"))) - - edit_stats = defaultdict(int) - out_text = [] - start = 0 - for i, char in enumerate(text): - char_token_length = char_token_lengths[i] - masks = token_masks[start: start + char_token_length] - labels = [model.config.id2label[mask] for mask in masks] - if all(l == "delete" for l in labels): - # If we delete whitespace, roll with it, otherwise ignore - if char.strip(): - out_text.append(char) - else: - edit_stats["delete"] += 1 - elif labels[0] == "newline-1": - out_text.append("\n") - out_text.append(char) - edit_stats["newline-1"] += 1 - elif labels[0] == "space-1": - out_text.append(" ") - out_text.append(char) - edit_stats["space-1"] += 1 - else: - out_text.append(char) - edit_stats["equal"] += 1 - - start += char_token_length - - out_text = "".join(out_text) - return out_text, edit_stats - - - - - - diff --git a/marker/postprocessors/t5.py b/marker/postprocessors/t5.py deleted file mode 100644 index dac471a8..00000000 --- a/marker/postprocessors/t5.py +++ /dev/null @@ -1,141 +0,0 @@ -from transformers import T5Config, T5PreTrainedModel -import torch -from torch import nn -from copy import deepcopy -from typing import Optional, Tuple, Union -from itertools import chain - -from transformers.modeling_outputs import TokenClassifierOutput -from transformers.models.t5.modeling_t5 import T5Stack -from transformers.utils.model_parallel_utils import get_device_map, assert_device_map - - -def byt5_tokenize(text: str, max_length: int, pad_token_id: int = 0): - byte_codes = [] - for char in text: - # Add 3 to account for special tokens - byte_codes.append([byte + 3 for byte in char.encode('utf-8')]) - - tokens = list(chain.from_iterable(byte_codes)) - # Map each token to the character it represents - char_token_lengths = [len(b) for b in byte_codes] - - batched_tokens = [] - attention_mask = [] - for i in range(0, len(tokens), max_length): - batched_tokens.append(tokens[i:i + max_length]) - attention_mask.append([1] * len(batched_tokens[-1])) - - # Pad last item - if len(batched_tokens[-1]) < max_length: - batched_tokens[-1] += [pad_token_id] * (max_length - len(batched_tokens[-1])) - attention_mask[-1] += [0] * (max_length - len(attention_mask[-1])) - - return {"input_ids": batched_tokens, "attention_mask": attention_mask, "char_token_lengths": char_token_lengths} - - - - -# From https://github.com/osainz59/t5-encoder -class T5ForTokenClassification(T5PreTrainedModel): - _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] - - def __init__(self, config: T5Config): - super().__init__(config) - self.model_dim = config.d_model - - self.shared = nn.Embedding(config.vocab_size, config.d_model) - - encoder_config = deepcopy(config) - encoder_config.is_decoder = False - encoder_config.is_encoder_decoder = False - encoder_config.use_cache = False - self.encoder = T5Stack(encoder_config, self.shared) - - classifier_dropout = ( - config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate - ) - self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.d_model, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - # Model parallel - self.model_parallel = False - self.device_map = None - - - def parallelize(self, device_map=None): - self.device_map = ( - get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) - if device_map is None - else device_map - ) - assert_device_map(self.device_map, len(self.encoder.block)) - self.encoder.parallelize(self.device_map) - self.classifier.to(self.encoder.first_device) - self.model_parallel = True - - def deparallelize(self): - self.encoder.deparallelize() - self.encoder = self.encoder.to("cpu") - self.classifier = self.classifier.to("cpu") - self.model_parallel = False - self.device_map = None - torch.cuda.empty_cache() - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings - self.encoder.set_input_embeddings(new_embeddings) - - def get_encoder(self): - return self.encoder - - def _prune_heads(self, heads_to_prune): - for layer, heads in heads_to_prune.items(): - self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions - ) \ No newline at end of file diff --git a/marker/settings.py b/marker/settings.py index 6d6bf7e7..e9df41c8 100644 --- a/marker/settings.py +++ b/marker/settings.py @@ -47,7 +47,7 @@ def TORCH_DEVICE_MODEL(self) -> str: OCR_ALL_PAGES: bool = False # Run OCR on every page even if text can be extracted ## Surya - SURYA_OCR_DPI: int = 96 + SURYA_OCR_DPI: int = 192 RECOGNITION_BATCH_SIZE: Optional[int] = None # Batch size for surya OCR defaults to 64 for cuda, 32 otherwise ## Tesseract @@ -75,12 +75,8 @@ def TORCH_DEVICE_MODEL(self) -> str: ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 12 for cuda, 6 otherwise ORDER_MAX_BBOXES: int = 255 - # Final editing model - EDITOR_BATCH_SIZE: Optional[int] = None # Defaults to 6 for cuda, 12 otherwise - EDITOR_MAX_LENGTH: int = 1024 - EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5" - ENABLE_EDITOR_MODEL: bool = False # The editor model can create false positives - EDITOR_CUTOFF_THRESH: float = 0.9 # Ignore predictions below this probability + # Table models + SURYA_TABLE_DPI: int = 192 # Debug DEBUG: bool = False # Enable debug logging diff --git a/marker/tables/cells.py b/marker/tables/cells.py deleted file mode 100644 index 46d83484..00000000 --- a/marker/tables/cells.py +++ /dev/null @@ -1,112 +0,0 @@ -from marker.schema.bbox import rescale_bbox, box_intersection_pct -from marker.schema.page import Page -import numpy as np -from sklearn.cluster import DBSCAN -from marker.settings import settings - - -def cluster_coords(coords, row_count): - if len(coords) == 0: - return [] - coords = np.array(sorted(set(coords))).reshape(-1, 1) - - clustering = DBSCAN(eps=.01, min_samples=max(2, row_count // 4)).fit(coords) - clusters = clustering.labels_ - - separators = [] - for label in set(clusters): - clustered_points = coords[clusters == label] - separators.append(np.mean(clustered_points)) - - separators = sorted(separators) - return separators - - -def find_column_separators(page: Page, table_box, rows, round_factor=.002, min_count=1): - left_edges = [] - right_edges = [] - centers = [] - - line_boxes = [p.bbox for p in page.text_lines.bboxes] - line_boxes = [rescale_bbox(page.text_lines.image_bbox, page.bbox, l) for l in line_boxes] - line_boxes = [l for l in line_boxes if box_intersection_pct(l, table_box) > settings.BBOX_INTERSECTION_THRESH] - - pwidth = page.bbox[2] - page.bbox[0] - pheight = page.bbox[3] - page.bbox[1] - for cell in line_boxes: - ncell = [cell[0] / pwidth, cell[1] / pheight, cell[2] / pwidth, cell[3] / pheight] - left_edges.append(ncell[0] / round_factor * round_factor) - right_edges.append(ncell[2] / round_factor * round_factor) - centers.append((ncell[0] + ncell[2]) / 2 * round_factor / round_factor) - - left_edges = [l for l in left_edges if left_edges.count(l) > min_count] - right_edges = [r for r in right_edges if right_edges.count(r) > min_count] - centers = [c for c in centers if centers.count(c) > min_count] - - sorted_left = cluster_coords(left_edges, len(rows)) - sorted_right = cluster_coords(right_edges, len(rows)) - sorted_center = cluster_coords(centers, len(rows)) - - # Find list with minimum length - separators = max([sorted_left, sorted_right, sorted_center], key=len) - separators.append(1) - separators.insert(0, 0) - return separators - - -def assign_cells_to_columns(page, table_box, rows, round_factor=.002, tolerance=.01): - separators = find_column_separators(page, table_box, rows, round_factor=round_factor) - additional_column_index = 0 - pwidth = page.bbox[2] - page.bbox[0] - row_dicts = [] - - for row in rows: - new_row = {} - last_col_index = -1 - for cell in row: - left_edge = cell[0][0] / pwidth - column_index = -1 - for i, separator in enumerate(separators): - if left_edge - tolerance < separator and last_col_index < i: - column_index = i - break - if column_index == -1: - column_index = len(separators) + additional_column_index - additional_column_index += 1 - new_row[column_index] = cell[1] - last_col_index = column_index - additional_column_index = 0 - row_dicts.append(new_row) - - max_row_idx = 0 - for row in row_dicts: - max_row_idx = max(max_row_idx, max(row.keys())) - - # Assign sorted cells to columns, account for blanks - new_rows = [] - for row in row_dicts: - flat_row = [] - for row_idx in range(1, max_row_idx + 1): - if row_idx in row: - flat_row.append(row[row_idx]) - else: - flat_row.append("") - new_rows.append(flat_row) - - # Pad rows to have the same length - max_row_len = max([len(r) for r in new_rows]) - for row in new_rows: - while len(row) < max_row_len: - row.append("") - - cols_to_remove = set() - for idx, col in enumerate(zip(*new_rows)): - col_total = sum([len(cell.strip()) > 0 for cell in col]) - if col_total == 0: - cols_to_remove.add(idx) - - rows = [] - for row in new_rows: - rows.append([col for idx, col in enumerate(row) if idx not in cols_to_remove]) - - return rows \ No newline at end of file diff --git a/marker/tables/edges.py b/marker/tables/edges.py deleted file mode 100644 index 9bf30ba1..00000000 --- a/marker/tables/edges.py +++ /dev/null @@ -1,122 +0,0 @@ -import math - -import cv2 -import numpy as np - - -def get_detected_lines_sobel(image): - sobelx = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3) - - scaled_sobel = np.uint8(255 * sobelx / np.max(sobelx)) - - kernel = np.ones((4, 1), np.uint8) - eroded = cv2.erode(scaled_sobel, kernel, iterations=1) - scaled_sobel = cv2.dilate(eroded, kernel, iterations=3) - - return scaled_sobel - - -def get_line_angle(x1, y1, x2, y2): - slope = (y2 - y1) / (x2 - x1) - - angle_radians = math.atan(slope) - angle_degrees = math.degrees(angle_radians) - - return angle_degrees - - -def get_detected_lines(image, slope_tol_deg=10): - new_image = image.astype(np.float32) * 255 # Convert to 0-255 range - new_image = get_detected_lines_sobel(new_image) - new_image = new_image.astype(np.uint8) - - edges = cv2.Canny(new_image, 50, 200, apertureSize=3) - - lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=2, maxLineGap=100) - - line_info = [] - if lines is not None: - for line in lines: - x1, y1, x2, y2 = line[0] - bbox = [x1, y1, x2, y2] - - vertical = False - if x2 == x1: - vertical = True - else: - line_angle = get_line_angle(x1, y1, x2, y2) - if 90 - slope_tol_deg < line_angle < 90 + slope_tol_deg: - vertical = True - elif -90 - slope_tol_deg < line_angle < -90 + slope_tol_deg: - vertical = True - if not vertical: - continue - - if bbox[3] < bbox[1]: - bbox[1], bbox[3] = bbox[3], bbox[1] - if bbox[2] < bbox[0]: - bbox[0], bbox[2] = bbox[2], bbox[0] - if vertical: - line_info.append(bbox) - return line_info - - -def get_vertical_lines(image, divisor=2, x_tolerance=10, y_tolerance=1): - vertical_lines = get_detected_lines(image) - - vertical_lines = sorted(vertical_lines, key=lambda x: x[0]) - for line in vertical_lines: - for i in range(0, len(line)): - line[i] = (line[i] // divisor) * divisor - - # Merge adjacent line segments together - to_remove = [] - for i, line in enumerate(vertical_lines): - for j, line2 in enumerate(vertical_lines): - if j <= i: - continue - if line[0] != line2[0]: - continue - - expanded_line1 = [line[0], line[1] - y_tolerance, line[2], - line[3] + y_tolerance] - - line1_points = set(range(int(expanded_line1[1]), int(expanded_line1[3]))) - line2_points = set(range(int(line2[1]), int(line2[3]))) - intersect_y = len(line1_points.intersection(line2_points)) > 0 - - if intersect_y: - vertical_lines[j][1] = min(line[1], line2[1]) - vertical_lines[j][3] = max(line[3], line2[3]) - to_remove.append(i) - - vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove] - - # Remove redundant segments - to_remove = [] - for i, line in enumerate(vertical_lines): - if i in to_remove: - continue - for j, line2 in enumerate(vertical_lines): - if j <= i or j in to_remove: - continue - close_in_x = abs(line[0] - line2[0]) < x_tolerance - line1_points = set(range(int(line[1]), int(line[3]))) - line2_points = set(range(int(line2[1]), int(line2[3]))) - - intersect_y = len(line1_points.intersection(line2_points)) > 0 - - if close_in_x and intersect_y: - # Keep the longer line and extend it - if len(line2_points) > len(line1_points): - vertical_lines[j][1] = min(line[1], line2[1]) - vertical_lines[j][3] = max(line[3], line2[3]) - to_remove.append(i) - else: - vertical_lines[i][1] = min(line[1], line2[1]) - vertical_lines[i][3] = max(line[3], line2[3]) - to_remove.append(j) - - vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove] - - return vertical_lines \ No newline at end of file diff --git a/marker/tables/table.py b/marker/tables/table.py index 280ae2a5..172ef7e7 100644 --- a/marker/tables/table.py +++ b/marker/tables/table.py @@ -1,155 +1,107 @@ -from marker.schema.bbox import merge_boxes, box_intersection_pct, rescale_bbox +from tqdm import tqdm +from pypdfium2 import PdfDocument +from tabled.assignment import assign_rows_columns +from tabled.formats import formatter +from tabled.inference.detection import merge_tables + +from surya.input.pdflines import get_page_text_lines +from tabled.inference.recognition import get_cells, recognize_tables + +from marker.pdf.images import render_image +from marker.schema.bbox import rescale_bbox from marker.schema.block import Line, Span, Block from marker.schema.page import Page -from tabulate import tabulate from typing import List from marker.settings import settings -from marker.tables.cells import assign_cells_to_columns -from marker.tables.utils import sort_table_blocks, replace_dots, replace_newlines - - -def get_table_surya(page, table_box, space_tol=.01) -> List[List[str]]: - table_rows = [] - table_row = [] - x_position = None - sorted_blocks = sort_table_blocks(page.blocks) - for block_idx, block in enumerate(sorted_blocks): - sorted_lines = sort_table_blocks(block.lines) - for line_idx, line in enumerate(sorted_lines): - line_bbox = line.bbox - intersect_pct = box_intersection_pct(line_bbox, table_box) - if intersect_pct < settings.TABLE_INTERSECTION_THRESH or len(line.spans) == 0: - continue - normed_x_start = line_bbox[0] / page.width - normed_x_end = line_bbox[2] / page.width - - cells = [[s.bbox, s.text] for s in line.spans] - if x_position is None or normed_x_start > x_position - space_tol: - # Same row - table_row.extend(cells) - else: - # New row - if len(table_row) > 0: - table_rows.append(table_row) - table_row = cells - x_position = normed_x_end - if len(table_row) > 0: - table_rows.append(table_row) - table_rows = assign_cells_to_columns(page, table_box, table_rows) - return table_rows - - -def get_table_pdftext(page: Page, table_box, space_tol=.01, round_factor=4) -> List[List[str]]: - page_width = page.width - table_rows = [] - table_cell = "" - cell_bbox = None - table_row = [] - sorted_char_blocks = sort_table_blocks(page.char_blocks) - - table_width = table_box[2] - table_box[0] - new_line_start_x = table_box[0] + table_width * .3 - table_width_pct = (table_width / page_width) * .95 - - for block_idx, block in enumerate(sorted_char_blocks): - sorted_lines = sort_table_blocks(block["lines"]) - for line_idx, line in enumerate(sorted_lines): - line_bbox = line["bbox"] - intersect_pct = box_intersection_pct(line_bbox, table_box) - if intersect_pct < settings.TABLE_INTERSECTION_THRESH: - continue - for span in line["spans"]: - for char in span["chars"]: - x_start, y_start, x_end, y_end = char["bbox"] - x_start /= page_width - x_end /= page_width - fullwidth_cell = False - - if cell_bbox is not None: - # Find boundaries of cell bbox before merging - cell_x_start, cell_y_start, cell_x_end, cell_y_end = cell_bbox - cell_x_start /= page_width - cell_x_end /= page_width - - fullwidth_cell = cell_x_end - cell_x_start >= table_width_pct - - cell_content = replace_dots(replace_newlines(table_cell)) - if cell_bbox is None: # First char - table_cell += char["char"] - cell_bbox = char["bbox"] - # Check if we are in the same cell, ensure cell is not full table width (like if stray text gets included in the table) - elif (cell_x_start - space_tol < x_start < cell_x_end + space_tol) and not fullwidth_cell: - table_cell += char["char"] - cell_bbox = merge_boxes(cell_bbox, char["bbox"]) - # New line and cell - # Use x_start < new_line_start_x to account for out-of-order cells in the pdf - elif x_start < cell_x_end - space_tol and x_start < new_line_start_x: - if len(table_cell) > 0: - table_row.append((cell_bbox, cell_content)) - table_cell = char["char"] - cell_bbox = char["bbox"] - if len(table_row) > 0: - table_row = sorted(table_row, key=lambda x: round(x[0][0] / round_factor)) - table_rows.append(table_row) - table_row = [] - else: # Same line, new cell, check against cell bbox - if len(table_cell) > 0: - table_row.append((cell_bbox, cell_content)) - table_cell = char["char"] - cell_bbox = char["bbox"] - - if len(table_cell) > 0: - table_row.append((cell_bbox, replace_dots(replace_newlines(table_cell)))) - if len(table_row) > 0: - table_row = sorted(table_row, key=lambda x: round(x[0][0] / round_factor)) - table_rows.append(table_row) - - total_cells = sum([len(row) for row in table_rows]) - if total_cells > 0: - table_rows = assign_cells_to_columns(page, table_box, table_rows) - return table_rows - else: - return [] - - -def merge_tables(page_table_boxes): - # Merge tables that are next to each other - expansion_factor = 1.02 - shrink_factor = .98 - ignore_boxes = set() - for i in range(len(page_table_boxes)): - if i in ignore_boxes: + + +def get_table_boxes(pages: List[Page], doc: PdfDocument, fname): + table_imgs = [] + table_counts = [] + table_bboxes = [] + img_sizes = [] + + for page in pages: + pnum = page.pnum + # The bbox for the entire table + bbox = [b.bbox for b in page.layout.bboxes if b.label == "Table"] + + if len(bbox) == 0: + table_counts.append(0) + img_sizes.append(None) continue - for j in range(i + 1, len(page_table_boxes)): - if j in ignore_boxes: - continue - expanded_box1 = [page_table_boxes[i][0] * shrink_factor, page_table_boxes[i][1], - page_table_boxes[i][2] * expansion_factor, page_table_boxes[i][3]] - expanded_box2 = [page_table_boxes[j][0] * shrink_factor, page_table_boxes[j][1], - page_table_boxes[j][2] * expansion_factor, page_table_boxes[j][3]] - if box_intersection_pct(expanded_box1, expanded_box2) > 0: - page_table_boxes[i] = merge_boxes(page_table_boxes[i], page_table_boxes[j]) - ignore_boxes.add(j) - return [b for i, b in enumerate(page_table_boxes) if i not in ignore_boxes] + highres_img = render_image(doc[pnum], dpi=settings.SURYA_TABLE_DPI) + + page_table_imgs = [] + lowres_bbox = [] + + # Merge tables that are next to each other + bbox = merge_tables(bbox) + + # Number of tables per page + table_counts.append(len(bbox)) + img_sizes.append(highres_img.size) + + for bb in bbox: + highres_bb = rescale_bbox(page.layout.image_bbox, [0, 0, highres_img.size[0], highres_img.size[1]], bb) + page_table_imgs.append(highres_img.crop(highres_bb)) + lowres_bbox.append(highres_bb) + table_imgs.extend(page_table_imgs) + table_bboxes.extend(lowres_bbox) + + table_idxs = [i for i, c in enumerate(table_counts) if c > 0] + sel_text_lines = get_page_text_lines( + fname, + table_idxs, + [hr for i, hr in enumerate(img_sizes) if i in table_idxs], + ) + text_lines = [] + out_img_sizes = [] + for i in range(len(table_counts)): + if i in table_idxs: + text_lines.extend([sel_text_lines.pop(0)] * table_counts[i]) + out_img_sizes.extend([img_sizes[i]] * table_counts[i]) + + assert len(table_imgs) == len(table_bboxes) == len(text_lines) == len(out_img_sizes) + assert sum(table_counts) == len(table_imgs) + + return table_imgs, table_bboxes, table_counts, text_lines, out_img_sizes + + +def format_tables(pages: List[Page], doc: PdfDocument, fname: str, detection_model, table_rec_model, ocr_model): + det_models = [detection_model, detection_model.processor] + rec_models = [table_rec_model, table_rec_model.processor, ocr_model, ocr_model.processor] + + # Don't look at table cell detection tqdm output + tqdm.disable = True + table_imgs, table_boxes, table_counts, table_text_lines, img_sizes = get_table_boxes(pages, doc, fname) + cells, needs_ocr = get_cells(table_imgs, table_boxes, img_sizes, table_text_lines, det_models, detect_boxes=settings.OCR_ALL_PAGES) + tqdm.disable = False + + table_rec = recognize_tables(table_imgs, cells, needs_ocr, rec_models) + cells = [assign_rows_columns(tr, im_size) for tr, im_size in zip(table_rec, img_sizes)] + table_md = [formatter("markdown", cell)[0] for cell in cells] -def format_tables(pages: List[Page]): - # Formats tables nicely into github flavored markdown table_count = 0 - for page in pages: + for page_idx, page in enumerate(pages): + page_table_count = table_counts[page_idx] + if page_table_count == 0: + continue + table_insert_points = {} blocks_to_remove = set() pnum = page.pnum - - page_table_boxes = [b for b in page.layout.bboxes if b.label == "Table"] - page_table_boxes = [rescale_bbox(page.layout.image_bbox, page.bbox, b.bbox) for b in page_table_boxes] - page_table_boxes = merge_tables(page_table_boxes) + highres_size = img_sizes[table_count] + page_table_boxes = table_boxes[table_count:table_count + page_table_count] for table_idx, table_box in enumerate(page_table_boxes): + lowres_table_box = rescale_bbox([0, 0, highres_size[0], highres_size[1]], page.bbox, table_box) + for block_idx, block in enumerate(page.blocks): - intersect_pct = block.intersection_pct(table_box) + intersect_pct = block.intersection_pct(lowres_table_box) if intersect_pct > settings.TABLE_INTERSECTION_THRESH and block.block_type == "Table": if table_idx not in table_insert_points: table_insert_points[table_idx] = max(0, block_idx - len(blocks_to_remove)) # Where to insert the new table @@ -163,17 +115,10 @@ def format_tables(pages: List[Page]): for table_idx, table_box in enumerate(page_table_boxes): if table_idx not in table_insert_points: + table_count += 1 continue - if page.ocr_method == "surya": - table_rows = get_table_surya(page, table_box) - else: - table_rows = get_table_pdftext(page, table_box) - # Skip empty tables - if len(table_rows) == 0: - continue - - table_text = tabulate(table_rows, headers="firstrow", tablefmt="github", disable_numparse=True) + markdown = table_md[table_count] table_block = Block( bbox=table_box, block_type="Table", @@ -187,7 +132,7 @@ def format_tables(pages: List[Page]): font_size=0, font_weight=0, block_type="Table", - text=table_text + text=markdown )] )] ) diff --git a/poetry.lock b/poetry.lock index c4be59e9..934fb970 100644 --- a/poetry.lock +++ b/poetry.lock @@ -601,6 +601,23 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coloredlogs" +version = "15.0.1" +description = "Colored terminal output for Python's logging module" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"}, + {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"}, +] + +[package.dependencies] +humanfriendly = ">=9.1" + +[package.extras] +cron = ["capturer (>=2.4)"] + [[package]] name = "comm" version = "0.2.2" @@ -799,6 +816,17 @@ files = [ {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, ] +[[package]] +name = "flatbuffers" +version = "24.3.25" +description = "The FlatBuffers serialization format for Python" +optional = false +python-versions = "*" +files = [ + {file = "flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812"}, + {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -1132,6 +1160,20 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] +[[package]] +name = "humanfriendly" +version = "10.0" +description = "Human friendly output for text interfaces using Python" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"}, + {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"}, +] + +[package.dependencies] +pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} + [[package]] name = "idna" version = "3.7" @@ -1143,25 +1185,6 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] -[[package]] -name = "importlib-metadata" -version = "8.0.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, - {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, -] - -[package.dependencies] -zipp = ">=0.5" - -[package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] - [[package]] name = "intel-openmp" version = "2021.4.0" @@ -1231,7 +1254,6 @@ prompt-toolkit = ">=3.0.41,<3.1.0" pygments = ">=2.4.0" stack-data = "*" traitlets = ">=5" -typing-extensions = {version = "*", markers = "python_version < \"3.10\""} [package.extras] all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio (<0.22)", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] @@ -1425,7 +1447,6 @@ files = [ ] [package.dependencies] -importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" @@ -1517,7 +1538,6 @@ files = [ ] [package.dependencies] -importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} jupyter-server = ">=1.1.2" [[package]] @@ -1589,7 +1609,6 @@ files = [ [package.dependencies] async-lru = ">=1.0.0" httpx = ">=0.25.0" -importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} ipykernel = ">=6.5.0" jinja2 = ">=3.0.3" jupyter-core = "*" @@ -1634,7 +1653,6 @@ files = [ [package.dependencies] babel = ">=2.10" -importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} jinja2 = ">=3.0.3" json5 = ">=0.9.0" jsonschema = ">=4.18.0" @@ -1999,7 +2017,6 @@ files = [ beautifulsoup4 = "*" bleach = "!=5.0.0" defusedxml = "*" -importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} jinja2 = ">=3.0" jupyter-core = ">=4.7" jupyterlab-pygments = "*" @@ -2284,6 +2301,7 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"}, {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"}, {file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"}, ] @@ -2299,6 +2317,48 @@ files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "onnxruntime" +version = "1.19.2" +description = "ONNX Runtime is a runtime accelerator for Machine Learning models" +optional = false +python-versions = "*" +files = [ + {file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"}, + {file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"}, + {file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"}, + {file = "onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"}, + {file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"}, + {file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"}, + {file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"}, +] + +[package.dependencies] +coloredlogs = "*" +flatbuffers = "*" +numpy = ">=1.21.6" +packaging = "*" +protobuf = "*" +sympy = "*" + [[package]] name = "opencv-python" version = "4.10.0.84" @@ -2317,12 +2377,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -2387,9 +2445,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2448,20 +2506,20 @@ testing = ["docopt", "pytest"] [[package]] name = "pdftext" -version = "0.3.10" +version = "0.3.13" description = "Extract structured text from pdfs quickly" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,!=3.8.*,>=3.9" files = [ - {file = "pdftext-0.3.10-py3-none-any.whl", hash = "sha256:99bd900d0d0692df06719c07ce10a859750ade3eb7f10c543f637118417497f9"}, - {file = "pdftext-0.3.10.tar.gz", hash = "sha256:90de726e818fb5683a0616cabb1a75a32a7224e873c3058006c93da6e440c66c"}, + {file = "pdftext-0.3.13-py3-none-any.whl", hash = "sha256:ae8f6876cdbbc1fe611527bb362cd3d584b4c8ec9370215560f2a01be4343bbc"}, + {file = "pdftext-0.3.13.tar.gz", hash = "sha256:a37ceb759ac0da34c48f85ab5d43d0b128ad9526f949e98b96568495c7be4187"}, ] [package.dependencies] +onnxruntime = ">=1.19.2,<2.0.0" pydantic = ">=2.7.1,<3.0.0" pydantic-settings = ">=2.2.1,<3.0.0" pypdfium2 = ">=4.29.0,<5.0.0" -scikit-learn = ">=1.4.2,<2.0.0" [[package]] name = "pexpect" @@ -2756,119 +2814,123 @@ files = [ [[package]] name = "pydantic" -version = "2.8.2" +version = "2.9.2" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, - {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, + {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"}, + {file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"}, ] [package.dependencies] -annotated-types = ">=0.4.0" -pydantic-core = "2.20.1" -typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} +annotated-types = ">=0.6.0" +pydantic-core = "2.23.4" +typing-extensions = [ + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, +] [package.extras] email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata"] [[package]] name = "pydantic-core" -version = "2.20.1" +version = "2.23.4" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, - {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, - {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, - {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, - {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, - {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, - {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, - {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, - {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, - {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, - {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, - {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, - {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, - {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, - {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, - {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, - {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, - {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, - {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, - {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, - {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, - {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, - {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, - {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, - {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, - {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, - {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, - {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, - {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, - {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, - {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, - {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, - {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, - {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, - {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"}, + {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"}, + {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"}, + {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"}, + {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"}, + {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"}, + {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"}, + {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"}, + {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"}, + {file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"}, + {file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"}, + {file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"}, + {file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"}, + {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"}, ] [package.dependencies] @@ -2876,13 +2938,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pydantic-settings" -version = "2.4.0" +version = "2.5.2" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.4.0-py3-none-any.whl", hash = "sha256:bb6849dc067f1687574c12a639e231f3a6feeed0a12d710c1382045c5db1c315"}, - {file = "pydantic_settings-2.4.0.tar.gz", hash = "sha256:ed81c3a0f46392b4d7c0a565c05884e6e54b3456e6f0fe4d8814981172dc9a88"}, + {file = "pydantic_settings-2.5.2-py3-none-any.whl", hash = "sha256:2c912e55fd5794a59bf8c832b9de832dcfdf4778d79ff79b708744eed499a907"}, + {file = "pydantic_settings-2.5.2.tar.gz", hash = "sha256:f90b139682bee4d2065273d5185d71d37ea46cfe57e1b5ae184fc6a0b2484ca0"}, ] [package.dependencies] @@ -2949,6 +3011,20 @@ files = [ {file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"}, ] +[[package]] +name = "pyreadline3" +version = "3.5.4" +description = "A python implementation of GNU readline." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6"}, + {file = "pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7"}, +] + +[package.extras] +dev = ["build", "flake8", "mypy", "pytest", "twine"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3743,87 +3819,103 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] [[package]] name = "scikit-learn" -version = "1.4.2" +version = "1.5.2" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.9" files = [ - {file = "scikit-learn-1.4.2.tar.gz", hash = "sha256:daa1c471d95bad080c6e44b4946c9390a4842adc3082572c20e4f8884e39e959"}, - {file = "scikit_learn-1.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8539a41b3d6d1af82eb629f9c57f37428ff1481c1e34dddb3b9d7af8ede67ac5"}, - {file = "scikit_learn-1.4.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:68b8404841f944a4a1459b07198fa2edd41a82f189b44f3e1d55c104dbc2e40c"}, - {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81bf5d8bbe87643103334032dd82f7419bc8c8d02a763643a6b9a5c7288c5054"}, - {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f0ea5d0f693cb247a073d21a4123bdf4172e470e6d163c12b74cbb1536cf38"}, - {file = "scikit_learn-1.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:87440e2e188c87db80ea4023440923dccbd56fbc2d557b18ced00fef79da0727"}, - {file = "scikit_learn-1.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:45dee87ac5309bb82e3ea633955030df9bbcb8d2cdb30383c6cd483691c546cc"}, - {file = "scikit_learn-1.4.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1d0b25d9c651fd050555aadd57431b53d4cf664e749069da77f3d52c5ad14b3b"}, - {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0203c368058ab92efc6168a1507d388d41469c873e96ec220ca8e74079bf62e"}, - {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44c62f2b124848a28fd695db5bc4da019287abf390bfce602ddc8aa1ec186aae"}, - {file = "scikit_learn-1.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:5cd7b524115499b18b63f0c96f4224eb885564937a0b3477531b2b63ce331904"}, - {file = "scikit_learn-1.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90378e1747949f90c8f385898fff35d73193dfcaec3dd75d6b542f90c4e89755"}, - {file = "scikit_learn-1.4.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ff4effe5a1d4e8fed260a83a163f7dbf4f6087b54528d8880bab1d1377bd78be"}, - {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:671e2f0c3f2c15409dae4f282a3a619601fa824d2c820e5b608d9d775f91780c"}, - {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36d0bc983336bbc1be22f9b686b50c964f593c8a9a913a792442af9bf4f5e68"}, - {file = "scikit_learn-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:d762070980c17ba3e9a4a1e043ba0518ce4c55152032f1af0ca6f39b376b5928"}, - {file = "scikit_learn-1.4.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9993d5e78a8148b1d0fdf5b15ed92452af5581734129998c26f481c46586d68"}, - {file = "scikit_learn-1.4.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:426d258fddac674fdf33f3cb2d54d26f49406e2599dbf9a32b4d1696091d4256"}, - {file = "scikit_learn-1.4.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5460a1a5b043ae5ae4596b3126a4ec33ccba1b51e7ca2c5d36dac2169f62ab1d"}, - {file = "scikit_learn-1.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49d64ef6cb8c093d883e5a36c4766548d974898d378e395ba41a806d0e824db8"}, - {file = "scikit_learn-1.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:c97a50b05c194be9146d61fe87dbf8eac62b203d9e87a3ccc6ae9aed2dfaf361"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, ] [package.dependencies] joblib = ">=1.2.0" numpy = ">=1.19.5" scipy = ">=1.6.0" -threadpoolctl = ">=2.0.0" +threadpoolctl = ">=3.1.0" [package.extras] -benchmark = ["matplotlib (>=3.3.4)", "memory-profiler (>=0.57.0)", "pandas (>=1.1.5)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.15.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] -tests = ["black (>=23.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.19.12)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.17.2)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] [[package]] name = "scipy" -version = "1.13.1" +version = "1.14.1" description = "Fundamental algorithms for scientific computing in Python" optional = false -python-versions = ">=3.9" -files = [ - {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, - {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, - {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, - {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, - {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, - {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, - {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, - {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, - {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, - {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, - {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, - {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, - {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, - {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, - {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, - {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, - {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, - {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, - {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, - {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, - {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, - {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, - {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, - {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, - {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, ] [package.dependencies] -numpy = ">=1.22.4,<2.3" +numpy = ">=1.23.5,<2.3" [package.extras] -dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] -test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "send2trash" @@ -3956,19 +4048,20 @@ snowflake = ["snowflake-connector-python (>=2.8.0)", "snowflake-snowpark-python [[package]] name = "surya-ocr" -version = "0.5.0" -description = "OCR, layout, reading order, and line detection in 90+ languages" +version = "0.6.3" +description = "OCR, layout, reading order, and table recognition in 90+ languages" optional = false -python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,!=3.8.*,>=3.9" +python-versions = ">=3.10" files = [ - {file = "surya_ocr-0.5.0-py3-none-any.whl", hash = "sha256:e70516d74f3816c5b2a61bdf8f7eeb5fbd5670514bc5ae2eb0947d33c60c22d3"}, - {file = "surya_ocr-0.5.0.tar.gz", hash = "sha256:a80740c2b000d9630cf3d5525043c95096efaeb6b0892254ff32339a171e789a"}, + {file = "surya_ocr-0.6.3-py3-none-any.whl", hash = "sha256:f4d98e643ed6003a1fed2a758bed391ffc7be908c849d3ab741b05c4d6a714a2"}, + {file = "surya_ocr-0.6.3.tar.gz", hash = "sha256:cf0e9382352eaf96ff74fe0ca5daff30f96f0897bb481ff418a8ae1a7ce31534"}, ] [package.dependencies] filetype = ">=1.2.0,<2.0.0" ftfy = ">=6.1.3,<7.0.0" opencv-python = ">=4.9.0.80,<5.0.0.0" +pdftext = ">=0.3.12,<0.4.0" pillow = ">=10.2.0,<11.0.0" pydantic = ">=2.5.3,<3.0.0" pydantic-settings = ">=2.1.0,<3.0.0" @@ -3995,6 +4088,27 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +[[package]] +name = "tabled-pdf" +version = "0.1.0" +description = "Detect and recognize tables in PDFs and images." +optional = false +python-versions = "<4.0,>=3.10" +files = [ + {file = "tabled_pdf-0.1.0-py3-none-any.whl", hash = "sha256:95e3e5863cfbe829c9f233e3e9dc31be8c5f24ffd2367f57e983e710aeee659e"}, + {file = "tabled_pdf-0.1.0.tar.gz", hash = "sha256:63a2c7d3ae55b3e7e467c2fbad9d78c7c57e31810324fc584cbf322e8e026890"}, +] + +[package.dependencies] +click = ">=8.1.7,<9.0.0" +pydantic = ">=2.9.2,<3.0.0" +pydantic-settings = ">=2.5.2,<3.0.0" +pypdfium2 = ">=4.30.0,<5.0.0" +python-dotenv = ">=1.0.1,<2.0.0" +scikit-learn = ">=1.5.2,<2.0.0" +surya-ocr = ">=0.6.3,<0.7.0" +tabulate = ">=0.9.0,<0.10.0" + [[package]] name = "tabulate" version = "0.9.0" @@ -4858,22 +4972,7 @@ files = [ idna = ">=2.0" multidict = ">=4.0" -[[package]] -name = "zipp" -version = "3.19.2" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.8" -files = [ - {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, - {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, -] - -[package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] - [metadata] lock-version = "2.0" -python-versions = ">=3.9,<3.13,!=3.9.7" -content-hash = "3f4bb2a0bfc8c717d377368f6e3fafcf7ef7d68030c6c16e0b3719dbdd9fca1f" +python-versions = "^3.10" +content-hash = "887985e53de36c13b8f82a96b1a93fea4ca6762db31bdcf9aa8147572c8a4771" diff --git a/pyproject.toml b/pyproject.toml index 92968ab8..8e2189b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,25 +20,23 @@ include = [ ] [tool.poetry.dependencies] -python = ">=3.9,<3.13,!=3.9.7" -scikit-learn = "^1.3.2,<=1.4.2" +python = "^3.10" Pillow = "^10.1.0" pydantic = "^2.4.2" pydantic-settings = "^2.0.3" -transformers = "^4.36.2" -numpy = "^1.26.1" +transformers = "^4.45.2" python-dotenv = "^1.0.0" -torch = "^2.2.2" # Issue with torch 2.3.0 and vision models - https://github.com/pytorch/pytorch/issues/121834 +torch = "^2.4.1" tqdm = "^4.66.1" tabulate = "^0.9.0" ftfy = "^6.1.1" -texify = "^0.1.10" +texify = "^0.2.0" rapidfuzz = "^3.8.1" -surya-ocr = "^0.5.0" +surya-ocr = "^0.6.3" filetype = "^1.2.0" regex = "^2024.4.28" -pdftext = "^0.3.10" -grpcio = "^1.63.0" +pdftext = "^0.3.13" +tabled-pdf = "^0.1.0" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0"