Skip to content

Commit 9b66df1

Browse files
committed
fix
1 parent c9e1895 commit 9b66df1

File tree

3 files changed

+757
-0
lines changed

3 files changed

+757
-0
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ def __hash__(self):
542542
"pyspark.sql.tests.arrow.test_arrow_grouped_map",
543543
"pyspark.sql.tests.arrow.test_arrow_python_udf",
544544
"pyspark.sql.tests.arrow.test_arrow_udf",
545+
"pyspark.sql.tests.arrow.test_arrow_udf_grouped_agg",
545546
"pyspark.sql.tests.arrow.test_arrow_udf_scalar",
546547
"pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
547548
"pyspark.sql.tests.pandas.test_pandas_grouped_map",

python/pyspark/sql/pandas/functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
456456
if kind == "arrow" and eval_type not in [
457457
PythonEvalType.SQL_SCALAR_ARROW_UDF,
458458
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
459+
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
459460
None,
460461
]: # None means it should infer the type from type hints.
461462
raise PySparkTypeError(

0 commit comments

Comments
 (0)