diff --git a/polars/polars-core/src/chunked_array/logical/categorical/mod.rs b/polars/polars-core/src/chunked_array/logical/categorical/mod.rs index 898d55c2f00c..9813271adf23 100644 --- a/polars/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/polars/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -4,6 +4,7 @@ mod merge; mod ops; pub mod stringcache; +use bitflags::bitflags; pub use builder::*; pub(crate) use merge::*; pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal}; @@ -12,13 +13,20 @@ use polars_utils::sync::SyncPtr; use super::*; use crate::prelude::*; +bitflags! { + #[derive(Default)] + struct BitSettings: u8 { + const ORIGINAL = 0x01; + const LEXICAL_SORT = 0x02; +}} + #[derive(Clone)] pub struct CategoricalChunked { logical: Logical, /// 1st bit: original local categorical /// meaning that n_unique is the same as the cat map length /// 2nd bit: use lexical sorting - bit_settings: u8, + bit_settings: BitSettings, } impl CategoricalChunked { @@ -58,7 +66,9 @@ impl CategoricalChunked { let ca = unsafe { UInt32Chunked::from_chunks(name, chunks) }; let mut logical = Logical::::new_logical::(ca); logical.2 = Some(DataType::Categorical(Some(Arc::new(rev_map)))); - let bit_settings = 1u8; + + let mut bit_settings = BitSettings::default(); + bit_settings.insert(BitSettings::ORIGINAL); Self { logical, bit_settings, @@ -67,14 +77,14 @@ impl CategoricalChunked { pub fn set_lexical_sorted(&mut self, toggle: bool) { if toggle { - self.bit_settings |= 1u8 << 1; + self.bit_settings.insert(BitSettings::LEXICAL_SORT); } else { - self.bit_settings &= !(1u8 << 1); + self.bit_settings.remove(BitSettings::LEXICAL_SORT); } } pub(crate) fn use_lexical_sort(&self) -> bool { - self.bit_settings & 1 << 1 != 0 + self.bit_settings.contains(BitSettings::LEXICAL_SORT) } /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`. @@ -103,14 +113,14 @@ impl CategoricalChunked { } pub(crate) fn can_fast_unique(&self) -> bool { - self.bit_settings & 1 << 0 != 0 && self.logical.chunks.len() == 1 + self.bit_settings.contains(BitSettings::ORIGINAL) && self.logical.chunks.len() == 1 } pub(crate) fn set_fast_unique(&mut self, can: bool) { if can { - self.bit_settings |= 1u8 << 0; + self.bit_settings.insert(BitSettings::ORIGINAL); } else { - self.bit_settings &= !(1u8 << 0); + self.bit_settings.remove(BitSettings::ORIGINAL); } } diff --git a/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index f08506f85170..06ff3736299f 100644 --- a/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -69,22 +69,38 @@ pub(crate) fn arg_sort_multiple_impl( pub(crate) fn argsort_multiple_row_fmt( by: &[Series], - descending: &[bool], + mut descending: Vec, nulls_last: bool, parallel: bool, ) -> PolarsResult { use polars_row::{convert_columns, SortField}; + broadcast_descending(by.len(), &mut descending); let mut cols = Vec::with_capacity(by.len()); let mut fields = Vec::with_capacity(by.len()); + debug_assert_eq!(by.len(), descending.len()); for (by, descending) in by.iter().zip(descending) { let by = convert_sort_column_multi_sort(by, true)?; - let data_type = by.dtype().to_arrow(); let by = by.rechunk(); - cols.push(by.chunks()[0].clone()); + + let arr = match by.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_) => { + let ca = by.categorical().unwrap(); + if ca.use_lexical_sort() { + by.to_arrow(0) + } else { + ca.logical().chunks[0].clone() + } + } + _ => by.to_arrow(0), + }; + let data_type = arr.data_type().clone(); + + cols.push(arr); fields.push(SortField { - descending: *descending, + descending, nulls_last, data_type, }) diff --git a/polars/polars-core/src/chunked_array/ops/sort/mod.rs b/polars/polars-core/src/chunked_array/ops/sort/mod.rs index 982d73001f4c..b6ffe7ff1219 100644 --- a/polars/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/polars/polars-core/src/chunked_array/ops/sort/mod.rs @@ -742,6 +742,14 @@ pub(crate) fn convert_sort_column_multi_sort( Ok(out) } +pub(super) fn broadcast_descending(n_cols: usize, descending: &mut Vec) { + if n_cols > descending.len() && descending.len() == 1 { + while n_cols != descending.len() { + descending.push(descending[0]); + } + } +} + #[cfg(feature = "sort_multiple")] pub(crate) fn prepare_arg_sort( columns: Vec, @@ -757,11 +765,7 @@ pub(crate) fn prepare_arg_sort( let first = columns.remove(0); // broadcast ordering - if n_cols > descending.len() && descending.len() == 1 { - while n_cols != descending.len() { - descending.push(descending[0]); - } - } + broadcast_descending(n_cols, &mut descending); Ok((first, columns, descending)) } diff --git a/polars/polars-core/src/frame/mod.rs b/polars/polars-core/src/frame/mod.rs index e8fdd1ab9670..d700e6957483 100644 --- a/polars/polars-core/src/frame/mod.rs +++ b/polars/polars-core/src/frame/mod.rs @@ -1841,8 +1841,8 @@ impl DataFrame { _ => { #[cfg(feature = "sort_multiple")] { - if std::env::var("POLARS_ROW_FMT_SORT").is_ok() { - argsort_multiple_row_fmt(&by_column, &descending, nulls_last, parallel)? + if nulls_last || std::env::var("POLARS_ROW_FMT_SORT").is_ok() { + argsort_multiple_row_fmt(&by_column, descending, nulls_last, parallel)? } else { let (first, by_column, descending) = prepare_arg_sort(by_column, descending)?; @@ -2451,7 +2451,7 @@ impl DataFrame { /// /// let df2: DataFrame = df1.describe(None)?; /// assert_eq!(df2.shape(), (9, 4)); - /// dbg!(df2); + /// println!("{}", df2); /// # Ok::<(), PolarsError>(()) /// ``` /// diff --git a/polars/polars-row/src/encode.rs b/polars/polars-row/src/encode.rs index fa64f859aca8..23e9348f54af 100644 --- a/polars/polars-row/src/encode.rs +++ b/polars/polars-row/src/encode.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, BinaryArray, BooleanArray, PrimitiveArray}; +use arrow::array::{Array, BinaryArray, BooleanArray, DictionaryArray, PrimitiveArray, Utf8Array}; use arrow::datatypes::{DataType as ArrowDataType, DataType}; use arrow::types::NativeType; @@ -51,6 +51,17 @@ unsafe fn encode_array(array: &dyn Array, field: &SortField, out: &mut RowsEncod DataType::LargeUtf8 => { panic!("should be cast to binary") } + DataType::Dictionary(_, _, _) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let iter = array + .iter_typed::>() + .unwrap() + .map(|opt_s| opt_s.map(|s| s.as_bytes())); + crate::encodings::variable::encode_iter(iter, out, field) + } dt => { with_match_arrow_primitive_type!(dt, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); @@ -84,10 +95,28 @@ pub fn allocate_rows_buf(columns: &[ArrayRef], fields: &[SortField]) -> RowsEnco // for the variable length columns we must iterate to determine the length per row location for array in columns.iter() { - if matches!(array.data_type(), ArrowDataType::LargeBinary) { - let array = array.as_any().downcast_ref::>().unwrap(); - for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { - *row_length += crate::encodings::variable::encoded_len(opt_val) + match array.data_type() { + ArrowDataType::LargeBinary => { + let array = array.as_any().downcast_ref::>().unwrap(); + for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { + *row_length += crate::encodings::variable::encoded_len(opt_val) + } + } + ArrowDataType::Dictionary(_, _, _) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let iter = array + .iter_typed::>() + .unwrap() + .map(|opt_s| opt_s.map(|s| s.as_bytes())); + for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { + *row_length += crate::encodings::variable::encoded_len(opt_val) + } + } + _ => { + // the rest is fixed } } } @@ -139,3 +168,62 @@ pub fn allocate_rows_buf(columns: &[ArrayRef], fields: &[SortField]) -> RowsEnco RowsEncoded::new(buf, offsets, None) } } + +#[cfg(test)] +mod test { + use arrow::array::{Int64Array, UInt8Array, Utf8Array}; + + use super::*; + use crate::encodings::variable::{ + BLOCK_CONTINUATION_TOKEN, BLOCK_SIZE, EMPTY_SENTINEL, NON_EMPTY_SENTINEL, + }; + + #[test] + fn test_str_encode() { + let sentence = "The black cat walked under a ladder but forget it's milk so it ..."; + let arr = + Utf8Array::::from_iter([Some("a"), Some(""), Some("meep"), Some(sentence), None]); + + let field = SortField { + descending: false, + nulls_last: false, + data_type: ArrowDataType::LargeBinary, + }; + let arr = arrow::compute::cast::cast(&arr, &ArrowDataType::LargeBinary, Default::default()) + .unwrap(); + let rows_encoded = convert_columns(&[arr], vec![field]); + let row1 = rows_encoded.get(0); + + // + 2 for the start valid byte and for the continuation token + assert_eq!(row1.len(), BLOCK_SIZE + 2); + let mut expected = [0u8; BLOCK_SIZE + 2]; + expected[0] = NON_EMPTY_SENTINEL; + expected[1] = b'a'; + *expected.last_mut().unwrap() = 1; + assert_eq!(row1, expected); + + let row2 = rows_encoded.get(1); + let expected = &[EMPTY_SENTINEL]; + assert_eq!(row2, expected); + + let row3 = rows_encoded.get(2); + let mut expected = [0u8; BLOCK_SIZE + 2]; + expected[0] = NON_EMPTY_SENTINEL; + *expected.last_mut().unwrap() = 4; + &mut expected[1..5].copy_from_slice(b"meep"); + assert_eq!(row3, expected); + + let row4 = rows_encoded.get(3); + let mut expected = [ + 2, 84, 104, 101, 32, 98, 108, 97, 99, 107, 32, 99, 97, 116, 32, 119, 97, 108, 107, 101, + 100, 32, 117, 110, 100, 101, 114, 32, 97, 32, 108, 97, 100, 255, 100, 101, 114, 32, 98, + 117, 116, 32, 102, 111, 114, 103, 101, 116, 32, 105, 116, 39, 115, 32, 109, 105, 108, + 107, 32, 115, 111, 32, 105, 116, 32, 46, 255, 46, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ]; + assert_eq!(row4, expected); + let row5 = rows_encoded.get(4); + let expected = &[0u8]; + assert_eq!(row5, expected); + } +} diff --git a/polars/polars-row/src/encodings/variable.rs b/polars/polars-row/src/encodings/variable.rs index 654d3949a318..941fe2226d8a 100644 --- a/polars/polars-row/src/encodings/variable.rs +++ b/polars/polars-row/src/encodings/variable.rs @@ -17,16 +17,16 @@ use crate::row::RowsEncoded; use crate::SortField; /// The block size of the variable length encoding -const BLOCK_SIZE: usize = 32; +pub(crate) const BLOCK_SIZE: usize = 32; /// The continuation token. -const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF; +pub(crate) const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF; /// Indicates an empty string -const EMPTY_SENTINEL: u8 = 1; +pub(crate) const EMPTY_SENTINEL: u8 = 1; /// Indicates a non-empty string -const NON_EMPTY_SENTINEL: u8 = 2; +pub(crate) const NON_EMPTY_SENTINEL: u8 = 2; /// Returns the ceil of `value`/`divisor` #[inline] @@ -71,7 +71,9 @@ unsafe fn encode_one(out: &mut [u8], val: Option<&[u8]>, field: &SortField) -> u 1 } Some(val) => { - let end_offset = padded_length(val.len()); + let block_count = ceil(val.len(), BLOCK_SIZE); + let end_offset = 1 + block_count * (BLOCK_SIZE + 1); + let dst = out.get_unchecked_release_mut(..end_offset); // Write `2_u8` to demarcate as non-empty, non-null string @@ -97,6 +99,16 @@ unsafe fn encode_one(out: &mut [u8], val: Option<&[u8]>, field: &SortField) -> u // replace the "there is another block" with // "we are finished this, this is the length of this block" *dst.last_mut().unwrap_unchecked() = BLOCK_SIZE as u8; + } else { + // get the last block + let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1); + let last_dst = dst.get_unchecked_release_mut(start_offset..); + std::ptr::copy_nonoverlapping( + src_remainder.as_ptr(), + last_dst.as_mut_ptr(), + src_remainder.len(), + ); + *dst.last_mut().unwrap_unchecked() = src_remainder.len() as u8; } if field.descending { diff --git a/polars/polars-row/src/lib.rs b/polars/polars-row/src/lib.rs index 0976d8aaac1e..2987eaa430a3 100644 --- a/polars/polars-row/src/lib.rs +++ b/polars/polars-row/src/lib.rs @@ -266,7 +266,7 @@ //! [COBS]: https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing //! [byte stuffing]: https://en.wikipedia.org/wiki/High-Level_Data_Link_Control#Asynchronous_framing -mod encode; +pub mod encode; mod encodings; mod row; mod sort_field; @@ -276,4 +276,5 @@ use arrow::array::*; pub type ArrayRef = Box; pub use encode::convert_columns; +pub use row::RowsEncoded; pub use sort_field::SortField; diff --git a/polars/polars-row/src/row.rs b/polars/polars-row/src/row.rs index 73d2b3990751..985943165529 100644 --- a/polars/polars-row/src/row.rs +++ b/polars/polars-row/src/row.rs @@ -24,6 +24,13 @@ impl RowsEncoded { buf: &self.buf, } } + + #[cfg(test)] + pub fn get(&self, i: usize) -> &[u8] { + let start = self.offsets[i]; + let end = self.offsets[i + 1]; + &self.buf[start..end] + } } pub struct RowsEncodedIter<'a> { diff --git a/polars/polars-row/src/sort_field.rs b/polars/polars-row/src/sort_field.rs index c566997044e1..e20af9ee83ec 100644 --- a/polars/polars-row/src/sort_field.rs +++ b/polars/polars-row/src/sort_field.rs @@ -3,6 +3,7 @@ use encodings::fixed::FixedLengthEncoding; use super::*; +#[derive(Clone)] pub struct SortField { /// Whether to sort in descending order pub descending: bool, diff --git a/py-polars/polars/internals/lazyframe/frame.py b/py-polars/polars/internals/lazyframe/frame.py index 5b6120933b0f..20a3d4c4bad6 100644 --- a/py-polars/polars/internals/lazyframe/frame.py +++ b/py-polars/polars/internals/lazyframe/frame.py @@ -1117,12 +1117,6 @@ def sort( if more_by: by.extend(pli.selection_to_pyexpr_list(more_by)) - # TODO: Do this check on the Rust side - if nulls_last and len(by) > 1: - raise ValueError( - "`nulls_last=True` only works when sorting by a single column" - ) - if isinstance(descending, bool): descending = [descending] return self._from_pyldf(self._ldf.sort_by_exprs(by, descending, nulls_last)) diff --git a/py-polars/polars/internals/whenthen.py b/py-polars/polars/internals/whenthen.py index c8d736ffbd3a..cca4ab8f733f 100644 --- a/py-polars/polars/internals/whenthen.py +++ b/py-polars/polars/internals/whenthen.py @@ -68,7 +68,7 @@ class WhenThen: def __init__(self, pywhenthen: Any): self._pywhenthen = pywhenthen - def when(self, predicate: pli.Expr | bool) -> WhenThenThen: + def when(self, predicate: pli.Expr | bool | pli.Series) -> WhenThenThen: """Start another "when, then, otherwise" layer.""" predicate = pli.expr_to_lit_or_expr(predicate) return WhenThenThen(self._pywhenthen.when(predicate._pyexpr)) @@ -131,7 +131,7 @@ def then( return WhenThen(pywhenthen) -def when(expr: pli.Expr | bool) -> When: +def when(expr: pli.Expr | bool | pli.Series) -> When: """ Start a "when, then, otherwise" expression. diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index c3b037f60046..4d41ef2c8d89 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -1,7 +1,10 @@ from __future__ import annotations +import random +import string from datetime import datetime +import numpy as np import pytest import polars as pl @@ -446,9 +449,6 @@ def test_sort_args() -> None: result = df.sort("a", nulls_last=True) assert_frame_equal(result, df) - with pytest.raises(ValueError): - df.sort("a", "b", nulls_last=True) - def test_sort_type_coersion_6892() -> None: df = pl.DataFrame({"a": [2, 1], "b": [2, 3]}) @@ -456,3 +456,33 @@ def test_sort_type_coersion_6892() -> None: "a": [1, 2], "b": [3, 2], } + + +@pytest.mark.slow() +def test_sort_row_fmt() -> None: + # we sort nulls_last as this will always dispatch + # to row_fmt and is the default in pandas + + n = 1000 + strs = pl.Series("strs", random.choices(string.ascii_lowercase, k=n)) + strs = pl.select( + pl.when(strs == "a") + .then("") + .when(strs == "b") + .then(None) + .otherwise(strs) + .alias("strs") + ).to_series() + + vals = pl.Series("vals", np.random.rand(n)) + + df = pl.DataFrame([vals, strs]) + df_pd = df.to_pandas() + + for descending in [True, False]: + pl.testing.assert_frame_equal( + df.sort(["strs", "vals"], nulls_last=True, descending=descending), + pl.from_pandas( + df_pd.sort_values(["strs", "vals"], ascending=not descending) + ), + ) diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 23e5e8ec6fee..8d713c0ef7cc 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -1935,12 +1935,7 @@ def test_trigonometric(f: str) -> None: expected = ( pl.Series("a", getattr(np, f)(s.to_numpy())) .to_frame() - .with_columns( - pl.when(s.is_null()) # type: ignore[arg-type] - .then(None) - .otherwise(pl.col("a")) - .alias("a") - ) + .with_columns(pl.when(s.is_null()).then(None).otherwise(pl.col("a")).alias("a")) .to_series() ) result = getattr(s, f)()