diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index c523b4a752a8..7cda6a42dcac 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -36,10 +36,17 @@ use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::{ ExecutionPlan, aggregates::{AggregateExec, AggregateMode}, - collect, + collect, displayable, limit::{GlobalLimitExec, LocalLimitExec}, }; +#[derive(Debug)] +struct AggregateRuntimeMetric { + mode: AggregateMode, + limit: Option, + output_rows: usize, +} + async fn run_plan_and_format(plan: Arc) -> Result { let cfg = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(cfg); @@ -48,6 +55,36 @@ async fn run_plan_and_format(plan: Arc) -> Result { Ok(actual) } +fn collect_aggregate_runtime_metrics( + plan: &Arc, + metrics: &mut Vec, +) { + if let Some(agg) = plan.downcast_ref::() { + let output_rows = agg + .metrics() + .and_then(|metrics| metrics.aggregate_by_name().output_rows()) + .expect("AggregateExec should record output_rows after execution"); + + metrics.push(AggregateRuntimeMetric { + mode: *agg.mode(), + limit: agg.limit_options().map(|config| config.limit()), + output_rows, + }); + } + + for child in plan.children() { + collect_aggregate_runtime_metrics(child, metrics); + } +} + +fn aggregate_runtime_metrics( + plan: &Arc, +) -> Vec { + let mut metrics = vec![]; + collect_aggregate_runtime_metrics(plan, &mut metrics); + metrics +} + #[tokio::test] async fn test_partial_final() -> Result<()> { let source = mock_data()?; @@ -104,6 +141,70 @@ async fn test_partial_final() -> Result<()> { Ok(()) } +// Ensure the operator respects the soft limit and stops early: `AggregateExec`'s +// `output_rows` metric should be smaller than the total distinct group count. 
+#[tokio::test] +async fn test_sql_partial_final_soft_limit_runtime_metrics() -> Result<()> { + let cfg = SessionConfig::new() + .with_target_partitions(2) + .with_batch_size(10) + .set_bool("datafusion.execution.enable_migration_aggregate", true); + let ctx = SessionContext::new_with_config(cfg); + + let dataframe = ctx + .sql( + "SELECT DISTINCT value % 100000 AS v \ + FROM generate_series(1000000) \ + LIMIT 10", + ) + .await?; + let plan = dataframe.create_physical_plan().await?; + let formatted_plan = displayable(plan.as_ref()).indent(false).to_string(); + assert!( + formatted_plan.contains("AggregateExec: mode=Partial"), + "expected a partial aggregate in plan:\n{formatted_plan}" + ); + assert!( + formatted_plan.contains("AggregateExec: mode=FinalPartitioned"), + "expected a final partitioned aggregate in plan:\n{formatted_plan}" + ); + + let batches = collect(Arc::clone(&plan), ctx.task_ctx()).await?; + assert_eq!( + batches.iter().map(|batch| batch.num_rows()).sum::(), + 10 + ); + + let metrics = aggregate_runtime_metrics(&plan); + let partial = metrics + .iter() + .find(|metric| metric.mode == AggregateMode::Partial) + .expect("expected partial aggregate metrics"); + let final_aggregate = metrics + .iter() + .find(|metric| { + matches!( + metric.mode, + AggregateMode::Final | AggregateMode::FinalPartitioned + ) + }) + .expect("expected final aggregate metrics"); + + assert_eq!(partial.limit, Some(10)); + assert_eq!(final_aggregate.limit, Some(10)); + + assert!( + partial.output_rows <= 100, + "partial aggregate should stop before emitting all distinct groups: {metrics:?}" + ); + assert!( + final_aggregate.output_rows <= 100, + "final aggregate should stop before emitting all distinct groups: {metrics:?}" + ); + + Ok(()) +} + #[tokio::test] async fn test_single_local() -> Result<()> { let source = mock_data()?; diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 
f25299631a92..0c8593efd05b 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -60,6 +60,33 @@ use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metric /// ## Final Stage Behavior /// Input: partial states /// Output: results for all groups (e.g. for avg(x), it's avg(x) calculated from the state) +/// +/// # Optimization: DISTINCT LIMIT Soft Limit +/// +/// This optimization applies to both [`PartialHashAggregateStream`] and [`FinalHashAggregateStream`]. +/// +/// Unordered distinct queries such as: +/// +/// ```sql +/// SELECT DISTINCT x FROM t LIMIT 10; +/// ``` +/// +/// are optimized into a two-stage aggregate like: +/// +/// ```txt +/// LimitExec, limit=10 +/// --AggregateExec(Final), group_by=[x], aggr=[], soft_limit=10 +/// ---- RepartitionExec, partitioning=hash(x) +/// ------ AggregateExec(Partial), group_by=[x], aggr=[], soft_limit=10 +/// -------- Scan(t) +/// ``` +/// +/// After each input batch, the stream checks whether the soft limit has been +/// reached. If so, it emits the accumulated groups and stops reading input. +/// +/// This operator does not guarantee an exact limit because a single batch can +/// cross the threshold. The downstream limit operator enforces the exact result +/// size. pub(crate) struct PartialHashAggregateStream { /// Output schema: group columns followed by partial aggregate state columns. schema: SchemaRef, @@ -78,6 +105,12 @@ pub(crate) struct PartialHashAggregateStream { /// Tracks partial aggregation row reduction, matching `GroupedHashAggregateStream`. reduction_factor: metrics::RatioMetrics, + + /// Optional soft limit on the number of groups to accumulate before output. + /// + /// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must + /// be empty. See the struct-level docs for details. 
+ group_values_soft_limit: Option, } /// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream @@ -99,6 +132,9 @@ pub(crate) struct FinalHashAggregateStream { /// Memory reservation for group keys and accumulators. reservation: MemoryReservation, + + /// See the docs on the field of the same name in [`PartialHashAggregateStream`]. + group_values_soft_limit: Option, } impl PartialHashAggregateStream { @@ -139,8 +175,21 @@ impl PartialHashAggregateStream { baseline_metrics, reservation, reduction_factor, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), }) } + + /// Returns `true` once the building group count reaches the soft limit, if one is set. See [`Self::group_values_soft_limit`]. + fn hit_soft_group_limit(&self) -> bool { + self.group_values_soft_limit + .is_some_and(|limit| limit <= self.hash_table.building_group_count()) + } + + fn start_output(&mut self) -> Result<()> { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + self.hash_table.start_output() + } } impl Stream for PartialHashAggregateStream { @@ -169,6 +218,18 @@ impl Stream for PartialHashAggregateStream { return Poll::Ready(Some(Err(e))); } + if self.hit_soft_group_limit() { + let timer = elapsed_compute.timer(); + let result = self.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + continue; + } + // TODO: impl memory-limited aggr, when OOM directly send // partial state to final aggregate stage if let Err(e) = @@ -181,11 +242,8 @@ impl Stream for PartialHashAggregateStream { return Poll::Ready(Some(Err(e))); } Poll::Ready(None) => { - let input_schema = self.input.schema(); - self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); - let timer = elapsed_compute.timer(); - let result = self.hash_table.start_output(); + let result = self.start_output(); timer.done(); if let Err(e) = result { @@ -262,8 +320,21 @@ impl 
FinalHashAggregateStream { hash_table, baseline_metrics, reservation, + 
group_values_soft_limit: agg.limit_options().map(|config| config.limit()), }) } + + /// Returns `true` once the building group count reaches the soft limit, if one is set. See [`Self::group_values_soft_limit`]. + fn hit_soft_group_limit(&self) -> bool { + self.group_values_soft_limit + .is_some_and(|limit| limit <= self.hash_table.building_group_count()) + } + + fn start_output(&mut self) -> Result<()> { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + self.hash_table.start_output() + } } impl Stream for FinalHashAggregateStream { @@ -291,6 +362,18 @@ impl Stream for FinalHashAggregateStream { return Poll::Ready(Some(Err(e))); } + if self.hit_soft_group_limit() { + let timer = elapsed_compute.timer(); + let result = self.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + continue; + } + if let Err(e) = self.reservation.try_resize(self.hash_table.memory_size()) { @@ -301,11 +384,8 @@ impl Stream for FinalHashAggregateStream { return Poll::Ready(Some(Err(e))); } Poll::Ready(None) => { - let input_schema = self.input.schema(); - self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); - let timer = elapsed_compute.timer(); - let result = self.hash_table.start_output(); + let result = self.start_output(); timer.done(); if let Err(e) = result { diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs index 278689d23f26..87f16d0eebe6 100644 --- a/datafusion/physical-plan/src/aggregates/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/hash_table.rs @@ -342,6 +342,10 @@ impl AggregateHashTable { } } + pub(super) fn building_group_count(&self) -> usize { + self.state.building().group_values.len() + } + pub(super) fn is_building(&self) -> bool { matches!(self.state, AggregateHashTableState::Building(_)) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs index 
67327abea360..b8d1f8f6db50 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1032,10 +1032,10 @@ impl AggregateExec { } self.mode == AggregateMode::Partial - && self.limit_options.is_none() && self.input_order_mode == InputOrderMode::Linear && !self.group_by.is_true_no_grouping() && self.group_by.is_single() + && self.limit_options_supported_by_hash_stream() } fn should_use_final_hash_stream(&self, context: &TaskContext) -> bool { @@ -1047,12 +1047,17 @@ impl AggregateExec { matches!( self.mode, AggregateMode::Final | AggregateMode::FinalPartitioned - ) && self.limit_options.is_none() + ) && self.limit_options_supported_by_hash_stream() && self.input_order_mode == InputOrderMode::Linear && !self.group_by.is_true_no_grouping() && self.group_by.is_single() } + /// True when `limit_options` is `None` or `is_unordered_unfiltered_group_by_distinct()` holds. See the limit optimization section in the `PartialHashAggregateStream` docs. + fn limit_options_supported_by_hash_stream(&self) -> bool { + self.limit_options.is_none() || self.is_unordered_unfiltered_group_by_distinct() + } + /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; @@ -3151,6 +3156,103 @@ mod tests { Ok(()) } + #[tokio::test] + async fn limited_distinct_aggregate_uses_migrated_hash_streams() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, false)])); + let input_batches = vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from(vec![1, 2, 1]))], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from(vec![3, 4]))], + )?, + ]; + let group_by = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + 
.set_bool("datafusion.execution.enable_migration_aggregate", true), + ), + ); 
+ + let partial_input = TestMemoryExec::try_new_exec( + std::slice::from_ref(&input_batches), + Arc::clone(&schema), + None, + )?; + let partial_aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + vec![], + vec![], + partial_input, + Arc::clone(&schema), + )? + .with_limit_options(Some(LimitOptions::new(2))), + ); + + let partial_stream = partial_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(partial_stream, StreamType::PartialHash(_))); + let stream: SendableRecordBatchStream = partial_stream.into(); + let partial_output = collect(stream).await?; + assert_eq!( + partial_output + .iter() + .map(RecordBatch::num_rows) + .sum::(), + 2 + ); + assert_snapshot!(batches_to_sort_string(&partial_output), @r" ++---+ +| a | ++---+ +| 1 | +| 2 | ++---+ +"); + + let final_input = + TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?; + let final_aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by.as_final(), + vec![], + vec![], + final_input, + Arc::clone(&schema), + )? + .with_limit_options(Some(LimitOptions::new(2))), + ); + + let final_stream = final_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(final_stream, StreamType::FinalHash(_))); + let stream: SendableRecordBatchStream = final_stream.into(); + let final_output = collect(stream).await?; + assert_eq!( + final_output + .iter() + .map(RecordBatch::num_rows) + .sum::(), + 2 + ); + assert_snapshot!(batches_to_sort_string(&final_output), @r" ++---+ +| a | ++---+ +| 1 | +| 2 | ++---+ +"); + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default());