feat(rust, python): support nulls_last for multi-column sort (pola-rs…
ritchie46 authored Feb 28, 2023
1 parent 41262b5 commit 6fe4f77
Showing 13 changed files with 206 additions and 48 deletions.
26 changes: 18 additions & 8 deletions polars/polars-core/src/chunked_array/logical/categorical/mod.rs
@@ -4,6 +4,7 @@ mod merge;
mod ops;
pub mod stringcache;

use bitflags::bitflags;
pub use builder::*;
pub(crate) use merge::*;
pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};
@@ -12,13 +13,20 @@ use polars_utils::sync::SyncPtr;
use super::*;
use crate::prelude::*;

bitflags! {
#[derive(Default)]
struct BitSettings: u8 {
const ORIGINAL = 0x01;
const LEXICAL_SORT = 0x02;
}}

#[derive(Clone)]
pub struct CategoricalChunked {
logical: Logical<CategoricalType, UInt32Type>,
/// 1st bit: original local categorical
/// meaning that n_unique is the same as the cat map length
/// 2nd bit: use lexical sorting
bit_settings: u8,
bit_settings: BitSettings,
}

impl CategoricalChunked {
@@ -58,7 +66,9 @@ impl CategoricalChunked {
let ca = unsafe { UInt32Chunked::from_chunks(name, chunks) };
let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(ca);
logical.2 = Some(DataType::Categorical(Some(Arc::new(rev_map))));
let bit_settings = 1u8;

let mut bit_settings = BitSettings::default();
bit_settings.insert(BitSettings::ORIGINAL);
Self {
logical,
bit_settings,
@@ -67,14 +77,14 @@ impl CategoricalChunked {

pub fn set_lexical_sorted(&mut self, toggle: bool) {
if toggle {
self.bit_settings |= 1u8 << 1;
self.bit_settings.insert(BitSettings::LEXICAL_SORT);
} else {
self.bit_settings &= !(1u8 << 1);
self.bit_settings.remove(BitSettings::LEXICAL_SORT);
}
}

pub(crate) fn use_lexical_sort(&self) -> bool {
self.bit_settings & 1 << 1 != 0
self.bit_settings.contains(BitSettings::LEXICAL_SORT)
}

/// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`.
@@ -103,14 +113,14 @@ impl CategoricalChunked {
}

pub(crate) fn can_fast_unique(&self) -> bool {
self.bit_settings & 1 << 0 != 0 && self.logical.chunks.len() == 1
self.bit_settings.contains(BitSettings::ORIGINAL) && self.logical.chunks.len() == 1
}

pub(crate) fn set_fast_unique(&mut self, can: bool) {
if can {
self.bit_settings |= 1u8 << 0;
self.bit_settings.insert(BitSettings::ORIGINAL);
} else {
self.bit_settings &= !(1u8 << 0);
self.bit_settings.remove(BitSettings::ORIGINAL);
}
}

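The raw u8 bit twiddling in CategoricalChunked is replaced by a bitflags-generated type, so the ORIGINAL and LEXICAL_SORT settings get named insert/remove/contains calls instead of manual shifts and masks. A minimal standalone sketch of that pattern, assuming the bitflags 1.x crate (the same macro syntax as above); Settings is a hypothetical stand-in for BitSettings:

use bitflags::bitflags;

bitflags! {
    // Assumes bitflags 1.x; each constant occupies one bit of the u8.
    #[derive(Default)]
    struct Settings: u8 {
        const ORIGINAL = 0x01;
        const LEXICAL_SORT = 0x02;
    }
}

fn main() {
    let mut settings = Settings::default();

    // insert/remove/contains replace the manual `|=`, `&= !` and `&` masking.
    settings.insert(Settings::ORIGINAL);
    assert!(settings.contains(Settings::ORIGINAL));
    assert!(!settings.contains(Settings::LEXICAL_SORT));

    settings.insert(Settings::LEXICAL_SORT);
    settings.remove(Settings::ORIGINAL);
    assert_eq!(settings.bits(), 0x02);
}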
24 changes: 20 additions & 4 deletions polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
@@ -69,22 +69,38 @@ pub(crate) fn arg_sort_multiple_impl<T: PartialOrd + Send + IsFloat + Copy>(

pub(crate) fn argsort_multiple_row_fmt(
by: &[Series],
descending: &[bool],
mut descending: Vec<bool>,
nulls_last: bool,
parallel: bool,
) -> PolarsResult<IdxCa> {
use polars_row::{convert_columns, SortField};
broadcast_descending(by.len(), &mut descending);

let mut cols = Vec::with_capacity(by.len());
let mut fields = Vec::with_capacity(by.len());

debug_assert_eq!(by.len(), descending.len());
for (by, descending) in by.iter().zip(descending) {
let by = convert_sort_column_multi_sort(by, true)?;
let data_type = by.dtype().to_arrow();
let by = by.rechunk();
cols.push(by.chunks()[0].clone());

let arr = match by.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
let ca = by.categorical().unwrap();
if ca.use_lexical_sort() {
by.to_arrow(0)
} else {
ca.logical().chunks[0].clone()
}
}
_ => by.to_arrow(0),
};
let data_type = arr.data_type().clone();

cols.push(arr);
fields.push(SortField {
descending: *descending,
descending,
nulls_last,
data_type,
})
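argsort_multiple_row_fmt now owns descending, broadcasts it across the sort columns, and, for categorical columns, picks either the physical u32 keys or the lexical string representation before handing the arrays to the row encoder. The idea behind the row format is that every column contributes an order-preserving byte key per row, so a single byte-wise comparison of the concatenated keys reproduces the full multi-column order, including descending and nulls_last. A self-contained conceptual sketch of that idea (an illustration only, not the actual polars-row encoding):

// Encode one nullable i32 into a 5-byte, order-preserving key:
// a sentinel byte for null handling followed by the big-endian value
// with its sign bit flipped (and all value bytes inverted for descending).
fn encode_i32(v: Option<i32>, descending: bool, nulls_last: bool) -> [u8; 5] {
    let mut key = [0u8; 5];
    match v {
        None => {
            // Smallest sentinel puts nulls first, largest puts them last.
            key[0] = if nulls_last { 0xFF } else { 0x00 };
        }
        Some(v) => {
            key[0] = 0x01;
            let mut bytes = ((v as u32) ^ 0x8000_0000).to_be_bytes();
            if descending {
                for b in &mut bytes {
                    *b = !*b;
                }
            }
            key[1..].copy_from_slice(&bytes);
        }
    }
    key
}

fn main() {
    let a = [Some(1), Some(1), None, Some(0)];
    let b = [Some(9), Some(3), Some(7), Some(5)];

    // Sort by (a ascending, b descending), nulls last: concatenate the
    // per-column keys and compare the resulting byte strings.
    let mut rows: Vec<(usize, Vec<u8>)> = (0..a.len())
        .map(|i| {
            let mut key = encode_i32(a[i], false, true).to_vec();
            key.extend_from_slice(&encode_i32(b[i], true, true));
            (i, key)
        })
        .collect();
    rows.sort_by(|x, y| x.1.cmp(&y.1));

    let order: Vec<usize> = rows.into_iter().map(|(i, _)| i).collect();
    assert_eq!(order, vec![3, 0, 1, 2]);
}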
14 changes: 9 additions & 5 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
@@ -742,6 +742,14 @@ pub(crate) fn convert_sort_column_multi_sort(
Ok(out)
}

pub(super) fn broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
}
}
}

#[cfg(feature = "sort_multiple")]
pub(crate) fn prepare_arg_sort(
columns: Vec<Series>,
@@ -757,11 +765,7 @@ pub(crate) fn prepare_arg_sort(
let first = columns.remove(0);

// broadcast ordering
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
}
}
broadcast_descending(n_cols, &mut descending);
Ok((first, columns, descending))
}

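broadcast_descending is factored out of prepare_arg_sort so the row-format path can reuse it: a single descending flag is repeated once per sort column, while a vector that already has one flag per column is left untouched. A small standalone check of that behavior (the helper is copied verbatim from the hunk above):

fn broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
    if n_cols > descending.len() && descending.len() == 1 {
        while n_cols != descending.len() {
            descending.push(descending[0]);
        }
    }
}

fn main() {
    // One flag is broadcast to all three sort columns.
    let mut one = vec![true];
    broadcast_descending(3, &mut one);
    assert_eq!(one, vec![true, true, true]);

    // An already-complete vector is left as-is.
    let mut full = vec![true, false];
    broadcast_descending(2, &mut full);
    assert_eq!(full, vec![true, false]);
}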
6 changes: 3 additions & 3 deletions polars/polars-core/src/frame/mod.rs
@@ -1841,8 +1841,8 @@ impl DataFrame {
_ => {
#[cfg(feature = "sort_multiple")]
{
if std::env::var("POLARS_ROW_FMT_SORT").is_ok() {
argsort_multiple_row_fmt(&by_column, &descending, nulls_last, parallel)?
if nulls_last || std::env::var("POLARS_ROW_FMT_SORT").is_ok() {
argsort_multiple_row_fmt(&by_column, descending, nulls_last, parallel)?
} else {
let (first, by_column, descending) =
prepare_arg_sort(by_column, descending)?;
@@ -2451,7 +2451,7 @@ impl DataFrame {
///
/// let df2: DataFrame = df1.describe(None)?;
/// assert_eq!(df2.shape(), (9, 4));
/// dbg!(df2);
/// println!("{}", df2);
/// # Ok::<(), PolarsError>(())
/// ```
///
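The multi-column branch now routes through the row-format sorter whenever nulls_last is requested, which is what lets the Python-side restriction further down be dropped; the POLARS_ROW_FMT_SORT environment variable still forces that path otherwise. A sketch of the dispatch as a hypothetical free function (in the diff it is an inline condition inside DataFrame sorting):

use std::env;

// Hypothetical restatement of the condition above.
fn use_row_format_sort(nulls_last: bool) -> bool {
    nulls_last || env::var("POLARS_ROW_FMT_SORT").is_ok()
}

fn main() {
    env::remove_var("POLARS_ROW_FMT_SORT");
    assert!(use_row_format_sort(true));
    assert!(!use_row_format_sort(false));

    // The environment variable forces the row-format path even without nulls_last.
    env::set_var("POLARS_ROW_FMT_SORT", "1");
    assert!(use_row_format_sort(false));
}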
98 changes: 93 additions & 5 deletions polars/polars-row/src/encode.rs
@@ -1,4 +1,4 @@
use arrow::array::{Array, BinaryArray, BooleanArray, PrimitiveArray};
use arrow::array::{Array, BinaryArray, BooleanArray, DictionaryArray, PrimitiveArray, Utf8Array};
use arrow::datatypes::{DataType as ArrowDataType, DataType};
use arrow::types::NativeType;

@@ -51,6 +51,17 @@ unsafe fn encode_array(array: &dyn Array, field: &SortField, out: &mut RowsEncod
DataType::LargeUtf8 => {
panic!("should be cast to binary")
}
DataType::Dictionary(_, _, _) => {
let array = array
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
let iter = array
.iter_typed::<Utf8Array<i64>>()
.unwrap()
.map(|opt_s| opt_s.map(|s| s.as_bytes()));
crate::encodings::variable::encode_iter(iter, out, field)
}
dt => {
with_match_arrow_primitive_type!(dt, |$T| {
let array = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
@@ -84,10 +95,28 @@ pub fn allocate_rows_buf(columns: &[ArrayRef], fields: &[SortField]) -> RowsEnco

// for the variable length columns we must iterate to determine the length per row location
for array in columns.iter() {
if matches!(array.data_type(), ArrowDataType::LargeBinary) {
let array = array.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) {
*row_length += crate::encodings::variable::encoded_len(opt_val)
match array.data_type() {
ArrowDataType::LargeBinary => {
let array = array.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) {
*row_length += crate::encodings::variable::encoded_len(opt_val)
}
}
ArrowDataType::Dictionary(_, _, _) => {
let array = array
.as_any()
.downcast_ref::<DictionaryArray<u32>>()
.unwrap();
let iter = array
.iter_typed::<Utf8Array<i64>>()
.unwrap()
.map(|opt_s| opt_s.map(|s| s.as_bytes()));
for (opt_val, row_length) in iter.zip(lengths.iter_mut()) {
*row_length += crate::encodings::variable::encoded_len(opt_val)
}
}
_ => {
// the rest is fixed
}
}
}
@@ -139,3 +168,62 @@ pub fn allocate_rows_buf(columns: &[ArrayRef], fields: &[SortField]) -> RowsEnco
RowsEncoded::new(buf, offsets, None)
}
}

#[cfg(test)]
mod test {
use arrow::array::{Int64Array, UInt8Array, Utf8Array};

use super::*;
use crate::encodings::variable::{
BLOCK_CONTINUATION_TOKEN, BLOCK_SIZE, EMPTY_SENTINEL, NON_EMPTY_SENTINEL,
};

#[test]
fn test_str_encode() {
let sentence = "The black cat walked under a ladder but forget it's milk so it ...";
let arr =
Utf8Array::<i64>::from_iter([Some("a"), Some(""), Some("meep"), Some(sentence), None]);

let field = SortField {
descending: false,
nulls_last: false,
data_type: ArrowDataType::LargeBinary,
};
let arr = arrow::compute::cast::cast(&arr, &ArrowDataType::LargeBinary, Default::default())
.unwrap();
let rows_encoded = convert_columns(&[arr], vec![field]);
let row1 = rows_encoded.get(0);

// + 2 for the start valid byte and for the continuation token
assert_eq!(row1.len(), BLOCK_SIZE + 2);
let mut expected = [0u8; BLOCK_SIZE + 2];
expected[0] = NON_EMPTY_SENTINEL;
expected[1] = b'a';
*expected.last_mut().unwrap() = 1;
assert_eq!(row1, expected);

let row2 = rows_encoded.get(1);
let expected = &[EMPTY_SENTINEL];
assert_eq!(row2, expected);

let row3 = rows_encoded.get(2);
let mut expected = [0u8; BLOCK_SIZE + 2];
expected[0] = NON_EMPTY_SENTINEL;
*expected.last_mut().unwrap() = 4;
&mut expected[1..5].copy_from_slice(b"meep");
assert_eq!(row3, expected);

let row4 = rows_encoded.get(3);
let mut expected = [
2, 84, 104, 101, 32, 98, 108, 97, 99, 107, 32, 99, 97, 116, 32, 119, 97, 108, 107, 101,
100, 32, 117, 110, 100, 101, 114, 32, 97, 32, 108, 97, 100, 255, 100, 101, 114, 32, 98,
117, 116, 32, 102, 111, 114, 103, 101, 116, 32, 105, 116, 39, 115, 32, 109, 105, 108,
107, 32, 115, 111, 32, 105, 116, 32, 46, 255, 46, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
];
assert_eq!(row4, expected);
let row5 = rows_encoded.get(4);
let expected = &[0u8];
assert_eq!(row5, expected);
}
}
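The sizes asserted in this test follow from the block layout of the variable-length encoding: a null or an empty value is a single sentinel byte, while a non-empty value takes one sentinel byte plus ceil(len / BLOCK_SIZE) blocks of BLOCK_SIZE data bytes and one continuation/length byte each. A worked check of that arithmetic (a standalone sketch; row_value_len is a local helper, not the crate-internal encoded_len):

const BLOCK_SIZE: usize = 32;

fn row_value_len(val: Option<&[u8]>) -> usize {
    match val {
        // Null and empty values each collapse to a single sentinel byte.
        None => 1,
        Some(v) if v.is_empty() => 1,
        // 1 sentinel + ceil(len / 32) blocks of 32 data bytes + 1 trailer byte.
        Some(v) => 1 + ((v.len() + BLOCK_SIZE - 1) / BLOCK_SIZE) * (BLOCK_SIZE + 1),
    }
}

fn main() {
    // "a" and "meep" both fit in one block: 1 + 1 * 33 = 34 = BLOCK_SIZE + 2.
    assert_eq!(row_value_len(Some(b"a")), BLOCK_SIZE + 2);
    assert_eq!(row_value_len(Some(b"meep")), BLOCK_SIZE + 2);

    // The 66-byte sentence needs 3 blocks: 1 + 3 * 33 = 100, the length of the
    // expected array in the test above.
    let sentence = "The black cat walked under a ladder but forget it's milk so it ...";
    assert_eq!(sentence.len(), 66);
    assert_eq!(row_value_len(Some(sentence.as_bytes())), 100);

    assert_eq!(row_value_len(Some(b"")), 1);
    assert_eq!(row_value_len(None), 1);
}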
22 changes: 17 additions & 5 deletions polars/polars-row/src/encodings/variable.rs
@@ -17,16 +17,16 @@ use crate::row::RowsEncoded;
use crate::SortField;

/// The block size of the variable length encoding
const BLOCK_SIZE: usize = 32;
pub(crate) const BLOCK_SIZE: usize = 32;

/// The continuation token.
const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF;
pub(crate) const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF;

/// Indicates an empty string
const EMPTY_SENTINEL: u8 = 1;
pub(crate) const EMPTY_SENTINEL: u8 = 1;

/// Indicates a non-empty string
const NON_EMPTY_SENTINEL: u8 = 2;
pub(crate) const NON_EMPTY_SENTINEL: u8 = 2;

/// Returns the ceil of `value`/`divisor`
#[inline]
@@ -71,7 +71,9 @@ unsafe fn encode_one(out: &mut [u8], val: Option<&[u8]>, field: &SortField) -> u
1
}
Some(val) => {
let end_offset = padded_length(val.len());
let block_count = ceil(val.len(), BLOCK_SIZE);
let end_offset = 1 + block_count * (BLOCK_SIZE + 1);

let dst = out.get_unchecked_release_mut(..end_offset);

// Write `2_u8` to demarcate as non-empty, non-null string
@@ -97,6 +99,16 @@ unsafe fn encode_one(out: &mut [u8], val: Option<&[u8]>, field: &SortField) -> u
// replace the "there is another block" with
// "we are finished this, this is the length of this block"
*dst.last_mut().unwrap_unchecked() = BLOCK_SIZE as u8;
} else {
// get the last block
let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1);
let last_dst = dst.get_unchecked_release_mut(start_offset..);
std::ptr::copy_nonoverlapping(
src_remainder.as_ptr(),
last_dst.as_mut_ptr(),
src_remainder.len(),
);
*dst.last_mut().unwrap_unchecked() = src_remainder.len() as u8;
}

if field.descending {
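The added else branch writes the final, partially filled block: the remainder bytes are copied into the last 32-byte block and its trailing byte records how many of them are real data, while every earlier full block ends with the continuation token. A safe, self-contained sketch of that write path (the constants mirror the ones above; descending inversion and the empty/null sentinels are left out):

const BLOCK_SIZE: usize = 32;
const BLOCK_CONTINUATION_TOKEN: u8 = 0xFF;
const NON_EMPTY_SENTINEL: u8 = 2;

fn encode_non_empty(val: &[u8]) -> Vec<u8> {
    let block_count = (val.len() + BLOCK_SIZE - 1) / BLOCK_SIZE;
    let mut out = vec![0u8; 1 + block_count * (BLOCK_SIZE + 1)];
    out[0] = NON_EMPTY_SENTINEL;

    for (i, chunk) in val.chunks(BLOCK_SIZE).enumerate() {
        let start = 1 + i * (BLOCK_SIZE + 1);
        out[start..start + chunk.len()].copy_from_slice(chunk);
        // A full block followed by more data gets the continuation token;
        // the last block stores how many of its bytes are real data.
        out[start + BLOCK_SIZE] = if i + 1 < block_count {
            BLOCK_CONTINUATION_TOKEN
        } else {
            chunk.len() as u8
        };
    }
    out
}

fn main() {
    // Matches the "meep" row in the encode.rs test: sentinel, the 4 data bytes,
    // zero padding, and a trailing length byte of 4.
    let row = encode_non_empty(b"meep");
    assert_eq!(row.len(), BLOCK_SIZE + 2);
    assert_eq!(row[0], NON_EMPTY_SENTINEL);
    assert_eq!(&row[1..5], b"meep");
    assert_eq!(*row.last().unwrap(), 4);
}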
3 changes: 2 additions & 1 deletion polars/polars-row/src/lib.rs
@@ -266,7 +266,7 @@
//! [COBS]: https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing
//! [byte stuffing]: https://en.wikipedia.org/wiki/High-Level_Data_Link_Control#Asynchronous_framing
mod encode;
pub mod encode;
mod encodings;
mod row;
mod sort_field;
@@ -276,4 +276,5 @@ use arrow::array::*;
pub type ArrayRef = Box<dyn Array>;

pub use encode::convert_columns;
pub use row::RowsEncoded;
pub use sort_field::SortField;
7 changes: 7 additions & 0 deletions polars/polars-row/src/row.rs
@@ -24,6 +24,13 @@ impl RowsEncoded {
buf: &self.buf,
}
}

#[cfg(test)]
pub fn get(&self, i: usize) -> &[u8] {
let start = self.offsets[i];
let end = self.offsets[i + 1];
&self.buf[start..end]
}
}

pub struct RowsEncodedIter<'a> {
1 change: 1 addition & 0 deletions polars/polars-row/src/sort_field.rs
@@ -3,6 +3,7 @@ use encodings::fixed::FixedLengthEncoding;

use super::*;

#[derive(Clone)]
pub struct SortField {
/// Whether to sort in descending order
pub descending: bool,
6 changes: 0 additions & 6 deletions py-polars/polars/internals/lazyframe/frame.py
@@ -1117,12 +1117,6 @@ def sort(
if more_by:
by.extend(pli.selection_to_pyexpr_list(more_by))

# TODO: Do this check on the Rust side
if nulls_last and len(by) > 1:
raise ValueError(
"`nulls_last=True` only works when sorting by a single column"
)

if isinstance(descending, bool):
descending = [descending]
return self._from_pyldf(self._ldf.sort_by_exprs(by, descending, nulls_last))