feat(rust, python): support nulls_last for multi-column sort #7242

Merged (1 commit) on Feb 28, 2023
26 changes: 18 additions & 8 deletions polars/polars-core/src/chunked_array/logical/categorical/mod.rs
@@ -4,6 +4,7 @@ mod merge;
mod ops;
pub mod stringcache;

use bitflags::bitflags;
pub use builder::*;
pub(crate) use merge::*;
pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};
@@ -12,13 +13,20 @@ use polars_utils::sync::SyncPtr;
use super::*;
use crate::prelude::*;

bitflags! {
#[derive(Default)]
struct BitSettings: u8 {
const ORIGINAL = 0x01;
const LEXICAL_SORT = 0x02;
}}

#[derive(Clone)]
pub struct CategoricalChunked {
logical: Logical<CategoricalType, UInt32Type>,
/// 1st bit: original local categorical
/// meaning that n_unique is the same as the cat map length
/// 2nd bit: use lexical sorting
bit_settings: u8,
bit_settings: BitSettings,
}

impl CategoricalChunked {
@@ -58,7 +66,9 @@ impl CategoricalChunked {
let ca = unsafe { UInt32Chunked::from_chunks(name, chunks) };
let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(ca);
logical.2 = Some(DataType::Categorical(Some(Arc::new(rev_map))));
let bit_settings = 1u8;

let mut bit_settings = BitSettings::default();
bit_settings.insert(BitSettings::ORIGINAL);
Self {
logical,
bit_settings,
@@ -67,14 +77,14 @@

pub fn set_lexical_sorted(&mut self, toggle: bool) {
if toggle {
self.bit_settings |= 1u8 << 1;
self.bit_settings.insert(BitSettings::LEXICAL_SORT);
} else {
self.bit_settings &= !(1u8 << 1);
self.bit_settings.remove(BitSettings::LEXICAL_SORT);
}
}

pub(crate) fn use_lexical_sort(&self) -> bool {
self.bit_settings & 1 << 1 != 0
self.bit_settings.contains(BitSettings::LEXICAL_SORT)
}

/// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`.
Expand Down Expand Up @@ -103,14 +113,14 @@ impl CategoricalChunked {
}

pub(crate) fn can_fast_unique(&self) -> bool {
self.bit_settings & 1 << 0 != 0 && self.logical.chunks.len() == 1
self.bit_settings.contains(BitSettings::ORIGINAL) && self.logical.chunks.len() == 1
}

pub(crate) fn set_fast_unique(&mut self, can: bool) {
if can {
self.bit_settings |= 1u8 << 0;
self.bit_settings.insert(BitSettings::ORIGINAL);
} else {
self.bit_settings &= !(1u8 << 0);
self.bit_settings.remove(BitSettings::ORIGINAL);
}
}

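For context: the raw `u8` bit twiddling in `CategoricalChunked` is replaced by named flags built with the `bitflags` crate. Below is a minimal standalone sketch of the pattern, assuming the bitflags 1.x macro imported above (illustrative only, not the crate's actual module). The `insert`/`remove`/`contains` calls in the hunks above are exactly this idiom applied to `ORIGINAL` (fast unique) and `LEXICAL_SORT`.

```rust
use bitflags::bitflags;

bitflags! {
    #[derive(Default)]
    struct BitSettings: u8 {
        const ORIGINAL = 0x01;
        const LEXICAL_SORT = 0x02;
    }
}

fn main() {
    let mut settings = BitSettings::default();          // all flags cleared
    settings.insert(BitSettings::ORIGINAL);              // was: bit_settings |= 1u8 << 0
    settings.insert(BitSettings::LEXICAL_SORT);          // was: bit_settings |= 1u8 << 1
    settings.remove(BitSettings::LEXICAL_SORT);          // was: bit_settings &= !(1u8 << 1)
    assert!(settings.contains(BitSettings::ORIGINAL));   // was: bit_settings & 1 << 0 != 0
    assert!(!settings.contains(BitSettings::LEXICAL_SORT));
}
```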
24 changes: 20 additions & 4 deletions polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
@@ -69,22 +69,38 @@ pub(crate) fn arg_sort_multiple_impl<T: PartialOrd + Send + IsFloat + Copy>(

pub(crate) fn argsort_multiple_row_fmt(
by: &[Series],
descending: &[bool],
mut descending: Vec<bool>,
nulls_last: bool,
parallel: bool,
) -> PolarsResult<IdxCa> {
use polars_row::{convert_columns, SortField};
broadcast_descending(by.len(), &mut descending);

let mut cols = Vec::with_capacity(by.len());
let mut fields = Vec::with_capacity(by.len());

debug_assert_eq!(by.len(), descending.len());
for (by, descending) in by.iter().zip(descending) {
let by = convert_sort_column_multi_sort(by, true)?;
let data_type = by.dtype().to_arrow();
let by = by.rechunk();
cols.push(by.chunks()[0].clone());

let arr = match by.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
let ca = by.categorical().unwrap();
if ca.use_lexical_sort() {
by.to_arrow(0)
} else {
ca.logical().chunks[0].clone()
}
}
_ => by.to_arrow(0),
};
let data_type = arr.data_type().clone();

cols.push(arr);
fields.push(SortField {
descending: *descending,
descending,
nulls_last,
data_type,
})
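Why encode rows at all: each row's sort keys are packed into one byte string whose lexicographic order matches the requested multi-column order (per-column descending, a shared nulls_last), so the argsort reduces to comparing byte slices. A toy sketch of that idea with hand-written stand-in encodings (not the real polars-row format):

```rust
fn main() {
    // (stand-in encoded key bytes, original row index); 0x00 marks a null here
    // and 0x02 a non-empty value, mirroring the sentinel idea used by polars-row.
    let mut rows: Vec<(Vec<u8>, u32)> = vec![
        (vec![2, b'b'], 0),
        (vec![0], 1), // null key
        (vec![2, b'a'], 2),
    ];
    // A single byte-wise comparison orders all key columns at once.
    rows.sort_by(|a, b| a.0.cmp(&b.0));
    let idx: Vec<u32> = rows.into_iter().map(|(_, i)| i).collect();
    // With a low 0x00 sentinel nulls sort first; a nulls_last encoding would
    // instead use a sentinel that compares greater than any value byte.
    assert_eq!(idx, vec![1, 2, 0]);
}
```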
14 changes: 9 additions & 5 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
@@ -742,6 +742,14 @@ pub(crate) fn convert_sort_column_multi_sort(
Ok(out)
}

pub(super) fn broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
}
}
}

#[cfg(feature = "sort_multiple")]
pub(crate) fn prepare_arg_sort(
columns: Vec<Series>,
@@ -757,11 +765,7 @@ pub(crate) fn prepare_arg_sort(
let first = columns.remove(0);

// broadcast ordering
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
}
}
broadcast_descending(n_cols, &mut descending);
Ok((first, columns, descending))
}

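The helper extracted here repeats a single `descending` flag until it matches the number of sort columns, and leaves an already full vector untouched. A small usage sketch, written as a test and assuming it sits where the `pub(super)` helper is visible:

```rust
#[test]
fn broadcast_descending_broadcasts_single_flag() {
    // One flag for three sort columns: broadcast it.
    let mut descending = vec![true];
    broadcast_descending(3, &mut descending);
    assert_eq!(descending, vec![true, true, true]);

    // Already one flag per column: left untouched.
    let mut descending = vec![true, false];
    broadcast_descending(2, &mut descending);
    assert_eq!(descending, vec![true, false]);
}
```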
6 changes: 3 additions & 3 deletions polars/polars-core/src/frame/mod.rs
@@ -1841,8 +1841,8 @@ impl DataFrame {
_ => {
#[cfg(feature = "sort_multiple")]
{
if std::env::var("POLARS_ROW_FMT_SORT").is_ok() {
argsort_multiple_row_fmt(&by_column, &descending, nulls_last, parallel)?
if nulls_last || std::env::var("POLARS_ROW_FMT_SORT").is_ok() {
argsort_multiple_row_fmt(&by_column, descending, nulls_last, parallel)?
} else {
let (first, by_column, descending) =
prepare_arg_sort(by_column, descending)?;
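With this change, any multi-column sort that requests nulls_last is routed through the row-format argsort; the POLARS_ROW_FMT_SORT environment variable still forces that path unconditionally, which can be handy when comparing the two code paths. A usage sketch:

```rust
fn main() {
    // Force the row-format multi-column sort branch even when nulls_last is false.
    std::env::set_var("POLARS_ROW_FMT_SORT", "1");
}
```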
@@ -2451,7 +2451,7 @@ impl DataFrame {
///
/// let df2: DataFrame = df1.describe(None)?;
/// assert_eq!(df2.shape(), (9, 4));
/// dbg!(df2);
/// println!("{}", df2);
/// # Ok::<(), PolarsError>(())
/// ```
///
98 changes: 93 additions & 5 deletions polars/polars-row/src/encode.rs
@@ -1,4 +1,4 @@
use arrow::array::{Array, BinaryArray, BooleanArray, PrimitiveArray};
use arrow::array::{Array, BinaryArray, BooleanArray, DictionaryArray, PrimitiveArray, Utf8Array};
use arrow::datatypes::{DataType as ArrowDataType, DataType};
use arrow::types::NativeType;

@@ -51,6 +51,17 @@ unsafe fn encode_array(array: &dyn Array, field: &SortField, out: &mut RowsEncod
DataType::LargeUtf8 => {
panic!("should be cast to binary")
}
DataType::Dictionary(_, _, _) => {
let array = array
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
let iter = array
.iter_typed::<Utf8Array<i64>>()
.unwrap()
.map(|opt_s| opt_s.map(|s| s.as_bytes()));
crate::encodings::variable::encode_iter(iter, out, field)
}
dt => {
with_match_arrow_primitive_type!(dt, |$T| {
let array = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
@@ -84,10 +95,28 @@ pub fn allocate_rows_buf(columns: &[ArrayRef], fields: &[SortField]) -> RowsEnco

// for the variable length columns we must iterate to determine the length per row location
for array in columns.iter() {
if matches!(array.data_type(), ArrowDataType::LargeBinary) {
let array = array.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) {
*row_length += crate::encodings::variable::encoded_len(opt_val)
match array.data_type() {
ArrowDataType::LargeBinary => {
let array = array.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) {
*row_length += crate::encodings::variable::encoded_len(opt_val)
}
}
ArrowDataType::Dictionary(_, _, _) => {
let array = array
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
let iter = array
.iter_typed::<Utf8Array<i64>>()
.unwrap()
.map(|opt_s| opt_s.map(|s| s.as_bytes()));
for (opt_val, row_length) in iter.zip(lengths.iter_mut()) {
*row_length += crate::encodings::variable::encoded_len(opt_val)
}
}
_ => {
// the rest is fixed
}
}
}
@@ -139,3 +168,62 @@ pub fn allocate_rows_buf(columns: &[ArrayRef], fields: &[SortField]) -> RowsEnco
RowsEncoded::new(buf, offsets, None)
}
}

#[cfg(test)]
mod test {
use arrow::array::{Int64Array, UInt8Array, Utf8Array};

use super::*;
use crate::encodings::variable::{
BLOCK_CONTINUATION_TOKEN, BLOCK_SIZE, EMPTY_SENTINEL, NON_EMPTY_SENTINEL,
};

#[test]
fn test_str_encode() {
let sentence = "The black cat walked under a ladder but forget it's milk so it ...";
let arr =
Utf8Array::<i64>::from_iter([Some("a"), Some(""), Some("meep"), Some(sentence), None]);

let field = SortField {
descending: false,
nulls_last: false,
data_type: ArrowDataType::LargeBinary,
};
let arr = arrow::compute::cast::cast(&arr, &ArrowDataType::LargeBinary, Default::default())
.unwrap();
let rows_encoded = convert_columns(&[arr], vec![field]);
let row1 = rows_encoded.get(0);

// + 2 for the start valid byte and for the continuation token
assert_eq!(row1.len(), BLOCK_SIZE + 2);
let mut expected = [0u8; BLOCK_SIZE + 2];
expected[0] = NON_EMPTY_SENTINEL;
expected[1] = b'a';
*expected.last_mut().unwrap() = 1;
assert_eq!(row1, expected);

let row2 = rows_encoded.get(1);
let expected = &[EMPTY_SENTINEL];
assert_eq!(row2, expected);

let row3 = rows_encoded.get(2);
let mut expected = [0u8; BLOCK_SIZE + 2];
expected[0] = NON_EMPTY_SENTINEL;
*expected.last_mut().unwrap() = 4;
expected[1..5].copy_from_slice(b"meep");
assert_eq!(row3, expected);

let row4 = rows_encoded.get(3);
let mut expected = [
2, 84, 104, 101, 32, 98, 108, 97, 99, 107, 32, 99, 97, 116, 32, 119, 97, 108, 107, 101,
100, 32, 117, 110, 100, 101, 114, 32, 97, 32, 108, 97, 100, 255, 100, 101, 114, 32, 98,
117, 116, 32, 102, 111, 114, 103, 101, 116, 32, 105, 116, 39, 115, 32, 109, 105, 108,
107, 32, 115, 111, 32, 105, 116, 32, 46, 255, 46, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
];
assert_eq!(row4, expected);
let row5 = rows_encoded.get(4);
let expected = &[0u8];
assert_eq!(row5, expected);
}
}
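A note on the sizes asserted in the test above: under this encoding a non-null, non-empty value of length `len` occupies one sentinel byte plus ceil(len / BLOCK_SIZE) blocks of BLOCK_SIZE data bytes, each block followed by a continuation/length byte. A minimal standalone sketch of that arithmetic (the helper name is illustrative, not part of the crate):

```rust
const BLOCK_SIZE: usize = 32;

/// Illustrative only: total encoded size of one variable-length value.
fn encoded_size(len: usize) -> usize {
    if len == 0 {
        1 // EMPTY_SENTINEL only; a null likewise occupies a single byte
    } else {
        let block_count = (len + BLOCK_SIZE - 1) / BLOCK_SIZE; // ceil(len / BLOCK_SIZE)
        1 + block_count * (BLOCK_SIZE + 1)
    }
}

fn main() {
    assert_eq!(encoded_size(1), BLOCK_SIZE + 2);  // "a"    -> 34 bytes (row1)
    assert_eq!(encoded_size(4), BLOCK_SIZE + 2);  // "meep" -> 34 bytes (row3)
    assert_eq!(encoded_size(66), 100);            // the 66-byte sentence (row4)
    assert_eq!(encoded_size(0), 1);               // ""     -> 1 byte    (row2)
}
```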
22 changes: 17 additions & 5 deletions polars/polars-row/src/encodings/variable.rs
@@ -17,16 +17,16 @@ use crate::row::RowsEncoded;
use crate::SortField;

/// The block size of the variable length encoding
const BLOCK_SIZE: usize = 32;
pub(crate) const BLOCK_SIZE: usize = 32;

/// The continuation token.
const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF;
pub(crate) const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF;

/// Indicates an empty string
const EMPTY_SENTINEL: u8 = 1;
pub(crate) const EMPTY_SENTINEL: u8 = 1;

/// Indicates a non-empty string
const NON_EMPTY_SENTINEL: u8 = 2;
pub(crate) const NON_EMPTY_SENTINEL: u8 = 2;

/// Returns the ceil of `value`/`divisor`
#[inline]
@@ -71,7 +71,9 @@ unsafe fn encode_one(out: &mut [u8], val: Option<&[u8]>, field: &SortField) -> u
1
}
Some(val) => {
let end_offset = padded_length(val.len());
let block_count = ceil(val.len(), BLOCK_SIZE);
let end_offset = 1 + block_count * (BLOCK_SIZE + 1);

let dst = out.get_unchecked_release_mut(..end_offset);

// Write `2_u8` to demarcate as non-empty, non-null string
Expand All @@ -97,6 +99,16 @@ unsafe fn encode_one(out: &mut [u8], val: Option<&[u8]>, field: &SortField) -> u
// replace the "there is another block" with
// "we are finished this, this is the length of this block"
*dst.last_mut().unwrap_unchecked() = BLOCK_SIZE as u8;
} else {
// get the last block
let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1);
let last_dst = dst.get_unchecked_release_mut(start_offset..);
std::ptr::copy_nonoverlapping(
src_remainder.as_ptr(),
last_dst.as_mut_ptr(),
src_remainder.len(),
);
*dst.last_mut().unwrap_unchecked() = src_remainder.len() as u8;
}

if field.descending {
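The new `else` branch above handles the final partial block: the remaining bytes are copied into the last block slot and that slot's trailing byte records how many of them are valid. A small runnable sketch of the offsets involved (names are illustrative; the worked numbers match the encode.rs test):

```rust
const BLOCK_SIZE: usize = 32;

fn main() {
    // e.g. the 66-byte sentence encoded in the test above
    let len = 66usize;
    let block_count = (len + BLOCK_SIZE - 1) / BLOCK_SIZE;           // 3 blocks
    let end_offset = 1 + block_count * (BLOCK_SIZE + 1);             // 100 bytes in total
    let last_block_start = 1 + (block_count - 1) * (BLOCK_SIZE + 1); // last block begins at offset 67
    let remainder = len - (block_count - 1) * BLOCK_SIZE;            // 2 bytes land in the last block
    assert_eq!((end_offset, last_block_start, remainder), (100, 67, 2));
    // The byte written at `end_offset - 1` is `remainder as u8`, which is why
    // the expected buffer for row4 in the test ends in 2.
}
```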
3 changes: 2 additions & 1 deletion polars/polars-row/src/lib.rs
@@ -266,7 +266,7 @@
//! [COBS]: https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing
//! [byte stuffing]: https://en.wikipedia.org/wiki/High-Level_Data_Link_Control#Asynchronous_framing

mod encode;
pub mod encode;
mod encodings;
mod row;
mod sort_field;
@@ -276,4 +276,5 @@ use arrow::array::*;
pub type ArrayRef = Box<dyn Array>;

pub use encode::convert_columns;
pub use row::RowsEncoded;
pub use sort_field::SortField;
7 changes: 7 additions & 0 deletions polars/polars-row/src/row.rs
@@ -24,6 +24,13 @@ impl RowsEncoded {
buf: &self.buf,
}
}

#[cfg(test)]
pub fn get(&self, i: usize) -> &[u8] {
let start = self.offsets[i];
let end = self.offsets[i + 1];
&self.buf[start..end]
}
}

pub struct RowsEncodedIter<'a> {
1 change: 1 addition & 0 deletions polars/polars-row/src/sort_field.rs
@@ -3,6 +3,7 @@ use encodings::fixed::FixedLengthEncoding;

use super::*;

#[derive(Clone)]
pub struct SortField {
/// Whether to sort in descending order
pub descending: bool,
6 changes: 0 additions & 6 deletions py-polars/polars/internals/lazyframe/frame.py
@@ -1082,12 +1082,6 @@ def sort(
if more_by:
by.extend(pli.selection_to_pyexpr_list(more_by))

# TODO: Do this check on the Rust side
if nulls_last and len(by) > 1:
raise ValueError(
"`nulls_last=True` only works when sorting by a single column"
)

if isinstance(descending, bool):
descending = [descending]
return self._from_pyldf(self._ldf.sort_by_exprs(by, descending, nulls_last))
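With the Python-side guard removed, nulls_last is accepted for multi-column sorts end to end. A minimal sketch of the Rust lazy API that the binding above calls into, assuming the polars crate with the lazy feature enabled and the three-argument sort_by_exprs signature shown in the Python wrapper:

```rust
use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df![
        "a" => [Some(1), None, Some(1), Some(2)],
        "b" => [Some("x"), Some("y"), None, Some("z")]
    ]?;

    // Sort by two columns with nulls placed last; before this PR the Python
    // wrapper rejected nulls_last=True for more than one sort column.
    let out = df
        .lazy()
        .sort_by_exprs([col("a"), col("b")], [false, false], true)
        .collect()?;
    println!("{}", out);
    Ok(())
}
```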