Skip to content

Commit

Permalink
feat(python): Expose plan and expression nodes through `NodeTraverser…
Browse files Browse the repository at this point in the history
…` to Python (pola-rs#15776)

Co-authored-by: ritchie <[email protected]>
  • Loading branch information
wence- and ritchie46 authored Apr 30, 2024
1 parent b285a7f commit 81f4ac2
Show file tree
Hide file tree
Showing 12 changed files with 1,730 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Copyright (c) 2020 Ritchie Vink
Some portions Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod bounds;
#[cfg(feature = "business")]
mod business;
#[cfg(feature = "dtype-categorical")]
mod cat;
pub mod cat;
#[cfg(feature = "round_series")]
mod clip;
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -38,13 +38,13 @@ mod nan;
mod peaks;
#[cfg(feature = "ffi_plugin")]
mod plugin;
mod pow;
pub mod pow;
#[cfg(feature = "random")]
mod random;
#[cfg(feature = "range")]
mod range;
#[cfg(feature = "rolling_window")]
mod rolling;
pub mod rolling;
#[cfg(feature = "round_series")]
mod round;
#[cfg(feature = "row_hash")]
Expand All @@ -63,7 +63,7 @@ mod struct_;
#[cfg(any(feature = "temporal", feature = "date_offset"))]
mod temporal;
#[cfg(feature = "trigonometry")]
mod trigonometry;
pub mod trigonometry;
mod unique;

use std::fmt::{Display, Formatter};
Expand All @@ -88,10 +88,10 @@ pub use self::boolean::BooleanFunction;
#[cfg(feature = "business")]
pub(super) use self::business::BusinessFunction;
#[cfg(feature = "dtype-categorical")]
pub(crate) use self::cat::CategoricalFunction;
pub use self::cat::CategoricalFunction;
#[cfg(feature = "temporal")]
pub(super) use self::datetime::TemporalFunction;
pub(super) use self::pow::PowFunction;
pub use self::pow::PowFunction;
#[cfg(feature = "range")]
pub(super) use self::range::RangeFunction;
#[cfg(feature = "rolling_window")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub mod dt;
mod expr;
mod expr_dyn_fn;
mod from;
pub(crate) mod function_expr;
pub mod function_expr;
pub mod functions;
mod list;
#[cfg(feature = "meta")]
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ polars-lazy = { workspace = true, features = ["python"] }
polars-ops = { workspace = true }
polars-parquet = { workspace = true, optional = true }
polars-plan = { workspace = true }
polars-time = { workspace = true }
polars-utils = { workspace = true }

ahash = { workspace = true }
Expand Down
10 changes: 10 additions & 0 deletions py-polars/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,16 @@ impl FromPyObject<'_> for Wrap<Schema> {
}
}

impl IntoPy<PyObject> for Wrap<&Schema> {
fn into_py(self, py: Python<'_>) -> PyObject {
let dict = PyDict::new(py);
for (k, v) in self.0.iter() {
dict.set_item(k.as_str(), Wrap(v.clone())).unwrap();
}
dict.into_py(py)
}
}

#[derive(Clone, Debug)]
#[repr(transparent)]
pub struct ObjectValue {
Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/lazyframe/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod exitable;

mod visit;
pub(crate) mod visitor;
use std::collections::HashMap;
use std::io::BufWriter;
use std::num::NonZeroUsize;
Expand All @@ -13,6 +14,7 @@ use polars_core::prelude::*;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList};
pub(crate) use visit::PyExprIR;

use crate::arrow_interop::to_rust::pyarrow_schema_to_rust;
use crate::error::PyPolarsErr;
Expand Down
207 changes: 207 additions & 0 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use std::sync::Mutex;

use polars_plan::logical_plan::{to_aexpr, Context, IR};
use polars_plan::prelude::expr_ir::ExprIR;
use polars_plan::prelude::{AExpr, PythonOptions};
use polars_utils::arena::{Arena, Node};
use pyo3::prelude::*;
use visitor::{expr_nodes, nodes};

use super::*;
use crate::raise_err;

#[derive(Clone)]
#[pyclass]
pub(crate) struct PyExprIR {
#[pyo3(get)]
node: usize,
#[pyo3(get)]
output_name: String,
}

impl From<ExprIR> for PyExprIR {
fn from(value: ExprIR) -> Self {
Self {
node: value.node().0,
output_name: value.output_name().into(),
}
}
}

impl From<&ExprIR> for PyExprIR {
fn from(value: &ExprIR) -> Self {
Self {
node: value.node().0,
output_name: value.output_name().into(),
}
}
}

#[pyclass]
struct NodeTraverser {
root: Node,
lp_arena: Arc<Mutex<Arena<IR>>>,
expr_arena: Arc<Mutex<Arena<AExpr>>>,
scratch: Vec<Node>,
expr_scratch: Vec<ExprIR>,
expr_mapping: Option<Vec<Node>>,
}

impl NodeTraverser {
fn fill_inputs(&mut self) {
let lp_arena = self.lp_arena.lock().unwrap();
let this_node = lp_arena.get(self.root);
self.scratch.clear();
this_node.copy_inputs(&mut self.scratch);
}

fn fill_expressions(&mut self) {
let lp_arena = self.lp_arena.lock().unwrap();
let this_node = lp_arena.get(self.root);
self.expr_scratch.clear();
this_node.copy_exprs(&mut self.expr_scratch);
}

fn scratch_to_list(&mut self) -> PyObject {
Python::with_gil(|py| {
PyList::new(py, self.scratch.drain(..).map(|node| node.0)).to_object(py)
})
}

fn expr_to_list(&mut self) -> PyObject {
Python::with_gil(|py| {
PyList::new(
py,
self.expr_scratch
.drain(..)
.map(|e| PyExprIR::from(e).into_py(py)),
)
.to_object(py)
})
}
}

#[pymethods]
impl NodeTraverser {
/// Get expression nodes
fn get_exprs(&mut self) -> PyObject {
self.fill_expressions();
self.expr_to_list()
}

/// Get input nodes
fn get_inputs(&mut self) -> PyObject {
self.fill_inputs();
self.scratch_to_list()
}

/// Get Schema of current node as python dict<str, pl.DataType>
fn get_schema(&self, py: Python<'_>) -> PyObject {
let lp_arena = self.lp_arena.lock().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena);
Wrap(&**schema).into_py(py)
}

/// Get expression dtype.
fn get_dtype(&self, expr_node: usize, py: Python<'_>) -> PyResult<PyObject> {
let expr_node = Node(expr_node);
let lp_arena = self.lp_arena.lock().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena);
let expr_arena = self.expr_arena.lock().unwrap();
let field = expr_arena
.get(expr_node)
.to_field(&schema, Context::Default, &expr_arena)
.map_err(PyPolarsErr::from)?;
Ok(Wrap(field.dtype).to_object(py))
}

/// Set the current node in the plan.
fn set_node(&mut self, node: usize) {
self.root = Node(node);
}

/// Set a python UDF that will replace the subtree location with this function src.
fn set_udf(&mut self, function: PyObject, schema: Wrap<Schema>) {
let ir = IR::PythonScan {
options: PythonOptions {
scan_fn: Some(function.into()),
schema: Arc::new(schema.0),
output_schema: None,
with_columns: None,
pyarrow: false,
predicate: None,
n_rows: None,
},
predicate: None,
};
let mut lp_arena = self.lp_arena.lock().unwrap();
lp_arena.replace(self.root, ir);
}

fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
let lp_arena = self.lp_arena.lock().unwrap();
let lp_node = lp_arena.get(self.root);
nodes::into_py(py, lp_node)
}

fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
let expr_arena = self.expr_arena.lock().unwrap();
let n = match &self.expr_mapping {
Some(mapping) => *mapping.get(node).unwrap(),
None => Node(node),
};
let expr = expr_arena.get(n);
expr_nodes::into_py(py, expr)
}

/// Add some expressions to the arena and return their new node ids as well
/// as the total number of nodes in the arena.
fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
let mut expr_arena = self.expr_arena.lock().unwrap();
Ok((
expressions
.into_iter()
.map(|e| to_aexpr(e.inner, &mut expr_arena).0)
.collect(),
expr_arena.len(),
))
}

/// Set up a mapping of expression nodes used in `view_expression_node``.
/// With a mapping set, `view_expression_node(i)` produces the node for
/// `mapping[i]`.
fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
if mapping.len() != self.expr_arena.lock().unwrap().len() {
raise_err!("Invalid mapping length", ComputeError);
}
self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
Ok(())
}

/// Unset the expression mapping (reinstates the identity map)
fn unset_expr_mapping(&mut self) {
self.expr_mapping = None;
}
}

#[pymethods]
#[allow(clippy::should_implement_trait)]
impl PyLazyFrame {
fn visit(&self) -> PyResult<NodeTraverser> {
let mut lp_arena = Arena::with_capacity(16);
let mut expr_arena = Arena::with_capacity(16);
let root = self
.ldf
.clone()
.optimize(&mut lp_arena, &mut expr_arena)
.map_err(PyPolarsErr::from)?;
Ok(NodeTraverser {
root,
lp_arena: Arc::new(Mutex::new(lp_arena)),
expr_arena: Arc::new(Mutex::new(expr_arena)),
scratch: vec![],
expr_scratch: vec![],
expr_mapping: None,
})
}
}
Loading

0 comments on commit 81f4ac2

Please sign in to comment.