Skip to content

Commit cdaaef7

Browse files
authored
Introduce Async User Defined Functions (apache#14837)
* introduce async udf for projection * refactor for filter * coalesce_batches for AsyncFuncExec * project filter to exclude the filter expression * coalesce the input batch of AsyncFuncExec * simple example * enhance comment * enhance doc and fix test * fix clippy and fmt * add missing dependency * fix clippy * rename the symbol * cargo fmt * fix fmt and rebase * add return_field_from_args for async scalar udf * modified into_scalar_udf method * add the async scalar udf in udfs doc * pretty doc * fix doc test * fix merge conflict
1 parent b6c8cc5 commit cdaaef7

File tree

16 files changed

+1350
-16
lines changed

16 files changed

+1350
-16
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray};
19+
use arrow::compute::kernels::cmp::eq;
20+
use arrow_schema::{DataType, Field, Schema};
21+
use async_trait::async_trait;
22+
use datafusion::common::error::Result;
23+
use datafusion::common::internal_err;
24+
use datafusion::common::types::{logical_int64, logical_string};
25+
use datafusion::common::utils::take_function_args;
26+
use datafusion::config::ConfigOptions;
27+
use datafusion::logical_expr::async_udf::{
28+
AsyncScalarFunctionArgs, AsyncScalarUDF, AsyncScalarUDFImpl,
29+
};
30+
use datafusion::logical_expr::{
31+
ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility,
32+
};
33+
use datafusion::logical_expr_common::signature::Coercion;
34+
use datafusion::physical_expr_common::datum::apply_cmp;
35+
use datafusion::prelude::SessionContext;
36+
use log::trace;
37+
use std::any::Any;
38+
use std::sync::Arc;
39+
40+
#[tokio::main]
41+
async fn main() -> Result<()> {
42+
let ctx: SessionContext = SessionContext::new();
43+
44+
let async_upper = AsyncUpper::new();
45+
let udf = AsyncScalarUDF::new(Arc::new(async_upper));
46+
ctx.register_udf(udf.into_scalar_udf());
47+
let async_equal = AsyncEqual::new();
48+
let udf = AsyncScalarUDF::new(Arc::new(async_equal));
49+
ctx.register_udf(udf.into_scalar_udf());
50+
ctx.register_batch("animal", animal()?)?;
51+
52+
// use Async UDF in the projection
53+
// +---------------+----------------------------------------------------------------------------------------+
54+
// | plan_type | plan |
55+
// +---------------+----------------------------------------------------------------------------------------+
56+
// | logical_plan | Projection: async_equal(a.id, Int64(1)) |
57+
// | | SubqueryAlias: a |
58+
// | | TableScan: animal projection=[id] |
59+
// | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] |
60+
// | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
61+
// | | CoalesceBatchesExec: target_batch_size=8192 |
62+
// | | DataSourceExec: partitions=1, partition_sizes=[1] |
63+
// | | |
64+
// +---------------+----------------------------------------------------------------------------------------+
65+
ctx.sql("explain select async_equal(a.id, 1) from animal a")
66+
.await?
67+
.show()
68+
.await?;
69+
70+
// +----------------------------+
71+
// | async_equal(a.id,Int64(1)) |
72+
// +----------------------------+
73+
// | true |
74+
// | false |
75+
// | false |
76+
// | false |
77+
// | false |
78+
// +----------------------------+
79+
ctx.sql("select async_equal(a.id, 1) from animal a")
80+
.await?
81+
.show()
82+
.await?;
83+
84+
// use Async UDF in the filter
85+
// +---------------+--------------------------------------------------------------------------------------------+
86+
// | plan_type | plan |
87+
// +---------------+--------------------------------------------------------------------------------------------+
88+
// | logical_plan | SubqueryAlias: a |
89+
// | | Filter: async_equal(animal.id, Int64(1)) |
90+
// | | TableScan: animal projection=[id, name] |
91+
// | physical_plan | CoalesceBatchesExec: target_batch_size=8192 |
92+
// | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |
93+
// | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |
94+
// | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] |
95+
// | | CoalesceBatchesExec: target_batch_size=8192 |
96+
// | | DataSourceExec: partitions=1, partition_sizes=[1] |
97+
// | | |
98+
// +---------------+--------------------------------------------------------------------------------------------+
99+
ctx.sql("explain select * from animal a where async_equal(a.id, 1)")
100+
.await?
101+
.show()
102+
.await?;
103+
104+
// +----+------+
105+
// | id | name |
106+
// +----+------+
107+
// | 1 | cat |
108+
// +----+------+
109+
ctx.sql("select * from animal a where async_equal(a.id, 1)")
110+
.await?
111+
.show()
112+
.await?;
113+
114+
Ok(())
115+
}
116+
117+
fn animal() -> Result<RecordBatch> {
118+
let schema = Arc::new(Schema::new(vec![
119+
Field::new("id", DataType::Int64, false),
120+
Field::new("name", DataType::Utf8, false),
121+
]));
122+
123+
let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
124+
let name_array = Arc::new(StringArray::from(vec![
125+
"cat", "dog", "fish", "bird", "snake",
126+
]));
127+
128+
Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?)
129+
}
130+
131+
#[derive(Debug)]
132+
pub struct AsyncUpper {
133+
signature: Signature,
134+
}
135+
136+
impl Default for AsyncUpper {
137+
fn default() -> Self {
138+
Self::new()
139+
}
140+
}
141+
142+
impl AsyncUpper {
143+
pub fn new() -> Self {
144+
Self {
145+
signature: Signature::new(
146+
TypeSignature::Coercible(vec![Coercion::Exact {
147+
desired_type: TypeSignatureClass::Native(logical_string()),
148+
}]),
149+
Volatility::Volatile,
150+
),
151+
}
152+
}
153+
}
154+
155+
#[async_trait]
156+
impl AsyncScalarUDFImpl for AsyncUpper {
157+
fn as_any(&self) -> &dyn Any {
158+
self
159+
}
160+
161+
fn name(&self) -> &str {
162+
"async_upper"
163+
}
164+
165+
fn signature(&self) -> &Signature {
166+
&self.signature
167+
}
168+
169+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
170+
Ok(DataType::Utf8)
171+
}
172+
173+
fn ideal_batch_size(&self) -> Option<usize> {
174+
Some(10)
175+
}
176+
177+
async fn invoke_async_with_args(
178+
&self,
179+
args: AsyncScalarFunctionArgs,
180+
_option: &ConfigOptions,
181+
) -> Result<ArrayRef> {
182+
trace!("Invoking async_upper with args: {:?}", args);
183+
let value = &args.args[0];
184+
let result = match value {
185+
ColumnarValue::Array(array) => {
186+
let string_array = array.as_string::<i32>();
187+
let iter = ArrayIter::new(string_array);
188+
let result = iter
189+
.map(|string| string.map(|s| s.to_uppercase()))
190+
.collect::<StringArray>();
191+
Arc::new(result) as ArrayRef
192+
}
193+
_ => return internal_err!("Expected a string argument, got {:?}", value),
194+
};
195+
Ok(result)
196+
}
197+
}
198+
199+
#[derive(Debug)]
200+
struct AsyncEqual {
201+
signature: Signature,
202+
}
203+
204+
impl Default for AsyncEqual {
205+
fn default() -> Self {
206+
Self::new()
207+
}
208+
}
209+
210+
impl AsyncEqual {
211+
pub fn new() -> Self {
212+
Self {
213+
signature: Signature::new(
214+
TypeSignature::Coercible(vec![
215+
Coercion::Exact {
216+
desired_type: TypeSignatureClass::Native(logical_int64()),
217+
},
218+
Coercion::Exact {
219+
desired_type: TypeSignatureClass::Native(logical_int64()),
220+
},
221+
]),
222+
Volatility::Volatile,
223+
),
224+
}
225+
}
226+
}
227+
228+
#[async_trait]
229+
impl AsyncScalarUDFImpl for AsyncEqual {
230+
fn as_any(&self) -> &dyn Any {
231+
self
232+
}
233+
234+
fn name(&self) -> &str {
235+
"async_equal"
236+
}
237+
238+
fn signature(&self) -> &Signature {
239+
&self.signature
240+
}
241+
242+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
243+
Ok(DataType::Boolean)
244+
}
245+
246+
async fn invoke_async_with_args(
247+
&self,
248+
args: AsyncScalarFunctionArgs,
249+
_option: &ConfigOptions,
250+
) -> Result<ArrayRef> {
251+
let [arg1, arg2] = take_function_args(self.name(), &args.args)?;
252+
apply_cmp(arg1, arg2, eq)?.to_array(args.number_rows)
253+
}
254+
}

0 commit comments

Comments
 (0)