feat(python): Reduce scan_csv() (and friends') memory usage when using BytesIO (#20649)

Co-authored-by: Itamar Turner-Trauring <[email protected]>
itamarst and pythonspeed authored Jan 13, 2025
1 parent e346f82 commit ec03299
Showing 5 changed files with 56 additions and 10 deletions.
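For context, here is a minimal sketch of the user-facing pattern this commit optimizes: scanning a CSV held in an in-memory io.BytesIO rather than on disk. It is adapted from the new test at the bottom of the diff; the column name and data are illustrative.

import io

import polars as pl

# Build an in-memory CSV, as a user might after fetching data over the network.
df = pl.DataFrame({"mydata": pl.int_range(0, 1_000_000, eager=True)})
f = io.BytesIO()
df.write_csv(f)
f.seek(0)

# scan_csv() accepts the BytesIO directly; with this change the scan can share
# the buffer's bytes instead of materializing an extra copy where CPython allows it.
out = pl.scan_csv(f).filter(pl.col("mydata") == 999_999).collect()
print(out)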
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/gather/chunked.rs
@@ -623,7 +623,7 @@ where
buffers.extend(data_buffers.iter().cloned());
v.insert(offset);
offset
}
},
};
buffer_offsets.push(offset);
}
24 changes: 16 additions & 8 deletions crates/polars-python/src/file.rs
@@ -322,14 +322,20 @@ pub fn get_python_scan_source_input(
write: bool,
) -> PyResult<PythonScanSourceInput> {
Python::with_gil(|py| {
let py_f_0 = py_f;
let py_f = py_f_0.clone_ref(py).into_bound(py);
let py_f = py_f.into_bound(py);

// CPython has some internal tricks that mean much of the time
// BytesIO.getvalue() involves no memory copying, unlike
// BytesIO.read(). So we want to handle BytesIO specially in order
// to save memory.
let py_f = read_if_bytesio(py_f);

// If the pyobject is a `bytes` class
if let Ok(b) = py_f.downcast::<PyBytes>() {
return Ok(PythonScanSourceInput::Buffer(MemSlice::from_arc(
b.as_bytes(),
Arc::new(py_f_0),
// We want to specifically keep alive the PyBytes object.
Arc::new(b.clone().unbind()),
)));
}

@@ -373,15 +379,17 @@ pub fn get_file_like(f: PyObject, truncate: bool) -> PyResult<Box<dyn FileLike>>
Ok(get_either_file(f, truncate)?.into_dyn())
}

/// If the given file-like is a BytesIO, read its contents.
/// If the given file-like is a BytesIO, read its contents in a memory-efficient
/// way.
fn read_if_bytesio(py_f: Bound<PyAny>) -> Bound<PyAny> {
if py_f.getattr("read").is_ok() {
let bytes_io = py_f.py().import("io").unwrap().getattr("BytesIO").unwrap();
if py_f.is_instance(&bytes_io).unwrap() {
// Note that BytesIO has some memory optimizations ensuring that much of
// the time getvalue() doesn't need to copy the underlying data:
let Ok(bytes) = py_f.call_method0("getvalue") else {
return py_f;
};
if bytes.downcast::<PyBytes>().is_ok() || bytes.downcast::<PyString>().is_ok() {
return bytes.clone();
}
return bytes;
}
py_f
}
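The comments added above lean on a CPython detail: BytesIO.getvalue() can return a bytes object that shares the stream's internal buffer, whereas BytesIO.read() generally materializes a fresh copy. A small standalone experiment along those lines (not part of the commit; exact numbers depend on the CPython version and on whether the stream has been written to after construction):

import io
import tracemalloc

payload = b"x" * 10_000_000  # ~10 MB, allocated before tracing starts

tracemalloc.start()
buf = io.BytesIO(payload)

# getvalue() can hand back a bytes object sharing BytesIO's internal buffer,
# so little or no new memory should show up as traced here.
shared = buf.getvalue()
after_getvalue = tracemalloc.get_traced_memory()[0]

# read() builds a new bytes object for the data it returns (per the rationale
# in the comment above), so this typically adds ~10 MB of traced memory.
buf.seek(0)
copied = buf.read()
after_read = tracemalloc.get_traced_memory()[0]

print(f"traced after getvalue(): {after_getvalue:,} bytes")
print(f"traced after read():     {after_read:,} bytes")
tracemalloc.stop()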
1 change: 1 addition & 0 deletions py-polars/src/memory.rs
@@ -69,6 +69,7 @@ unsafe impl<A: GlobalAlloc> GlobalAlloc for TracemallocAllocator<A> {
}

unsafe fn realloc(&self, ptr: *mut u8, layout: std::alloc::Layout, new_size: usize) -> *mut u8 {
PyTraceMalloc_Untrack(TRACEMALLOC_DOMAIN, ptr as uintptr_t);
let result = self.wrapped_alloc.realloc(ptr, layout, new_size);
PyTraceMalloc_Track(TRACEMALLOC_DOMAIN, result as uintptr_t, new_size);
result
11 changes: 10 additions & 1 deletion py-polars/tests/unit/conftest.py
@@ -5,6 +5,7 @@
import random
import string
import sys
import time
import tracemalloc
from typing import TYPE_CHECKING, Any, cast

@@ -205,7 +206,11 @@ def get_peak(self) -> int:
return tracemalloc.get_traced_memory()[1]


@pytest.fixture
# The bizarre syntax is from
# https://github.com/pytest-dev/pytest/issues/1368#issuecomment-2344450259 - we
# need to mark any test using this fixture as slow because we have a sleep
# added to work around a CPython bug, see the end of the function.
@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
"""
Provide an API for measuring peak memory usage.
@@ -231,6 +236,10 @@ def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]:
try:
yield MemoryUsage()
finally:
# Workaround for https://github.com/python/cpython/issues/128679
time.sleep(1)
gc.collect()

tracemalloc.stop()


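The params=[pytest.param(0, marks=pytest.mark.slow)] trick works because marks attached to a fixture's parameters propagate to every test that requests the fixture. A stripped-down illustration of the same pattern (hypothetical names; assumes a slow marker is registered, as it is in this test suite):

import time
from typing import Any, Generator

import pytest


# The single dummy parameter (0) exists only to carry the `slow` mark, which
# pytest then applies to every test that uses this fixture.
@pytest.fixture(params=[pytest.param(0, marks=pytest.mark.slow)])
def tracked_resource() -> Generator[str, Any, Any]:
    yield "resource"
    # Teardown sleeps, mirroring the CPython-bug workaround above, which is
    # why the tests need to be marked slow in the first place.
    time.sleep(1)


def test_uses_tracked_resource(tracked_resource: str) -> None:
    assert tracked_resource == "resource"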
28 changes: 28 additions & 0 deletions py-polars/tests/unit/io/test_scan.py
@@ -16,6 +16,7 @@

if TYPE_CHECKING:
from polars._typing import SchemaDict
from tests.unit.conftest import MemoryUsage


@dataclass
@@ -929,3 +930,30 @@ def test_predicate_stats_eval_nested_binary() -> None:
),
pl.DataFrame({"x": [2]}),
)


@pytest.mark.slow
@pytest.mark.parametrize("streaming", [True, False])
def test_scan_csv_bytesio_memory_usage(
streaming: bool,
memory_usage_without_pyarrow: MemoryUsage,
) -> None:
memory_usage = memory_usage_without_pyarrow

# Create CSV that is ~6-7 MB in size:
f = io.BytesIO()
df = pl.DataFrame({"mydata": pl.int_range(0, 1_000_000, eager=True)})
df.write_csv(f)
assert 6_000_000 < f.tell() < 7_000_000
f.seek(0, 0)

# A lazy scan shouldn't make a full copy of the data:
starting_memory = memory_usage.get_current()
assert (
pl.scan_csv(f)
.filter(pl.col("mydata") == 999_999)
.collect(new_streaming=streaming) # type: ignore[call-overload]
.item()
== 999_999
)
assert memory_usage.get_peak() - starting_memory < 1_000_000
