Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use core::{
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use core::ops::ControlFlow;
#[cfg(feature = "visitor")]
use sqlparser_derive::{Visit, VisitMut};

Expand Down Expand Up @@ -242,7 +244,6 @@ impl<T> DerefMut for Parens<T> {
/// An identifier, decomposed into its value or character data and the quote style.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Ident {
/// The value of the identifier without quotes.
pub value: String,
Expand Down Expand Up @@ -388,6 +389,22 @@ impl fmt::Display for Ident {
}
}

#[cfg(feature = "visitor")]
impl Visit for Ident {
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.pre_visit_ident(self)?;
visitor.post_visit_ident(self)
}
}

#[cfg(feature = "visitor")]
impl VisitMut for Ident {
fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
visitor.pre_visit_ident(self)?;
visitor.post_visit_ident(self)
}
}

/// A name of a table, view, custom type, etc., possibly multi-part, i.e. db.schema.obj
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
64 changes: 62 additions & 2 deletions src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
use alloc::{boxed::Box, string::String, vec::Vec};
use core::ops::ControlFlow;

use crate::ast::{Expr, ObjectName, Query, Select, Statement, TableFactor, ValueWithSpan};
use crate::ast::{Expr, Ident, ObjectName, Query, Select, Statement, TableFactor, ValueWithSpan};

/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
/// recursively visiting parsed SQL statements.
Expand Down Expand Up @@ -269,6 +269,16 @@ pub trait Visitor {
fn post_visit_value(&mut self, _value: &ValueWithSpan) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any identifiers that appear in the AST before visiting children
fn pre_visit_ident(&mut self, _ident: &Ident) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any identifiers that appear in the AST after visiting children
fn post_visit_ident(&mut self, _ident: &Ident) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}

/// A visitor that can be used to mutate an AST tree.
Expand Down Expand Up @@ -397,6 +407,16 @@ pub trait VisitorMut {
fn post_visit_value(&mut self, _value: &mut ValueWithSpan) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any identifiers that appear in the AST before visiting children
fn pre_visit_ident(&mut self, _ident: &mut Ident) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}

/// Invoked for any identifiers that appear in the AST after visiting children
fn post_visit_ident(&mut self, _ident: &mut Ident) -> ControlFlow<Self::Break> {
ControlFlow::Continue(())
}
}

struct RelationVisitor<F>(F);
Expand Down Expand Up @@ -1014,11 +1034,32 @@ mod tests {
let flow = s.visit(&mut visitor);
assert_eq!(flow, ControlFlow::Continue(()));
}

#[derive(Default)]
struct IdentVisitor {
idents: Vec<String>,
}

impl Visitor for IdentVisitor {
type Break = ();

fn pre_visit_ident(&mut self, ident: &Ident) -> ControlFlow<Self::Break> {
self.idents.push(ident.value.clone());
ControlFlow::Continue(())
}
}

#[test]
fn test_pre_visit_ident() {
let mut visitor = IdentVisitor::default();
do_visit("SELECT a, b FROM t", &mut visitor);
assert_eq!(visitor.idents, vec!["a", "b", "t"]);
}
}

#[cfg(test)]
mod visit_mut_tests {
use crate::ast::{Statement, Value, ValueWithSpan, VisitMut, VisitorMut};
use crate::ast::{Ident, Statement, Value, ValueWithSpan, VisitMut, VisitorMut};
use crate::dialect::GenericDialect;
use crate::parser::Parser;
use crate::tokenizer::Tokenizer;
Expand Down Expand Up @@ -1079,4 +1120,23 @@ mod visit_mut_tests {
assert_eq!(mutated.to_string(), expected)
}
}

#[derive(Default)]
struct IdentMutator;

impl VisitorMut for IdentMutator {
type Break = ();

fn pre_visit_ident(&mut self, ident: &mut Ident) -> ControlFlow<Self::Break> {
ident.value = ident.value.to_uppercase();
ControlFlow::Continue(())
}
}

#[test]
fn test_pre_visit_ident_mut() {
let mut visitor = IdentMutator;
let mutated = do_visit_mut("SELECT a, b FROM t", &mut visitor);
assert_eq!(mutated.to_string(), "SELECT A, B FROM T");
}
}