From 0835bf159569eec692365d03f5fd77dfbcf994d3 Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Mon, 23 Sep 2024 15:59:48 +0200 Subject: [PATCH] Removed `Catalog.open_object()` and refactor method to return file object from row (#467) * removed catalog.open_object and refactor method to return file objects from row * removed not used method --- src/datachain/catalog/catalog.py | 68 +++++++++----------------------- tests/func/test_catalog.py | 54 ++++++++++--------------- 2 files changed, 39 insertions(+), 83 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index f9a1bf326..33764b538 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -79,6 +79,7 @@ ) from datachain.dataset import DatasetVersion from datachain.job import Job + from datachain.lib.file import File logger = logging.getLogger("datachain") @@ -1399,65 +1400,34 @@ def edit_dataset( dataset = self.get_dataset(name) return self.update_dataset(dataset, **update_data) - def get_file_signals( - self, dataset_name: str, dataset_version: int, row: RowDict - ) -> Optional[RowDict]: + def get_file_from_row( + self, dataset_name: str, dataset_version: int, row: RowDict, signal_name: str + ) -> "File": """ - Function that returns file signals from dataset row. - Note that signal names are without prefix, so if there was 'laion__file__source' - in original row, result will have just 'source' - Example output: - { - "source": "s3://ldb-public", - "path": "animals/dogs/dog.jpg", - ... - } + Function that returns specific file signal from dataset row by name.
""" from datachain.lib.file import File from datachain.lib.signal_schema import DEFAULT_DELIMITER, SignalSchema version = self.get_dataset(dataset_name).get_version(dataset_version) - - file_signals_values = RowDict() - schema = SignalSchema.deserialize(version.feature_schema) - for file_signals in schema.get_signals(File): - prefix = file_signals.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER - file_signals_values[file_signals] = { - c_name.removeprefix(prefix): c_value - for c_name, c_value in row.items() - if c_name.startswith(prefix) - and DEFAULT_DELIMITER not in c_name.removeprefix(prefix) - } - if not file_signals_values: - return None - - # there can be multiple file signals in a schema, but taking the first - # one for now. In future we might add ability to choose from which one - # to open object - return next(iter(file_signals_values.values())) - - def open_object( - self, - dataset_name: str, - dataset_version: int, - row: RowDict, - use_cache: bool = True, - **config: Any, - ): - from datachain.lib.file import File + if signal_name not in schema.get_signals(File): + raise RuntimeError( + f"File signal with path {signal_name} not found in " + f"dataset {dataset_name}@v{dataset_version} signals schema", + ) - file_signals = self.get_file_signals(dataset_name, dataset_version, row) - if not file_signals: - raise RuntimeError("Cannot open object without file signals") + prefix = signal_name.replace(".", DEFAULT_DELIMITER) + DEFAULT_DELIMITER + file_signals = { + c_name.removeprefix(prefix): c_value + for c_name, c_value in row.items() + if c_name.startswith(prefix) + and DEFAULT_DELIMITER not in c_name.removeprefix(prefix) + and c_name.removeprefix(prefix) in File.model_fields + } - config = config or self.client_config - client = self.get_client(file_signals["source"], **config) - return client.open_object( - File._from_row(file_signals), - use_cache=use_cache, - ) + return File(**file_signals) def ls( self, diff --git a/tests/func/test_catalog.py 
b/tests/func/test_catalog.py index 29eebc0d2..25eacf581 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -847,7 +847,7 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys): assert catalog.get_temp_table_names() == [] -def test_get_file_signals(cloud_test_catalog, dogs_dataset): +def test_get_file_from_row(cloud_test_catalog, dogs_dataset): catalog = cloud_test_catalog.catalog catalog.metastore.update_dataset_version( dogs_dataset, @@ -863,18 +863,22 @@ def test_get_file_signals(cloud_test_catalog, dogs_dataset): "name": "Jon", "age": 25, "f1__source": "s3://first_bucket", - "f1__name": "image1.jpg", + "f1__path": "image1.jpg", "f2__source": "s3://second_bucket", - "f2__name": "image2.jpg", + "f2__path": "image2.jpg", } - assert catalog.get_file_signals(dogs_dataset.name, 1, row) == { - "source": "s3://first_bucket", - "name": "image1.jpg", - } + assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "f1") == File( + source="s3://first_bucket", + path="image1.jpg", + ) + assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "f2") == File( + source="s3://second_bucket", + path="image2.jpg", + ) -def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset): +def test_get_file_from_row_with_custom_types(cloud_test_catalog, dogs_dataset): catalog = cloud_test_catalog.catalog catalog.metastore.update_dataset_version( dogs_dataset, @@ -885,7 +889,7 @@ def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset): "f1": "File@v1", "f2": "File@v1", "_custom_types": { - "File@v1": {"source": "str", "name": "str"}, + "File@v1": {"source": "str", "path": "str"}, }, }, ) @@ -893,36 +897,18 @@ def test_get_file_signals_with_custom_types(cloud_test_catalog, dogs_dataset): "name": "Jon", "age": 25, "f1__source": "s3://first_bucket", - "f1__name": "image1.jpg", + "f1__path": "image1.jpg", "f2__source": "s3://second_bucket", - "f2__name": "image2.jpg", - } - - assert 
catalog.get_file_signals(dogs_dataset.name, 1, row) == { - "source": "s3://first_bucket", - "name": "image1.jpg", + "f2__path": "image2.jpg", } - -def test_get_file_signals_no_signals(cloud_test_catalog, dogs_dataset): - catalog = cloud_test_catalog.catalog - catalog.metastore.update_dataset_version( - dogs_dataset, - 1, - feature_schema={ - "name": "str", - "age": "str", - }, + assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "f1") == File( + source="s3://first_bucket", + path="image1.jpg", ) - row = { - "name": "Jon", - "age": 25, - } - - assert catalog.get_file_signals(dogs_dataset.name, 1, row) is None -def test_open_object_no_file_signals(cloud_test_catalog, dogs_dataset): +def test_get_file_from_row_no_signals(cloud_test_catalog, dogs_dataset): catalog = cloud_test_catalog.catalog catalog.metastore.update_dataset_version( dogs_dataset, @@ -938,4 +924,4 @@ def test_open_object_no_file_signals(cloud_test_catalog, dogs_dataset): } with pytest.raises(RuntimeError): - assert catalog.open_object(dogs_dataset.name, 1, row) + assert catalog.get_file_from_row(dogs_dataset.name, 1, row, "missing")