Fix: ensure get_df uses SQLAlchemy engine to avoid pandas warning #52224

Merged 9 commits on Jul 16, 2025
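
For context, the warning this PR targets is the UserWarning pandas emits when read_sql is handed a raw DBAPI connection rather than a SQLAlchemy connectable. A minimal before/after sketch, with placeholder connection details that are not taken from the PR:

# Illustrative sketch only; host/user/password/database values are placeholders.
import pandas as pd
import psycopg2
from sqlalchemy import create_engine

SQL = "SELECT 1 AS answer"

# Before: a raw psycopg2 (DBAPI) connection makes pandas warn along the lines of
# "pandas only supports SQLAlchemy connectable ...".
raw_conn = psycopg2.connect("host=localhost dbname=airflow user=airflow password=airflow")
df_with_warning = pd.read_sql(SQL, con=raw_conn)
raw_conn.close()

# After: routing the same query through a SQLAlchemy engine/connection, as get_df
# now does, avoids the warning entirely.
engine = create_engine("postgresql+psycopg2://airflow:airflow@localhost/airflow")
with engine.connect() as conn:
    df_clean = pd.read_sql(SQL, con=conn)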
9 changes: 9 additions & 0 deletions providers/postgres/pyproject.toml
@@ -71,6 +71,13 @@ dependencies = [
"openlineage" = [
"apache-airflow-providers-openlineage"
]
"pandas" = [
'pandas>=2.1.2; python_version <"3.13"',
'pandas>=2.2.3; python_version >="3.13"',
]
"polars" = [
"polars>=1.26.0"
]

[dependency-groups]
dev = [
@@ -81,6 +88,8 @@ dev = [
"apache-airflow-providers-common-sql",
"apache-airflow-providers-openlineage",
# Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas]",
"apache-airflow-providers-common-sql[polars]",
]

# To build docs:
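The extras above only declare the optional dataframe backends; whether they are importable at runtime is a separate question. A small, hedged helper for checking which df_type values an environment can serve (the function name is invented for illustration and is not part of the provider):

# Hedged sketch; available_df_backends is not part of the provider package.
import importlib.util


def available_df_backends() -> list[str]:
    """Return the df_type values ("pandas", "polars") importable in this environment."""
    return [name for name in ("pandas", "polars") if importlib.util.find_spec(name) is not None]


print(available_df_backends())  # e.g. ['pandas'] when only the pandas extra is installed

The hook change below (module airflow.providers.postgres.hooks.postgres) takes the stricter route: if the pandas import fails, get_df raises AirflowOptionalProviderFeatureException with an install hint.
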
@@ -18,21 +18,27 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from contextlib import closing
from copy import deepcopy
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload

import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2.extras import DictCursor, Json, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowException
from airflow.exceptions import (
    AirflowException,
    AirflowOptionalProviderFeatureException,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.postgres.dialects.postgres import PostgresDialect

if TYPE_CHECKING:
    from pandas import DataFrame as PandasDataFrame
    from polars import DataFrame as PolarsDataFrame
    from psycopg2.extensions import connection

    from airflow.providers.common.sql.dialects.dialect import Dialect
@@ -177,6 +183,62 @@ def get_conn(self) -> connection:
        self.conn = psycopg2.connect(**conn_args)
        return self.conn

    @overload
    def get_df(
        self,
        sql: str | list[str],
        parameters: list | tuple | Mapping[str, Any] | None = None,
        *,
        df_type: Literal["pandas"] = "pandas",
        **kwargs: Any,
    ) -> PandasDataFrame: ...

    @overload
    def get_df(
        self,
        sql: str | list[str],
        parameters: list | tuple | Mapping[str, Any] | None = None,
        *,
        df_type: Literal["polars"] = ...,
        **kwargs: Any,
    ) -> PolarsDataFrame: ...

    def get_df(
        self,
        sql: str | list[str],
        parameters: list | tuple | Mapping[str, Any] | None = None,
        *,
        df_type: Literal["pandas", "polars"] = "pandas",
        **kwargs: Any,
    ) -> PandasDataFrame | PolarsDataFrame:
        """
        Execute the SQL and return a dataframe.

        :param sql: the SQL statement to be executed (str) or a list of SQL statements to execute
        :param parameters: The parameters to render the SQL query with.
        :param df_type: Type of dataframe to return, either "pandas" or "polars"
        :param kwargs: (optional) passed into `pandas.io.sql.read_sql` or `polars.read_database` method
        :return: A pandas or polars DataFrame containing the query results.
        """
        if df_type == "pandas":
            try:
                from pandas.io import sql as psql
            except ImportError:
                raise AirflowOptionalProviderFeatureException(
                    "pandas library not installed, run: pip install "
                    "'apache-airflow-providers-common-sql[pandas]'."
                )

            engine = self.get_sqlalchemy_engine()
            with engine.connect() as conn:
                return psql.read_sql(sql, con=conn, params=parameters, **kwargs)

        elif df_type == "polars":
            return self._get_polars_df(sql, parameters, **kwargs)

        else:
            raise ValueError(f"Unsupported df_type: {df_type}")

    def copy_expert(self, sql: str, filename: str) -> None:
        """
        Execute SQL using psycopg2's ``copy_expert`` method.
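A usage sketch of the new method under assumed defaults: "postgres_default" is the stock connection id and some_table is a placeholder table name; neither comes from the PR itself.

# Illustrative only; connection id and table name are placeholders.
from airflow.providers.postgres.hooks.postgres import PostgresHook

hook = PostgresHook(postgres_conn_id="postgres_default")

# pandas (the default) now goes through hook.get_sqlalchemy_engine(), so no DBAPI warning
pandas_df = hook.get_df("SELECT * FROM some_table", df_type="pandas")

# polars is delegated to the common-sql _get_polars_df helper and requires the polars extra
polars_df = hook.get_df("SELECT * FROM some_table", df_type="polars")

Thanks to the @overload stubs, a type checker narrows pandas_df to pandas.DataFrame and polars_df to polars.DataFrame without casts.
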
34 changes: 34 additions & 0 deletions providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -22,6 +22,8 @@
import os
from unittest import mock

import pandas as pd
import polars as pl
import psycopg2.extras
import pytest
import sqlalchemy
@@ -33,6 +35,8 @@
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.types import NOTSET

from tests_common.test_utils.common_sql import mock_db_hook

INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, description, host, {}, login, password, port, is_encrypted, is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"


@@ -517,6 +521,36 @@ def test_serialize_cell(self, raw_cell, expected_serialized):
        else:
            assert expected_serialized == raw_cell

    @pytest.mark.parametrize(
        "df_type, expected_type",
        [
            ("pandas", pd.DataFrame),
            ("polars", pl.DataFrame),
        ],
    )
    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook._get_polars_df")
    @mock.patch("pandas.io.sql.read_sql")
    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.get_sqlalchemy_engine")
    def test_get_df_with_df_type(
        self, mock_get_engine, mock_read_sql, mock_polars_df, df_type, expected_type
    ):
        hook = mock_db_hook(PostgresHook)
        mock_read_sql.return_value = pd.DataFrame()
        mock_polars_df.return_value = pl.DataFrame()
        sql = "SELECT * FROM table"
        if df_type == "pandas":
            mock_conn = mock.MagicMock()
            mock_engine = mock.MagicMock()
            mock_engine.connect.return_value.__enter__.return_value = mock_conn
            mock_get_engine.return_value = mock_engine
            df = hook.get_df(sql, df_type="pandas")
            mock_read_sql.assert_called_once_with(sql, con=mock_conn, params=None)
            assert isinstance(df, expected_type)
        elif df_type == "polars":
            df = hook.get_df(sql, df_type="polars")
            mock_polars_df.assert_called_once_with(sql, None)
            assert isinstance(df, expected_type)

    def test_insert_rows(self):
        table = "table"
        rows = [("hello",), ("world",)]