diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index b0618b971..6c0ef7f72 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -142,6 +142,8 @@ enum AggFunction { FIRST = 7; FIRST_IGNORES_NULL = 8; BLOOM_FILTER = 9; + LAST = 10; + LAST_IGNORES_NULL = 11; BRICKHOUSE_COLLECT = 1000; BRICKHOUSE_COMBINE_UNIQUE = 1001; UDAF = 1002; diff --git a/native-engine/auron-planner/src/lib.rs b/native-engine/auron-planner/src/lib.rs index a0f7b83d2..fb862bffd 100644 --- a/native-engine/auron-planner/src/lib.rs +++ b/native-engine/auron-planner/src/lib.rs @@ -135,6 +135,8 @@ impl From for AggFunction { protobuf::AggFunction::CollectSet => AggFunction::CollectSet, protobuf::AggFunction::First => AggFunction::First, protobuf::AggFunction::FirstIgnoresNull => AggFunction::FirstIgnoresNull, + protobuf::AggFunction::Last => AggFunction::Last, + protobuf::AggFunction::LastIgnoresNull => AggFunction::LastIgnoresNull, protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter, protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect, protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique, diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 84a625734..6af38eb4b 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -657,6 +657,12 @@ impl PhysicalPlanner { protobuf::AggFunction::FirstIgnoresNull => { WindowFunction::Agg(AggFunction::FirstIgnoresNull) } + protobuf::AggFunction::Last => { + WindowFunction::Agg(AggFunction::Last) + } + protobuf::AggFunction::LastIgnoresNull => { + WindowFunction::Agg(AggFunction::LastIgnoresNull) + } protobuf::AggFunction::BloomFilter => { WindowFunction::Agg(AggFunction::BloomFilter) } diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index 5eb4c3dad..99adc470d 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -33,6 +33,8 @@ use crate::agg::{ count::AggCount, first::AggFirst, first_ignores_null::AggFirstIgnoresNull, + last::AggLast, + last_ignores_null::AggLastIgnoresNull, maxmin::{AggMax, AggMin}, spark_udaf_wrapper::SparkUDAFWrapper, sum::AggSum, @@ -212,6 +214,14 @@ pub fn create_agg( let dt = children[0].data_type(input_schema)?; Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?) } + AggFunction::Last => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggLast::try_new(children[0].clone(), dt)?) + } + AggFunction::LastIgnoresNull => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggLastIgnoresNull::try_new(children[0].clone(), dt)?) + } AggFunction::BloomFilter => { let dt = children[0].data_type(input_schema)?; let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); diff --git a/native-engine/datafusion-ext-plans/src/agg/last.rs b/native-engine/datafusion-ext-plans/src/agg/last.rs new file mode 100644 index 000000000..6cf753b03 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/agg/last.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExprRef, +}; +use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array}; + +use crate::{ + agg::{ + Agg, + acc::{ + AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn, + AccScalarValueColumn, create_acc_generic_column, + }, + agg::IdxSelection, + }, + idx_for_zipped, +}; + +pub struct AggLast { + child: PhysicalExprRef, + data_type: DataType, + acc_array_data_types: Vec, +} + +impl AggLast { + pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result { + let acc_array_data_types = vec![data_type.clone()]; + Ok(Self { + child, + data_type, + acc_array_data_types, + }) + } +} + +impl Debug for AggLast { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Last({:?})", self.child) + } +} + +impl Agg for AggLast { + fn as_any(&self) -> &dyn Any { + self + } + + fn exprs(&self) -> Vec { + vec![self.child.clone()] + } + + fn with_new_exprs(&self, exprs: Vec) -> Result> { + Ok(Arc::new(Self::try_new( + exprs[0].clone(), + self.data_type.clone(), + )?)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nullable(&self) -> bool { + true + } + + fn create_acc_column(&self, num_rows: usize) -> AccColumnRef { + create_acc_generic_column(self.data_type.clone(), num_rows) + } + + fn acc_array_data_types(&self) -> &[DataType] { + &self.acc_array_data_types + } + + fn partial_update( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + partial_args: &[ArrayRef], + partial_arg_idx: IdxSelection<'_>, + ) -> Result<()> { + let partial_arg = &partial_args[0]; + accs.ensure_size(acc_idx); + + macro_rules! handle_bytes { + ($array:expr) => {{ + let accs = downcast_any!(accs, mut AccBytesColumn)?; + let partial_arg = $array; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref()))); + } else { + accs.set_value(acc_idx, None); + } + } + } + }} + } + + downcast_primitive_array! { + partial_arg => { + if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) { + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } else { + accs.set_value(acc_idx, None); + } + } + } + } + } + DataType::Boolean => { + let accs = downcast_any!(accs, mut AccBooleanColumn)?; + let partial_arg = downcast_any!(partial_arg, BooleanArray)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } else { + accs.set_value(acc_idx, None); + } + } + } + } + DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?), + DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?), + _other => { + let accs = downcast_any!(accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?); + } else { + accs.set_value(acc_idx, ScalarValue::try_from(&self.data_type)?); + } + } + } + } + } + Ok(()) + } + + fn partial_merge( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + merging_accs: &mut AccColumnRef, + merging_acc_idx: IdxSelection<'_>, + ) -> Result<()> { + accs.ensure_size(acc_idx); + + // For last, always overwrite with the merging accumulator's value + macro_rules! handle_primitive { + ($ty:ty) => {{ + type TNative = <$ty as ArrowPrimitiveType>::Native; + let accs = downcast_any!(accs, mut AccPrimColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + accs.set_value(acc_idx, merging_accs.value(merging_acc_idx)); + } + } + }} + } + + macro_rules! handle_boolean { + () => {{ + let accs = downcast_any!(accs, mut AccBooleanColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + accs.set_value(acc_idx, merging_accs.value(merging_acc_idx)); + } + } + }}; + } + + macro_rules! handle_bytes { + () => {{ + let accs = downcast_any!(accs, mut AccBytesColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx)); + } + } + }}; + } + + downcast_primitive! { + (&self.data_type) => (handle_primitive), + DataType::Boolean => handle_boolean!(), + DataType::Utf8 | DataType::Binary => handle_bytes!(), + DataType::Null => {} + _ => { + let accs = downcast_any!(accs, mut AccScalarValueColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx)); + } + } + } + } + Ok(()) + } + + fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { + Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) + } +} diff --git a/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs new file mode 100644 index 000000000..fde2afd94 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{common::Result, physical_expr::PhysicalExprRef}; +use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array}; + +use crate::{ + agg::{ + Agg, + acc::{ + AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn, + AccScalarValueColumn, create_acc_generic_column, + }, + agg::IdxSelection, + }, + idx_for_zipped, +}; + +pub struct AggLastIgnoresNull { + child: PhysicalExprRef, + data_type: DataType, + acc_array_data_types: Vec, +} + +impl AggLastIgnoresNull { + pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result { + let acc_array_data_types = vec![data_type.clone()]; + Ok(Self { + child, + data_type, + acc_array_data_types, + }) + } +} + +impl Debug for AggLastIgnoresNull { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "LastIgnoresNull({:?})", self.child) + } +} + +impl Agg for AggLastIgnoresNull { + fn as_any(&self) -> &dyn Any { + self + } + + fn exprs(&self) -> Vec { + vec![self.child.clone()] + } + + fn with_new_exprs(&self, exprs: Vec) -> Result> { + Ok(Arc::new(Self::try_new( + exprs[0].clone(), + self.data_type.clone(), + )?)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nullable(&self) -> bool { + true + } + + fn create_acc_column(&self, num_rows: usize) -> AccColumnRef { + create_acc_generic_column(self.data_type.clone(), num_rows) + } + + fn acc_array_data_types(&self) -> &[DataType] { + &self.acc_array_data_types + } + + fn partial_update( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + partial_args: &[ArrayRef], + partial_arg_idx: IdxSelection<'_>, + ) -> Result<()> { + let partial_arg = &partial_args[0]; + accs.ensure_size(acc_idx); + + macro_rules! handle_bytes { + ($array:expr) => {{ + let accs = downcast_any!(accs, mut AccBytesColumn)?; + let partial_arg = $array; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref()))); + } + } + } + }} + } + + downcast_primitive_array! { + partial_arg => { + if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) { + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } + } + } + } + } + DataType::Boolean => { + let accs = downcast_any!(accs, mut AccBooleanColumn)?; + let partial_arg = downcast_any!(partial_arg, BooleanArray)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx))); + } + } + } + } + DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?), + DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?), + _other => { + let accs = downcast_any!(accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?); + } + } + } + } + } + Ok(()) + } + + fn partial_merge( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + merging_accs: &mut AccColumnRef, + merging_acc_idx: IdxSelection<'_>, + ) -> Result<()> { + accs.ensure_size(acc_idx); + + // primitive types + macro_rules! handle_primitive { + ($ty:ty) => {{ + type TNative = <$ty as ArrowPrimitiveType>::Native; + let accs = downcast_any!(accs, mut AccPrimColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_accs.value(merging_acc_idx).is_some() { + accs.set_value(acc_idx, merging_accs.value(merging_acc_idx)); + } + } + } + }} + } + + macro_rules! handle_boolean { + () => {{ + let accs = downcast_any!(accs, mut AccBooleanColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_accs.value(merging_acc_idx).is_some() { + accs.set_value(acc_idx, merging_accs.value(merging_acc_idx)); + } + } + } + }}; + } + + macro_rules! handle_bytes { + () => {{ + let accs = downcast_any!(accs, mut AccBytesColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if merging_accs.value(merging_acc_idx).is_some() { + accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx)); + } + } + } + }}; + } + + downcast_primitive! { + (&self.data_type) => (handle_primitive), + DataType::Boolean => handle_boolean!(), + DataType::Utf8 | DataType::Binary => handle_bytes!(), + DataType::Null => {} + _ => { + let accs = downcast_any!(accs, mut AccScalarValueColumn)?; + let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if !merging_accs.value(merging_acc_idx).is_null() { + accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx)); + } + } + } + } + } + Ok(()) + } + + fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { + Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) + } +} diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index 9f19b02c8..0aa579ebd 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -25,6 +25,8 @@ pub mod collect; pub mod count; pub mod first; pub mod first_ignores_null; +pub mod last; +pub mod last_ignores_null; pub mod maxmin; pub mod spark_udaf_wrapper; pub mod sum; @@ -69,6 +71,8 @@ pub enum AggFunction { Min, First, FirstIgnoresNull, + Last, + LastIgnoresNull, CollectList, CollectSet, BloomFilter, diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala new file mode 100644 index 000000000..1c0669e96 --- /dev/null +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.auron + +import org.apache.spark.sql.AuronQueryTest +import org.apache.spark.sql.execution.auron.plan.NativeWindowBase + +import org.apache.auron.util.AuronTestUtils + +class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQLTestHelper { + + test("first_value window function") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | first_value(v) over (partition by grp order by id) as first_v + |from t1 + |""".stripMargin) + } + } + } + + test("first_value window function with ignore nulls") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, null), (2, 1, 'b'), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | first_value(v) ignore nulls over (partition by grp order by id) as first_non_null_v + |from t1 + |""".stripMargin) + } + } + } + } + + test("last_value window function") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | last_value(v) over (partition by grp order by id) as last_v + |from t1 + |""".stripMargin) + } + } + } + + test("last_value window function with ignore nulls") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | last_value(v) ignore nulls over (partition by grp order by id) as last_non_null_v + |from t1 + |""".stripMargin) + } + } + } + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 750aaa524..c5ad62cf6 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.auron.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Max, Min, Sum, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Last, Max, Min, Sum, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero @@ -1259,6 +1259,18 @@ object NativeConverters extends Logging { }) aggBuilder.addChildren(convertExpr(child)) + case Last(child, ignoresNullExpr) => + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + aggBuilder.setAggFunction(if (ignoresNull) { + pb.AggFunction.LAST_IGNORES_NULL + } else { + pb.AggFunction.LAST + }) + aggBuilder.addChildren(convertExpr(child)) + case CollectList(child, _, _) => aggBuilder.setAggFunction(pb.AggFunction.COLLECT_LIST) aggBuilder.addChildren(convertExpr(child)) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala index fad61ff09..eb16a989e 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala @@ -36,6 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.expressions.WindowExpression import org.apache.spark.sql.catalyst.expressions.aggregate.Average import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.expressions.aggregate.First +import org.apache.spark.sql.catalyst.expressions.aggregate.Last import org.apache.spark.sql.catalyst.expressions.aggregate.Max import org.apache.spark.sql.catalyst.expressions.aggregate.Min import org.apache.spark.sql.catalyst.expressions.aggregate.Sum @@ -158,6 +160,32 @@ abstract class NativeWindowBase( windowExprBuilder.setAggFunc(pb.AggFunction.COUNT) windowExprBuilder.addChildren(NativeConverters.convertExpr(child)) + case e @ First(child, ignoresNullExpr) => + assert( + spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow) + s"window frame not supported: ${spec.frameSpecification}") + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + windowExprBuilder.setFuncType(pb.WindowFunctionType.Agg) + windowExprBuilder.setAggFunc( + if (ignoresNull) pb.AggFunction.FIRST_IGNORES_NULL else pb.AggFunction.FIRST) + windowExprBuilder.addChildren(NativeConverters.convertExpr(child)) + + case e @ Last(child, ignoresNullExpr) => + assert( + spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow) + s"window frame not supported: ${spec.frameSpecification}") + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + windowExprBuilder.setFuncType(pb.WindowFunctionType.Agg) + windowExprBuilder.setAggFunc( + if (ignoresNull) pb.AggFunction.LAST_IGNORES_NULL else pb.AggFunction.LAST) + windowExprBuilder.addChildren(NativeConverters.convertExpr(child)) + case other => throw new NotImplementedError(s"window function not supported: $other") }