From 936c82b0fa3a5b4bfa8c4912b97fbb4066b3016b Mon Sep 17 00:00:00 2001 From: codingl2k1 <138426806+codingl2k1@users.noreply.github.com> Date: Mon, 9 Oct 2023 11:46:22 +0800 Subject: [PATCH] BUG: Fix read_csv with index_col (#736) Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- python/xorbits/_mars/dataframe/datasource/read_csv.py | 9 ++++++--- .../datasource/tests/test_datasource_execution.py | 4 ++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/xorbits/_mars/dataframe/datasource/read_csv.py b/python/xorbits/_mars/dataframe/datasource/read_csv.py index b0ea4666e..af3913a2b 100644 --- a/python/xorbits/_mars/dataframe/datasource/read_csv.py +++ b/python/xorbits/_mars/dataframe/datasource/read_csv.py @@ -103,6 +103,7 @@ class DataFrameReadCSV( sep = StringField("sep") header = AnyField("header") index_col = Int32Field("index_col") + index_names = ListField("index_names") skiprows = Int32Field("skiprows") compression = StringField("compression") usecols = AnyField("usecols") @@ -266,6 +267,7 @@ def _pandas_read_csv(cls, f, op): nrows=op.nrows, **csv_kwargs, ) + df.index.names = op.index_names if op.keep_usecols_order: df = df[op.usecols] return df @@ -352,7 +354,7 @@ def read_csv( path: str, names: Union[List, Tuple] = None, sep: str = ",", - index_col: int = None, + index_col: Union[int, str, List[int], List[str]] = None, compression: str = None, header: Union[str, List] = "infer", dtype: Union[str, Dict] = None, @@ -709,8 +711,8 @@ def read_csv( else: index_value = parse_index(mini_df.index) columns_value = parse_index(mini_df.columns, store_data=True) - if index_col and not isinstance(index_col, int): - index_col = list(mini_df.columns).index(index_col) + # Set names and index_col may lose multiindex names, so we have to fix it. + index_names = mini_df.index.names # convert path to abs_path abs_path = convert_to_abspath(path, storage_options) @@ -721,6 +723,7 @@ def read_csv( sep=sep, header=header, index_col=index_col, + index_names=index_names, usecols=usecols, skiprows=skiprows, compression=compression, diff --git a/python/xorbits/_mars/dataframe/datasource/tests/test_datasource_execution.py b/python/xorbits/_mars/dataframe/datasource/tests/test_datasource_execution.py index 8b89220b8..28375d15b 100644 --- a/python/xorbits/_mars/dataframe/datasource/tests/test_datasource_execution.py +++ b/python/xorbits/_mars/dataframe/datasource/tests/test_datasource_execution.py @@ -606,6 +606,10 @@ def test_read_csv_execution(setup): mdf2 = md.read_csv(file_path, index_col=0, chunk_bytes=100).execute().fetch() pd.testing.assert_frame_equal(pdf, mdf2) + mdf3 = md.read_csv(file_path, index_col=[0, 1]).execute().fetch() + pdf3 = pd.read_csv(file_path, index_col=[0, 1]) + pd.testing.assert_frame_equal(pdf3, mdf3) + # test nan with tempfile.TemporaryDirectory() as tempdir: file_path = os.path.join(tempdir, "test.csv")