Skip to content

Commit

Permalink
Properly parse strings that contain delimiters
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 12, 2021
1 parent ab7a6f0 commit 067dd73
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 3 deletions.
48 changes: 48 additions & 0 deletions polars/polars-io/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,52 @@ mod test {
.unwrap()
.series_equal(&Series::new("column_3", &[3, 3])));
}

#[test]
fn test_escape_comma() {
    // The second column is a quoted string that itself contains the delimiter;
    // every row must still be parsed as exactly three fields.
    let data = r#"column_1,column_2,column_3
-86.64408227,"Autauga, Alabama, US",11
-86.64408227,"Autauga, Alabama, US",12"#;
    let frame = CsvReader::new(Cursor::new(data)).finish().unwrap();
    assert_eq!(frame.shape(), (2, 3));

    // The column that follows the quoted field must hold the trailing integers.
    let expected = Series::new("column_3", &[11, 12]);
    let parsed = frame.column("column_3").unwrap();
    assert!(parsed.series_equal(&expected));
}

#[test]
fn test_escape_2() {
    // Harder than it looks: several of the quoted fields consist of nothing but
    // a delimiter, a space, or punctuation. The five fields on every row are:
    //   hello | "," | " " | world | "!"
    let data = r#"hello,","," ",world,"!"
hello,","," ",world,"!"
hello,","," ",world,"!"
hello,","," ",world,"!""#;
    let reader = CsvReader::new(Cursor::new(data))
        .has_header(false)
        .with_n_threads(Some(1));
    let df = reader.finish().unwrap();

    // Without a header the columns are addressed as `column_<n>`; each column
    // must repeat the same value on all four rows.
    let expectations = [
        ("column_1", "hello"),
        ("column_2", ","),
        ("column_3", " "),
        ("column_4", "world"),
        ("column_5", "!"),
    ];
    for &(name, value) in expectations.iter() {
        let wanted = Series::new("", &[value; 4]);
        let parsed = df.column(name).unwrap();
        assert!(parsed.series_equal(&wanted));
    }
}
}
70 changes: 69 additions & 1 deletion polars/polars-io/src/csv_core/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,72 @@ pub(crate) fn get_line_stats(mut bytes: &[u8], n_lines: usize) -> Option<(f32, f
Some((mean, std))
}

/// An adapted version of `std::slice::Split` that is aware of quoted string fields.
///
/// This exists solely because we cannot split the lines naively as
///
/// ```text
/// lines.split(b',').for_each(do_stuff)
/// ```
///
/// That fails when a string field contains the delimiter.
/// For instance `"Street, City, Country"` is a single valid string field that contains
/// multiple delimiters.
///
/// A field whose first byte is `"` is treated as quoted: delimiters are ignored until the
/// quotes are balanced again. Because an escaped quote is written as two consecutive quotes
/// (`""`), it flips the quote state twice and therefore never terminates the field, so
/// `"say ""hi"", ok"` is yielded as one field.
///
/// Fields are yielded verbatim as sub-slices of the input, including any surrounding quotes.
struct SplitFields<'a> {
    // The part of the line that has not been yielded yet.
    v: &'a [u8],
    // Byte that separates two fields.
    delimiter: u8,
    // Set once the last field has been yielded.
    finished: bool,
}

impl<'a> SplitFields<'a> {
    /// Create a field iterator over a single line `slice`, splitting on `delimiter`.
    fn new(slice: &'a [u8], delimiter: u8) -> Self {
        Self {
            v: slice,
            delimiter,
            finished: false,
        }
    }

    /// Yield whatever is left of the line as the final field and mark the iterator as
    /// exhausted. Returns `None` when the final field was already yielded.
    fn finish(&mut self) -> Option<&'a [u8]> {
        if self.finished {
            None
        } else {
            self.finished = true;
            Some(self.v)
        }
    }
}

impl<'a> Iterator for SplitFields<'a> {
    type Item = &'a [u8];

    /// Return the next field of the line, or `None` once the line is exhausted.
    ///
    /// As with `std::slice::Split`, a trailing delimiter produces a final empty field and
    /// an empty line produces exactly one empty field.
    #[inline]
    fn next(&mut self) -> Option<&'a [u8]> {
        if self.finished {
            return None;
        }
        let delimiter = self.delimiter;

        let split_at = if self.v.first() == Some(&b'"') {
            // Quoted field: track quote parity and only accept a delimiter that sits
            // outside of the quotes. Searching for the first `",` pair is not enough,
            // because an escaped quote directly followed by a delimiter (`"",`) inside
            // the string would end the field too early.
            let mut in_quotes = false;
            self.v.iter().position(|&byte| {
                if byte == b'"' {
                    in_quotes = !in_quotes;
                    false
                } else {
                    !in_quotes && byte == delimiter
                }
            })
        } else {
            // Unquoted field: it ends at the first delimiter.
            self.v.iter().position(|&byte| byte == delimiter)
        };

        match split_at {
            // No delimiter left (or the quotes never close): the remainder is the last field.
            None => self.finish(),
            Some(pos) => {
                let field = &self.v[..pos];
                // Skip the delimiter itself.
                self.v = &self.v[pos + 1..];
                Some(field)
            }
        }
    }
}

/// Parse CSV.
///
/// # Arguments
Expand Down Expand Up @@ -175,7 +241,9 @@ pub(crate) fn parse_lines(
.expect("at least one column should be projected");
let mut processed_fields = 0;

for (idx, field) in line.split(|b| *b == delimiter).enumerate() {
let iter = SplitFields::new(line, delimiter);

for (idx, field) in iter.enumerate() {
if idx == next_projected {
debug_assert!(processed_fields < buffers.len());
let buf = unsafe {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-polars"
version = "0.6.0-alha.1"
version = "0.6.0-beta.1"
authors = ["ritchie46 <[email protected]>"]
edition = "2018"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion py-polars/pypolars/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,4 @@ def concat(dfs: "List[DataFrame]", rechunk=True) -> "DataFrame":


def range(lower: int, upper: int, step: Optional[int] = None) -> Series:
return Series("range", np.arange(lower, upper, step), nullable=False)
return Series("range", np.arange(lower, upper, step), nullable=False)

0 comments on commit 067dd73

Please sign in to comment.