Skip to content

Commit

Permalink
Fugue count matching rows (#294)
Browse files Browse the repository at this point in the history
* adding in count_matching_rows

* linting / cleanup
  • Loading branch information
fdosani authored Apr 30, 2024
1 parent bcedfdc commit 889235c
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 4 deletions.
1 change: 1 addition & 0 deletions datacompy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datacompy.fugue import (
all_columns_match,
all_rows_overlap,
count_matching_rows,
intersect_columns,
is_match,
report,
Expand Down
96 changes: 95 additions & 1 deletion datacompy/fugue.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,101 @@ def all_rows_overlap(
return all(overlap)


def count_matching_rows(
df1: AnyDataFrame,
df2: AnyDataFrame,
join_columns: Union[str, List[str]],
abs_tol: float = 0,
rel_tol: float = 0,
df1_name: str = "df1",
df2_name: str = "df2",
ignore_spaces: bool = False,
ignore_case: bool = False,
cast_column_names_lower: bool = True,
parallelism: Optional[int] = None,
strict_schema: bool = False,
) -> int:
"""Count the number of rows match (on overlapping fields)
Parameters
----------
df1 : ``AnyDataFrame``
First dataframe to check
df2 : ``AnyDataFrame``
Second dataframe to check
join_columns : list or str, optional
Column(s) to join dataframes on. If a string is passed in, that one
column will be used.
abs_tol : float, optional
Absolute tolerance between two values.
rel_tol : float, optional
Relative tolerance between two values.
df1_name : str, optional
A string name for the first dataframe. This allows the reporting to
print out an actual name instead of "df1", and allows human users to
more easily track the dataframes.
df2_name : str, optional
A string name for the second dataframe
ignore_spaces : bool, optional
Flag to strip whitespace (including newlines) from string columns (including any join
columns)
ignore_case : bool, optional
Flag to ignore the case of string columns
cast_column_names_lower: bool, optional
Boolean indicator that controls of column names will be cast into lower case
parallelism: int, optional
An integer representing the amount of parallelism. Entering a value for this
will force to use of Fugue over just vanilla Pandas
strict_schema: bool, optional
The schema must match exactly if set to ``True``. This includes the names and types. Allows for a fast fail.
Returns
-------
int
Number of matching rows
"""
if (
isinstance(df1, pd.DataFrame)
and isinstance(df2, pd.DataFrame)
and parallelism is None # user did not specify parallelism
and fa.get_current_parallelism() == 1 # currently on a local execution engine
):
comp = Compare(
df1=df1,
df2=df2,
join_columns=join_columns,
abs_tol=abs_tol,
rel_tol=rel_tol,
df1_name=df1_name,
df2_name=df2_name,
ignore_spaces=ignore_spaces,
ignore_case=ignore_case,
cast_column_names_lower=cast_column_names_lower,
)
return comp.count_matching_rows()

try:
count_matching_rows = _distributed_compare(
df1=df1,
df2=df2,
join_columns=join_columns,
return_obj_func=lambda comp: comp.count_matching_rows(),
abs_tol=abs_tol,
rel_tol=rel_tol,
df1_name=df1_name,
df2_name=df2_name,
ignore_spaces=ignore_spaces,
ignore_case=ignore_case,
cast_column_names_lower=cast_column_names_lower,
parallelism=parallelism,
strict_schema=strict_schema,
)
except _StrictSchemaError:
return False

return sum(count_matching_rows)


def report(
df1: AnyDataFrame,
df2: AnyDataFrame,
Expand Down Expand Up @@ -460,7 +555,6 @@ def _any(col: str) -> int:
any_mismatch = len(match_sample) > 0

# Column Matching
cnt_intersect = shape0("intersect_rows_shape")
rpt += render(
"column_comparison.txt",
len([col for col in column_stats if col["unequal_cnt"] > 0]),
Expand Down
18 changes: 16 additions & 2 deletions tests/test_fugue/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import numpy as np
import pandas as pd
import pytest


@pytest.fixture
Expand All @@ -24,7 +24,8 @@ def ref_df():
c=np.random.choice(["aaa", "b_c", "csd"], 100),
)
)
return [df1, df1_copy, df2, df3, df4]
df5 = df1.sample(frac=0.1)
return [df1, df1_copy, df2, df3, df4, df5]


@pytest.fixture
Expand Down Expand Up @@ -87,3 +88,16 @@ def large_diff_df2():
np.random.seed(0)
data = np.random.randint(6, 11, size=10000)
return pd.DataFrame({"x": data, "y": np.array([9] * 10000)}).convert_dtypes()


@pytest.fixture
def count_matching_rows_df():
np.random.seed(0)
df1 = pd.DataFrame(
dict(
a=np.arange(0, 100),
b=np.arange(0, 100),
)
)
df2 = df1.sample(frac=0.1)
return [df1, df2]
38 changes: 38 additions & 0 deletions tests/test_fugue/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from datacompy import (
all_columns_match,
all_rows_overlap,
count_matching_rows,
intersect_columns,
is_match,
unq_columns,
Expand Down Expand Up @@ -138,3 +139,40 @@ def test_all_rows_overlap_duckdb(
duckdb.sql("SELECT 'a' AS a, 'b' AS b"),
join_columns="a",
)


def test_count_matching_rows_duckdb(count_matching_rows_df):
with duckdb.connect():
df1 = duckdb.from_df(count_matching_rows_df[0])
df1_copy = duckdb.from_df(count_matching_rows_df[0])
df2 = duckdb.from_df(count_matching_rows_df[1])

assert (
count_matching_rows(
df1,
df1_copy,
join_columns="a",
)
== 100
)
assert count_matching_rows(df1, df2, join_columns="a") == 10
# Fugue

assert (
count_matching_rows(
df1,
df1_copy,
join_columns="a",
parallelism=2,
)
== 100
)
assert (
count_matching_rows(
df1,
df2,
join_columns="a",
parallelism=2,
)
== 10
)
40 changes: 39 additions & 1 deletion tests/test_fugue/test_fugue_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Compare,
all_columns_match,
all_rows_overlap,
count_matching_rows,
intersect_columns,
is_match,
report,
Expand Down Expand Up @@ -144,7 +145,6 @@ def test_report_pandas(

def test_unique_columns_native(ref_df):
df1 = ref_df[0]
df1_copy = ref_df[1]
df2 = ref_df[2]
df3 = ref_df[3]

Expand Down Expand Up @@ -192,3 +192,41 @@ def test_all_rows_overlap_native(
# Fugue
assert all_rows_overlap(ref_df[0], shuffle_df, join_columns="a", parallelism=2)
assert not all_rows_overlap(ref_df[0], ref_df[4], join_columns="a", parallelism=2)


def test_count_matching_rows_native(count_matching_rows_df):
# defaults to Compare class
assert (
count_matching_rows(
count_matching_rows_df[0],
count_matching_rows_df[0].copy(),
join_columns="a",
)
== 100
)
assert (
count_matching_rows(
count_matching_rows_df[0], count_matching_rows_df[1], join_columns="a"
)
== 10
)
# Fugue

assert (
count_matching_rows(
count_matching_rows_df[0],
count_matching_rows_df[0].copy(),
join_columns="a",
parallelism=2,
)
== 100
)
assert (
count_matching_rows(
count_matching_rows_df[0],
count_matching_rows_df[1],
join_columns="a",
parallelism=2,
)
== 10
)
35 changes: 35 additions & 0 deletions tests/test_fugue/test_fugue_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from datacompy import (
all_columns_match,
all_rows_overlap,
count_matching_rows,
intersect_columns,
is_match,
unq_columns,
Expand Down Expand Up @@ -122,3 +123,37 @@ def test_all_rows_overlap_polars(
assert all_rows_overlap(rdf, rdf_copy, join_columns="a")
assert all_rows_overlap(rdf, sdf, join_columns="a")
assert not all_rows_overlap(rdf, rdf4, join_columns="a")


def test_count_matching_rows_polars(count_matching_rows_df):
df1 = pl.from_pandas(count_matching_rows_df[0])
df2 = pl.from_pandas(count_matching_rows_df[1])
assert (
count_matching_rows(
df1,
df1.clone(),
join_columns="a",
)
== 100
)
assert count_matching_rows(df1, df2, join_columns="a") == 10
# Fugue

assert (
count_matching_rows(
df1,
df1.clone(),
join_columns="a",
parallelism=2,
)
== 100
)
assert (
count_matching_rows(
df1,
df2,
join_columns="a",
parallelism=2,
)
== 10
)
42 changes: 42 additions & 0 deletions tests/test_fugue/test_fugue_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Compare,
all_columns_match,
all_rows_overlap,
count_matching_rows,
intersect_columns,
is_match,
report,
Expand Down Expand Up @@ -200,3 +201,44 @@ def test_all_rows_overlap_spark(
spark_session.sql("SELECT 'a' AS a, 'b' AS b"),
join_columns="a",
)


def test_count_matching_rows_spark(spark_session, count_matching_rows_df):
count_matching_rows_df[0].iteritems = count_matching_rows_df[
0
].items # pandas 2 compatibility
count_matching_rows_df[1].iteritems = count_matching_rows_df[
1
].items # pandas 2 compatibility
df1 = spark_session.createDataFrame(count_matching_rows_df[0])
df1_copy = spark_session.createDataFrame(count_matching_rows_df[0])
df2 = spark_session.createDataFrame(count_matching_rows_df[1])
assert (
count_matching_rows(
df1,
df1_copy,
join_columns="a",
)
== 100
)
assert count_matching_rows(df1, df2, join_columns="a") == 10
# Fugue

assert (
count_matching_rows(
df1,
df1_copy,
join_columns="a",
parallelism=2,
)
== 100
)
assert (
count_matching_rows(
df1,
df2,
join_columns="a",
parallelism=2,
)
== 10
)

0 comments on commit 889235c

Please sign in to comment.