Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,17 @@ use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
use datafusion_physical_plan::{
ExecutionPlan,
aggregates::{AggregateExec, AggregateMode},
collect,
collect, displayable,
limit::{GlobalLimitExec, LocalLimitExec},
};

/// Snapshot of one `AggregateExec` node taken from a physical plan, as
/// gathered by `collect_aggregate_runtime_metrics`.
#[derive(Debug)]
struct AggregateRuntimeMetric {
/// Aggregation mode reported by `AggregateExec::mode`.
mode: AggregateMode,
/// Limit read from `AggregateExec::limit_options`; `None` when the node
/// has no limit options.
limit: Option<usize>,
/// Value of the node's `output_rows` metric.
output_rows: usize,
}

async fn run_plan_and_format(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
let cfg = SessionConfig::new().with_target_partitions(1);
let ctx = SessionContext::new_with_config(cfg);
Expand All @@ -48,6 +55,36 @@ async fn run_plan_and_format(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
Ok(actual)
}

/// Walks `plan` depth-first, visiting each node before its children, and
/// appends one [`AggregateRuntimeMetric`] to `metrics` for every
/// [`AggregateExec`] encountered.
///
/// # Panics
///
/// Panics if an [`AggregateExec`] exposes no `output_rows` metric, so this
/// should only be called after the plan has been executed.
fn collect_aggregate_runtime_metrics(
    plan: &Arc<dyn ExecutionPlan>,
    metrics: &mut Vec<AggregateRuntimeMetric>,
) {
    let snapshot = plan.downcast_ref::<AggregateExec>().map(|agg| {
        let output_rows = agg
            .metrics()
            .and_then(|set| set.aggregate_by_name().output_rows())
            .expect("AggregateExec should record output_rows after execution");

        AggregateRuntimeMetric {
            mode: *agg.mode(),
            limit: agg.limit_options().map(|options| options.limit()),
            output_rows,
        }
    });

    // Record this node before descending so entries appear in pre-order.
    metrics.extend(snapshot);

    for child in plan.children() {
        collect_aggregate_runtime_metrics(child, metrics);
    }
}

/// Returns one [`AggregateRuntimeMetric`] per `AggregateExec` found in `plan`,
/// in the order produced by `collect_aggregate_runtime_metrics`.
fn aggregate_runtime_metrics(
    plan: &Arc<dyn ExecutionPlan>,
) -> Vec<AggregateRuntimeMetric> {
    let mut collected = Vec::new();
    collect_aggregate_runtime_metrics(plan, &mut collected);
    collected
}

#[tokio::test]
async fn test_partial_final() -> Result<()> {
let source = mock_data()?;
Expand Down Expand Up @@ -104,6 +141,70 @@ async fn test_partial_final() -> Result<()> {
Ok(())
}

// Ensure the operator respects the soft limit and stops early: `AggregateExec`'s
// `output_rows` metric should be smaller than the total distinct group count.
#[tokio::test]
async fn test_sql_partial_final_soft_limit_runtime_metrics() -> Result<()> {
    let session_config = SessionConfig::new()
        .with_target_partitions(2)
        .with_batch_size(10)
        .set_bool("datafusion.execution.enable_migration_aggregate", true);
    let ctx = SessionContext::new_with_config(session_config);

    // The query has far more distinct values of `v` than the 10 rows it asks for.
    let query = "SELECT DISTINCT value % 100000 AS v \
                 FROM generate_series(1000000) \
                 LIMIT 10";
    let plan = ctx.sql(query).await?.create_physical_plan().await?;

    // The plan must contain both aggregate stages for the checks below to be meaningful.
    let formatted_plan = displayable(plan.as_ref()).indent(false).to_string();
    assert!(
        formatted_plan.contains("AggregateExec: mode=Partial"),
        "expected a partial aggregate in plan:\n{formatted_plan}"
    );
    assert!(
        formatted_plan.contains("AggregateExec: mode=FinalPartitioned"),
        "expected a final partitioned aggregate in plan:\n{formatted_plan}"
    );

    // Executing the plan populates the metrics and must yield exactly 10 rows.
    let batches = collect(Arc::clone(&plan), ctx.task_ctx()).await?;
    let total_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum();
    assert_eq!(total_rows, 10);

    let metrics = aggregate_runtime_metrics(&plan);
    let partial = metrics
        .iter()
        .find(|entry| entry.mode == AggregateMode::Partial)
        .expect("expected partial aggregate metrics");
    let final_aggregate = metrics
        .iter()
        .find(|entry| {
            [AggregateMode::Final, AggregateMode::FinalPartitioned].contains(&entry.mode)
        })
        .expect("expected final aggregate metrics");

    // Both stages must carry the limit of 10.
    for aggregate in [partial, final_aggregate] {
        assert_eq!(aggregate.limit, Some(10));
    }

    // Neither stage may emit more than 100 rows.
    assert!(
        partial.output_rows <= 100,
        "partial aggregate should stop before emitting all distinct groups: {metrics:?}"
    );
    assert!(
        final_aggregate.output_rows <= 100,
        "final aggregate should stop before emitting all distinct groups: {metrics:?}"
    );

    Ok(())
}

#[tokio::test]
async fn test_single_local() -> Result<()> {
let source = mock_data()?;
Expand Down
96 changes: 88 additions & 8 deletions datafusion/physical-plan/src/aggregates/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,33 @@ use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metric
/// ## Final Stage Behavior
/// Input: partial states
/// Output: results for all groups (e.g. for avg(x), it's avg(x) calculated from the state)
///
/// # Optimization: DISTINCT LIMIT Soft Limit
///
/// This optimization applies to both [`PartialHashAggregateStream`] and [`FinalHashAggregateStream`].
///
/// Unordered distinct queries such as:
///
/// ```sql
/// SELECT DISTINCT x FROM t LIMIT 10;
/// ```
///
/// are optimized into a two-stage aggregate like:
///
/// ```txt
/// LimitExec, limit=10
/// --AggregateExec(Final), group_by=[x], aggr=[], soft_limit=10
/// ---- RepartitionExec, partitioning=hash(x)
/// ------ AggregateExec(Partial), group_by=[x], aggr=[], soft_limit=10
/// -------- Scan(t)
/// ```
///
/// After each input batch, the stream checks whether the soft limit has been
/// reached. If so, it emits the accumulated groups and stops reading input.
///
/// This operator does not guarantee an exact limit because a single batch can
/// cross the threshold. The downstream limit operator enforces the exact result
/// size.
pub(crate) struct PartialHashAggregateStream {
/// Output schema: group columns followed by partial aggregate state columns.
schema: SchemaRef,
Expand All @@ -78,6 +105,12 @@ pub(crate) struct PartialHashAggregateStream {

/// Tracks partial aggregation row reduction, matching `GroupedHashAggregateStream`.
reduction_factor: metrics::RatioMetrics,

/// Optional soft limit on the number of groups to accumulate before output.
///
/// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must
/// be empty. See struct comments for details.
group_values_soft_limit: Option<usize>,
}

/// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream
Expand All @@ -99,6 +132,9 @@ pub(crate) struct FinalHashAggregateStream {

/// Memory reservation for group keys and accumulators.
reservation: MemoryReservation,

/// See the comments on the same field in [`PartialHashAggregateStream`].
group_values_soft_limit: Option<usize>,
}

impl PartialHashAggregateStream {
Expand Down Expand Up @@ -139,8 +175,21 @@ impl PartialHashAggregateStream {
baseline_metrics,
reservation,
reduction_factor,
group_values_soft_limit: agg.limit_options().map(|config| config.limit()),
})
}

/// Returns `true` once the hash table's building group count has reached
/// the soft limit; always `false` when no soft limit is configured.
///
/// See [`Self::group_values_soft_limit`] for details on the limit itself.
fn hit_soft_group_limit(&self) -> bool {
    match self.group_values_soft_limit {
        // The group count is only consulted when a limit is actually set.
        Some(limit) => self.hash_table.building_group_count() >= limit,
        None => false,
    }
}

/// Replaces the input with an empty stream of the same schema, then asks
/// the hash table to start output, returning its result.
fn start_output(&mut self) -> Result<()> {
    // Swapping in an empty stream discards the original input stream.
    let exhausted_input = EmptyRecordBatchStream::new(self.input.schema());
    self.input = Box::pin(exhausted_input);

    self.hash_table.start_output()
}
}

impl Stream for PartialHashAggregateStream {
Expand Down Expand Up @@ -169,6 +218,18 @@ impl Stream for PartialHashAggregateStream {
return Poll::Ready(Some(Err(e)));
}

if self.hit_soft_group_limit() {
let timer = elapsed_compute.timer();
let result = self.start_output();
timer.done();

if let Err(e) = result {
return Poll::Ready(Some(Err(e)));
}

continue;
}

// TODO: impl memory-limited aggr, when OOM directly send
// partial state to final aggregate stage
if let Err(e) =
Expand All @@ -181,11 +242,8 @@ impl Stream for PartialHashAggregateStream {
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
let input_schema = self.input.schema();
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));

let timer = elapsed_compute.timer();
let result = self.hash_table.start_output();
let result = self.start_output();
timer.done();

if let Err(e) = result {
Expand Down Expand Up @@ -262,8 +320,21 @@ impl FinalHashAggregateStream {
hash_table,
baseline_metrics,
reservation,
group_values_soft_limit: agg.limit_options().map(|config| config.limit()),
})
}

/// Returns `true` once the hash table's building group count has reached
/// the soft limit; always `false` when no soft limit is configured.
///
/// See [`Self::group_values_soft_limit`] for details on the limit itself.
fn hit_soft_group_limit(&self) -> bool {
    // Without a configured limit the group count is never consulted.
    let Some(limit) = self.group_values_soft_limit else {
        return false;
    };

    limit <= self.hash_table.building_group_count()
}

/// Replaces the input with an empty stream of the same schema, then asks
/// the hash table to start output, returning its result.
fn start_output(&mut self) -> Result<()> {
    // Swapping in an empty stream discards the original input stream.
    let exhausted_input = EmptyRecordBatchStream::new(self.input.schema());
    self.input = Box::pin(exhausted_input);

    self.hash_table.start_output()
}
}

impl Stream for FinalHashAggregateStream {
Expand Down Expand Up @@ -291,6 +362,18 @@ impl Stream for FinalHashAggregateStream {
return Poll::Ready(Some(Err(e)));
}

if self.hit_soft_group_limit() {
let timer = elapsed_compute.timer();
let result = self.start_output();
timer.done();

if let Err(e) = result {
return Poll::Ready(Some(Err(e)));
}

continue;
}

if let Err(e) =
self.reservation.try_resize(self.hash_table.memory_size())
{
Expand All @@ -301,11 +384,8 @@ impl Stream for FinalHashAggregateStream {
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
let input_schema = self.input.schema();
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));

let timer = elapsed_compute.timer();
let result = self.hash_table.start_output();
let result = self.start_output();
timer.done();

if let Err(e) = result {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/physical-plan/src/aggregates/hash_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ impl<Mode> AggregateHashTable<Mode> {
}
}

/// Returns the length of `group_values` in the table's building state.
///
/// NOTE(review): goes through `self.state.building()`, so this is presumably
/// only valid while the table is still building — confirm against that accessor.
pub(super) fn building_group_count(&self) -> usize {
    let building = self.state.building();
    building.group_values.len()
}

/// Returns `true` while the table's state is `AggregateHashTableState::Building`.
pub(super) fn is_building(&self) -> bool {
matches!(self.state, AggregateHashTableState::Building(_))
}
Expand Down
Loading
Loading