Skip to content

Commit f061d1b

Browse files
committed
feldera-types: fix serialization of SqlType
A custom deserialization was implemented for `SqlType`, but its custom serialization counterpart was not implemented. This resulted in an error when deserializing the output of its own serialization for some of its variants (notably, interval types). This commit implements the custom serialization and adds a corresponding test. This also removes the `From` trait implementation, which is not used anywhere. Signed-off-by: Simon Kassing <simon.kassing@feldera.com>
1 parent 69465b2 commit f061d1b

File tree

2 files changed

+139
-64
lines changed

2 files changed

+139
-64
lines changed

crates/feldera-types/src/program_schema.rs

Lines changed: 106 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use serde::{Deserialize, Deserializer, Serialize};
1+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
22
use std::cmp::Ordering;
33
use std::collections::BTreeMap;
44
use std::fmt::Display;
@@ -309,9 +309,8 @@ impl<'de> Deserialize<'de> for Field {
309309
///
310310
/// `INTERVAL 1 DAY`, `INTERVAL 1 DAY TO HOUR`, `INTERVAL 1 DAY TO MINUTE`,
311311
/// would yield `Day`, `DayToHour`, `DayToMinute`, as the `IntervalUnit` respectively.
312-
#[derive(Serialize, Deserialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
312+
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
313313
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
314-
#[serde(rename_all = "UPPERCASE")]
315314
pub enum IntervalUnit {
316315
/// Unit for `INTERVAL ... DAY`.
317316
Day,
@@ -342,70 +341,50 @@ pub enum IntervalUnit {
342341
}
343342

344343
/// The available SQL types as specified in `CREATE` statements.
345-
#[derive(Serialize, ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
344+
#[derive(ToSchema, Debug, Eq, PartialEq, Clone, Copy)]
346345
#[cfg_attr(feature = "testing", derive(proptest_derive::Arbitrary))]
347346
pub enum SqlType {
348347
/// SQL `BOOLEAN` type.
349-
#[serde(rename = "BOOLEAN")]
350348
Boolean,
351349
/// SQL `TINYINT` type.
352-
#[serde(rename = "TINYINT")]
353350
TinyInt,
354351
/// SQL `SMALLINT` or `INT2` type.
355-
#[serde(rename = "SMALLINT")]
356352
SmallInt,
357353
/// SQL `INTEGER`, `INT`, `SIGNED`, `INT4` type.
358-
#[serde(rename = "INTEGER")]
359354
Int,
360355
/// SQL `BIGINT` or `INT64` type.
361-
#[serde(rename = "BIGINT")]
362356
BigInt,
363357
/// SQL `REAL` or `FLOAT4` or `FLOAT32` type.
364-
#[serde(rename = "REAL")]
365358
Real,
366359
/// SQL `DOUBLE` or `FLOAT8` or `FLOAT64` type.
367-
#[serde(rename = "DOUBLE")]
368360
Double,
369361
/// SQL `DECIMAL` or `DEC` or `NUMERIC` type.
370-
#[serde(rename = "DECIMAL")]
371362
Decimal,
372363
/// SQL `CHAR(n)` or `CHARACTER(n)` type.
373-
#[serde(rename = "CHAR")]
374364
Char,
375365
/// SQL `VARCHAR`, `CHARACTER VARYING`, `TEXT`, or `STRING` type.
376-
#[serde(rename = "VARCHAR")]
377366
Varchar,
378367
/// SQL `BINARY(n)` type.
379-
#[serde(rename = "BINARY")]
380368
Binary,
381369
/// SQL `VARBINARY` or `BYTEA` type.
382-
#[serde(rename = "VARBINARY")]
383370
Varbinary,
384371
/// SQL `TIME` type.
385-
#[serde(rename = "TIME")]
386372
Time,
387373
/// SQL `DATE` type.
388-
#[serde(rename = "DATE")]
389374
Date,
390375
/// SQL `TIMESTAMP` type.
391-
#[serde(rename = "TIMESTAMP")]
392376
Timestamp,
393377
/// SQL `INTERVAL ... X` type where `X` is a unit.
394378
Interval(IntervalUnit),
395379
/// SQL `ARRAY` type.
396-
#[serde(rename = "ARRAY")]
397380
Array,
398381
/// A complex SQL struct type (`CREATE TYPE x ...`).
399-
#[serde(rename = "STRUCT")]
400382
Struct,
401383
/// SQL `MAP` type.
402-
#[serde(rename = "MAP")]
403384
Map,
404385
/// SQL `NULL` type.
405-
#[serde(rename = "NULL")]
406386
Null,
407387
/// SQL `VARIANT` type.
408-
#[serde(rename = "VARIANT")]
409388
Variant,
410389
}
411390

@@ -463,9 +442,12 @@ impl<'de> Deserialize<'de> for SqlType {
463442
}
464443
}
465444

466-
impl From<SqlType> for &'static str {
467-
fn from(value: SqlType) -> &'static str {
468-
match value {
445+
impl Serialize for SqlType {
446+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
447+
where
448+
S: Serializer,
449+
{
450+
let type_str = match self {
469451
SqlType::Boolean => "BOOLEAN",
470452
SqlType::TinyInt => "TINYINT",
471453
SqlType::SmallInt => "SMALLINT",
@@ -481,13 +463,28 @@ impl From<SqlType> for &'static str {
481463
SqlType::Time => "TIME",
482464
SqlType::Date => "DATE",
483465
SqlType::Timestamp => "TIMESTAMP",
484-
SqlType::Interval(_) => "INTERVAL",
466+
SqlType::Interval(interval_unit) => match interval_unit {
467+
IntervalUnit::Day => "INTERVAL_DAY",
468+
IntervalUnit::DayToHour => "INTERVAL_DAY_HOUR",
469+
IntervalUnit::DayToMinute => "INTERVAL_DAY_MINUTE",
470+
IntervalUnit::DayToSecond => "INTERVAL_DAY_SECOND",
471+
IntervalUnit::Hour => "INTERVAL_HOUR",
472+
IntervalUnit::HourToMinute => "INTERVAL_HOUR_MINUTE",
473+
IntervalUnit::HourToSecond => "INTERVAL_HOUR_SECOND",
474+
IntervalUnit::Minute => "INTERVAL_MINUTE",
475+
IntervalUnit::MinuteToSecond => "INTERVAL_MINUTE_SECOND",
476+
IntervalUnit::Month => "INTERVAL_MONTH",
477+
IntervalUnit::Second => "INTERVAL_SECOND",
478+
IntervalUnit::Year => "INTERVAL_YEAR",
479+
IntervalUnit::YearToMonth => "INTERVAL_YEAR_MONTH",
480+
},
485481
SqlType::Array => "ARRAY",
486482
SqlType::Struct => "STRUCT",
487483
SqlType::Map => "MAP",
488-
SqlType::Variant => "VARIANT",
489484
SqlType::Null => "NULL",
490-
}
485+
SqlType::Variant => "VARIANT",
486+
};
487+
serializer.serialize_str(type_str)
491488
}
492489
}
493490

@@ -651,9 +648,87 @@ impl ColumnType {
651648

652649
#[cfg(test)]
653650
mod tests {
654-
use super::SqlIdentifier;
651+
use super::{IntervalUnit, SqlIdentifier};
655652
use crate::program_schema::SqlType;
656653

654+
#[test]
655+
fn serde_sql_type() {
656+
for (sql_str_base, expected_value) in [
657+
("Boolean", SqlType::Boolean),
658+
("TinyInt", SqlType::TinyInt),
659+
("SmallInt", SqlType::SmallInt),
660+
("Integer", SqlType::Int),
661+
("BigInt", SqlType::BigInt),
662+
("Real", SqlType::Real),
663+
("Double", SqlType::Double),
664+
("Decimal", SqlType::Decimal),
665+
("Char", SqlType::Char),
666+
("Varchar", SqlType::Varchar),
667+
("Binary", SqlType::Binary),
668+
("Varbinary", SqlType::Varbinary),
669+
("Time", SqlType::Time),
670+
("Date", SqlType::Date),
671+
("Timestamp", SqlType::Timestamp),
672+
("Interval_Day", SqlType::Interval(IntervalUnit::Day)),
673+
(
674+
"Interval_Day_Hour",
675+
SqlType::Interval(IntervalUnit::DayToHour),
676+
),
677+
(
678+
"Interval_Day_Minute",
679+
SqlType::Interval(IntervalUnit::DayToMinute),
680+
),
681+
(
682+
"Interval_Day_Second",
683+
SqlType::Interval(IntervalUnit::DayToSecond),
684+
),
685+
("Interval_Hour", SqlType::Interval(IntervalUnit::Hour)),
686+
(
687+
"Interval_Hour_Minute",
688+
SqlType::Interval(IntervalUnit::HourToMinute),
689+
),
690+
(
691+
"Interval_Hour_Second",
692+
SqlType::Interval(IntervalUnit::HourToSecond),
693+
),
694+
("Interval_Minute", SqlType::Interval(IntervalUnit::Minute)),
695+
(
696+
"Interval_Minute_Second",
697+
SqlType::Interval(IntervalUnit::MinuteToSecond),
698+
),
699+
("Interval_Month", SqlType::Interval(IntervalUnit::Month)),
700+
("Interval_Second", SqlType::Interval(IntervalUnit::Second)),
701+
("Interval_Year", SqlType::Interval(IntervalUnit::Year)),
702+
(
703+
"Interval_Year_Month",
704+
SqlType::Interval(IntervalUnit::YearToMonth),
705+
),
706+
("Array", SqlType::Array),
707+
("Struct", SqlType::Struct),
708+
("Map", SqlType::Map),
709+
("Null", SqlType::Null),
710+
("Variant", SqlType::Variant),
711+
] {
712+
for sql_str in [
713+
sql_str_base, // Capitalized
714+
&sql_str_base.to_lowercase(), // lowercase
715+
&sql_str_base.to_uppercase(), // UPPERCASE
716+
] {
717+
let value1: SqlType = serde_json::from_str(&format!("\"{}\"", sql_str)).expect(
718+
&format!("\"{sql_str}\" should deserialize into its SQL type"),
719+
);
720+
assert_eq!(value1, expected_value);
721+
let serialized_str =
722+
serde_json::to_string(&value1).expect("Value should serialize into JSON");
723+
let value2: SqlType = serde_json::from_str(&serialized_str).expect(&format!(
724+
"{} should deserialize back into its SQL type",
725+
serialized_str
726+
));
727+
assert_eq!(value1, value2);
728+
}
729+
}
730+
}
731+
657732
#[test]
658733
fn deserialize_interval_types() {
659734
use super::IntervalUnit::*;

0 commit comments

Comments
 (0)