@@ -297,6 +297,10 @@ pub(crate) async fn execute_sql(
297297
298298/// Plan and translate `sql` against `state`, applying `PREPARE`/`EXECUTE`
299299/// substitution within the scope of a single ad-hoc request.
300+ ///
301+ /// Only the final statement returns rows. Earlier statements may be
302+ /// `PREPARE`s or any non-result-producing statement (e.g. `INSERT`),
303+ /// executed for their side effect.
300304async fn execute_sql_with_state (
301305 state : SessionState ,
302306 sql : & str ,
@@ -315,9 +319,6 @@ async fn execute_sql_with_state(
315319 let mut prepared: HashMap < String , LogicalPlan > = HashMap :: new ( ) ;
316320 let sql_options = SQLOptions :: new ( ) . with_allow_ddl ( false ) ;
317321
318- // For now, only the final statement may produce a result set. All
319- // preceding statements must be PREPAREs whose inner plans are stashed
320- // for a later EXECUTE in the same request.
321322 while statements. len ( ) > 1 {
322323 let stmt = statements. pop_front ( ) . unwrap ( ) ;
323324 let plan = state. statement_to_plan ( stmt) . await ?;
@@ -326,14 +327,40 @@ async fn execute_sql_with_state(
326327 sql_options. verify_plan ( & input) ?;
327328 prepared. insert ( name, ( * input) . clone ( ) ) ;
328329 }
329- _ => {
330+ LogicalPlan :: Statement ( Statement :: Execute ( Execute { name, parameters } ) ) => {
331+ // `EXECUTE` of a previously-prepared statement, used here
332+ // for its side effects (e.g. a prepared INSERT).
333+ let prepared_plan =
334+ prepared
335+ . remove ( & name)
336+ . ok_or_else ( || PipelineError :: AdHocQueryError {
337+ error : format ! (
338+ "prepared statement '{name}' is not defined in this request"
339+ ) ,
340+ df : None ,
341+ } ) ?;
342+ let values = execute_parameters_to_scalars ( & parameters) ?;
343+ let bound = prepared_plan. replace_params_with_values ( & ParamValues :: List ( values) ) ?;
344+ sql_options. verify_plan ( & bound) ?;
345+ drain_intermediate_plan ( & state, bound) . await ?;
346+ }
347+ other if is_result_producing_plan ( & other) => {
330348 return Err ( PipelineError :: AdHocQueryError {
331- error : "only PREPARE statements may precede the final statement \
332- in a multi-statement ad-hoc query"
349+ error : "only the final statement in a multi-statement \
350+ ad-hoc query may return a result set; \
351+ move SELECTs to the end or split into \
352+ separate requests"
333353 . to_string ( ) ,
334354 df : None ,
335355 } ) ;
336356 }
357+ other => {
358+ // Non-result-producing intermediate statement: DML (INSERT,
359+ // UPDATE, DELETE) or a remaining `Statement` plan. Execute it for
360+ // its side effects and discard the per-statement count row.
361+ sql_options. verify_plan ( & other) ?;
362+ drain_intermediate_plan ( & state, other) . await ?;
363+ }
337364 }
338365 }
339366
@@ -374,6 +401,26 @@ async fn execute_sql_with_state(
374401 Ok ( DataFrame :: new ( state, final_plan) )
375402}
376403
404+ /// True if executing this plan would surface rows to the caller. Used to
405+ /// reject queries like `SELECT; INSERT` where the early `SELECT` would
406+ /// otherwise be silently dropped.
407+ fn is_result_producing_plan ( plan : & LogicalPlan ) -> bool {
408+ !matches ! ( plan, LogicalPlan :: Dml ( _) | LogicalPlan :: Statement ( _) )
409+ }
410+
411+ /// Execute an intermediate statement for its side effects and drop the
412+ /// resulting batches. INSERTs produce a one-row count; we keep that
413+ /// count out of the response stream so only the request's final
414+ /// statement contributes rows.
415+ async fn drain_intermediate_plan (
416+ state : & SessionState ,
417+ plan : LogicalPlan ,
418+ ) -> Result < ( ) , PipelineError > {
419+ let df = DataFrame :: new ( state. clone ( ) , plan) ;
420+ let _ = df. collect ( ) . await ?;
421+ Ok ( ( ) )
422+ }
423+
377424/// Convert `EXECUTE` positional parameters to DataFusion's `ScalarAndMetadata`
378425/// list, rejecting anything that is not a literal value.
379426fn execute_parameters_to_scalars ( params : & [ Expr ] ) -> Result < Vec < ScalarAndMetadata > , PipelineError > {
@@ -614,15 +661,74 @@ mod tests {
614661 assert_eq ! ( total_rows, 0 ) ;
615662 }
616663
664+ /// An intermediate `SELECT` (or any other result-producing statement)
665+ /// must be rejected: only one result set comes back per request, so
666+ /// executing the earlier SELECT silently would discard its rows.
617667 #[ tokio:: test]
618- async fn non_prepare_intermediate_statement_errors ( ) {
668+ async fn intermediate_select_is_rejected ( ) {
619669 let state = test_state ( ) ;
620670 let err = execute_sql_with_state ( state, "SELECT 1; SELECT 2" )
621671 . await
622672 . unwrap_err ( ) ;
623- assert ! (
624- format!( "{err:?}" ) . contains( "PREPARE" ) ,
625- "unexpected error: {err:?}"
626- ) ;
673+ let msg = format ! ( "{err}" ) ;
674+ assert ! ( msg. contains( "final statement" ) , "unexpected error: {msg}" ) ;
675+ }
676+
677+ /// Multiple `INSERT`s followed by a `SELECT` must execute in order,
678+ /// committing each insert's side effect, and only surface the final
679+ /// `SELECT`'s rows.
680+ #[ tokio:: test]
681+ async fn intermediate_inserts_run_and_final_select_returns_rows ( ) {
682+ use datafusion:: arrow:: array:: Int64Array ;
683+ use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
684+ use datafusion:: datasource:: MemTable ;
685+ use std:: sync:: Arc ;
686+
687+ // Register a writable in-memory table so DML executes for real.
688+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "x" , DataType :: Int64 , false ) ] ) ) ;
689+ let mem = MemTable :: try_new ( schema. clone ( ) , vec ! [ vec![ ] ] ) . unwrap ( ) ;
690+ let ctx = SessionContext :: new_with_state ( test_state ( ) ) ;
691+ ctx. register_table ( "t" , Arc :: new ( mem) ) . unwrap ( ) ;
692+ let state = ctx. state ( ) ;
693+
694+ let batches = collect_rows (
695+ state,
696+ "INSERT INTO t VALUES (1); INSERT INTO t VALUES (2); \
697+ SELECT SUM(x) AS s FROM t",
698+ )
699+ . await ;
700+ let total_rows: usize = batches. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
701+ assert_eq ! ( total_rows, 1 ) ;
702+ let col = batches[ 0 ]
703+ . column ( 0 )
704+ . as_any ( )
705+ . downcast_ref :: < Int64Array > ( )
706+ . expect ( "int64 column" ) ;
707+ assert_eq ! ( col. value( 0 ) , 3 ) ;
708+ }
709+
710+ /// A trailing `INSERT` (no final SELECT) must still execute, and
711+ /// the final statement's count row is surfaced as today.
712+ #[ tokio:: test]
713+ async fn final_insert_returns_count ( ) {
714+ use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
715+ use datafusion:: datasource:: MemTable ;
716+ use std:: sync:: Arc ;
717+
718+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "x" , DataType :: Int64 , false ) ] ) ) ;
719+ let mem = MemTable :: try_new ( schema. clone ( ) , vec ! [ vec![ ] ] ) . unwrap ( ) ;
720+ let ctx = SessionContext :: new_with_state ( test_state ( ) ) ;
721+ ctx. register_table ( "t" , Arc :: new ( mem) ) . unwrap ( ) ;
722+ let state = ctx. state ( ) ;
723+
724+ let batches = collect_rows (
725+ state,
726+ "INSERT INTO t VALUES (10); INSERT INTO t VALUES (20)" ,
727+ )
728+ . await ;
729+ // The final INSERT yields a single-row count batch; check only
730+ // that one row came back.
731+ let total_rows: usize = batches. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
732+ assert_eq ! ( total_rows, 1 ) ;
627733 }
628734}
0 commit comments