Skip to content

Commit

Permalink
Dataframe equality (pola-rs#4076)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jul 27, 2022
1 parent 08dd52c commit 8fc5e86
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 1 deletion.
80 changes: 80 additions & 0 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,86 @@ def to_numpy(self) -> np.ndarray:
else:
return out

def _comp(
self: DF, other: Any, op: Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"]
) -> DF:
"""Compare a DataFrame with another object."""
if isinstance(other, DataFrame):
return self._compare_to_other_df(other, op)
else:
return self._compare_to_non_df(other, op)

def _compare_to_other_df(
self: DF,
other: DataFrame,
op: Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"],
) -> DF:
"""Compare a DataFrame with another DataFrame."""
if self.columns != other.columns:
raise ValueError("DataFrame columns do not match")
if self.shape != other.shape:
raise ValueError("DataFrame dimensions do not match")

suffix = "__POLARS_CMP_OTHER"
other_renamed = other.select(pli.all().suffix(suffix))
combined = pli.concat([self, other_renamed], how="horizontal")

if op == "eq":
expr = [pli.col(n) == pli.col(f"{n}{suffix}") for n in self.columns]
elif op == "neq":
expr = [pli.col(n) != pli.col(f"{n}{suffix}") for n in self.columns]
elif op == "gt":
expr = [pli.col(n) > pli.col(f"{n}{suffix}") for n in self.columns]
elif op == "lt":
expr = [pli.col(n) < pli.col(f"{n}{suffix}") for n in self.columns]
elif op == "gt_eq":
expr = [pli.col(n) >= pli.col(f"{n}{suffix}") for n in self.columns]
elif op == "lt_eq":
expr = [pli.col(n) <= pli.col(f"{n}{suffix}") for n in self.columns]
else:
raise ValueError(f"got unexpected comparison operator: {op}")

return combined.select(expr) # type: ignore[return-value]

def _compare_to_non_df(
self: DF,
other: Any,
op: Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"],
) -> DF:
"""Compare a DataFrame with a non-DataFrame object."""
if op == "eq":
return self.select(pli.all() == other)
elif op == "neq":
return self.select(pli.all() != other)
elif op == "gt":
return self.select(pli.all() > other)
elif op == "lt":
return self.select(pli.all() < other)
elif op == "gt_eq":
return self.select(pli.all() >= other)
elif op == "lt_eq":
return self.select(pli.all() <= other)
else:
raise ValueError(f"got unexpected comparison operator: {op}")

def __eq__(self: DF, other: Any) -> DF: # type: ignore[override]
return self._comp(other, "eq")

def __ne__(self: DF, other: Any) -> DF: # type: ignore[override]
return self._comp(other, "neq")

def __gt__(self: DF, other: Any) -> DF:
return self._comp(other, "gt")

def __lt__(self: DF, other: Any) -> DF:
return self._comp(other, "lt")

def __ge__(self: DF, other: Any) -> DF:
return self._comp(other, "gt_eq")

def __le__(self: DF, other: Any) -> DF:
return self._comp(other, "lt_eq")

def __getstate__(self) -> list[pli.Series]:
return self.get_columns()

Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def __xor__(self, other: Series) -> Series:
def __rxor__(self, other: Series) -> Series:
return self.__xor__(other)

def _comp(self, other: Any, op: str) -> Series:
def _comp(
self, other: Any, op: Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"]
) -> Series:
if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other, self.time_unit)
f = get_ffi_func(op + "_<>", Int64, self._s)
Expand Down
38 changes: 38 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,44 @@ def test_special_char_colname_init() -> None:
assert df.is_empty()


def test_comparisons() -> None:
df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})

# Constants
assert_frame_equal(df == 2, pl.DataFrame({"a": [False, True], "b": [False, False]}))
assert_frame_equal(df != 2, pl.DataFrame({"a": [True, False], "b": [True, True]}))
assert_frame_equal(df < 3.0, pl.DataFrame({"a": [True, True], "b": [False, False]}))
assert_frame_equal(df >= 2, pl.DataFrame({"a": [False, True], "b": [True, True]}))
assert_frame_equal(df <= 2, pl.DataFrame({"a": [True, True], "b": [False, False]}))

with pytest.raises(pl.ComputeError):
df > "2" # noqa: B015

# Series
s = pl.Series([3, 1])
assert_frame_equal(df >= s, pl.DataFrame({"a": [False, True], "b": [True, True]}))

# DataFrame
other = pl.DataFrame({"a": [1, 2], "b": [2, 3]})
assert_frame_equal(
df == other, pl.DataFrame({"a": [True, True], "b": [False, False]})
)

# DataFrame columns mismatch
with pytest.raises(ValueError):
df == pl.DataFrame({"a": [1, 2], "c": [3, 4]}) # noqa: B015
with pytest.raises(ValueError):
df == pl.DataFrame({"b": [3, 4], "a": [1, 2]}) # noqa: B015

# DataFrame shape mismatch
with pytest.raises(ValueError):
df == pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) # noqa: B015

# Type mismatch
with pytest.raises(pl.ComputeError):
df == pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) # noqa: B015


def test_selection() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]})

Expand Down

0 comments on commit 8fc5e86

Please sign in to comment.