From 742961de19e6884824b09ed7bdf3eaf19b8b11c5 Mon Sep 17 00:00:00 2001 From: Mark Jacobson <52427991+marksparkza@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:46:57 +0200 Subject: [PATCH] Get catalog record by DOI case-insensitively Fixes SAEON/mims-catalog#2 --- odp/api/routers/catalog.py | 7 +++---- test/api/test_catalog.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/odp/api/routers/catalog.py b/odp/api/routers/catalog.py index dedb7b1..9377afc 100644 --- a/odp/api/routers/catalog.py +++ b/odp/api/routers/catalog.py @@ -22,7 +22,7 @@ from odp.api.models import CatalogModel, PublishedDataCiteRecordModel, PublishedSAEONRecordModel, RetractedRecordModel, SearchResult from odp.const import DOI_REGEX, ODPCatalog, ODPScope from odp.db import Session -from odp.db.models import Catalog, CatalogRecord, CatalogRecordFacet, PublishedRecord +from odp.db.models import Catalog, CatalogRecord, CatalogRecordFacet, PublishedRecord, Record from odp.lib.datacite import DataciteClient, DataciteError router = APIRouter() @@ -282,9 +282,8 @@ async def get_catalog_record_by_id_or_doi( except ValueError: if re.match(DOI_REGEX, record_id_or_doi): - stmt = stmt.where(CatalogRecord.published_record.comparator.contains({ - 'doi': record_id_or_doi - })) + stmt = stmt.join(Record) + stmt = stmt.where(func.lower(Record.doi) == record_id_or_doi.lower()) else: raise HTTPException(HTTP_422_UNPROCESSABLE_ENTITY, 'Invalid record identifier: expecting a UUID or DOI') diff --git a/test/api/test_catalog.py b/test/api/test_catalog.py index 692ba04..e11e905 100644 --- a/test/api/test_catalog.py +++ b/test/api/test_catalog.py @@ -242,7 +242,7 @@ def check_metadata_record(schema_id, deep=True): route = f'/catalog/{catalog_id}/records' resp_code = 200 if endpoint == 'get': - route += f'/{example_record.doi}' if example_record.doi else f'/{example_record.id}' + route += f'/{example_record.doi.swapcase()}' if example_record.doi else f'/{example_record.id}' resp_code = 200 if published else 404 r = api(scopes, create_scopes=False).get(route) @@ -319,7 +319,7 @@ def test_get_published_metadata_value( ) route = f'/catalog/{catalog_id}/getvalue/' - route += example_record.doi if example_record.doi else example_record.id + route += example_record.doi.swapcase() if example_record.doi else example_record.id r = api(scopes, create_scopes=False).get(route, params=dict( schema_id=schema_id, @@ -351,7 +351,7 @@ def test_get_published_metadata_document( ) route = f'/catalog/{catalog_id}/getvalue/' - route += example_record.doi if example_record.doi else example_record.id + route += example_record.doi.swapcase() if example_record.doi else example_record.id r = api(scopes, create_scopes=False).get(route, params=dict( schema_id=schema_id,