Skip to content

Commit

Permalink
feat(python, rust): clearer message when stringcache-related errors o…
Browse files Browse the repository at this point in the history
…ccur (pola-rs#9715)
  • Loading branch information
MarcoGorelli authored Jul 5, 2023
1 parent e577509 commit ee9c589
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ impl CategoricalChunked {
};

if is_local_different_source {
polars_bail!(
ComputeError:
"cannot concat categoricals coming from a different source; consider setting a global StringCache"
);
polars_bail!(string_cache_mismatch);
} else {
let len = self.len();
let new_rev_map = self.merge_categorical_map(other)?;
Expand Down
6 changes: 1 addition & 5 deletions polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ use crate::series::IsSorted;
#[cfg(feature = "dtype-categorical")]
pub fn _check_categorical_src(l: &DataType, r: &DataType) -> PolarsResult<()> {
if let (DataType::Categorical(Some(l)), DataType::Categorical(Some(r))) = (l, r) {
polars_ensure!(
l.same_src(r),
ComputeError: "joins/or comparisons on categoricals can only happen if they were \
created under the same global string cache"
);
polars_ensure!(l.same_src(r), string_cache_mismatch);
}
Ok(())
}
Expand Down
23 changes: 23 additions & 0 deletions polars/polars-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub enum PolarsError {
SchemaMismatch(ErrString),
#[error("lengths don't match: {0}")]
ShapeMismatch(ErrString),
#[error("string caches don't match: {0}")]
StringCacheMismatch(ErrString),
#[error("field not found: {0}")]
StructFieldNotFound(ErrString),
}
Expand Down Expand Up @@ -91,6 +93,7 @@ impl PolarsError {
SchemaFieldNotFound(msg) => SchemaFieldNotFound(func(msg).into()),
SchemaMismatch(msg) => SchemaMismatch(func(msg).into()),
ShapeMismatch(msg) => ShapeMismatch(func(msg).into()),
StringCacheMismatch(msg) => StringCacheMismatch(func(msg).into()),
StructFieldNotFound(msg) => StructFieldNotFound(func(msg).into()),
}
}
Expand Down Expand Up @@ -158,6 +161,26 @@ macro_rules! polars_err {
(unpack) => {
polars_err!(SchemaMismatch: "cannot unpack series, data types don't match")
};
(string_cache_mismatch) => {
polars_err!(StringCacheMismatch: r#"
cannot compare categoricals coming from different sources, consider setting a global StringCache.
Help: if you're using Python, this may look something like:
with pl.StringCache():
# Initialize Categoricals.
df1 = pl.DataFrame({'a': ['1', '2']}, schema={'a': pl.Categorical})
df2 = pl.DataFrame({'a': ['1', '3']}, schema={'a': pl.Categorical})
# Your operations go here.
pl.concat([df1, df2])
Alternatively, if the performance cost is acceptable, you could just set:
import polars as pl
pl.enable_string_cache(True)
on startup."#.trim_start())
};
(duplicate = $name:expr) => {
polars_err!(Duplicate: "column with name '{}' has more than one occurrences", $name)
};
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SchemaError,
SchemaFieldNotFoundError,
ShapeError,
StringCacheMismatchError,
StructFieldNotFoundError,
)
except ImportError:
Expand Down Expand Up @@ -43,6 +44,9 @@ class SchemaFieldNotFoundError(Exception): # type: ignore[no-redef]
class ShapeError(Exception): # type: ignore[no-redef]
"""Exception raised when trying to combine data structures with incompatible shapes.""" # noqa: W505

class StringCacheMismatchError(Exception): # type: ignore[no-redef]
"""Exception raised when string caches come from different sources."""

class StructFieldNotFoundError(Exception): # type: ignore[no-redef]
"""Exception raised when a specified schema field is not found."""

Expand Down Expand Up @@ -96,6 +100,7 @@ class ChronoFormatWarning(Warning):
"SchemaError",
"SchemaFieldNotFoundError",
"ShapeError",
"StringCacheMismatchError",
"StructFieldNotFoundError",
"TooManyRowsReturnedError",
]
4 changes: 4 additions & 0 deletions py-polars/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ impl std::convert::From<PyPolarsErr> for PyErr {
}
PolarsError::SchemaMismatch(err) => SchemaError::new_err(err.to_string()),
PolarsError::ShapeMismatch(err) => ShapeError::new_err(err.to_string()),
PolarsError::StringCacheMismatch(err) => {
StringCacheMismatchError::new_err(err.to_string())
}
PolarsError::StructFieldNotFound(name) => {
StructFieldNotFoundError::new_err(name.to_string())
}
Expand Down Expand Up @@ -75,6 +78,7 @@ create_exception!(exceptions, NoDataError, PyException);
create_exception!(exceptions, SchemaError, PyException);
create_exception!(exceptions, SchemaFieldNotFoundError, PyException);
create_exception!(exceptions, ShapeError, PyException);
create_exception!(exceptions, StringCacheMismatchError, PyException);
create_exception!(exceptions, StructFieldNotFoundError, PyException);

#[macro_export]
Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
.unwrap();
m.add("ShapeError", py.get_type::<crate::error::ShapeError>())
.unwrap();
m.add(
"StringCacheMismatchError",
py.get_type::<crate::error::StringCacheMismatchError>(),
)
.unwrap();
m.add(
"StructFieldNotFoundError",
py.get_type::<StructFieldNotFoundError>(),
Expand Down
5 changes: 3 additions & 2 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import polars as pl
from polars import StringCache
from polars.exceptions import StringCacheMismatchError
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -301,8 +302,8 @@ def test_err_on_categorical_asof_join_by_arg() -> None:
]
)
with pytest.raises(
pl.ComputeError,
match=r"joins/or comparisons on categoricals can only happen if they were created under the same global string cache",
StringCacheMismatchError,
match="cannot compare categoricals coming from different sources",
):
df1.join_asof(df2, on=pl.col("time").set_sorted(), by="cat")

Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/unit/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import polars as pl
from polars.config import _get_float_fmt
from polars.exceptions import StringCacheMismatchError
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -481,7 +482,7 @@ def test_string_cache() -> None:

df1a = df1.with_columns(pl.col("a").cast(pl.Categorical))
df2a = df2.with_columns(pl.col("a").cast(pl.Categorical))
with pytest.raises(pl.ComputeError):
with pytest.raises(StringCacheMismatchError):
_ = df1a.join(df2a, on="a", how="inner")

# now turn on the cache
Expand Down

0 comments on commit ee9c589

Please sign in to comment.