Skip to content

Feature Proposal: Validation decorators for pipeline-style code (e.g., adding/removing columns) #5

@DeflateAwning

Description

@DeflateAwning

I use Polars in data pipelines, which are often composed of chains of functions which take in a df, modify it, and then return a slightly-modified version of that dataframe. My understanding is that this style is pretty common in complex data pipelines because it allows for easy testing and is easy to separate out code chunks.

It seems that so far this library only enforces complete schemas, and does not have validations for this sort of transformation-style code. Any thoughts on adding this functionality?

Here is an example of the type of function I'm thinking of, with tests:

"""Decorators for validating pipeline-style transformations on Polars DataFrames."""

# pyright: strict

import functools
from collections.abc import Callable, Iterable
from typing import TypeVar, cast

import polars as pl

from pn_data.helpers.polars.polars_validation_errors import PolarsColumnChangeCheckFailedError


# Type variable bound to any callable that returns a Polars DataFrame. It lets the decorator
# below be annotated as returning the same callable type it receives.
F = TypeVar("F", bound=Callable[..., pl.DataFrame])


def assert_column_change(
    add: Iterable[str] = (),
    drop: Iterable[str] = (),
) -> Callable[[F], F]:
    """Build a decorator that validates the column changes made by a DataFrame transform.

    The decorated function must take the input DataFrame as its first positional
    argument and return a DataFrame. On every call the wrapper checks that:

    * none of the ``add`` columns exist in the input (entry check),
    * all of the ``drop`` columns exist in the input (entry check),
    * the output columns equal ``(input columns | add) - drop`` exactly (exit check).

    Column order is not checked; only the set of column names is compared.

    Args:
        add: Names of columns the function is expected to add. A bare string is
            treated as a single column name.
        drop: Names of columns the function is expected to remove. A bare string
            is treated as a single column name.

    Returns:
        A decorator that wraps a transform function with the entry and exit checks.

    Raises:
        ValueError: At decoration time, if ``add`` and ``drop`` share a column name.
        PolarsColumnChangeCheckFailedError: At call time, if an entry or exit check fails.
    """
    # A bare string is itself an iterable of characters; without this guard "abc"
    # would silently become the three columns {"a", "b", "c"}.
    expect_added = {add} if isinstance(add, str) else set(add)
    expect_removed = {drop} if isinstance(drop, str) else set(drop)

    overlap = expect_added & expect_removed
    if overlap:
        raise ValueError(
            "'add' and 'drop' must not have overlapping column names. "
            f"Overlapping: {sorted(overlap)}"
        )

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(df: pl.DataFrame, *args: object, **kwargs: object) -> pl.DataFrame:
            orig_columns = set(df.columns)

            # Entry checks
            preexisting = expect_added & orig_columns
            if preexisting:
                raise PolarsColumnChangeCheckFailedError(
                    "Unexpected pre-existing columns. Columns in the 'add' argument should "
                    "not already exist in the entry dataframe. Unexpected columns on entry: "
                    f"{sorted(preexisting)}"
                )

            missing_on_entry = expect_removed - orig_columns
            if missing_on_entry:
                raise PolarsColumnChangeCheckFailedError(
                    "Missing input column. All columns in the 'drop' argument must be "
                    "present in the entry dataframe. "
                    f"Column(s) missing: {sorted(missing_on_entry)}"
                )

            # Execute function
            result_df = func(df, *args, **kwargs)

            # Exit checks
            expected_columns = (orig_columns | expect_added) - expect_removed
            actual_columns = set(result_df.columns)

            if actual_columns != expected_columns:
                # Sorted so the message is deterministic regardless of set ordering.
                raise PolarsColumnChangeCheckFailedError(
                    "Unexpected final columns. "
                    f"Extra columns: {sorted(actual_columns - expected_columns)}, "
                    f"Missing columns: {sorted(expected_columns - actual_columns)}"
                )

            return result_df

        # The wrapper has a narrower static signature than F, but functools.wraps makes it
        # a drop-in replacement at runtime; the quoted type avoids evaluating F here.
        return cast("F", wrapper)

    return decorator




class PolarsColumnChangeCheckFailedError(GenericPolarsValidationError):
    """Raised when a DataFrame's columns do not match the declared column changes.

    Used by the ``assert_column_change`` decorator for both its entry checks (columns
    present or absent in the input) and its exit check (columns of the returned DataFrame).
    """

Example usage/test cases:

# pyright: strict

import polars as pl
import pytest

from polars.polars_validation_errors import PolarsColumnChangeCheckFailedError
from polars import validation_decorators as pl_validation_decorators


@pl_validation_decorators.assert_column_change(add=["a"], drop=["b"])
def _transform_add_a_drop_b(df: pl.DataFrame) -> pl.DataFrame:
    """Add column 'a' as a copy of the first input column, then drop column 'b'."""
    first_column_name = df.columns[0]
    with_a = df.with_columns(pl.col(first_column_name).alias("a"))
    return with_a.drop("b")


def test_assert_column_change_WITH_normal_pass() -> None:
    """A well-behaved transform passes validation and returns the expected columns."""
    source_data = {"x": [1, 2, 3], "y": [4, 5, 6], "b": [7, 8, 9]}

    result: pl.DataFrame = _transform_add_a_drop_b(pl.DataFrame(source_data))

    assert result.columns == ["x", "y", "a"]


def test_assert_column_change_WITH_invalid_construction() -> None:
    """Declaring the same column in both 'add' and 'drop' is rejected with a ValueError."""

    def _identity(df: pl.DataFrame) -> pl.DataFrame:
        return df

    with pytest.raises(ValueError, match="overlapping column names"):
        pl_validation_decorators.assert_column_change(add=["a"], drop=["a"])(_identity)


def test_assert_column_change_WITH_bad_input_columns() -> None:
    """Inputs that violate the declared entry conditions are rejected before the call."""
    # 'a' is declared as added, so it must not already be present on entry.
    df_with_preexisting_a = pl.DataFrame({"x": [100, 200, 300], "a": [1, 2, 3]})
    with pytest.raises(
        PolarsColumnChangeCheckFailedError, match="Unexpected pre-existing columns"
    ):
        _transform_add_a_drop_b(df_with_preexisting_a)

    # 'b' is declared as dropped, so it must be present on entry.
    df_without_b = pl.DataFrame({"x": [100, 200, 300]})
    with pytest.raises(PolarsColumnChangeCheckFailedError, match="Missing input column"):
        _transform_add_a_drop_b(df_without_b)


def test_assert_column_change_WITH_misbehaving_function() -> None:
    """A transform that fails to add a declared column is caught by the exit check."""

    @pl_validation_decorators.assert_column_change(add=["a"], drop=["b"])
    def _drops_b_only(df: pl.DataFrame) -> pl.DataFrame:
        # Misbehaves on purpose: drops 'b' but never adds the declared 'a' column.
        return df.drop("b")

    df_input = pl.DataFrame({"x": [100, 200, 300], "b": [1, 2, 3]})

    with pytest.raises(PolarsColumnChangeCheckFailedError, match="Unexpected final columns"):
        _drops_b_only(df_input)

Note: This is copied from my issue suggesting this functionality over in the Polars repo: pola-rs/polars#21512

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions