Skip to content

Commit

Permalink
BUG: Fix read_csv with index_col (#736)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
codingl2k1 and mergify[bot] authored Oct 9, 2023
1 parent 771858d commit 936c82b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/xorbits/_mars/dataframe/datasource/read_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class DataFrameReadCSV(
sep = StringField("sep")
header = AnyField("header")
index_col = Int32Field("index_col")
index_names = ListField("index_names")
skiprows = Int32Field("skiprows")
compression = StringField("compression")
usecols = AnyField("usecols")
Expand Down Expand Up @@ -266,6 +267,7 @@ def _pandas_read_csv(cls, f, op):
nrows=op.nrows,
**csv_kwargs,
)
df.index.names = op.index_names
if op.keep_usecols_order:
df = df[op.usecols]
return df
Expand Down Expand Up @@ -352,7 +354,7 @@ def read_csv(
path: str,
names: Union[List, Tuple] = None,
sep: str = ",",
index_col: int = None,
index_col: Union[int, str, List[int], List[str]] = None,
compression: str = None,
header: Union[str, List] = "infer",
dtype: Union[str, Dict] = None,
Expand Down Expand Up @@ -709,8 +711,8 @@ def read_csv(
else:
index_value = parse_index(mini_df.index)
columns_value = parse_index(mini_df.columns, store_data=True)
if index_col and not isinstance(index_col, int):
index_col = list(mini_df.columns).index(index_col)
# Set names and index_col may lose multiindex names, so we have to fix it.
index_names = mini_df.index.names

# convert path to abs_path
abs_path = convert_to_abspath(path, storage_options)
Expand All @@ -721,6 +723,7 @@ def read_csv(
sep=sep,
header=header,
index_col=index_col,
index_names=index_names,
usecols=usecols,
skiprows=skiprows,
compression=compression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,10 @@ def test_read_csv_execution(setup):
mdf2 = md.read_csv(file_path, index_col=0, chunk_bytes=100).execute().fetch()
pd.testing.assert_frame_equal(pdf, mdf2)

mdf3 = md.read_csv(file_path, index_col=[0, 1]).execute().fetch()
pdf3 = pd.read_csv(file_path, index_col=[0, 1])
pd.testing.assert_frame_equal(pdf3, mdf3)

# test nan
with tempfile.TemporaryDirectory() as tempdir:
file_path = os.path.join(tempdir, "test.csv")
Expand Down

0 comments on commit 936c82b

Please sign in to comment.