Commit 3b75442

[SPARK-52402][PS] Fix divide-by-zero errors in Kendall and Pearson correlation under ANSI mode
### What changes were proposed in this pull request?

Fix divide-by-zero errors in groupby().corr() for the "kendall" and "pearson" methods with ANSI mode enabled.

### Why are the changes needed?

Ensure pandas on Spark works well with ANSI mode on. Part of https://issues.apache.org/jira/browse/SPARK-52169.

### Does this PR introduce _any_ user-facing change?

Yes.

```py
>>> ps.set_option("compute.fail_on_ansi_mode", False)
>>> ps.set_option("compute.ansi_mode_support", True)
>>> df = ps.DataFrame(
...     {"A": [0, 0, 0, 1, 1, 2], "B": [-1, 2, 3, 5, 6, 0], "C": [4, 6, 5, 1, 3, 0]},
...     columns=["A", "B", "C"]
... )
```

FROM

```py
>>> df.groupby("A").corr('kendall')
25/06/04 14:40:03 ERROR Executor: Exception in task 0.0 in stage 13.0 (TID 51)
org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. SQLSTATE: 22012
== DataFrame ==
"__truediv__" was called from
...
```

TO

```py
>>> df.groupby("A").corr('kendall')
            B         C
A
0 B  1.000000  0.333333
  C  0.333333  1.000000
1 B  1.000000  1.000000
  C  1.000000  1.000000
2 B  1.000000       NaN
  C       NaN  1.000000
```

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51090 from xinrong-meng/ansi_corr.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
1 parent 2695636 commit 3b75442
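The heart of the change is swapping plain column division, which raises DIVIDE_BY_ZERO under ANSI mode, for `F.try_divide`, which returns NULL when the divisor is zero. A minimal standalone sketch of that behavior (the DataFrame and column names here are illustrative, not from the patch):

```py
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", True)

df = spark.createDataFrame([(1.0, 0.0)], ["num", "den"])

# Plain division fails at execution time under ANSI mode:
# df.select((F.col("num") / F.col("den")).alias("q")).first()
#   -> SparkArithmeticException: [DIVIDE_BY_ZERO]

# try_divide tolerates the zero divisor:
print(df.select(F.try_divide("num", "den").alias("q")).first())
# Row(q=None), i.e. NULL instead of an error
```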

File tree

4 files changed, +51 -33 lines changed


python/pyspark/pandas/correlation.py

Lines changed: 51 additions & 29 deletions
@@ -19,7 +19,7 @@
 from pyspark.sql import DataFrame as SparkDataFrame, functions as F
 from pyspark.sql.window import Window
-from pyspark.pandas.utils import verify_temp_column_name
+from pyspark.pandas.utils import verify_temp_column_name, is_ansi_mode_enabled


 CORRELATION_VALUE_1_COLUMN = "__correlation_value_1_input__"
@@ -60,6 +60,7 @@ def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame:
             .alias(CORRELATION_VALUE_2_COLUMN),
         ],
     )
+    spark_session = sdf.sparkSession

     if method in ["pearson", "spearman"]:
         # convert values to avg ranks for spearman correlation
@@ -125,16 +126,20 @@ def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame:
             )
         )

+        if is_ansi_mode_enabled(spark_session):
+            corr_expr = F.try_divide(
+                F.covar_samp(CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN),
+                F.stddev_samp(CORRELATION_VALUE_1_COLUMN)
+                * F.stddev_samp(CORRELATION_VALUE_2_COLUMN),
+            )
+        else:
+            corr_expr = F.corr(CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN)
+
         sdf = sdf.groupby(groupKeys).agg(
-            F.corr(CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN).alias(
-                CORRELATION_CORR_OUTPUT_COLUMN
+            corr_expr.alias(CORRELATION_CORR_OUTPUT_COLUMN),
+            F.count(F.when(~F.isnull(CORRELATION_VALUE_1_COLUMN), 1)).alias(
+                CORRELATION_COUNT_OUTPUT_COLUMN
             ),
-            F.count(
-                F.when(
-                    ~F.isnull(CORRELATION_VALUE_1_COLUMN),
-                    1,
-                )
-            ).alias(CORRELATION_COUNT_OUTPUT_COLUMN),
         )

         return sdf
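The Pearson branch above leans on the identity corr(x, y) = covar_samp(x, y) / (stddev_samp(x) * stddev_samp(y)), so wrapping the quotient in `try_divide` turns a zero-variance group into NULL instead of an ANSI error. A hedged standalone check of that identity (toy data and column names, not from the patch):

```py
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0, 2.0), (2.0, 4.0), (3.0, 5.0)], ["x", "y"])

row = df.agg(
    F.corr("x", "y").alias("builtin"),
    F.try_divide(
        F.covar_samp("x", "y"),
        F.stddev_samp("x") * F.stddev_samp("y"),
    ).alias("rewritten"),
).first()

print(row.builtin, row.rewritten)  # both ~0.98198 for this data
# With a constant column, stddev_samp is 0: both forms return NULL,
# while a plain division would raise DIVIDE_BY_ZERO under ANSI mode.
```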
@@ -219,6 +224,42 @@ def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame:
             F.col(CORRELATION_VALUE_2_COLUMN) == F.col(CORRELATION_VALUE_Y_COLUMN)
         )

+        if is_ansi_mode_enabled(spark_session):
+            corr_expr = F.try_divide(
+                F.col(CORRELATION_KENDALL_P_COLUMN) - F.col(CORRELATION_KENDALL_Q_COLUMN),
+                F.sqrt(
+                    (
+                        F.col(CORRELATION_KENDALL_P_COLUMN)
+                        + F.col(CORRELATION_KENDALL_Q_COLUMN)
+                        + F.col(CORRELATION_KENDALL_T_COLUMN)
+                    )
+                    * (
+                        F.col(CORRELATION_KENDALL_P_COLUMN)
+                        + F.col(CORRELATION_KENDALL_Q_COLUMN)
+                        + F.col(CORRELATION_KENDALL_U_COLUMN)
+                    )
+                ),
+            )
+        else:
+            corr_expr = (
+                F.col(CORRELATION_KENDALL_P_COLUMN) - F.col(CORRELATION_KENDALL_Q_COLUMN)
+            ) / F.sqrt(
+                (
+                    (
+                        F.col(CORRELATION_KENDALL_P_COLUMN)
+                        + F.col(CORRELATION_KENDALL_Q_COLUMN)
+                        + (F.col(CORRELATION_KENDALL_T_COLUMN))
+                    )
+                )
+                * (
+                    (
+                        F.col(CORRELATION_KENDALL_P_COLUMN)
+                        + F.col(CORRELATION_KENDALL_Q_COLUMN)
+                        + (F.col(CORRELATION_KENDALL_U_COLUMN))
+                    )
+                )
+            )
+
         sdf = (
             sdf.groupby(groupKeys)
             .agg(
@@ -232,26 +273,7 @@ def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame:
                     ).otherwise(F.lit(0))
                 ).alias(CORRELATION_COUNT_OUTPUT_COLUMN),
             )
-            .withColumn(
-                CORRELATION_CORR_OUTPUT_COLUMN,
-                (F.col(CORRELATION_KENDALL_P_COLUMN) - F.col(CORRELATION_KENDALL_Q_COLUMN))
-                / F.sqrt(
-                    (
-                        (
-                            F.col(CORRELATION_KENDALL_P_COLUMN)
-                            + F.col(CORRELATION_KENDALL_Q_COLUMN)
-                            + (F.col(CORRELATION_KENDALL_T_COLUMN))
-                        )
-                    )
-                    * (
-                        (
-                            F.col(CORRELATION_KENDALL_P_COLUMN)
-                            + F.col(CORRELATION_KENDALL_Q_COLUMN)
-                            + (F.col(CORRELATION_KENDALL_U_COLUMN))
-                        )
-                    )
-                ),
-            )
+            .withColumn(CORRELATION_CORR_OUTPUT_COLUMN, corr_expr)
         )

         sdf = sdf.select(
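The Kendall branch computes tau-b, tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U)), where P and Q count concordant and discordant pairs and T and U count pairs tied only in the first or second variable (assuming the standard tau-b reading of the P/Q/T/U columns). A single-row group has no pairs at all, so the denominator is 0 and `try_divide` yields NULL, which surfaces as the NaN for group A == 2 in the expected output above. A pure-Python sanity check against group A == 0 (B = [-1, 2, 3], C = [4, 6, 5]):

```py
from itertools import combinations
from math import sqrt

B, C = [-1, 2, 3], [4, 6, 5]
P = Q = T = U = 0
for (b1, c1), (b2, c2) in combinations(zip(B, C), 2):
    if b1 == b2 and c1 != c2:
        T += 1  # pair tied only in B
    elif c1 == c2 and b1 != b2:
        U += 1  # pair tied only in C
    elif (b1 - b2) * (c1 - c2) > 0:
        P += 1  # concordant pair
    elif (b1 - b2) * (c1 - c2) < 0:
        Q += 1  # discordant pair

print((P - Q) / sqrt((P + Q + T) * (P + Q + U)))
# 0.3333333333333333, matching the 0.333333 for group A == 0 above
```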

python/pyspark/pandas/tests/computation/test_corr.py

Lines changed: 0 additions & 2 deletions
@@ -22,11 +22,9 @@
 from pyspark import pandas as ps
 from pyspark.testing.pandasutils import PandasOnSparkTestCase, SPARK_CONF_ARROW_ENABLED
 from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.testing.utils import is_ansi_mode_test, ansi_mode_not_supported_message


 class FrameCorrMixin:
-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_dataframe_corr(self):
         pdf = pd.DataFrame(
             index=[

python/pyspark/pandas/tests/diff_frames_ops/test_corrwith.py

Lines changed: 0 additions & 1 deletion
@@ -82,7 +82,6 @@ def tearDownClass(cls):
         reset_option("compute.ops_on_diff_frames")
         super().tearDownClass()

-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_corrwith(self):
         df1 = ps.DataFrame({"A": [1, np.nan, 7, 8], "X": [5, 8, np.nan, 3], "C": [10, 4, 9, 3]})
         df2 = ps.DataFrame({"A": [5, 3, 6, 4], "B": [11, 2, 4, 3], "C": [4, 3, 8, np.nan]})

python/pyspark/pandas/tests/groupby/test_corr.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@ def test_corr(self):
             almost=True,
         )

-    @unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
     def test_method(self):
         for m in ["pearson", "spearman", "kendall"]:
             self.assert_eq(
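With the skip decorators gone, `test_method` now runs under ANSI mode and covers all three methods. A hedged sketch of the scenario it exercises (options and data taken from the commit message above; the real test fixtures differ):

```py
from pyspark import pandas as ps

ps.set_option("compute.fail_on_ansi_mode", False)
ps.set_option("compute.ansi_mode_support", True)

df = ps.DataFrame({"A": [0, 0, 0, 1, 1, 2], "B": [-1, 2, 3, 5, 6, 0], "C": [4, 6, 5, 1, 3, 0]})
for m in ["pearson", "spearman", "kendall"]:
    print(df.groupby("A").corr(m))  # no DIVIDE_BY_ZERO; degenerate groups give NaN
```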
