diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 447d89bb4..9f742b8f2 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -39,6 +39,8 @@ use core::{ #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "visitor")] +use core::ops::ControlFlow; #[cfg(feature = "visitor")] use sqlparser_derive::{Visit, VisitMut}; @@ -242,7 +244,6 @@ impl DerefMut for Parens { /// An identifier, decomposed into its value or character data and the quote style. #[derive(Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct Ident { /// The value of the identifier without quotes. pub value: String, @@ -388,6 +389,22 @@ impl fmt::Display for Ident { } } +#[cfg(feature = "visitor")] +impl Visit for Ident { + fn visit(&self, visitor: &mut V) -> ControlFlow { + visitor.pre_visit_ident(self)?; + visitor.post_visit_ident(self) + } +} + +#[cfg(feature = "visitor")] +impl VisitMut for Ident { + fn visit(&mut self, visitor: &mut V) -> ControlFlow { + visitor.pre_visit_ident(self)?; + visitor.post_visit_ident(self) + } +} + /// A name of a table, view, custom type, etc., possibly multi-part, i.e. db.schema.obj #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index b14ca544a..c70e83ec6 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -21,7 +21,7 @@ use alloc::{boxed::Box, string::String, vec::Vec}; use core::ops::ControlFlow; -use crate::ast::{Expr, ObjectName, Query, Select, Statement, TableFactor, ValueWithSpan}; +use crate::ast::{Expr, Ident, ObjectName, Query, Select, Statement, TableFactor, ValueWithSpan}; /// A type that can be visited by a [`Visitor`]. See [`Visitor`] for /// recursively visiting parsed SQL statements. @@ -269,6 +269,16 @@ pub trait Visitor { fn post_visit_value(&mut self, _value: &ValueWithSpan) -> ControlFlow { ControlFlow::Continue(()) } + + /// Invoked for any identifiers that appear in the AST before visiting children + fn pre_visit_ident(&mut self, _ident: &Ident) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any identifiers that appear in the AST after visiting children + fn post_visit_ident(&mut self, _ident: &Ident) -> ControlFlow { + ControlFlow::Continue(()) + } } /// A visitor that can be used to mutate an AST tree. @@ -397,6 +407,16 @@ pub trait VisitorMut { fn post_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow { ControlFlow::Continue(()) } + + /// Invoked for any identifiers that appear in the AST before visiting children + fn pre_visit_ident(&mut self, _ident: &mut Ident) -> ControlFlow { + ControlFlow::Continue(()) + } + + /// Invoked for any identifiers that appear in the AST after visiting children + fn post_visit_ident(&mut self, _ident: &mut Ident) -> ControlFlow { + ControlFlow::Continue(()) + } } struct RelationVisitor(F); @@ -1014,11 +1034,32 @@ mod tests { let flow = s.visit(&mut visitor); assert_eq!(flow, ControlFlow::Continue(())); } + + #[derive(Default)] + struct IdentVisitor { + idents: Vec, + } + + impl Visitor for IdentVisitor { + type Break = (); + + fn pre_visit_ident(&mut self, ident: &Ident) -> ControlFlow { + self.idents.push(ident.value.clone()); + ControlFlow::Continue(()) + } + } + + #[test] + fn test_pre_visit_ident() { + let mut visitor = IdentVisitor::default(); + do_visit("SELECT a, b FROM t", &mut visitor); + assert_eq!(visitor.idents, vec!["a", "b", "t"]); + } } #[cfg(test)] mod visit_mut_tests { - use crate::ast::{Statement, Value, ValueWithSpan, VisitMut, VisitorMut}; + use crate::ast::{Ident, Statement, Value, ValueWithSpan, VisitMut, VisitorMut}; use crate::dialect::GenericDialect; use crate::parser::Parser; use crate::tokenizer::Tokenizer; @@ -1079,4 +1120,23 @@ mod visit_mut_tests { assert_eq!(mutated.to_string(), expected) } } + + #[derive(Default)] + struct IdentMutator; + + impl VisitorMut for IdentMutator { + type Break = (); + + fn pre_visit_ident(&mut self, ident: &mut Ident) -> ControlFlow { + ident.value = ident.value.to_uppercase(); + ControlFlow::Continue(()) + } + } + + #[test] + fn test_pre_visit_ident_mut() { + let mut visitor = IdentMutator; + let mutated = do_visit_mut("SELECT a, b FROM t", &mut visitor); + assert_eq!(mutated.to_string(), "SELECT A, B FROM T"); + } }