diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 26d95965..ee86e419 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -14,7 +14,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest, macos-13, ubuntu-24.04-arm] - go: [1.23, 1.24] + go: [1.24] fail-fast: false steps: - uses: actions/checkout@v4 @@ -31,7 +31,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest] - go: [1.23, 1.24] + go: [1.24] defaults: run: shell: bash @@ -51,7 +51,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.23, 1.24] + go: [1.24] fail-fast: false steps: - uses: actions/checkout@v4 @@ -68,7 +68,7 @@ jobs: shell: bash strategy: matrix: - go: [1.23, 1.24] + go: [1.24] os: [macos-13, macos-14, ubuntu-latest, ubuntu-24.04-arm] include: - os: "macos-13" @@ -101,7 +101,7 @@ jobs: shell: bash strategy: matrix: - go: [1.23, 1.24] + go: [1.24] fail-fast: false steps: - uses: actions/checkout@v4 diff --git a/.golangci.yml b/.golangci.yml index 96c0915e..f3528695 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -21,6 +21,7 @@ linters: - noinlineerr - paralleltest - testpackage + - unqueryvet - varnamelen - wrapcheck - wsl diff --git a/README.md b/README.md index 00ceb2b6..b251cc30 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,14 @@ When passing a `time.Time` to go-duckdb, go-duckdb transforms it to an instant w even when using `TIMESTAMP_TZ`. Later, scanning either type of value returns an instant, as SQL types do not model time zone information for individual values. +**Connection lifetime** + +Temporary objects and state, such as temporary tables, are scoped to connections. +When closing a connection, Go's `database/sql` pooling logic might cache it as an idle connection, +instead of invoking its clean-up code by closing the connection. +That behavior can lead to, e.g., temporary tables persisting longer than expected. 
+To disable keeping idle connections alive, use `db.SetMaxIdleConns(0)`. + ## Memory Allocation DuckDB lives in process. diff --git a/appender.go b/appender.go index 835aa4e6..7e88fbb5 100644 --- a/appender.go +++ b/appender.go @@ -7,13 +7,15 @@ import ( "github.com/marcboeker/go-duckdb/mapping" ) -// Appender holds the DuckDB appender. It allows efficient bulk loading into a DuckDB database. +// Appender wraps functionality around the DuckDB appender. +// It enables efficient bulk transformations. type Appender struct { - conn *Conn - schema string - table string + // The raw sql.Conn's driver connection. + conn *Conn + // The DuckDB appender. appender mapping.Appender - closed bool + // True, if the appender has been closed. + closed bool // The chunk to append to. chunk DataChunk @@ -23,63 +25,99 @@ type Appender struct { rowCount int } -// NewAppenderFromConn returns a new Appender for the default catalog from a DuckDB driver connection. +// NewAppenderFromConn returns a new Appender for the default catalog. +// The Appender batches rows via AppendRow. Upon reaching the auto-flush threshold or +// upon calling Flush or Close, it appends these rows to the table. +// Thus, it can be used instead of INSERT INTO statements to enable bulk insertions. +// `driverConn` is the raw sql.Conn's driver connection. +// `schema` and `table` specify the table (`schema.table`) to append to. func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appender, error) { return NewAppender(driverConn, "", schema, table) } -// NewAppender returns a new Appender from a DuckDB driver connection. +// NewAppender returns a new Appender. +// The Appender batches rows via AppendRow. Upon reaching the auto-flush threshold or +// upon calling Flush or Close, it appends these rows to the table. +// Thus, it can be used instead of INSERT INTO statements to enable bulk insertions. +// `driverConn` is the raw sql.Conn's driver connection. 
+// `catalog`, `schema` and `table` specify the table (`catalog.schema.table`) to append to. func NewAppender(driverConn driver.Conn, catalog, schema, table string) (*Appender, error) { - conn, ok := driverConn.(*Conn) - if !ok { - return nil, getError(errInvalidCon, nil) - } - if conn.closed { - return nil, getError(errClosedCon, nil) + var a Appender + err := a.appenderConn(driverConn) + if err != nil { + return nil, err } - var appender mapping.Appender - state := mapping.AppenderCreateExt(conn.conn, catalog, schema, table, &appender) + state := mapping.AppenderCreateExt(a.conn.conn, catalog, schema, table, &a.appender) if state == mapping.StateError { - err := getDuckDBError(mapping.AppenderError(appender)) - mapping.AppenderDestroy(&appender) + err = errorDataError(mapping.AppenderErrorData(a.appender)) + mapping.AppenderDestroy(&a.appender) return nil, getError(errAppenderCreation, err) } - a := &Appender{ - conn: conn, - schema: schema, - table: table, - appender: appender, - rowCount: 0, - } - // Get the column types. - columnCount := mapping.AppenderColumnCount(appender) + columnCount := mapping.AppenderColumnCount(a.appender) for i := mapping.IdxT(0); i < columnCount; i++ { - colType := mapping.AppenderColumnType(appender, i) + colType := mapping.AppenderColumnType(a.appender, i) a.types = append(a.types, colType) // Ensure that we only create an appender for supported column types. t := mapping.GetTypeId(colType) name, found := unsupportedTypeToStringMap[t] if found { - err := addIndexToError(unsupportedTypeError(name), int(i)+1) - destroyTypeSlice(a.types) - mapping.AppenderDestroy(&appender) + err = addIndexToError(unsupportedTypeError(name), int(i)+1) + destroyLogicalTypes(a.types) + mapping.AppenderDestroy(&a.appender) return nil, getError(errAppenderCreation, err) } } - // Initialize the data chunk. 
- if err := a.chunk.initFromTypes(a.types, true); err != nil { - a.chunk.close() - destroyTypeSlice(a.types) - mapping.AppenderDestroy(&appender) + return a.initAppenderChunk() +} + +// NewQueryAppender returns a new query Appender. +// The Appender batches rows via AppendRow. Upon reaching the auto-flush threshold or +// upon calling Flush or Close, it executes the query, treating the batched rows as a temporary table. +// `driverConn` is the raw sql.Conn's driver connection. +// `query` is the query to execute. It can be a INSERT, DELETE, UPDATE or MERGE INTO statement. +// `table` is the (optional) table name of the temporary table containing the batched rows. +// It defaults to `appended_data`. +// `colTypes` are the column types of the temporary table. +// `colNames` are the (optional) names of the columns of the temporary table containing the batched rows. +// They default to `col1`, `col2`, ... +func NewQueryAppender(driverConn driver.Conn, query, table string, colTypes []TypeInfo, colNames []string) (*Appender, error) { + var a Appender + err := a.appenderConn(driverConn) + if err != nil { + return nil, err + } + + if query == "" { + return nil, getError(errAppenderEmptyQuery, nil) + } + if len(colTypes) == 0 { + return nil, getError(errAppenderEmptyColumnTypes, nil) + } + if len(colNames) != 0 && len(colTypes) != 0 { + if len(colNames) != len(colTypes) { + return nil, getError(errAppenderColumnMismatch, nil) + } + } + + // Get the logical types via the type infos. 
+ for _, ct := range colTypes { + a.types = append(a.types, ct.logicalType()) + } + + state := mapping.AppenderCreateQuery(a.conn.conn, query, a.types, table, colNames, &a.appender) + if state == mapping.StateError { + destroyLogicalTypes(a.types) + err = errorDataError(mapping.AppenderErrorData(a.appender)) + mapping.AppenderDestroy(&a.appender) return nil, getError(errAppenderCreation, err) } - return a, nil + return a.initAppenderChunk() } // Flush the data chunks to the underlying table and clear the internal cache. @@ -117,7 +155,7 @@ func (a *Appender) Close() error { } // Destroy all appender data and the appender. - destroyTypeSlice(a.types) + destroyLogicalTypes(a.types) var errClose error if mapping.AppenderDestroy(&a.appender) == mapping.StateError { errClose = errAppenderClose @@ -145,6 +183,30 @@ func (a *Appender) AppendRow(args ...driver.Value) error { return nil } +func (a *Appender) appenderConn(driverConn driver.Conn) error { + var ok bool + a.conn, ok = driverConn.(*Conn) + if !ok { + return getError(errInvalidCon, nil) + } + if a.conn.closed { + return getError(errClosedCon, nil) + } + + return nil +} + +func (a *Appender) initAppenderChunk() (*Appender, error) { + if err := a.chunk.initFromTypes(a.types, true); err != nil { + a.chunk.close() + destroyLogicalTypes(a.types) + mapping.AppenderDestroy(&a.appender) + return nil, getError(errAppenderCreation, err) + } + + return a, nil +} + func (a *Appender) appendRowSlice(args []driver.Value) error { // Early-out, if the number of args does not match the column count. 
if len(args) != len(a.types) { @@ -187,9 +249,3 @@ func (a *Appender) appendDataChunk() error { return nil } - -func destroyTypeSlice(slice []mapping.LogicalType) { - for _, t := range slice { - mapping.DestroyLogicalType(&t) - } -} diff --git a/appender_test.go b/appender_test.go index f8b1a9d6..07e301d1 100644 --- a/appender_test.go +++ b/appender_test.go @@ -963,6 +963,106 @@ func TestAppenderAppendDataChunk(t *testing.T) { require.NoError(t, a.Flush()) } +func TestAppenderUpsert(t *testing.T) { + c := newConnectorWrapper(t, ``, nil) + defer closeConnectorWrapper(t, c) + + // Create a table with a PK for UPSERT. + db := sql.OpenDB(c) + defer closeDbWrapper(t, db) + _, err := db.Exec(` + CREATE TABLE test ( + id INT PRIMARY KEY, + u UNION(num INT, str VARCHAR) + )`) + require.NoError(t, err) + + conn := openDriverConnWrapper(t, c) + defer closeDriverConnWrapper(t, &conn) + + // Create the types. + intType, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + varcharType, err := NewTypeInfo(TYPE_VARCHAR) + require.NoError(t, err) + + memberTypes := []TypeInfo{intType, varcharType} + memberNames := []string{"num", "str"} + unionType, err := NewUnionInfo(memberTypes, memberNames) + require.NoError(t, err) + + // Create the INSERT query appender. + query := `INSERT INTO test SELECT col1, col2 FROM appended_data` + colTypes := []TypeInfo{intType, unionType} + aInsert := newQueryAppenderWrapper(t, &conn, query, "", colTypes, []string{}) + + // Close without appending anything. + closeAppenderWrapper(t, aInsert) + + // Create again and try to append with mismatching column names. + aInsert = newQueryAppenderWrapper(t, &conn, query, "", colTypes, []string{"a", "b"}) + require.NoError(t, aInsert.AppendRow(0, Union{Value: "str1", Tag: "str"})) + require.ErrorContains(t, aInsert.Close(), "Referenced column \"col1\" not found in FROM clause!") + + // Now re-create and test "normally". 
+ aInsert = newQueryAppenderWrapper(t, &conn, query, "", colTypes, []string{}) + + // Append and insert (flush) two rows. + require.NoError(t, aInsert.AppendRow(0, Union{Value: "str1", Tag: "str"})) + require.NoError(t, aInsert.AppendRow(1, Union{Value: 42, Tag: "num"})) + require.NoError(t, aInsert.Flush()) + + // Create another INSERT appender selecting only some columns. + query = `INSERT INTO test SELECT id + 10, u FROM appended_data` + colTypes = []TypeInfo{intType, unionType, intType} + colNames := []string{"id", "u", "other"} + aInsertOther := newQueryAppenderWrapper(t, &conn, query, "", colTypes, colNames) + defer closeAppenderWrapper(t, aInsertOther) + + // Append and insert (flush) two rows. + require.NoError(t, aInsertOther.AppendRow(10, Union{Value: "str10", Tag: "str"}, 101)) + require.NoError(t, aInsertOther.AppendRow(11, Union{Value: 50, Tag: "num"}, 102)) + require.NoError(t, aInsertOther.Flush()) + + // Create the UPSERT query appender. + query = `INSERT INTO test SELECT * FROM my_append_tbl ON CONFLICT DO UPDATE SET u = EXCLUDED.u;` + colTypes = []TypeInfo{intType, unionType} + aUpsert := newQueryAppenderWrapper(t, &conn, query, "my_append_tbl", colTypes, []string{}) + defer closeAppenderWrapper(t, aUpsert) + + // Append and upsert (flush) two rows. + require.NoError(t, aUpsert.AppendRow(2, Union{Value: "str2", Tag: "str"})) + require.NoError(t, aUpsert.AppendRow(0, Union{Value: 43, Tag: "num"})) + require.NoError(t, aUpsert.Flush()) + + // Verify results. 
+ res, err := db.QueryContext(context.Background(), `SELECT id, u FROM test ORDER BY id`) + require.NoError(t, err) + defer closeRowsWrapper(t, res) + + testCases := []struct { + id int32 + u Union + }{ + {0, Union{Value: int32(43), Tag: "num"}}, + {1, Union{Value: int32(42), Tag: "num"}}, + {2, Union{Value: "str2", Tag: "str"}}, + {20, Union{Value: "str10", Tag: "str"}}, + {21, Union{Value: int32(50), Tag: "num"}}, + } + + i := 0 + for res.Next() { + var id int32 + var u Union + require.NoError(t, res.Scan(&id, &u)) + require.Equal(t, testCases[i].id, id) + require.Equal(t, testCases[i].u, u) + i++ + } + require.Equal(t, len(testCases), i) +} + func BenchmarkAppenderNested(b *testing.B) { c, db, conn, a := prepareAppender(b, createNestedDataTableSQL) defer cleanupAppender(b, c, db, conn, a) diff --git a/arrowmapping/go.mod b/arrowmapping/go.mod index 0040d92c..aa55007f 100644 --- a/arrowmapping/go.mod +++ b/arrowmapping/go.mod @@ -3,10 +3,10 @@ module github.com/marcboeker/go-duckdb/arrowmapping go 1.24 require ( - github.com/duckdb/duckdb-go-bindings v0.1.19 - github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19 + github.com/duckdb/duckdb-go-bindings v0.1.20 + github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20 ) diff --git a/arrowmapping/go.sum b/arrowmapping/go.sum index b96bda22..669cb112 100644 --- a/arrowmapping/go.sum +++ b/arrowmapping/go.sum @@ -1,17 +1,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew 
v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/duckdb/duckdb-go-bindings v0.1.19 h1:t8fwgKlr/5BEa5TJzvo3Vdr3yAgoYiR7L/TqyMuUQ2k= -github.com/duckdb/duckdb-go-bindings v0.1.19/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19 h1:CdNZfRcFUFxI4Q+1Tu4TBFln9tkIn6bDwVwh9LeEsoo= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19 h1:mVijr3WFz3TXZLtAm5Hb6qEnstacZdFI5QQNuE9R2QQ= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19 h1:jhchUY24T5bQLOwGyK0BzB6+HQmsRjAbgUZDKWo4ajs= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19 h1:CFcH+Bze2OgTaTLM94P3gJ554alnCCDnt1BH/nO8RJ8= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19 h1:x/8t04sgCVU8JL0XLUZWmC1FAX13ZjM58EmsyPjvrvY= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= +github.com/duckdb/duckdb-go-bindings v0.1.20 h1:k9TOW6/oMSrGG+j7TVbi2CeTBuN0BEYFZ9IgI9zKXFE= +github.com/duckdb/duckdb-go-bindings v0.1.20/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= +github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20 h1:4q5OfbXLoJZ2lbb74ttohBj8Lhz8CbyqVSCTH907VhI= +github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= +github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20 h1:Wb6Um5ocqcXQCL4bZ1263jUnPeVgjeR/s3JvUBKwaN0= +github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= 
+github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20 h1:U8osykHQE1CWhJO+cZpQW4GSy2++gUgWEJ9UD1phEc0= +github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= +github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20 h1:SahHKRsqg3Nr5vCyXosbN1gtv26jw4JJUGShNJI3T/I= +github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= +github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20 h1:0IJWGjcRYAX7aSNtIytGsR+X9C3MUGuku1KkV4cif3g= +github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/duckdb_test.go b/duckdb_test.go index cd63cf6c..8e3f3be2 100644 --- a/duckdb_test.go +++ b/duckdb_test.go @@ -92,6 +92,12 @@ func newAppenderWrapper[T require.TestingT](t T, conn *driver.Conn, schema, tabl return a } +func newQueryAppenderWrapper[T require.TestingT](t T, conn *driver.Conn, query, table string, colTypes []TypeInfo, colNames []string) *Appender { + a, err := NewQueryAppender(*conn, query, table, colTypes, colNames) + require.NoError(t, err) + return a +} + func closeAppenderWrapper[T require.TestingT](t T, a *Appender) { if a == nil { return diff --git a/errors.go b/errors.go index 065f5670..c79af327 100644 --- a/errors.go +++ b/errors.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "strings" + + "github.com/marcboeker/go-duckdb/mapping" ) func getError(errDriver, err error) error { @@ -114,6 +116,9 @@ var ( errAppenderAppendRow = errors.New("could not append row") errAppenderAppendAfterClose = fmt.Errorf("%w: appender already closed", errAppenderAppendRow) errAppenderFlush = errors.New("could not flush appender") + errAppenderEmptyQuery 
= errors.New("empty query") + errAppenderEmptyColumnTypes = errors.New("empty column types") + errAppenderColumnMismatch = errors.New("mismatch between the number of column types and names") errUnsupportedMapKeyType = errors.New("MAP key type not supported") errEmptyName = errors.New("empty name") @@ -144,49 +149,49 @@ var ( type ErrorType int const ( - ErrorTypeInvalid ErrorType = iota // invalid type - ErrorTypeOutOfRange // value out of range error - ErrorTypeConversion // conversion/casting error - ErrorTypeUnknownType // unknown type error - ErrorTypeDecimal // decimal related - ErrorTypeMismatchType // type mismatch - ErrorTypeDivideByZero // divide by 0 - ErrorTypeObjectSize // object size exceeded - ErrorTypeInvalidType // incompatible for operation - ErrorTypeSerialization // serialization - ErrorTypeTransaction // transaction management - ErrorTypeNotImplemented // method not implemented - ErrorTypeExpression // expression parsing - ErrorTypeCatalog // catalog related - ErrorTypeParser // parser related - ErrorTypePlanner // planner related - ErrorTypeScheduler // scheduler related - ErrorTypeExecutor // executor related - ErrorTypeConstraint // constraint related - ErrorTypeIndex // index related - ErrorTypeStat // stat related - ErrorTypeConnection // connection related - ErrorTypeSyntax // syntax related - ErrorTypeSettings // settings related - ErrorTypeBinder // binder related - ErrorTypeNetwork // network related - ErrorTypeOptimizer // optimizer related - ErrorTypeNullPointer // nullptr exception - ErrorTypeIO // IO exception - ErrorTypeInterrupt // interrupt - ErrorTypeFatal // Fatal exceptions are non-recoverable, and render the entire DB in an unusable state - ErrorTypeInternal // Internal exceptions indicate something went wrong internally (i.e. 
bug in the code base) - ErrorTypeInvalidInput // Input or arguments error - ErrorTypeOutOfMemory // out of memory - ErrorTypePermission // insufficient permissions - ErrorTypeParameterNotResolved // parameter types could not be resolved - ErrorTypeParameterNotAllowed // parameter types not allowed - ErrorTypeDependency // dependency - ErrorTypeHTTP - ErrorTypeMissingExtension // Thrown when an extension is used but not loaded - ErrorTypeAutoLoad // Thrown when an extension is used but not loaded - ErrorTypeSequence - ErrorTypeInvalidConfiguration // An invalid configuration was detected (e.g. a Secret param was missing, or a required setting not found) + ErrorTypeInvalid = ErrorType(mapping.ErrorTypeInvalid) // Invalid type. + ErrorTypeOutOfRange = ErrorType(mapping.ErrorTypeOutOfRange) // The type's value is out of range. + ErrorTypeConversion = ErrorType(mapping.ErrorTypeConversion) // Conversion/casting error. + ErrorTypeUnknownType = ErrorType(mapping.ErrorTypeUnknownType) // The type is unknown. + ErrorTypeDecimal = ErrorType(mapping.ErrorTypeDecimal) // Decimal-related error. + ErrorTypeMismatchType = ErrorType(mapping.ErrorTypeMismatchType) // Types don't match. + ErrorTypeDivideByZero = ErrorType(mapping.ErrorTypeDivideByZero) // Division by zero. + ErrorTypeObjectSize = ErrorType(mapping.ErrorTypeObjectSize) // Exceeds object size. + ErrorTypeInvalidType = ErrorType(mapping.ErrorTypeInvalidType) // Incompatible types. + ErrorTypeSerialization = ErrorType(mapping.ErrorTypeSerialization) // Type serialization error. + ErrorTypeTransaction = ErrorType(mapping.ErrorTypeTransaction) // Transaction conflict. + ErrorTypeNotImplemented = ErrorType(mapping.ErrorTypeNotImplemented) // Missing functionality. + ErrorTypeExpression = ErrorType(mapping.ErrorTypeExpression) // Expression error. + ErrorTypeCatalog = ErrorType(mapping.ErrorTypeCatalog) // Catalog error. + ErrorTypeParser = ErrorType(mapping.ErrorTypeParser) // Error during parsing. 
+ ErrorTypePlanner = ErrorType(mapping.ErrorTypePlanner) // Error during planning. + ErrorTypeScheduler = ErrorType(mapping.ErrorTypeScheduler) // Scheduling error. + ErrorTypeExecutor = ErrorType(mapping.ErrorTypeExecutor) // Executor error. + ErrorTypeConstraint = ErrorType(mapping.ErrorTypeConstraint) // Constraint violation. + ErrorTypeIndex = ErrorType(mapping.ErrorTypeIndex) // Index error. + ErrorTypeStat = ErrorType(mapping.ErrorTypeStat) // Statistics error. + ErrorTypeConnection = ErrorType(mapping.ErrorTypeConnection) // Connection error. + ErrorTypeSyntax = ErrorType(mapping.ErrorTypeSyntax) // Invalid syntax. + ErrorTypeSettings = ErrorType(mapping.ErrorTypeSettings) // Settings-related error. + ErrorTypeBinder = ErrorType(mapping.ErrorTypeBinder) // Binding error. + ErrorTypeNetwork = ErrorType(mapping.ErrorTypeNetwork) // Network error. + ErrorTypeOptimizer = ErrorType(mapping.ErrorTypeOptimizer) // Optimizer error. + ErrorTypeNullPointer = ErrorType(mapping.ErrorTypeNullPointer) // Null-pointer exception. + ErrorTypeIO = ErrorType(mapping.ErrorTypeErrorIO) // IO exception. + ErrorTypeInterrupt = ErrorType(mapping.ErrorTypeInterrupt) // Query interruption. + ErrorTypeFatal = ErrorType(mapping.ErrorTypeFatal) // Fatal exception. Non-recoverable. The DB enters an invalid state and must be restarted. + ErrorTypeInternal = ErrorType(mapping.ErrorTypeInternal) // Internal exception. Indicates a bug, and should be reported. + ErrorTypeInvalidInput = ErrorType(mapping.ErrorTypeInvalidInput) // Invalid input. + ErrorTypeOutOfMemory = ErrorType(mapping.ErrorTypeOutOfMemory) // Out-of-memory error. + ErrorTypePermission = ErrorType(mapping.ErrorTypePermission) // Invalid permissions. + ErrorTypeParameterNotResolved = ErrorType(mapping.ErrorTypeParameterNotResolved) // Error when resolving types. + ErrorTypeParameterNotAllowed = ErrorType(mapping.ErrorTypeParameterNotAllowed) // Invalid parameter. 
+ ErrorTypeDependency = ErrorType(mapping.ErrorTypeDependency) // Dependency error. + ErrorTypeHTTP = ErrorType(mapping.ErrorTypeHTTP) // HTTP error. + ErrorTypeMissingExtension = ErrorType(mapping.ErrorTypeMissingExtension) // Usage of a non-loaded extension. + ErrorTypeAutoLoad = ErrorType(mapping.ErrorTypeAutoload) // Usage of a non-loaded extension that cannot be loaded automatically. + ErrorTypeSequence = ErrorType(mapping.ErrorTypeSequence) // Sequence error. + ErrorTypeInvalidConfiguration = ErrorType(mapping.ErrorTypeInvalidConfiguration) // Indicates an invalid configuration, e.g., a missing Secret parameter, or a mandatory setting is not provided. ) var errorPrefixMap = map[string]ErrorType{ @@ -251,6 +256,18 @@ func (e *Error) Is(err error) bool { return false } +func errorDataError(errorData mapping.ErrorData) error { + defer mapping.DestroyErrorData(&errorData) + if !mapping.ErrorDataHasError(errorData) { + return nil + } + + t := mapping.ErrorDataErrorType(errorData) + msg := mapping.ErrorDataMessage(errorData) + + return &Error{ErrorType(t), msg} +} + func getDuckDBError(errMsg string) error { errType := ErrorTypeInvalid diff --git a/errors_test.go b/errors_test.go index 6cdff716..86aa4f6e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -91,6 +91,45 @@ func TestErrAppender(t *testing.T) { testError(t, err, errAppenderCreation.Error()) }) + t.Run(errAppenderEmptyQuery.Error(), func(t *testing.T) { + c := newConnectorWrapper(t, ``, nil) + defer closeConnectorWrapper(t, c) + + conn := openDriverConnWrapper(t, c) + defer closeDriverConnWrapper(t, &conn) + + a, err := NewQueryAppender(conn, "", "", []TypeInfo{}, []string{}) + defer closeAppenderWrapper(t, a) + testError(t, err, errAppenderEmptyQuery.Error()) + }) + + t.Run(errAppenderEmptyColumnTypes.Error(), func(t *testing.T) { + c := newConnectorWrapper(t, ``, nil) + defer closeConnectorWrapper(t, c) + + conn := openDriverConnWrapper(t, c) + defer closeDriverConnWrapper(t, &conn) + + a, err := 
NewQueryAppender(conn, `INSERT INTO test SELECT * FROM appended_data`, "", []TypeInfo{}, []string{"c1", "c2"}) + defer closeAppenderWrapper(t, a) + testError(t, err, errAppenderEmptyColumnTypes.Error()) + }) + + t.Run(errAppenderColumnMismatch.Error(), func(t *testing.T) { + c := newConnectorWrapper(t, ``, nil) + defer closeConnectorWrapper(t, c) + + conn := openDriverConnWrapper(t, c) + defer closeDriverConnWrapper(t, &conn) + + info, err := NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + a, err := NewQueryAppender(conn, `INSERT INTO test SELECT * FROM appended_data`, "", []TypeInfo{info}, []string{"c1", "c2"}) + defer closeAppenderWrapper(t, a) + testError(t, err, errAppenderColumnMismatch.Error()) + }) + t.Run(errAppenderDoubleClose.Error(), func(t *testing.T) { c := newConnectorWrapper(t, ``, nil) defer closeConnectorWrapper(t, c) diff --git a/go.mod b/go.mod index 2dac3349..f670f67c 100644 --- a/go.mod +++ b/go.mod @@ -6,19 +6,19 @@ require ( github.com/apache/arrow-go/v18 v18.4.1 github.com/go-viper/mapstructure/v2 v2.4.0 github.com/google/uuid v1.6.0 - github.com/marcboeker/go-duckdb/arrowmapping v0.0.19 - github.com/marcboeker/go-duckdb/mapping v0.0.19 + github.com/marcboeker/go-duckdb/arrowmapping v0.0.20 + github.com/marcboeker/go-duckdb/mapping v0.0.20 github.com/stretchr/testify v1.11.0 ) require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/duckdb/duckdb-go-bindings v0.1.19 // indirect - github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19 // indirect - github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19 // indirect - github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19 // indirect - github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19 // indirect - github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19 // indirect + github.com/duckdb/duckdb-go-bindings v0.1.20 // indirect + github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20 // indirect + 
github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20 // indirect + github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20 // indirect + github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20 // indirect + github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/klauspost/compress v1.18.0 // indirect diff --git a/go.sum b/go.sum index d2d21686..71df26a9 100644 --- a/go.sum +++ b/go.sum @@ -7,18 +7,18 @@ github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJe github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/duckdb/duckdb-go-bindings v0.1.19 h1:t8fwgKlr/5BEa5TJzvo3Vdr3yAgoYiR7L/TqyMuUQ2k= -github.com/duckdb/duckdb-go-bindings v0.1.19/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19 h1:CdNZfRcFUFxI4Q+1Tu4TBFln9tkIn6bDwVwh9LeEsoo= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19 h1:mVijr3WFz3TXZLtAm5Hb6qEnstacZdFI5QQNuE9R2QQ= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19 h1:jhchUY24T5bQLOwGyK0BzB6+HQmsRjAbgUZDKWo4ajs= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19 h1:CFcH+Bze2OgTaTLM94P3gJ554alnCCDnt1BH/nO8RJ8= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19/go.mod 
h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19 h1:x/8t04sgCVU8JL0XLUZWmC1FAX13ZjM58EmsyPjvrvY= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= +github.com/duckdb/duckdb-go-bindings v0.1.20 h1:k9TOW6/oMSrGG+j7TVbi2CeTBuN0BEYFZ9IgI9zKXFE= +github.com/duckdb/duckdb-go-bindings v0.1.20/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= +github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20 h1:4q5OfbXLoJZ2lbb74ttohBj8Lhz8CbyqVSCTH907VhI= +github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= +github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20 h1:Wb6Um5ocqcXQCL4bZ1263jUnPeVgjeR/s3JvUBKwaN0= +github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= +github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20 h1:U8osykHQE1CWhJO+cZpQW4GSy2++gUgWEJ9UD1phEc0= +github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= +github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20 h1:SahHKRsqg3Nr5vCyXosbN1gtv26jw4JJUGShNJI3T/I= +github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= +github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20 h1:0IJWGjcRYAX7aSNtIytGsR+X9C3MUGuku1KkV4cif3g= +github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= @@ -44,10 +44,10 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod 
h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/marcboeker/go-duckdb/arrowmapping v0.0.19 h1:kMxJBauR2+jwRoSFjiL/DysQtKRBCkNSLZz7GUvEG8A= -github.com/marcboeker/go-duckdb/arrowmapping v0.0.19/go.mod h1:19JWoch6I++gIrWUz1MLImIoFGri9yL54JaWn/Ujvbo= -github.com/marcboeker/go-duckdb/mapping v0.0.19 h1:xZ7LCyFZZm/4X631lOZY74p3QHINMnWJ+OakKw5d3Ao= -github.com/marcboeker/go-duckdb/mapping v0.0.19/go.mod h1:Kz9xYOkhhkgCaGgAg34ciKaks9ED2V7BzHzG6dnVo/o= +github.com/marcboeker/go-duckdb/arrowmapping v0.0.20 h1:VJ3wcHr1rFC/XgCQ2FpS/aHim4dnFrhlXq9ahBgE1es= +github.com/marcboeker/go-duckdb/arrowmapping v0.0.20/go.mod h1:I/7TU+sG5WePL3yASm1MA7A6qDQM9PjcRFrSenelTeg= +github.com/marcboeker/go-duckdb/mapping v0.0.20 h1:USRyGoWOGEBS1owldvBvHw5/t+Sgb+d/QHWs6/56I1A= +github.com/marcboeker/go-duckdb/mapping v0.0.20/go.mod h1:NrLL1bOu0XMiaUs1CGiB9z7/HAO8+LPgoK1KmWYGeyk= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= diff --git a/mapping/go.mod b/mapping/go.mod index 5b0b2dc7..e8b74780 100644 --- a/mapping/go.mod +++ b/mapping/go.mod @@ -3,10 +3,10 @@ module github.com/marcboeker/go-duckdb/mapping go 1.24 require ( - github.com/duckdb/duckdb-go-bindings v0.1.19 - github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19 - github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19 + github.com/duckdb/duckdb-go-bindings v0.1.20 + 
github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20 + github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20 ) diff --git a/mapping/go.sum b/mapping/go.sum index b96bda22..669cb112 100644 --- a/mapping/go.sum +++ b/mapping/go.sum @@ -1,17 +1,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/duckdb/duckdb-go-bindings v0.1.19 h1:t8fwgKlr/5BEa5TJzvo3Vdr3yAgoYiR7L/TqyMuUQ2k= -github.com/duckdb/duckdb-go-bindings v0.1.19/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19 h1:CdNZfRcFUFxI4Q+1Tu4TBFln9tkIn6bDwVwh9LeEsoo= -github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.19/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19 h1:mVijr3WFz3TXZLtAm5Hb6qEnstacZdFI5QQNuE9R2QQ= -github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.19/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19 h1:jhchUY24T5bQLOwGyK0BzB6+HQmsRjAbgUZDKWo4ajs= -github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.19/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19 h1:CFcH+Bze2OgTaTLM94P3gJ554alnCCDnt1BH/nO8RJ8= -github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.19/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19 h1:x/8t04sgCVU8JL0XLUZWmC1FAX13ZjM58EmsyPjvrvY= -github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.19/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= +github.com/duckdb/duckdb-go-bindings v0.1.20 h1:k9TOW6/oMSrGG+j7TVbi2CeTBuN0BEYFZ9IgI9zKXFE= 
+github.com/duckdb/duckdb-go-bindings v0.1.20/go.mod h1:pBnfviMzANT/9hi4bg+zW4ykRZZPCXlVuvBWEcZofkc= +github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20 h1:4q5OfbXLoJZ2lbb74ttohBj8Lhz8CbyqVSCTH907VhI= +github.com/duckdb/duckdb-go-bindings/darwin-amd64 v0.1.20/go.mod h1:Ezo7IbAfB8NP7CqPIN8XEHKUg5xdRRQhcPPlCXImXYA= +github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20 h1:Wb6Um5ocqcXQCL4bZ1263jUnPeVgjeR/s3JvUBKwaN0= +github.com/duckdb/duckdb-go-bindings/darwin-arm64 v0.1.20/go.mod h1:eS7m/mLnPQgVF4za1+xTyorKRBuK0/BA44Oy6DgrGXI= +github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20 h1:U8osykHQE1CWhJO+cZpQW4GSy2++gUgWEJ9UD1phEc0= +github.com/duckdb/duckdb-go-bindings/linux-amd64 v0.1.20/go.mod h1:1GOuk1PixiESxLaCGFhag+oFi7aP+9W8byymRAvunBk= +github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20 h1:SahHKRsqg3Nr5vCyXosbN1gtv26jw4JJUGShNJI3T/I= +github.com/duckdb/duckdb-go-bindings/linux-arm64 v0.1.20/go.mod h1:o7crKMpT2eOIi5/FY6HPqaXcvieeLSqdXXaXbruGX7w= +github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20 h1:0IJWGjcRYAX7aSNtIytGsR+X9C3MUGuku1KkV4cif3g= +github.com/duckdb/duckdb-go-bindings/windows-amd64 v0.1.20/go.mod h1:IlOhJdVKUJCAPj3QsDszUo8DVdvp1nBFp4TUJVdw99s= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= diff --git a/profiling_test.go b/profiling_test.go index 08048a44..59c1c7f0 100644 --- a/profiling_test.go +++ b/profiling_test.go @@ -13,7 +13,10 @@ func TestProfiling(t *testing.T) { conn := openConnWrapper(t, db, context.Background()) defer closeConnWrapper(t, conn) - _, err := conn.ExecContext(context.Background(), `PRAGMA enable_profiling = 'no_output'`) + _, err := GetProfilingInfo(conn) + require.ErrorContains(t, err, errProfilingInfoEmpty.Error()) + + _, err = conn.ExecContext(context.Background(), `PRAGMA 
enable_profiling = 'no_output'`) require.NoError(t, err) _, err = conn.ExecContext(context.Background(), `PRAGMA profiling_mode = 'detailed'`) require.NoError(t, err) @@ -25,13 +28,16 @@ func TestProfiling(t *testing.T) { info, err := GetProfilingInfo(conn) require.NoError(t, err) - _, err = conn.ExecContext(context.Background(), `PRAGMA disable_profiling`) - require.NoError(t, err) - // Verify the metrics. require.NotEmpty(t, info.Metrics, "metrics must not be empty") require.NotEmpty(t, info.Children, "children must not be empty") require.NotEmpty(t, info.Children[0].Metrics, "child metrics must not be empty") + + _, err = conn.ExecContext(context.Background(), `PRAGMA disable_profiling`) + require.NoError(t, err) + + info, err = GetProfilingInfo(conn) + require.ErrorContains(t, err, errProfilingInfoEmpty.Error()) } func TestErrProfiling(t *testing.T) { diff --git a/rows.go b/rows.go index 6aa09bd6..25b236ce 100644 --- a/rows.go +++ b/rows.go @@ -4,10 +4,8 @@ import ( "database/sql/driver" "fmt" "io" - "math/big" "reflect" "strings" - "time" "github.com/marcboeker/go-duckdb/mapping" ) @@ -28,22 +26,33 @@ type rows struct { chunkIdx mapping.IdxT // rowCount is the number of scanned rows. 
rowCount int + // cached column metadata to avoid repeated CGO calls + scanTypes []reflect.Type + dbTypeNames []string } func newRowsWithStmt(res mapping.Result, stmt *Stmt) *rows { columnCount := mapping.ColumnCount(&res) r := rows{ - res: res, - stmt: stmt, - chunk: DataChunk{}, - chunkCount: mapping.ResultChunkCount(res), - chunkIdx: 0, - rowCount: 0, + res: res, + stmt: stmt, + chunk: DataChunk{}, + chunkCount: mapping.ResultChunkCount(res), + chunkIdx: 0, + rowCount: 0, + scanTypes: make([]reflect.Type, columnCount), + dbTypeNames: make([]string, columnCount), } for i := mapping.IdxT(0); i < columnCount; i++ { columnName := mapping.ColumnName(&res, i) r.chunk.columnNames = append(r.chunk.columnNames, columnName) + + // Cache column metadata + logicalType := mapping.ColumnLogicalType(&res, i) + r.scanTypes[i] = r.getScanType(logicalType, i) + r.dbTypeNames[i] = logicalTypeString(logicalType) + mapping.DestroyLogicalType(&logicalType) } return &r @@ -86,64 +95,65 @@ func (r *rows) Next(dst []driver.Value) error { // ColumnTypeScanType implements driver.RowsColumnTypeScanType. 
func (r *rows) ColumnTypeScanType(index int) reflect.Type { - logicalType := mapping.ColumnLogicalType(&r.res, mapping.IdxT(index)) - defer mapping.DestroyLogicalType(&logicalType) + return r.scanTypes[index] +} +func (r *rows) getScanType(logicalType mapping.LogicalType, index mapping.IdxT) reflect.Type { alias := mapping.LogicalTypeGetAlias(logicalType) if alias == aliasJSON { - return reflect.TypeFor[any]() + return reflectTypeAny } - t := mapping.ColumnType(&r.res, mapping.IdxT(index)) + t := mapping.ColumnType(&r.res, index) switch t { case TYPE_INVALID: return nil case TYPE_BOOLEAN: - return reflect.TypeOf(true) + return reflectTypeBool case TYPE_TINYINT: - return reflect.TypeOf(int8(0)) + return reflectTypeInt8 case TYPE_SMALLINT: - return reflect.TypeOf(int16(0)) + return reflectTypeInt16 case TYPE_INTEGER: - return reflect.TypeOf(int32(0)) + return reflectTypeInt32 case TYPE_BIGINT: - return reflect.TypeOf(int64(0)) + return reflectTypeInt64 case TYPE_UTINYINT: - return reflect.TypeOf(uint8(0)) + return reflectTypeUint8 case TYPE_USMALLINT: - return reflect.TypeOf(uint16(0)) + return reflectTypeUint16 case TYPE_UINTEGER: - return reflect.TypeOf(uint32(0)) + return reflectTypeUint32 case TYPE_UBIGINT: - return reflect.TypeOf(uint64(0)) + return reflectTypeUint64 case TYPE_FLOAT: - return reflect.TypeOf(float32(0)) + return reflectTypeFloat32 case TYPE_DOUBLE: - return reflect.TypeOf(float64(0)) + return reflectTypeFloat64 case TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, TYPE_DATE, TYPE_TIME, TYPE_TIME_TZ, TYPE_TIMESTAMP_TZ: - return reflect.TypeOf(time.Time{}) + return reflectTypeTime case TYPE_INTERVAL: - return reflect.TypeOf(Interval{}) + return reflectTypeInterval case TYPE_HUGEINT: - return reflect.TypeOf(big.NewInt(0)) + return reflectTypeBigInt case TYPE_VARCHAR, TYPE_ENUM: - return reflect.TypeOf("") + return reflectTypeString case TYPE_BLOB: - return reflect.TypeOf([]byte{}) + return reflectTypeBytes case TYPE_DECIMAL: - 
return reflect.TypeOf(Decimal{}) + return reflectTypeDecimal case TYPE_LIST: - return reflect.TypeOf([]any{}) + return reflectTypeSliceAny case TYPE_STRUCT: - return reflect.TypeOf(map[string]any{}) + return reflectTypeMapString case TYPE_MAP: - return reflect.TypeOf(Map{}) + return reflectTypeMap case TYPE_ARRAY: - return reflect.TypeOf([]any{}) + return reflectTypeSliceAny case TYPE_UNION: - return reflect.TypeOf(Union{}) + return reflectTypeUnion case TYPE_UUID: - return reflect.TypeOf([]byte{}) + return reflectTypeBytes default: return nil } @@ -151,21 +161,7 @@ func (r *rows) ColumnTypeScanType(index int) reflect.Type { // ColumnTypeDatabaseTypeName implements driver.RowsColumnTypeScanType. func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - logicalType := mapping.ColumnLogicalType(&r.res, mapping.IdxT(index)) - defer mapping.DestroyLogicalType(&logicalType) - - alias := mapping.LogicalTypeGetAlias(logicalType) - if alias == aliasJSON { - return aliasJSON - } - - t := mapping.ColumnType(&r.res, mapping.IdxT(index)) - switch t { - case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY, TYPE_UNION: - return logicalTypeName(logicalType) - default: - return typeToStringMap[t] - } + return r.dbTypeNames[index] } func (r *rows) Close() error { @@ -186,6 +182,22 @@ func (r *rows) Close() error { return err } +// logicalTypeString converts a LogicalType to its string representation. 
+func logicalTypeString(logicalType mapping.LogicalType) string { + alias := mapping.LogicalTypeGetAlias(logicalType) + if alias == aliasJSON { + return aliasJSON + } + + t := mapping.GetTypeId(logicalType) + switch t { + case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY, TYPE_UNION: + return logicalTypeName(logicalType) + default: + return typeToStringMap[t] + } +} + func logicalTypeName(logicalType mapping.LogicalType) string { t := mapping.GetTypeId(logicalType) switch t { @@ -223,22 +235,29 @@ func logicalTypeNameList(logicalType mapping.LogicalType) string { } func logicalTypeNameStruct(logicalType mapping.LogicalType) string { + var sb strings.Builder + sb.WriteString("STRUCT(") + count := mapping.StructTypeChildCount(logicalType) - name := "STRUCT(" for i := mapping.IdxT(0); i < count; i++ { + if i > 0 { + sb.WriteString(", ") + } + childName := mapping.StructTypeChildName(logicalType, i) childType := mapping.StructTypeChildType(logicalType, i) - // Add comma if not at the end of the list. 
- name += escapeStructFieldName(childName) + " " + logicalTypeName(childType) - if i != count-1 { - name += ", " - } + sb.WriteString(escapeStructFieldName(childName)) + sb.WriteByte(' ') + sb.WriteString(logicalTypeName(childType)) + mapping.DestroyLogicalType(&childType) } - return name + ")" + sb.WriteByte(')') + + return sb.String() } func logicalTypeNameMap(logicalType mapping.LogicalType) string { @@ -261,22 +280,28 @@ func logicalTypeNameArray(logicalType mapping.LogicalType) string { } func logicalTypeNameUnion(logicalType mapping.LogicalType) string { + var sb strings.Builder + sb.WriteString("UNION(") + count := int(mapping.UnionTypeMemberCount(logicalType)) - name := "UNION(" for i := range count { + if i > 0 { + sb.WriteString(", ") + } + memberName := mapping.UnionTypeMemberName(logicalType, mapping.IdxT(i)) memberType := mapping.UnionTypeMemberType(logicalType, mapping.IdxT(i)) - // Add comma if not at the end of the list - name += escapeStructFieldName(memberName) + " " + logicalTypeName(memberType) - if i != count-1 { - name += ", " - } + sb.WriteString(escapeStructFieldName(memberName)) + sb.WriteByte(' ') + sb.WriteString(logicalTypeName(memberType)) mapping.DestroyLogicalType(&memberType) } - return name + ")" + + sb.WriteByte(')') + return sb.String() } func escapeStructFieldName(s string) string { diff --git a/scalar_udf.go b/scalar_udf.go index fca81a3c..79304ae6 100644 --- a/scalar_udf.go +++ b/scalar_udf.go @@ -7,6 +7,9 @@ typedef void (*scalar_udf_callback_t)(void *, void *, void *); void scalar_udf_delete_callback(void *); typedef void (*scalar_udf_delete_callback_t)(void *); +void *scalar_udf_bind_copy_callback(void *); +typedef void *(*scalar_udf_bind_copy_callback_t)(void *); + void scalar_udf_bind_callback(void *); typedef void (*scalar_udf_bind_callback_t)(void *); */ @@ -189,9 +192,9 @@ func scalar_udf_callback(functionInfoPtr, inputPtr, outputPtr unsafe.Pointer) { nullInNullOut := !function.Config().SpecialNullHandling bindDataPtr 
:= mapping.ScalarFunctionGetBindData(functionInfo) - info := getPinned[bindInfo](bindDataPtr) + info := getPinned[*bindInfo](bindDataPtr) - f := function.RowExecutor(&info) + f := function.RowExecutor(info) values := make([]driver.Value, len(inputChunk.columns)) // Execute the user-defined scalar function for each row. @@ -241,6 +244,22 @@ func scalar_udf_delete_callback(info unsafe.Pointer) { h.Delete() } +//export scalar_udf_bind_copy_callback +func scalar_udf_bind_copy_callback(dataPtr unsafe.Pointer) unsafe.Pointer { + // Copy and pin the bind data. + data := getPinned[*bindInfo](dataPtr) + dataCopy := *data + + value := pinnedValue[*bindInfo]{ + pinner: &runtime.Pinner{}, + value: &dataCopy, + } + h := cgo.NewHandle(value) + value.pinner.Pin(&h) + + return unsafe.Pointer(&h) +} + //export scalar_udf_bind_callback func scalar_udf_bind_callback(bindInfoPtr unsafe.Pointer) { info := mapping.BindInfo{Ptr: bindInfoPtr} @@ -256,10 +275,14 @@ func scalar_udf_bind_callback(bindInfoPtr unsafe.Pointer) { id := mapping.ClientContextGetConnectionId(ctx) data := bindInfo{connId: uint64(id)} + // Set the copy callback of the bind info. + copyPtr := unsafe.Pointer(C.scalar_udf_bind_copy_callback_t(C.scalar_udf_bind_copy_callback)) + mapping.ScalarFunctionSetBindDataCopy(info, copyPtr) + // Pin the bind data. 
- value := pinnedValue[bindInfo]{ + value := pinnedValue[*bindInfo]{ pinner: &runtime.Pinner{}, - value: data, + value: &data, } h := cgo.NewHandle(value) value.pinner.Pin(&h) diff --git a/scalar_udf_test.go b/scalar_udf_test.go index d76c84ab..a61f5d33 100644 --- a/scalar_udf_test.go +++ b/scalar_udf_test.go @@ -141,7 +141,7 @@ func (*anyTypeSUDF) Executor() ScalarFuncExecutor { } func (*getConnIdUDF) Config() ScalarFuncConfig { - return ScalarFuncConfig{[]TypeInfo{}, currentInfo, nil, false, false} + return ScalarFuncConfig{[]TypeInfo{}, currentInfo, nil, true, false} } func (*getConnIdUDF) Executor() ScalarFuncExecutor { @@ -481,6 +481,11 @@ func TestGetConnIdScalarUDF(t *testing.T) { row = conn2.QueryRowContext(ctx, `SELECT get_conn_id() AS connId`) require.NoError(t, row.Scan(&connId)) require.Equal(t, conn2Id, connId) + + var res bool + row = conn2.QueryRowContext(ctx, fmt.Sprintf(`SELECT true AS res WHERE get_conn_id() = %d`, conn2Id)) + require.NoError(t, row.Scan(&res)) + require.True(t, res) } func TestErrScalarUDF(t *testing.T) { diff --git a/statement.go b/statement.go index 48999d18..0834667f 100644 --- a/statement.go +++ b/statement.go @@ -170,19 +170,19 @@ func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (mapping.Stat var state mapping.State switch t { case TYPE_TIMESTAMP: - v, err := getMappedTimestamp(t, val.Value) + v, err := inferTimestamp(t, val.Value) if err != nil { return mapping.StateError, err } state = mapping.BindTimestamp(*s.preparedStmt, mapping.IdxT(n+1), v) case TYPE_TIMESTAMP_TZ: - v, err := getMappedTimestamp(t, val.Value) + v, err := inferTimestamp(t, val.Value) if err != nil { return mapping.StateError, err } state = mapping.BindTimestampTZ(*s.preparedStmt, mapping.IdxT(n+1), v) case TYPE_TIMESTAMP_S: - v, err := getMappedTimestampS(val.Value) + v, err := inferTimestampS(val.Value) if err != nil { return mapping.StateError, err } @@ -190,7 +190,7 @@ func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n 
int) (mapping.Stat state = mapping.BindValue(*s.preparedStmt, mapping.IdxT(n+1), tS) mapping.DestroyValue(&tS) case TYPE_TIMESTAMP_MS: - v, err := getMappedTimestampMS(val.Value) + v, err := inferTimestampMS(val.Value) if err != nil { return mapping.StateError, err } @@ -198,7 +198,7 @@ func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (mapping.Stat state = mapping.BindValue(*s.preparedStmt, mapping.IdxT(n+1), tMS) mapping.DestroyValue(&tMS) case TYPE_TIMESTAMP_NS: - v, err := getMappedTimestampNS(val.Value) + v, err := inferTimestampNS(val.Value) if err != nil { return mapping.StateError, err } @@ -210,7 +210,7 @@ func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (mapping.Stat } func (s *Stmt) bindDate(val driver.NamedValue, n int) (mapping.State, error) { - date, err := getMappedDate(val.Value) + date, err := inferDate(val.Value) if err != nil { return mapping.StateError, err } @@ -282,7 +282,7 @@ func (s *Stmt) bindCompositeValue(val driver.NamedValue, n int) (mapping.State, } func (s *Stmt) tryBindComplexValue(val driver.NamedValue, n int) (mapping.State, error) { - lt, mappedVal, err := createValueByReflection(val.Value) + lt, mappedVal, err := inferLogicalTypeAndValue(val.Value) defer mapping.DestroyLogicalType(<) defer mapping.DestroyValue(&mappedVal) if err != nil { @@ -393,7 +393,11 @@ func (s *Stmt) bindValue(val driver.NamedValue, n int) (mapping.State, error) { case []byte: return mapping.BindBlob(*s.preparedStmt, mapping.IdxT(n+1), v), nil case Interval: - return mapping.BindInterval(*s.preparedStmt, mapping.IdxT(n+1), v.getMappedInterval()), nil + i, inferErr := inferInterval(v) + if inferErr != nil { + return mapping.StateError, inferErr + } + return mapping.BindInterval(*s.preparedStmt, mapping.IdxT(n+1), i), nil case nil: return mapping.BindNull(*s.preparedStmt, mapping.IdxT(n+1)), nil } diff --git a/statement_test.go b/statement_test.go index 41c4b3f0..95645385 100644 --- a/statement_test.go +++ 
b/statement_test.go @@ -842,5 +842,31 @@ func TestDriverValuer(t *testing.T) { // Expected to fail - no driver.Valuer implementation _, err = db.Exec(`INSERT INTO valuer_test (ids) VALUES (?)`, []uuid.UUID{uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), uuid.MustParse("3a92e387-4b7d-4098-b273-967d48f6925f")}) require.Error(t, err, "[]uuid.UUID should fail without driver.Valuer") - require.Contains(t, err.Error(), "unsupported data type: UUID") + require.Contains(t, err.Error(), castErrMsg) +} + +func TestMixedTypeSliceBinding(t *testing.T) { + db := openDbWrapper(t, ``) + defer closeDbWrapper(t, db) + + createTable(t, db, `CREATE TABLE mixed_slice_test (foo TEXT)`) + _, err := db.Exec(`INSERT INTO mixed_slice_test VALUES ('hello this is text'), ('other')`) + require.NoError(t, err) + + sameTypeSlice := []any{"hello this is text", "other"} + _, err = db.Query("FROM mixed_slice_test WHERE foo IN ?", sameTypeSlice) + require.NoError(t, err) + + mixedSlice := []any{"hello this is text", 2} + + _, err = db.Query("FROM mixed_slice_test WHERE foo IN ?", mixedSlice) + require.ErrorContains(t, err, "mixed types in slice: cannot bind VARCHAR (index 0) and BIGINT (index 1)") + + // Same test with named parameters + _, err = db.Query("FROM mixed_slice_test WHERE foo IN $foo", sql.Named("foo", mixedSlice)) + require.ErrorContains(t, err, "mixed types in slice: cannot bind VARCHAR (index 0) and BIGINT (index 1)") + + nestedMixedSlice := [][]any{{"hello this is text"}, {2}} + _, err = db.Query("FROM mixed_slice_test WHERE foo IN ?", nestedMixedSlice) + require.ErrorContains(t, err, "mixed types in slice: cannot bind VARCHAR[] (index 0) and BIGINT[] (index 1)") } diff --git a/table_source.go b/table_source.go index 1609e7d2..611dc611 100644 --- a/table_source.go +++ b/table_source.go @@ -93,6 +93,7 @@ type ( ) // ParallelRow wrapper + func (s parallelRowTSWrapper) ColumnInfos() []ColumnInfo { return s.s.ColumnInfos() } @@ -117,6 +118,7 @@ func (s parallelRowTSWrapper) 
FillRow(ls any, chunk Row) (bool, error) { } // ParallelChunk wrapper + func (s parallelChunkTSWrapper) ColumnInfos() []ColumnInfo { return s.s.ColumnInfos() } diff --git a/type.go b/type.go index 64d9e4d4..b7cd351e 100644 --- a/type.go +++ b/type.go @@ -45,6 +45,7 @@ const ( TYPE_ANY = mapping.TypeAny TYPE_BIGNUM = mapping.TypeBigNum TYPE_SQLNULL = mapping.TypeSQLNull + // TODO: add TYPE_TIME_NS here, or support it. ) // FIXME: Implement support for these types. diff --git a/type_info.go b/type_info.go index 14d8b919..3b8033fb 100644 --- a/type_info.go +++ b/type_info.go @@ -324,7 +324,7 @@ func (info *typeInfo) logicalListType() mapping.LogicalType { func (info *typeInfo) logicalStructType() mapping.LogicalType { var types []mapping.LogicalType - defer destroyLogicalTypes(&types) + defer destroyLogicalTypes(types) var names []string for _, entry := range info.structEntries { @@ -350,7 +350,7 @@ func (info *typeInfo) logicalArrayType() mapping.LogicalType { func (info *typeInfo) logicalUnionType() mapping.LogicalType { var types []mapping.LogicalType - defer destroyLogicalTypes(&types) + defer destroyLogicalTypes(types) for _, t := range info.types { types = append(types, t.logicalType()) } @@ -361,8 +361,8 @@ func funcName(i interface{}) string { return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } -func destroyLogicalTypes(types *[]mapping.LogicalType) { - for _, t := range *types { +func destroyLogicalTypes(types []mapping.LogicalType) { + for _, t := range types { mapping.DestroyLogicalType(&t) } } diff --git a/types.go b/types.go index 8795b41d..4b775103 100644 --- a/types.go +++ b/types.go @@ -16,6 +16,35 @@ import ( "github.com/marcboeker/go-duckdb/mapping" ) +// go-duckdb exports the following type wrappers: +// UUID, Map, Interval, Decimal, Union, Composite (optional, used to scan LIST and STRUCT). + +// Pre-computed reflect type values to avoid repeated allocations. 
+var ( + reflectTypeBool = reflect.TypeOf(true) + reflectTypeInt8 = reflect.TypeOf(int8(0)) + reflectTypeInt16 = reflect.TypeOf(int16(0)) + reflectTypeInt32 = reflect.TypeOf(int32(0)) + reflectTypeInt64 = reflect.TypeOf(int64(0)) + reflectTypeUint8 = reflect.TypeOf(uint8(0)) + reflectTypeUint16 = reflect.TypeOf(uint16(0)) + reflectTypeUint32 = reflect.TypeOf(uint32(0)) + reflectTypeUint64 = reflect.TypeOf(uint64(0)) + reflectTypeFloat32 = reflect.TypeOf(float32(0)) + reflectTypeFloat64 = reflect.TypeOf(float64(0)) + reflectTypeTime = reflect.TypeOf(time.Time{}) + reflectTypeInterval = reflect.TypeOf(Interval{}) + reflectTypeBigInt = reflect.TypeOf(big.NewInt(0)) + reflectTypeString = reflect.TypeOf("") + reflectTypeBytes = reflect.TypeOf([]byte{}) + reflectTypeDecimal = reflect.TypeOf(Decimal{}) + reflectTypeSliceAny = reflect.TypeOf([]any{}) + reflectTypeMapString = reflect.TypeOf(map[string]any{}) + reflectTypeMap = reflect.TypeOf(Map{}) + reflectTypeUnion = reflect.TypeOf(Union{}) + reflectTypeAny = reflect.TypeFor[any]() +) + type numericType interface { int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64 } @@ -66,6 +95,27 @@ func (u *UUID) Value() (driver.Value, error) { return u.String(), nil } +func inferUUID(val any) (mapping.HugeInt, error) { + var id UUID + switch v := val.(type) { + case UUID: + id = v + case *UUID: + id = *v + case []uint8: + if len(v) != uuidLength { + return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(id).String()) + } + for i := range uuidLength { + id[i] = v[i] + } + default: + return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(id).String()) + } + hi := uuidToHugeInt(id) + return hi, nil +} + // duckdb_hugeint is composed of (lower, upper) components. 
// The value is computed as: upper * 2^64 + lower @@ -108,6 +158,67 @@ func hugeIntFromNative(i *big.Int) (mapping.HugeInt, error) { return mapping.NewHugeInt(r.Uint64(), q.Int64()), nil } +func inferHugeInt(val any) (mapping.HugeInt, error) { + var err error + var hi mapping.HugeInt + switch v := val.(type) { + case uint8: + hi = mapping.NewHugeInt(uint64(v), 0) + case int8: + hi = mapping.NewHugeInt(uint64(v), 0) + case uint16: + hi = mapping.NewHugeInt(uint64(v), 0) + case int16: + hi = mapping.NewHugeInt(uint64(v), 0) + case uint32: + hi = mapping.NewHugeInt(uint64(v), 0) + case int32: + hi = mapping.NewHugeInt(uint64(v), 0) + case uint64: + hi = mapping.NewHugeInt(v, 0) + case int64: + hi, err = hugeIntFromNative(big.NewInt(v)) + if err != nil { + return mapping.HugeInt{}, err + } + case uint: + hi = mapping.NewHugeInt(uint64(v), 0) + case int: + hi, err = hugeIntFromNative(big.NewInt(int64(v))) + if err != nil { + return mapping.HugeInt{}, err + } + case float32: + hi, err = hugeIntFromNative(big.NewInt(int64(v))) + if err != nil { + return mapping.HugeInt{}, err + } + case float64: + hi, err = hugeIntFromNative(big.NewInt(int64(v))) + if err != nil { + return mapping.HugeInt{}, err + } + case *big.Int: + if v == nil { + return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String()) + } + if hi, err = hugeIntFromNative(v); err != nil { + return mapping.HugeInt{}, err + } + case Decimal: + if v.Value == nil { + return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String()) + } + if hi, err = hugeIntFromNative(v.Value); err != nil { + return mapping.HugeInt{}, err + } + default: + return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String()) + } + + return hi, nil +} + type Map map[any]any func (m *Map) Scan(v any) error { @@ -134,11 +245,18 @@ type Interval struct { Micros int64 `json:"micros"` } -func (i *Interval) getMappedInterval() mapping.Interval { - return 
mapping.NewInterval(i.Months, i.Days, i.Micros) +func inferInterval(val any) (mapping.Interval, error) { + var i Interval + switch v := val.(type) { + case Interval: + i = v + default: + return mapping.Interval{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(i).String()) + } + return mapping.NewInterval(i.Months, i.Days, i.Micros), nil } -// Use as the `Scanner` type for any composite types (maps, lists, structs) +// Composite can be used as the `Scanner` type for any composite types (maps, lists, structs). type Composite[T any] struct { t T } @@ -242,27 +360,27 @@ func getTSTicks(t Type, val any) (int64, error) { return ti.UnixNano(), nil } -func getMappedTimestamp(t Type, val any) (mapping.Timestamp, error) { +func inferTimestamp(t Type, val any) (mapping.Timestamp, error) { ticks, err := getTSTicks(t, val) return mapping.NewTimestamp(ticks), err } -func getMappedTimestampS(val any) (mapping.TimestampS, error) { +func inferTimestampS(val any) (mapping.TimestampS, error) { ticks, err := getTSTicks(TYPE_TIMESTAMP_S, val) return mapping.NewTimestampS(ticks), err } -func getMappedTimestampMS(val any) (mapping.TimestampMS, error) { +func inferTimestampMS(val any) (mapping.TimestampMS, error) { ticks, err := getTSTicks(TYPE_TIMESTAMP_MS, val) return mapping.NewTimestampMS(ticks), err } -func getMappedTimestampNS(val any) (mapping.TimestampNS, error) { +func inferTimestampNS(val any) (mapping.TimestampNS, error) { ticks, err := getTSTicks(TYPE_TIMESTAMP_NS, val) return mapping.NewTimestampNS(ticks), err } -func getMappedDate[T any](val T) (mapping.Date, error) { +func inferDate[T any](val T) (mapping.Date, error) { ti, err := castToTime(val) if err != nil { return mapping.Date{}, err @@ -272,6 +390,23 @@ func getMappedDate[T any](val T) (mapping.Date, error) { return date, err } +func inferTime(val any) (mapping.Time, error) { + ticks, err := getTimeTicks(val) + if err != nil { + return mapping.Time{}, err + } + return mapping.NewTime(ticks), nil +} + +func 
inferTimeTZ(val any) (mapping.TimeTZ, error) { + ticks, err := getTimeTicks(val) + if err != nil { + return mapping.TimeTZ{}, err + } + // The UTC offset is 0. + return mapping.CreateTimeTZ(ticks, 0), nil +} + func getTimeTicks[T any](val T) (int64, error) { ti, err := castToTime(val) if err != nil { diff --git a/types_test.go b/types_test.go index d5790374..c4ed8344 100644 --- a/types_test.go +++ b/types_test.go @@ -1098,3 +1098,50 @@ func TestUnionTypes(t *testing.T) { require.Equal(t, int32(123), val.Value) }) } + +func TestInferPrimitiveType(t *testing.T) { + db := openDbWrapper(t, ``) + defer closeDbWrapper(t, db) + + testCases := []struct { + input any + }{ + {[]Map{nil}}, + {[]bool{true, false}}, + {[]int8{-7}}, + {[]int16{-42}}, + {[]int32{-4}}, + {[]int64{-6}}, + {[]int{-22}}, + {[]uint8{7}}, + {[]uint16{42}}, + {[]uint32{4}}, + {[]uint64{6}}, + {[]uint{22}}, + {[]float32{7.8}}, + {[]float64{22.3}}, + {[]string{"Hello from Amsterdam!"}}, + {[][]byte{{71, 111}}}, + {[]time.Time{time.Now()}}, + {[]Interval{{22, 10, 7}}}, + {[]*big.Int{big.NewInt(22)}}, + {[]Decimal{{2, 2, big.NewInt(7)}}}, + {[]UUID{UUID(uuid.New())}}, + } + for _, tc := range testCases { + _, err := db.Exec(`SELECT a FROM (VALUES (?)) t(a)`, tc.input) + require.NoError(t, err) + } + + // Not yet supported. 
+ testCases = []struct { + input any + }{ + {[]Union{{42, "n"}}}, + {[]Map{map[any]any{"hello": "world", "beautiful": "day"}}}, + } + for _, tc := range testCases { + _, err := db.Exec(`SELECT a FROM (VALUES (?)) t(a)`, tc.input) + require.ErrorContains(t, err, unsupportedTypeErrMsg) + } +} diff --git a/value.go b/value.go index 9f3d3d13..66a680ea 100644 --- a/value.go +++ b/value.go @@ -2,6 +2,7 @@ package duckdb import ( "fmt" + "math/big" "reflect" "time" @@ -67,28 +68,26 @@ func getValue(info TypeInfo, v mapping.Value) (any, error) { } } -func createValue(lt mapping.LogicalType, v any) (mapping.Value, error) { +func createValue(lt mapping.LogicalType, val any) (mapping.Value, error) { t := mapping.GetTypeId(lt) - r, err := tryCreateValueByTypeId(t, v) - if err != nil { - return r, err - } - if r.Ptr != nil { - return r, nil + if isPrimitiveType(t) { + return createPrimitiveValue(t, val) } + switch t { case TYPE_ARRAY: - return getMappedSliceValue(lt, t, v) + return createSliceValue(lt, t, val) case TYPE_LIST: - return getMappedSliceValue(lt, t, v) + return createSliceValue(lt, t, val) case TYPE_STRUCT: - return getMappedStructValue(lt, v) + return createStructValue(lt, val) default: - return mapping.Value{}, unsupportedTypeError(reflect.TypeOf(v).Name()) + return mapping.Value{}, unsupportedTypeError(reflect.TypeOf(val).Name()) } } -func tryCreateValueByTypeId(t mapping.Type, v any) (mapping.Value, error) { +//nolint:gocyclo +func createPrimitiveValue(t mapping.Type, v any) (mapping.Value, error) { switch t { case TYPE_SQLNULL: return mapping.CreateNullValue(), nil @@ -117,31 +116,69 @@ func tryCreateValueByTypeId(t mapping.Type, v any) (mapping.Value, error) { case TYPE_VARCHAR: return mapping.CreateVarchar(v.(string)), nil case TYPE_TIMESTAMP, TYPE_TIMESTAMP_TZ: - vv, err := getMappedTimestamp(t, v) + vv, err := inferTimestamp(t, v) if err != nil { return mapping.Value{}, err } return mapping.CreateTimestamp(vv), nil case TYPE_TIMESTAMP_S: - vv, err := 
getMappedTimestampS(v) + vv, err := inferTimestampS(v) if err != nil { return mapping.Value{}, err } return mapping.CreateTimestampS(vv), nil case TYPE_TIMESTAMP_MS: - vv, err := getMappedTimestampMS(v) + vv, err := inferTimestampMS(v) if err != nil { return mapping.Value{}, err } return mapping.CreateTimestampMS(vv), nil case TYPE_TIMESTAMP_NS: - vv, err := getMappedTimestampNS(v) + vv, err := inferTimestampNS(v) if err != nil { return mapping.Value{}, err } return mapping.CreateTimestampNS(vv), nil + case TYPE_DATE: + vv, err := inferDate(v) + if err != nil { + return mapping.Value{}, err + } + return mapping.CreateDate(vv), nil + case TYPE_TIME: + vv, err := inferTime(v) + if err != nil { + return mapping.Value{}, err + } + return mapping.CreateTime(vv), nil + case TYPE_TIME_TZ: + vv, err := inferTimeTZ(v) + if err != nil { + return mapping.Value{}, err + } + return mapping.CreateTimeTZValue(vv), nil + case TYPE_INTERVAL: + vv, err := inferInterval(v) + if err != nil { + return mapping.Value{}, err + } + return mapping.CreateInterval(vv), nil + case TYPE_HUGEINT: + vv, err := inferHugeInt(v) + if err != nil { + return mapping.Value{}, err + } + return mapping.CreateHugeInt(vv), nil + case TYPE_UUID: + vv, err := inferUUID(v) + if err != nil { + return mapping.Value{}, err + } + lower, upper := mapping.HugeIntMembers(&vv) + uHugeInt := mapping.NewUHugeInt(lower, uint64(upper)) + return mapping.CreateUUID(uHugeInt), nil } - return mapping.Value{}, nil + return mapping.Value{}, unsupportedTypeError(typeToStringMap[t]) } func getPointerValue(v any) any { @@ -177,38 +214,70 @@ func isNil(i any) bool { } } -// leave the logic type clean up call to caller -func createValueByReflection(v any) (mapping.LogicalType, mapping.Value, error) { - t, vv := inferTypeId(v) - if t != TYPE_INVALID { - retVal, err := tryCreateValueByTypeId(t, vv) - return mapping.CreateLogicalType(t), retVal, err +func inferLogicalTypeAndValue(v any) (mapping.LogicalType, mapping.Value, error) { + 
// Try to create a primitive type. + t, vv := inferPrimitiveType(v) + if isPrimitiveType(t) { + val, err := createPrimitiveValue(t, vv) + if err != nil { + return mapping.LogicalType{}, mapping.Value{}, err + } + return mapping.CreateLogicalType(t), val, err } + + // User-provided type with a Stringer interface: + // We create a string and return a VARCHAR value. + // TYPE_DECIMAL has a Stringer interface. if ss, ok := v.(fmt.Stringer); ok { t = TYPE_VARCHAR - retVal, err := tryCreateValueByTypeId(t, ss.String()) - return mapping.CreateLogicalType(t), retVal, err + val, err := createPrimitiveValue(t, ss.String()) + if err != nil { + return mapping.LogicalType{}, mapping.Value{}, err + } + return mapping.CreateLogicalType(t), val, err } + + // SQLNULL type. if isNil(v) { t = TYPE_SQLNULL - retVal, err := tryCreateValueByTypeId(t, v) - return mapping.CreateLogicalType(t), retVal, err + val, err := createPrimitiveValue(t, v) + if err != nil { + return mapping.LogicalType{}, mapping.Value{}, err + } + return mapping.CreateLogicalType(t), val, err + } + + if t == TYPE_MAP { + // TODO. + return mapping.LogicalType{}, mapping.Value{}, unsupportedTypeError(typeToStringMap[t]) + } + if t == TYPE_UNION { + // TODO. + return mapping.LogicalType{}, mapping.Value{}, unsupportedTypeError(typeToStringMap[t]) } + + // Complex types. r := reflect.ValueOf(v) switch r.Kind() { + case reflect.Struct, reflect.Map: + // TODO. + return mapping.LogicalType{}, mapping.Value{}, unsupportedTypeError(typeToStringMap[TYPE_STRUCT]) case reflect.Ptr: - return createValueByReflection(getPointerValue(v)) - case reflect.Slice: - return tryGetMappedSliceValue(r.Interface(), false, r.Len()) - case reflect.Array: - return tryGetMappedSliceValue(r.Interface(), true, r.Len()) + // Extract pointer and recurse. 
+ return inferLogicalTypeAndValue(getPointerValue(v)) + case reflect.Array, reflect.Slice: + return inferSliceLogicalTypeAndValue(r.Interface(), r.Kind() == reflect.Array, r.Len()) default: return mapping.LogicalType{}, mapping.Value{}, unsupportedTypeError(reflect.TypeOf(v).Name()) } } -func inferTypeId(v any) (Type, any) { +func inferPrimitiveType(v any) (Type, any) { + // Return TYPE_INVALID for + // TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_ARRAY, + // and for the unsupported types. t := TYPE_INVALID + switch vv := v.(type) { case nil: t = TYPE_SQLNULL @@ -243,56 +312,106 @@ func inferTypeId(v any) (Type, any) { case string: t = TYPE_VARCHAR case []byte: + // No support for TYPE_BLOB. t = TYPE_VARCHAR v = string(vv) case time.Time: + // There is no way to distinguish between + // TYPE_DATE, TYPE_TIME, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, + // TYPE_TIME_TZ, TYPE_TIMESTAMP_TZ. t = TYPE_TIMESTAMP + case Interval: + t = TYPE_INTERVAL + case *big.Int: + t = TYPE_HUGEINT + case Decimal: + t = TYPE_DECIMAL + case UUID: + t = TYPE_UUID + case Map: + // We special-case TYPE_MAP to disambiguate with structs passed as map[string]any. + t = TYPE_MAP + case Union: + t = TYPE_UNION } + return t, v } -func tryGetMappedSliceValue[T any](val T, isArray bool, sliceLength int) (mapping.LogicalType, mapping.Value, error) { +func isPrimitiveType(t Type) bool { + switch t { + case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY, TYPE_UNION: + // Complex type. + return false + case TYPE_INVALID, TYPE_UHUGEINT, TYPE_BIT, TYPE_ANY, TYPE_BIGNUM: + // Invalid or unsupported. 
+ return false + } + return true +} + +func inferSliceLogicalTypeAndValue[T any](val T, array bool, length int) (mapping.LogicalType, mapping.Value, error) { createFunc := mapping.CreateListValue typeFunc := mapping.CreateListType - if isArray { + if array { createFunc = mapping.CreateArrayValue typeFunc = func(child mapping.LogicalType) mapping.LogicalType { - return mapping.CreateArrayType(child, mapping.IdxT(sliceLength)) + return mapping.CreateArrayType(child, mapping.IdxT(length)) } } - vSlice, err := extractSlice(val) + slice, err := extractSlice(val) if err != nil { - return mapping.LogicalType{}, mapping.Value{}, fmt.Errorf("could not cast %T to []any: %w", val, err) + return mapping.LogicalType{}, mapping.Value{}, err } - childValues := make([]mapping.Value, 0, sliceLength) - defer destroyValueSlice(childValues) - childLogicTypes := make([]mapping.LogicalType, 0, sliceLength) - defer destroyLogicalTypes(&childLogicTypes) - if len(vSlice) == 0 { + + values := make([]mapping.Value, 0, length) + defer destroyValueSlice(values) + + if len(slice) == 0 { lt := mapping.CreateLogicalType(TYPE_SQLNULL) defer mapping.DestroyLogicalType(<) - return typeFunc(lt), createFunc(lt, childValues), nil + return typeFunc(lt), createFunc(lt, values), nil } - elementLogicType := mapping.LogicalType{} - for _, v := range vSlice { - et, vv, err := createValueByReflection(v) - if err != nil { - return mapping.LogicalType{}, mapping.Value{}, err + + logicalTypes := make([]mapping.LogicalType, 0, length) + defer destroyLogicalTypes(logicalTypes) + + var elemLogicalType mapping.LogicalType + expectedIndex := -1 + expectedTypeStr := "" + + for i, v := range slice { + et, vv, errInfer := inferLogicalTypeAndValue(v) + if errInfer != nil { + return mapping.LogicalType{}, mapping.Value{}, errInfer } + values = append(values, vv) + logicalTypes = append(logicalTypes, et) + if et.Ptr != nil { - elementLogicType = et + if elemLogicalType.Ptr == nil { + elemLogicalType = et + expectedIndex = 
i + expectedTypeStr = logicalTypeString(et) + continue + } + // Check if this element's type matches the first non-null element's type + if currentTypeStr := logicalTypeString(et); currentTypeStr != expectedTypeStr { + return mapping.LogicalType{}, mapping.Value{}, + fmt.Errorf("mixed types in slice: cannot bind %s (index %d) and %s (index %d)", + expectedTypeStr, expectedIndex, currentTypeStr, i) + } } - childValues = append(childValues, vv) - childLogicTypes = append(childLogicTypes, et) } - if elementLogicType.Ptr == nil { - return elementLogicType, mapping.Value{}, unsupportedTypeError(reflect.TypeOf(val).Name()) + + if elemLogicalType.Ptr == nil { + return elemLogicalType, mapping.Value{}, unsupportedTypeError(reflect.TypeOf(val).Name()) } - return typeFunc(elementLogicType), createFunc(elementLogicType, childValues), nil + return typeFunc(elemLogicalType), createFunc(elemLogicalType, values), nil } -func getMappedSliceValue[T any](lt mapping.LogicalType, t Type, val T) (mapping.Value, error) { +func createSliceValue[T any](lt mapping.LogicalType, t Type, val T) (mapping.Value, error) { var childType mapping.LogicalType switch t { case TYPE_ARRAY: @@ -302,36 +421,37 @@ func getMappedSliceValue[T any](lt mapping.LogicalType, t Type, val T) (mapping. 
} defer mapping.DestroyLogicalType(&childType) - vSlice, err := extractSlice(val) + slice, err := extractSlice(val) if err != nil { - return mapping.Value{}, fmt.Errorf("could not cast %T to []any: %w", val, err) + return mapping.Value{}, err } - var childValues []mapping.Value - defer destroyValueSlice(childValues) + var values []mapping.Value + defer destroyValueSlice(values) - for _, v := range vSlice { + for _, v := range slice { vv, err := createValue(childType, v) if err != nil { - return mapping.Value{}, fmt.Errorf("could not create value %w", err) + return mapping.Value{}, err } - childValues = append(childValues, vv) + values = append(values, vv) } var v mapping.Value switch t { case TYPE_ARRAY: - v = mapping.CreateArrayValue(childType, childValues) + v = mapping.CreateArrayValue(childType, values) case TYPE_LIST: - v = mapping.CreateListValue(childType, childValues) + v = mapping.CreateListValue(childType, values) } + return v, nil } -func getMappedStructValue(lt mapping.LogicalType, val any) (mapping.Value, error) { - vMap, ok := val.(map[string]any) +func createStructValue(lt mapping.LogicalType, val any) (mapping.Value, error) { + m, ok := val.(map[string]any) if !ok { - return mapping.Value{}, fmt.Errorf("could not cast %T to map[string]any", val) + return mapping.Value{}, castError(reflect.TypeOf(val).Name(), reflect.TypeOf(map[string]any{}).Name()) } var values []mapping.Value @@ -339,21 +459,22 @@ func getMappedStructValue(lt mapping.LogicalType, val any) (mapping.Value, error childCount := mapping.StructTypeChildCount(lt) for i := mapping.IdxT(0); i < childCount; i++ { - childName := mapping.StructTypeChildName(lt, i) - childType := mapping.StructTypeChildType(lt, i) - defer mapping.DestroyLogicalType(&childType) + name := mapping.StructTypeChildName(lt, i) + t := mapping.StructTypeChildType(lt, i) + defer mapping.DestroyLogicalType(&t) - v, exists := vMap[childName] + v, exists := m[name] if exists { - vv, err := createValue(childType, v) + vv, 
err := createValue(t, v) if err != nil { - return mapping.Value{}, fmt.Errorf("could not create value %w", err) + return mapping.Value{}, err } values = append(values, vv) } else { values = append(values, mapping.CreateNullValue()) } } + return mapping.CreateStructValue(lt, values), nil } @@ -383,17 +504,16 @@ func extractSlice[S any](val S) ([]any, error) { if kind != reflect.Array && kind != reflect.Slice { return nil, castError(reflect.TypeOf(val).String(), reflect.TypeOf(s).String()) } - // Insert the values into the child vector. + + // Insert the values into the slice. rv := reflect.ValueOf(val) s = make([]any, rv.Len()) - for i := range rv.Len() { idx := rv.Index(i) if canNil(idx) && idx.IsNil() { s[i] = nil continue } - s[i] = idx.Interface() } } diff --git a/vector_setters.go b/vector_setters.go index a981f61e..11aeb9ba 100644 --- a/vector_setters.go +++ b/vector_setters.go @@ -2,7 +2,6 @@ package duckdb import ( "encoding/json" - "math/big" "reflect" "strconv" "unsafe" @@ -90,25 +89,25 @@ func setBool[S any](vec *vector, rowIdx mapping.IdxT, val S) error { func setTS(vec *vector, rowIdx mapping.IdxT, val any) error { switch vec.Type { case TYPE_TIMESTAMP, TYPE_TIMESTAMP_TZ: - ts, err := getMappedTimestamp(vec.Type, val) + ts, err := inferTimestamp(vec.Type, val) if err != nil { return err } setPrimitive(vec, rowIdx, ts) case TYPE_TIMESTAMP_S: - ts, err := getMappedTimestampS(val) + ts, err := inferTimestampS(val) if err != nil { return err } setPrimitive(vec, rowIdx, ts) case TYPE_TIMESTAMP_MS: - ts, err := getMappedTimestampMS(val) + ts, err := inferTimestampMS(val) if err != nil { return err } setPrimitive(vec, rowIdx, ts) case TYPE_TIMESTAMP_NS: - ts, err := getMappedTimestampNS(val) + ts, err := inferTimestampNS(val) if err != nil { return err } @@ -120,7 +119,7 @@ func setTS(vec *vector, rowIdx mapping.IdxT, val any) error { } func setDate[S any](vec *vector, rowIdx mapping.IdxT, val S) error { - date, err := getMappedDate(val) + date, err := 
inferDate(val) if err != nil { return err } @@ -129,94 +128,39 @@ func setDate[S any](vec *vector, rowIdx mapping.IdxT, val S) error { } func setTime[S any](vec *vector, rowIdx mapping.IdxT, val S) error { - ticks, err := getTimeTicks(val) - if err != nil { - return err - } - switch vec.Type { case TYPE_TIME: - ti := mapping.NewTime(ticks) + ti, err := inferTime(val) + if err != nil { + return err + } setPrimitive(vec, rowIdx, ti) case TYPE_TIME_TZ: // The UTC offset is 0. - ti := mapping.CreateTimeTZ(ticks, 0) + ti, err := inferTimeTZ(val) + if err != nil { + return err + } setPrimitive(vec, rowIdx, ti) } return nil } func setInterval[S any](vec *vector, rowIdx mapping.IdxT, val S) error { - var i Interval - switch v := any(val).(type) { - case Interval: - i = v - default: - return castError(reflect.TypeOf(val).String(), reflect.TypeOf(i).String()) + i, err := inferInterval(val) + if err != nil { + return err } - interval := mapping.NewInterval(i.Months, i.Days, i.Micros) - setPrimitive(vec, rowIdx, interval) + setPrimitive(vec, rowIdx, i) return nil } func setHugeint[S any](vec *vector, rowIdx mapping.IdxT, val S) error { - var err error - var fv mapping.HugeInt - switch v := any(val).(type) { - case uint8: - fv = mapping.NewHugeInt(uint64(v), 0) - case int8: - fv = mapping.NewHugeInt(uint64(v), 0) - case uint16: - fv = mapping.NewHugeInt(uint64(v), 0) - case int16: - fv = mapping.NewHugeInt(uint64(v), 0) - case uint32: - fv = mapping.NewHugeInt(uint64(v), 0) - case int32: - fv = mapping.NewHugeInt(uint64(v), 0) - case uint64: - fv = mapping.NewHugeInt(v, 0) - case int64: - fv, err = hugeIntFromNative(big.NewInt(v)) - if err != nil { - return err - } - case uint: - fv = mapping.NewHugeInt(uint64(v), 0) - case int: - fv, err = hugeIntFromNative(big.NewInt(int64(v))) - if err != nil { - return err - } - case float32: - fv, err = hugeIntFromNative(big.NewInt(int64(v))) - if err != nil { - return err - } - case float64: - fv, err = 
hugeIntFromNative(big.NewInt(int64(v))) - if err != nil { - return err - } - case *big.Int: - if v == nil { - return castError(reflect.TypeOf(val).String(), reflect.TypeOf(fv).String()) - } - if fv, err = hugeIntFromNative(v); err != nil { - return err - } - case Decimal: - if v.Value == nil { - return castError(reflect.TypeOf(val).String(), reflect.TypeOf(fv).String()) - } - if fv, err = hugeIntFromNative(v.Value); err != nil { - return err - } - default: - return castError(reflect.TypeOf(val).String(), reflect.TypeOf(fv).String()) + hi, err := inferHugeInt(val) + if err != nil { + return err } - setPrimitive(vec, rowIdx, fv) + setPrimitive(vec, rowIdx, hi) return nil } @@ -388,24 +332,11 @@ func setSliceChildren(vec *vector, s []any, offset mapping.IdxT) error { } func setUUID[S any](vec *vector, rowIdx mapping.IdxT, val S) error { - var uuid UUID - switch v := any(val).(type) { - case UUID: - uuid = v - case *UUID: - uuid = *v - case []uint8: - if len(v) != uuidLength { - return castError(reflect.TypeOf(val).String(), reflect.TypeOf(uuid).String()) - } - for i := range uuidLength { - uuid[i] = v[i] - } - default: - return castError(reflect.TypeOf(val).String(), reflect.TypeOf(uuid).String()) + id, err := inferUUID(val) + if err != nil { + return err } - hi := uuidToHugeInt(uuid) - setPrimitive(vec, rowIdx, hi) + setPrimitive(vec, rowIdx, id) return nil }