From c2b717dcf3be6ee57394ee42fc00d2f526dccdd4 Mon Sep 17 00:00:00 2001
From: Robbe Sneyders
Date: Sat, 24 Feb 2024 11:45:32 +0100
Subject: [PATCH] Don't write metadata file (#875)

We reintroduced writing the metadata file in #864 to preserve the
divisions of the data when writing and reading again. We turned this
behavior off in the past, but without properly documenting the reason.

However, I'm now running into issues with Dask workers dying when
writing large datasets, presumably because of the metadata file, as
documented in these Dask issues:
- https://github.com/dask/dask/issues/6600
- https://github.com/dask/dask/issues/3873
- https://github.com/dask/dask/issues/8901

Also, while I ran into issues with the preservation of divisions
before, I can't reproduce them locally with a small example.

Let's turn writing metadata off again and validate whether we still
run into issues with this.
---
 src/fondant/component/data_io.py |  5 +----
 tests/component/test_data_io.py  | 28 +++++++---------------------
 2 files changed, 8 insertions(+), 25 deletions(-)

diff --git a/src/fondant/component/data_io.py b/src/fondant/component/data_io.py
index 76d5e7a2..5eddefc3 100644
--- a/src/fondant/component/data_io.py
+++ b/src/fondant/component/data_io.py
@@ -5,7 +5,6 @@
 
 import dask.dataframe as dd
 from dask.diagnostics import ProgressBar
-from dask.distributed import Client
 
 from fondant.core.component_spec import OperationSpec
 from fondant.core.manifest import Manifest
@@ -157,7 +156,6 @@ def __init__(
     def write_dataframe(
         self,
         dataframe: dd.DataFrame,
-        dask_client: t.Optional[Client] = None,
     ) -> None:
         dataframe.index = dataframe.index.rename(DEFAULT_INDEX_NAME)
 
@@ -176,7 +174,7 @@
 
         with ProgressBar():
             logging.info("Writing data...")
-            dd.compute(write_task, scheduler=dask_client)
+            dd.compute(write_task)
 
     @staticmethod
     def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):
@@ -234,7 +232,6 @@ def _create_write_task(
             schema=schema,
             overwrite=False,
             compute=False,
-            write_metadata_file=True,
         )
         logging.info(f"Creating write task for: {location}")
         return write_task
diff --git a/tests/component/test_data_io.py b/tests/component/test_data_io.py
index ec733d3c..905269a6 100644
--- a/tests/component/test_data_io.py
+++ b/tests/component/test_data_io.py
@@ -4,7 +4,6 @@
 import dask.dataframe as dd
 import pyarrow as pa
 import pytest
-from dask.distributed import Client
 from fondant.component.data_io import DaskDataLoader, DaskDataWriter
 from fondant.core.component_spec import ComponentSpec, OperationSpec
 from fondant.core.manifest import Manifest
@@ -21,13 +20,6 @@
 NUMBER_OF_TEST_ROWS = 151
 
 
-@pytest.fixture()
-def dask_client():  # noqa: PT004
-    client = Client()
-    yield
-    client.close()
-
-
 @pytest.fixture()
 def manifest():
     return Manifest.from_file(manifest_path)
@@ -121,7 +113,6 @@ def test_write_dataset(
     dataframe,
     manifest,
     component_spec,
-    dask_client,
 ):
     """Test writing out subsets."""
     # Dictionary specifying the expected subsets to write and their column names
@@ -134,7 +125,7 @@
         operation_spec=OperationSpec(component_spec),
     )
     # write dataframe to temp dir
-    data_writer.write_dataframe(dataframe, dask_client)
+    data_writer.write_dataframe(dataframe)
     # read written data and assert
     dataframe = dd.read_parquet(
         temp_dir
@@ -152,7 +143,6 @@ def test_write_dataset_custom_produces(
     dataframe,
     manifest,
     component_spec_produces,
-    dask_client,
 ):
     """Test writing out subsets."""
     produces = {
@@ -175,7 +165,7 @@ def test_write_dataset_custom_produces(
     )
     # write dataframe to temp dir
-    data_writer.write_dataframe(dataframe, dask_client)
+    data_writer.write_dataframe(dataframe)
     # # read written data and assert
     dataframe = dd.read_parquet(
         temp_dir
@@ -194,7 +184,6 @@ def test_write_reset_index(
     dataframe,
     manifest,
     component_spec,
-    dask_client,
 ):
     """Test writing out the index and fields that have no dask index and checking
     if the id index was created.
@@ -207,19 +196,18 @@
         manifest=manifest,
         operation_spec=OperationSpec(component_spec),
     )
-    data_writer.write_dataframe(dataframe, dask_client)
+    data_writer.write_dataframe(dataframe)
     dataframe = dd.read_parquet(fn)
     assert dataframe.index.name == "id"
 
 
 @pytest.mark.parametrize("partitions", list(range(1, 5)))
-def test_write_divisions(  # noqa: PLR0913
+def test_write_divisions(
     tmp_path_factory,
     dataframe,
     manifest,
     component_spec,
     partitions,
-    dask_client,
 ):
     """Test writing out index and subsets and asserting they have the divisions of the dataframe."""
     # repartition the dataframe (default is 3 partitions)
@@ -233,7 +221,7 @@
         operation_spec=OperationSpec(component_spec),
     )
 
-    data_writer.write_dataframe(dataframe, dask_client)
+    data_writer.write_dataframe(dataframe)
 
     dataframe = dd.read_parquet(fn)
     assert dataframe.index.name == "id"
@@ -245,7 +233,6 @@
     dataframe,
     manifest,
     component_spec,
-    dask_client,
 ):
     """Test writing out fields but the dataframe columns are incomplete."""
     with tmp_path_factory.mktemp("temp") as fn:
@@ -262,7 +249,7 @@
             r"but not found in dataframe"
         )
         with pytest.raises(ValueError, match=expected_error_msg):
-            data_writer.write_dataframe(dataframe, dask_client)
+            data_writer.write_dataframe(dataframe)
 
 
 def test_write_fields_invalid_several_fields_missing(
@@ -270,7 +257,6 @@
     tmp_path_factory,
     dataframe,
     manifest,
     component_spec,
-    dask_client,
 ):
     """Test writing out fields but the dataframe columns are incomplete."""
     with tmp_path_factory.mktemp("temp") as fn:
@@ -288,4 +274,4 @@
             r"but not found in dataframe"
        )
         with pytest.raises(ValueError, match=expected_error_msg):
-            data_writer.write_dataframe(dataframe, dask_client)
+            data_writer.write_dataframe(dataframe)
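
As a sanity check on the divisions question raised in the commit message, the round trip can be reproduced locally. The sketch below is illustrative only and not part of the patch: the output path is a throwaway example, and it assumes a recent Dask release in which dd.read_parquet accepts the calculate_divisions flag to recover divisions from parquet row-group statistics rather than from a _metadata file.

import pandas as pd
import dask.dataframe as dd

# Small dataframe with a sorted index, so divisions are known up front.
pdf = pd.DataFrame({"value": range(9)}, index=pd.RangeIndex(9, name="id"))
ddf = dd.from_pandas(pdf, npartitions=3)
assert ddf.known_divisions

# Write without a _metadata file, matching the behavior after this patch.
ddf.to_parquet("/tmp/no_metadata_example", write_metadata_file=False)

# Recover divisions from the parquet row-group statistics instead of
# a _metadata file.
roundtrip = dd.read_parquet("/tmp/no_metadata_example", calculate_divisions=True)
print(roundtrip.known_divisions)  # expected True for a sorted index
print(roundtrip.divisions)

If this prints True, divisions survive the round trip without a metadata file, consistent with the "can't reproduce locally with a small example" observation in the commit message; whether large distributed writes still hit the linked worker issues is what this change sets out to validate.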