Skip to content

Commit

Permalink
feat(python)!: Return instantiated DataType objects in schema/`dtyp…
Browse files Browse the repository at this point in the history
…e` methods (pola-rs#12470)
  • Loading branch information
stinodego authored Dec 3, 2023
1 parent 433b87d commit 6491844
Show file tree
Hide file tree
Showing 20 changed files with 217 additions and 124 deletions.
6 changes: 3 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
import deltalake
from xlsxwriter import Workbook

from polars import Expr, LazyFrame, Series
from polars import DataType, Expr, LazyFrame, Series
from polars.interchange.dataframe import PolarsDataFrame
from polars.type_aliases import (
AsofJoinStrategy,
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def columns(self, names: Sequence[str]) -> None:
self._df.set_column_names(names)

@property
def dtypes(self) -> list[PolarsDataType]:
def dtypes(self) -> list[DataType]:
"""
Get the datatypes of the columns of this DataFrame.
Expand Down Expand Up @@ -1255,7 +1255,7 @@ def flags(self) -> dict[str, dict[str, bool]]:
return {name: self[name].flags for name in self.columns}

@property
def schema(self) -> SchemaDict:
def schema(self) -> OrderedDict[str, DataType]:
"""
Get a dict[column name, DataType].
Expand Down
19 changes: 12 additions & 7 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,24 @@ def is_nested(self) -> bool: # noqa: D102
class DataType(metaclass=DataTypeClass):
"""Base class for all Polars data types."""

def __new__(cls, *args: Any, **kwargs: Any) -> PolarsDataType: # type: ignore[misc] # noqa: D102
# this formulation allows for equivalent use of "pl.Type" and "pl.Type()", while
# still respecting types that take initialisation params (eg: Duration/Datetime)
if args or kwargs:
return super().__new__(cls)
return cls

def __reduce__(self) -> Any:
    """
    Describe how this data type instance is pickled.

    Returns the three-part reduce value: the reconstruction callable, the
    arguments handed to that callable, and the instance ``__dict__`` that is
    restored as state afterwards.
    """
    # NOTE(review): `_custom_reconstruct` is defined elsewhere in this module;
    # presumably it rebuilds the instance without calling `__init__` - confirm.
    reconstruct_args = (type(self), object, None)
    instance_state = self.__dict__
    return (_custom_reconstruct, reconstruct_args, instance_state)

def _string_repr(self) -> str:
    """Return the string form of this data type as produced by `_dtype_str_repr`."""
    rendered = _dtype_str_repr(self)
    return rendered

def __eq__(self, other: PolarsDataType) -> bool:  # type: ignore[override]
    """
    Compare this data type with another data type class or instance.

    This base implementation looks only at types: a data type *class* is equal
    when it is a subclass of this instance's type, and any other object is
    equal when it is an instance of this instance's type.
    """
    own_type = type(self)
    # Instances (and any non-dtype object) are matched with `isinstance`.
    if type(other) is not DataTypeClass:
        return isinstance(other, own_type)
    # `other` is an uninstantiated data type class, e.g. `pl.Int64`.
    return issubclass(other, own_type)

def __hash__(self) -> int:
    """
    Return the hash of this instance's class.

    Every instance therefore hashes the same as its own class, in line with
    `__eq__` treating an instance and its class as equal.
    """
    dtype_class = type(self)
    return hash(dtype_class)

def __repr__(self) -> str:
    """Return the bare class name of this data type, e.g. ``Int64``."""
    dtype_class = type(self)
    return dtype_class.__name__

@classmethod
def base_type(cls) -> DataTypeClass:
"""
Expand Down
5 changes: 2 additions & 3 deletions py-polars/polars/io/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from polars.io.pyarrow_dataset import scan_pyarrow_dataset

if TYPE_CHECKING:
from polars import DataFrame, LazyFrame
from polars.type_aliases import PolarsDataType
from polars import DataFrame, DataType, LazyFrame


def read_delta(
Expand Down Expand Up @@ -320,7 +319,7 @@ def _check_if_delta_available() -> None:
)


def _check_for_unsupported_types(dtypes: list[PolarsDataType]) -> None:
def _check_for_unsupported_types(dtypes: list[DataType]) -> None:
schema_dtypes = unpack_dtypes(*dtypes)
unsupported_types = {Time, Categorical, Null}
overlap = schema_dtypes & unsupported_types
Expand Down
5 changes: 2 additions & 3 deletions py-polars/polars/io/ipc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
if TYPE_CHECKING:
from io import BytesIO

from polars import DataFrame, LazyFrame
from polars.type_aliases import PolarsDataType
from polars import DataFrame, DataType, LazyFrame


def read_ipc(
Expand Down Expand Up @@ -185,7 +184,7 @@ def read_ipc_stream(
)


def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, PolarsDataType]:
def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, DataType]:
"""
Get the schema of an IPC file without reading data.
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/io/parquet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
if TYPE_CHECKING:
from io import BytesIO

from polars import DataFrame, LazyFrame
from polars.type_aliases import ParallelStrategy, PolarsDataType
from polars import DataFrame, DataType, LazyFrame
from polars.type_aliases import ParallelStrategy


def read_parquet(
Expand Down Expand Up @@ -143,7 +143,7 @@ def read_parquet(

def read_parquet_schema(
source: str | BinaryIO | Path | bytes,
) -> dict[str, PolarsDataType]:
) -> dict[str, DataType]:
"""
Get the schema of a Parquet file without reading data.
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@

import pyarrow as pa

from polars import DataFrame, Expr
from polars import DataFrame, DataType, Expr
from polars.dependencies import numpy as np
from polars.type_aliases import (
AsofJoinStrategy,
Expand Down Expand Up @@ -693,7 +693,7 @@ def columns(self) -> list[str]:
return self._ldf.columns()

@property
def dtypes(self) -> list[PolarsDataType]:
def dtypes(self) -> list[DataType]:
"""
Get dtypes of columns in LazyFrame.
Expand All @@ -717,7 +717,7 @@ def dtypes(self) -> list[PolarsDataType]:
return self._ldf.dtypes()

@property
def schema(self) -> SchemaDict:
def schema(self) -> OrderedDict[str, DataType]:
"""
Get a dict[column name, DataType].
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def median(self) -> dt.date | dt.datetime | dt.timedelta | None:
if s.dtype == Date:
return _to_python_date(int(out))
else:
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr]
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined]
return None

def mean(self) -> dt.date | dt.datetime | None:
Expand All @@ -108,7 +108,7 @@ def mean(self) -> dt.date | dt.datetime | None:
if s.dtype == Date:
return _to_python_date(int(out))
else:
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr]
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined]
return None

def to_string(self, format: str) -> Series:
Expand Down
27 changes: 15 additions & 12 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
if TYPE_CHECKING:
import sys

from polars import DataFrame, Expr
from polars import DataFrame, DataType, Expr
from polars.series._numpy import SeriesView
from polars.type_aliases import (
ClosedInterval,
Expand Down Expand Up @@ -365,7 +365,7 @@ def _get_ptr(self) -> tuple[int, int, int]:
return self._s.get_ptr()

@property
def dtype(self) -> PolarsDataType:
def dtype(self) -> DataType:
"""
Get the data type of this Series.
Expand Down Expand Up @@ -398,10 +398,13 @@ def flags(self) -> dict[str, bool]:
return out

@property
def inner_dtype(self) -> PolarsDataType | None:
def inner_dtype(self) -> DataType | None:
"""
Get the inner dtype of a List typed Series.
.. deprecated:: 0.19.14
Use `Series.dtype.inner` instead.
Returns
-------
DataType
Expand All @@ -412,7 +415,7 @@ def inner_dtype(self) -> PolarsDataType | None:
version="0.19.14",
)
try:
return self.dtype.inner # type: ignore[union-attr]
return self.dtype.inner # type: ignore[attr-defined]
except AttributeError:
return None

Expand Down Expand Up @@ -502,12 +505,12 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
time_unit = "us"
elif self.dtype == Datetime:
# Use local time zone info
time_zone = self.dtype.time_zone # type: ignore[union-attr]
time_zone = self.dtype.time_zone # type: ignore[attr-defined]
if str(other.tzinfo) != str(time_zone):
raise TypeError(
f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}"
)
time_unit = self.dtype.time_unit # type: ignore[union-attr]
time_unit = self.dtype.time_unit # type: ignore[attr-defined]
else:
raise ValueError(
f"cannot compare datetime.datetime to Series of type {self.dtype}"
Expand All @@ -524,7 +527,7 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
return self._from_pyseries(f(d))

elif isinstance(other, timedelta) and self.dtype == Duration:
time_unit = self.dtype.time_unit # type: ignore[union-attr]
time_unit = self.dtype.time_unit # type: ignore[attr-defined]
td = _timedelta_to_pl_timedelta(other, time_unit) # type: ignore[arg-type]
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
Expand Down Expand Up @@ -4051,9 +4054,9 @@ def convert_to_date(arr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
if self.dtype == Date:
tp = "datetime64[D]"
elif self.dtype == Duration:
tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[union-attr]
tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[attr-defined]
else:
tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[union-attr]
tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[attr-defined]
return arr.astype(tp)

def raise_no_zero_copy() -> None:
Expand All @@ -4066,7 +4069,7 @@ def raise_no_zero_copy() -> None:
writable=writable,
use_pyarrow=use_pyarrow,
)
np_array.shape = (self.len(), self.dtype.width) # type: ignore[union-attr]
np_array.shape = (self.len(), self.dtype.width) # type: ignore[attr-defined]
return np_array

if (
Expand Down Expand Up @@ -6972,7 +6975,7 @@ def is_boolean(self) -> bool:
True
"""
return self.dtype is Boolean
return self.dtype == Boolean

@deprecate_function("Use `Series.dtype == pl.Utf8` instead.", version="0.19.14")
def is_utf8(self) -> bool:
Expand All @@ -6989,7 +6992,7 @@ def is_utf8(self) -> bool:
True
"""
return self.dtype is Utf8
return self.dtype == Utf8

@deprecate_renamed_function("gather_every", version="0.19.14")
def take_every(self, n: int) -> Series:
Expand Down
7 changes: 3 additions & 4 deletions py-polars/polars/series/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from polars.utils.various import sphinx_accessor

if TYPE_CHECKING:
from polars import DataFrame, Series
from polars import DataFrame, DataType, Series
from polars.polars import PySeries
from polars.type_aliases import SchemaDict
elif os.getenv("BUILDING_SPHINX_DOCS"):
property = sphinx_accessor

Expand Down Expand Up @@ -66,10 +65,10 @@ def rename_fields(self, names: Sequence[str]) -> Series:
"""

@property
def schema(self) -> SchemaDict:
def schema(self) -> OrderedDict[str, DataType]:
"""Get the struct definition as a name/dtype schema dict."""
if getattr(self, "_s", None) is None:
return {}
return OrderedDict()
return OrderedDict(self._s.dtype().to_schema())

def unnest(self) -> DataFrame:
Expand Down
10 changes: 5 additions & 5 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from polars.testing.asserts.utils import raise_assertion_error

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType
from polars import DataType


def assert_series_equal(
Expand Down Expand Up @@ -252,19 +252,19 @@ def _assert_series_nan_values_match(left: Series, right: Series) -> None:
)


def _comparing_floats(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_floats(left: DataType, right: DataType) -> bool:
return left.is_float() and right.is_float()


def _comparing_lists(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_lists(left: DataType, right: DataType) -> bool:
return left in (List, Array) and right in (List, Array)


def _comparing_structs(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_structs(left: DataType, right: DataType) -> bool:
return left == Struct and right == Struct


def _comparing_nested_floats(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_nested_floats(left: DataType, right: DataType) -> bool:
if not (_comparing_lists(left, right) or _comparing_structs(left, right)):
return False

Expand Down
Loading

0 comments on commit 6491844

Please sign in to comment.