Skip to content

Commit

Permalink
fix(python): Raise a proper python typed exception when IO writers tr…
Browse files Browse the repository at this point in the history
…y to write to an non existent folder (pola-rs#12936)
  • Loading branch information
Yerachmiel-Feltzman authored Dec 8, 2023
1 parent 2451a6a commit 05fdb79
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 18 deletions.
11 changes: 5 additions & 6 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ impl PyDataFrame {
use polars::io::avro::AvroWriter;

if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s).unwrap();
let f = std::fs::File::create(s)?;
AvroWriter::new(f)
.with_compression(compression.0)
.with_name(name)
Expand Down Expand Up @@ -616,9 +616,8 @@ impl PyDataFrame {
let null = null_value.unwrap_or_default();

if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s)?;
py.allow_threads(|| {
let f = std::fs::File::create(s).map_err(PolarsError::Io)?;

// No need for a buffered writer, because the csv writer does internal buffering.
CsvWriter::new(f)
.include_bom(include_bom)
Expand Down Expand Up @@ -666,8 +665,8 @@ impl PyDataFrame {
compression: Wrap<Option<IpcCompression>>,
) -> PyResult<()> {
if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s)?;
py.allow_threads(|| {
let f = std::fs::File::create(s).unwrap();
IpcWriter::new(f)
.with_compression(compression.0)
.finish(&mut self.df)
Expand All @@ -692,8 +691,8 @@ impl PyDataFrame {
compression: Wrap<Option<IpcCompression>>,
) -> PyResult<()> {
if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s)?;
py.allow_threads(|| {
let f = std::fs::File::create(s).unwrap();
IpcStreamWriter::new(f)
.with_compression(compression.0)
.finish(&mut self.df)
Expand Down Expand Up @@ -803,7 +802,7 @@ impl PyDataFrame {
let compression = parse_parquet_compression(compression, compression_level)?;

if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s).unwrap();
let f = std::fs::File::create(s)?;
py.allow_threads(|| {
ParquetWriter::new(f)
.with_compression(compression)
Expand Down
14 changes: 2 additions & 12 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import textwrap
import zlib
from datetime import date, datetime, time, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict

import numpy as np
Expand All @@ -20,6 +19,8 @@
from polars.utils.various import normalize_filepath

if TYPE_CHECKING:
from pathlib import Path

from polars.type_aliases import TimeUnit


Expand Down Expand Up @@ -1709,14 +1710,3 @@ def test_csv_no_new_line_last() -> None:
"a": [1, 2, 3],
"b": [1.0, 2.0, 2.1],
}


def test_write_csv_no_location_raise_io_exception() -> None:
df = pl.DataFrame({"a": [1]})
non_existing_path = Path("non", "existing", "path", "file.csv")
if non_existing_path.exists():
pytest.fail(
"Testing on a non existing path failed because the path does exist."
)
with pytest.raises(FileNotFoundError):
df.write_csv(non_existing_path)
27 changes: 27 additions & 0 deletions py-polars/tests/unit/io/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,33 @@ def test_read_missing_file(read_function: Callable[[Any], pl.DataFrame]) -> None
read_function("fake_file_path")


@pytest.mark.parametrize(
"write_method_name",
[
# "write_excel" not included
# because it already raises a FileCreateError
# from the underlying library dependency
"write_csv",
"write_ipc",
"write_ipc_stream",
"write_json",
"write_ndjson",
"write_parquet",
"write_avro",
],
)
def test_write_missing_directory(write_method_name: str) -> None:
df = pl.DataFrame({"a": [1]})
non_existing_path = Path("non", "existing", "path")
if non_existing_path.exists():
pytest.fail(
"Testing on a non existing path failed because the path does exist."
)
write_method = getattr(df, write_method_name)
with pytest.raises(FileNotFoundError):
write_method(non_existing_path)


def test_read_missing_file_path_truncated() -> None:
content = "lskdfj".join(str(i) for i in range(25))
with pytest.raises(
Expand Down

0 comments on commit 05fdb79

Please sign in to comment.