From 0686926048de1238f99062affe235690c02668b5 Mon Sep 17 00:00:00 2001 From: Joeri de Ruiter Date: Mon, 21 Aug 2023 16:58:16 +0200 Subject: [PATCH] Type annotations for bookwyrm.importers --- bookwyrm/importers/calibre_import.py | 6 ++- bookwyrm/importers/importer.py | 51 ++++++++++++++++------- bookwyrm/importers/librarything_import.py | 20 ++++++--- bookwyrm/importers/openlibrary_import.py | 4 +- bookwyrm/models/import_job.py | 6 +-- mypy.ini | 3 ++ 6 files changed, 62 insertions(+), 28 deletions(-) diff --git a/bookwyrm/importers/calibre_import.py b/bookwyrm/importers/calibre_import.py index 5426e9333c..5c22a539df 100644 --- a/bookwyrm/importers/calibre_import.py +++ b/bookwyrm/importers/calibre_import.py @@ -1,4 +1,6 @@ """ handle reading a csv from calibre """ +from typing import Any, Optional + from bookwyrm.models import Shelf from . import Importer @@ -9,7 +11,7 @@ class CalibreImporter(Importer): service = "Calibre" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): # Add timestamp to row_mappings_guesses for date_added to avoid # integrity error row_mappings_guesses = [] @@ -23,6 +25,6 @@ def __init__(self, *args, **kwargs): self.row_mappings_guesses = row_mappings_guesses super().__init__(*args, **kwargs) - def get_shelf(self, normalized_row): + def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]: # Calibre export does not indicate which shelf to use. Use a default one for now return Shelf.TO_READ diff --git a/bookwyrm/importers/importer.py b/bookwyrm/importers/importer.py index 4c2abb5218..5b3192fa5b 100644 --- a/bookwyrm/importers/importer.py +++ b/bookwyrm/importers/importer.py @@ -1,8 +1,10 @@ """ handle reading a csv from an external service, defaults are from Goodreads """ import csv from datetime import timedelta +from typing import Iterable, Optional + from django.utils import timezone -from bookwyrm.models import ImportJob, ImportItem, SiteSettings +from bookwyrm.models import ImportJob, ImportItem, SiteSettings, User class Importer: @@ -35,19 +37,26 @@ class Importer: } # pylint: disable=too-many-locals - def create_job(self, user, csv_file, include_reviews, privacy): + def create_job( + self, user: User, csv_file: Iterable[str], include_reviews: bool, privacy: str + ) -> ImportJob: """check over a csv and creates a database entry for the job""" csv_reader = csv.DictReader(csv_file, delimiter=self.delimiter) rows = list(csv_reader) if len(rows) < 1: raise ValueError("CSV file is empty") - rows = enumerate(rows) + + mappings = ( + self.create_row_mappings(list(fieldnames)) + if (fieldnames := csv_reader.fieldnames) + else {} + ) job = ImportJob.objects.create( user=user, include_reviews=include_reviews, privacy=privacy, - mappings=self.create_row_mappings(csv_reader.fieldnames), + mappings=mappings, source=self.service, ) @@ -55,16 +64,20 @@ def create_job(self, user, csv_file, include_reviews, privacy): if enforce_limit and allowed_imports <= 0: job.complete_job() return job - for index, entry in rows: + for index, entry in enumerate(rows): if enforce_limit and index >= allowed_imports: break self.create_item(job, index, entry) return job - def update_legacy_job(self, job): + def update_legacy_job(self, job: ImportJob) -> None: """patch up a job that was in the old format""" items = job.items - headers = list(items.first().data.keys()) + first_item = items.first() + if first_item is None: + return + + headers = list(first_item.data.keys()) job.mappings = self.create_row_mappings(headers) job.updated_date = timezone.now() job.save() @@ -75,24 +88,24 @@ def update_legacy_job(self, job): item.normalized_data = normalized item.save() - def create_row_mappings(self, headers): + def create_row_mappings(self, headers: list[str]) -> dict[str, Optional[str]]: """guess what the headers mean""" mappings = {} for (key, guesses) in self.row_mappings_guesses: - value = [h for h in headers if h.lower() in guesses] - value = value[0] if len(value) else None + values = [h for h in headers if h.lower() in guesses] + value = values[0] if len(values) else None if value: headers.remove(value) mappings[key] = value return mappings - def create_item(self, job, index, data): + def create_item(self, job: ImportJob, index: int, data: dict[str, str]) -> None: """creates and saves an import item""" normalized = self.normalize_row(data, job.mappings) normalized["shelf"] = self.get_shelf(normalized) ImportItem(job=job, index=index, data=data, normalized_data=normalized).save() - def get_shelf(self, normalized_row): + def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]: """determine which shelf to use""" shelf_name = normalized_row.get("shelf") if not shelf_name: @@ -103,11 +116,15 @@ def get_shelf(self, normalized_row): ] return shelf[0] if shelf else None - def normalize_row(self, entry, mappings): # pylint: disable=no-self-use + # pylint: disable=no-self-use + def normalize_row( + self, entry: dict[str, str], mappings: dict[str, Optional[str]] + ) -> dict[str, Optional[str]]: """use the dataclass to create the formatted row of data""" - return {k: entry.get(v) for k, v in mappings.items()} + return {k: entry.get(v) if v else None for k, v in mappings.items()} - def get_import_limit(self, user): # pylint: disable=no-self-use + # pylint: disable=no-self-use + def get_import_limit(self, user: User) -> tuple[int, int]: """check if import limit is set and return how many imports are left""" site_settings = SiteSettings.objects.get() import_size_limit = site_settings.import_size_limit @@ -125,7 +142,9 @@ def get_import_limit(self, user): # pylint: disable=no-self-use allowed_imports = import_size_limit - imported_books return enforce_limit, allowed_imports - def create_retry_job(self, user, original_job, items): + def create_retry_job( + self, user: User, original_job: ImportJob, items: list[ImportItem] + ) -> ImportJob: """retry items that didn't import""" job = ImportJob.objects.create( user=user, diff --git a/bookwyrm/importers/librarything_import.py b/bookwyrm/importers/librarything_import.py index ea31b46eb6..145657ba08 100644 --- a/bookwyrm/importers/librarything_import.py +++ b/bookwyrm/importers/librarything_import.py @@ -1,11 +1,16 @@ """ handle reading a tsv from librarything """ import re +from typing import Optional from bookwyrm.models import Shelf from . import Importer +def _remove_brackets(value: Optional[str]) -> Optional[str]: + return re.sub(r"\[|\]", "", value) if value else None + + class LibrarythingImporter(Importer): """csv downloads from librarything""" @@ -13,16 +18,19 @@ class LibrarythingImporter(Importer): delimiter = "\t" encoding = "ISO-8859-1" - def normalize_row(self, entry, mappings): # pylint: disable=no-self-use + def normalize_row( + self, entry: dict[str, str], mappings: dict[str, Optional[str]] + ) -> dict[str, Optional[str]]: # pylint: disable=no-self-use """use the dataclass to create the formatted row of data""" - remove_brackets = lambda v: re.sub(r"\[|\]", "", v) if v else None - normalized = {k: remove_brackets(entry.get(v)) for k, v in mappings.items()} - isbn_13 = normalized.get("isbn_13") - isbn_13 = isbn_13.split(", ") if isbn_13 else [] + normalized = { + k: _remove_brackets(entry.get(v) if v else None) + for k, v in mappings.items() + } + isbn_13 = value.split(", ") if (value := normalized.get("isbn_13")) else [] normalized["isbn_13"] = isbn_13[1] if len(isbn_13) > 1 else None return normalized - def get_shelf(self, normalized_row): + def get_shelf(self, normalized_row: dict[str, Optional[str]]) -> Optional[str]: if normalized_row["date_finished"]: return Shelf.READ_FINISHED if normalized_row["date_started"]: diff --git a/bookwyrm/importers/openlibrary_import.py b/bookwyrm/importers/openlibrary_import.py index ef10306091..6a954ed3c7 100644 --- a/bookwyrm/importers/openlibrary_import.py +++ b/bookwyrm/importers/openlibrary_import.py @@ -1,4 +1,6 @@ """ handle reading a csv from openlibrary""" +from typing import Any + from . import Importer @@ -7,7 +9,7 @@ class OpenLibraryImporter(Importer): service = "OpenLibrary" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): self.row_mappings_guesses.append(("openlibrary_key", ["edition id"])) self.row_mappings_guesses.append(("openlibrary_work_key", ["work id"])) super().__init__(*args, **kwargs) diff --git a/bookwyrm/models/import_job.py b/bookwyrm/models/import_job.py index bb5144297c..8929e90376 100644 --- a/bookwyrm/models/import_job.py +++ b/bookwyrm/models/import_job.py @@ -54,10 +54,10 @@ def construct_search_term(title, author): class ImportJob(models.Model): """entry for a specific request for book data import""" - user = models.ForeignKey(User, on_delete=models.CASCADE) + user: User = models.ForeignKey(User, on_delete=models.CASCADE) created_date = models.DateTimeField(default=timezone.now) updated_date = models.DateTimeField(default=timezone.now) - include_reviews = models.BooleanField(default=True) + include_reviews: bool = models.BooleanField(default=True) mappings = models.JSONField() source = models.CharField(max_length=100) privacy = models.CharField(max_length=255, default="public", choices=PrivacyLevels) @@ -76,7 +76,7 @@ def start_job(self): self.save(update_fields=["task_id"]) - def complete_job(self): + def complete_job(self) -> None: """Report that the job has completed""" self.status = "complete" self.complete = True diff --git a/mypy.ini b/mypy.ini index 2a29e314f0..fe181e365c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -13,6 +13,9 @@ implicit_reexport = True [mypy-bookwyrm.connectors.*] ignore_errors = False +[mypy-bookwyrm.importers.*] +ignore_errors = False + [mypy-celerywyrm.*] ignore_errors = False