Skip to content

Commit

Permalink
perf: use zeroable_vec in ewm_mean_by (pola-rs#16166)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored May 12, 2024
1 parent 9ea2504 commit 2f81742
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 31 deletions.
72 changes: 42 additions & 30 deletions crates/polars-ops/src/series/ops/ewm_by.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
use arrow::compute::concatenate::concatenate_validities;
use arrow::compute::utils::combine_validities_and;
use bytemuck::allocation::zeroed_vec;
use num_traits::{Float, FromPrimitive, One, Zero};
use polars_core::prelude::*;
use polars_core::utils::align_chunks_binary;

pub fn ewm_mean_by(
s: &Series,
times: &Series,
half_life: i64,
assume_sorted: bool,
) -> PolarsResult<Series> {
let func = match assume_sorted {
true => ewm_mean_by_impl_sorted,
false => ewm_mean_by_impl,
};
match (s.dtype(), times.dtype()) {
(DataType::Float64, DataType::Int64) => {
Ok(func(s.f64().unwrap(), times.i64().unwrap(), half_life).into_series())
},
(DataType::Float32, DataType::Int64) => {
Ok(ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life).into_series())
},
(DataType::Float64, DataType::Int64) => Ok((if assume_sorted {
ewm_mean_by_impl_sorted(s.f64().unwrap(), times.i64().unwrap(), half_life)
} else {
ewm_mean_by_impl(s.f64().unwrap(), times.i64().unwrap(), half_life)
})
.into_series()),
(DataType::Float32, DataType::Int64) => Ok((if assume_sorted {
ewm_mean_by_impl_sorted(s.f32().unwrap(), times.i64().unwrap(), half_life)
} else {
ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life)
})
.into_series()),
#[cfg(feature = "dtype-datetime")]
(_, DataType::Datetime(time_unit, _)) => {
let half_life = adjust_half_life_to_time_unit(half_life, time_unit);
Expand Down Expand Up @@ -61,50 +67,56 @@ where
ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
{
let sorting_indices = times.arg_sort(Default::default());
let values = unsafe { values.take_unchecked(&sorting_indices) };
let times = unsafe { times.take_unchecked(&sorting_indices) };
let sorted_values = unsafe { values.take_unchecked(&sorting_indices) };
let sorted_times = unsafe { times.take_unchecked(&sorting_indices) };
let sorting_indices = sorting_indices
.cont_slice()
.expect("`arg_sort` should have returned a single chunk");

let mut out = vec![None; times.len()];
let mut out: Vec<_> = zeroed_vec(sorted_times.len());

let mut skip_rows: usize = 0;
let mut prev_time: i64 = 0;
let mut prev_result = T::Native::zero();
for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() {
for (idx, (value, time)) in sorted_values.iter().zip(sorted_times.iter()).enumerate() {
if let (Some(time), Some(value)) = (time, value) {
prev_time = time;
prev_result = value;
unsafe {
let out_idx = sorting_indices.get_unchecked(idx);
*out.get_unchecked_mut(*out_idx as usize) = Some(prev_result);
*out.get_unchecked_mut(*out_idx as usize) = prev_result;
}
skip_rows = idx + 1;
break;
};
}
values
sorted_values
.iter()
.zip(times.iter())
.zip(sorted_times.iter())
.enumerate()
.skip(skip_rows)
.for_each(|(idx, (value, time))| {
let result_opt = match (time, value) {
(Some(time), Some(value)) => {
let result = update(value, prev_result, time, prev_time, half_life);
prev_time = time;
prev_result = result;
Some(result)
},
_ => None,
if let (Some(time), Some(value)) = (time, value) {
let result = update(value, prev_result, time, prev_time, half_life);
prev_time = time;
prev_result = result;
unsafe {
let out_idx = sorting_indices.get_unchecked(idx);
*out.get_unchecked_mut(*out_idx as usize) = result;
}
};
unsafe {
let out_idx = sorting_indices.get_unchecked(idx);
*out.get_unchecked_mut(*out_idx as usize) = result_opt;
}
});
ChunkedArray::<T>::from_iter_options(values.name(), out.into_iter())
let mut arr = T::Array::from_zeroable_vec(out, values.dtype().to_arrow(true));
if (times.null_count() > 0) || (values.null_count() > 0) {
let (times, values) = align_chunks_binary(times, values);
let times_chunk_refs: Vec<_> = times.chunks().iter().map(|c| &**c).collect();
let times_validity = concatenate_validities(&times_chunk_refs);
let values_chunk_refs: Vec<_> = values.chunks().iter().map(|c| &**c).collect();
let values_validity = concatenate_validities(&values_chunk_refs);
let validity = combine_validities_and(times_validity.as_ref(), values_validity.as_ref());
arr = arr.with_validity_typed(validity);
}
ChunkedArray::with_chunk(values.name(), arr)
}

/// Fastpath if `times` is known to already be sorted.
Expand Down
17 changes: 16 additions & 1 deletion py-polars/tests/unit/operations/test_ewm_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars.type_aliases import PolarsIntegerType, TimeUnit
Expand Down Expand Up @@ -223,3 +223,18 @@ def test_ewma_by_warn_two_chunks() -> None:
pl.col("values").ewm_mean_by("by", half_life="2i"),
)
assert_frame_equal(result, expected.sort("by"))


def test_ewma_by_multiple_chunks() -> None:
# times contains null
times = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64))
values = pl.Series([1, 2]).append(pl.Series([3]))
result = values.ewm_mean_by(times, half_life="2i")
expected = pl.Series([1.0, 1.292893, None])
assert_series_equal(result, expected)

# values contains null
times = pl.Series([1, 2]).append(pl.Series([3]))
values = pl.Series([1, 2]).append(pl.Series([None], dtype=pl.Int64))
result = values.ewm_mean_by(times, half_life="2i")
assert_series_equal(result, expected)

0 comments on commit 2f81742

Please sign in to comment.