diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index 34d313af8232e..06622ef71d880 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -16,7 +16,7 @@ # import numbers -from typing import Any, Union +from typing import Any, Union, Callable import numpy as np import pandas as pd @@ -271,13 +271,22 @@ def floordiv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) if not is_valid_operand_for_numeric_arithmetic(right): raise TypeError("Floor division can not be applied to given types.") + spark_session = left._internal.spark_frame.sparkSession + use_try_divide = is_ansi_mode_enabled(spark_session) + + def fallback_div(x: PySparkColumn, y: PySparkColumn) -> PySparkColumn: + return x.__div__(y) + + safe_div: Callable[[PySparkColumn, PySparkColumn], PySparkColumn] = ( + F.try_divide if use_try_divide else fallback_div + ) def floordiv(left: PySparkColumn, right: Any) -> PySparkColumn: return F.when(F.lit(right is np.nan), np.nan).otherwise( F.when( F.lit(right != 0) | F.lit(right).isNull(), F.floor(left.__div__(right)), - ).otherwise(F.lit(np.inf).__div__(left)) + ).otherwise(safe_div(F.lit(np.inf), left)) ) right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type) @@ -369,6 +378,15 @@ def floordiv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) if not is_valid_operand_for_numeric_arithmetic(right): raise TypeError("Floor division can not be applied to given types.") + spark_session = left._internal.spark_frame.sparkSession + use_try_divide = is_ansi_mode_enabled(spark_session) + + def fallback_div(x: PySparkColumn, y: PySparkColumn) -> PySparkColumn: + return x.__div__(y) + + safe_div: Callable[[PySparkColumn, PySparkColumn], PySparkColumn] = ( + F.try_divide if use_try_divide else fallback_div + ) def floordiv(left: PySparkColumn, right: Any) -> PySparkColumn: return F.when(F.lit(right is np.nan), np.nan).otherwise( @@ -377,7 +395,7 @@ def floordiv(left: PySparkColumn, right: Any) -> PySparkColumn: F.floor(left.__div__(right)), ).otherwise( F.when(F.lit(left == np.inf) | F.lit(left == -np.inf), left).otherwise( - F.lit(np.inf).__div__(left) + safe_div(F.lit(np.inf), left) ) ) ) diff --git a/python/pyspark/pandas/tests/computation/test_binary_ops.py b/python/pyspark/pandas/tests/computation/test_binary_ops.py index 3c9b7293d5d53..cda9958ad3dec 100644 --- a/python/pyspark/pandas/tests/computation/test_binary_ops.py +++ b/python/pyspark/pandas/tests/computation/test_binary_ops.py @@ -208,7 +208,11 @@ def test_binary_operator_truediv(self): self.assertRaisesRegex(TypeError, ks_err_msg, lambda: 1 / psdf["a"]) def test_binary_operator_floordiv(self): - psdf = ps.DataFrame({"a": ["x"], "b": [1]}) + pdf = pd.DataFrame({"a": ["x"], "b": [1], "c": [1.0], "d": [0]}) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf["b"] // 0, psdf["b"] // 0) + self.assert_eq(pdf["c"] // 0, psdf["c"] // 0) + self.assert_eq(pdf["d"] // 0, psdf["d"] // 0) ks_err_msg = "Floor division can not be applied to strings" self.assertRaisesRegex(TypeError, ks_err_msg, lambda: psdf["a"] // psdf["b"])