diff --git a/plugins/destination/snowflake/client/client.go b/plugins/destination/snowflake/client/client.go index c843d56672b82a..789b653110ea7b 100644 --- a/plugins/destination/snowflake/client/client.go +++ b/plugins/destination/snowflake/client/client.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "sync" @@ -14,6 +15,8 @@ import ( _ "github.com/snowflakedb/gosnowflake" // "snowflake" database/sql driver. ) +var errInvalidSpec = errors.New("invalid spec") + type Client struct { plugin.UnimplementedSource batchwriter.UnimplementedDeleteRecord @@ -32,20 +35,21 @@ func New(_ context.Context, logger zerolog.Logger, spec []byte, _ plugin.NewClie setupWriteOnce: &sync.Once{}, } if err := json.Unmarshal(spec, &c.spec); err != nil { - return nil, fmt.Errorf("failed to unmarshal snowflake spec: %w", err) + return nil, errors.Join(errInvalidSpec, err) } c.spec.SetDefaults() c.writer, err = batchwriter.New(c, batchwriter.WithLogger(c.logger), batchwriter.WithBatchSize(c.spec.BatchSize), batchwriter.WithBatchSizeBytes(c.spec.BatchSizeBytes)) if err != nil { - return nil, err + return nil, errors.Join(errInvalidSpec, err) } dsn, err := c.spec.DSN() if err != nil { - return nil, err + return nil, errors.Join(errInvalidSpec, err) } + db, err := sql.Open("snowflake", dsn+"&BINARY_INPUT_FORMAT=BASE64&BINARY_OUTPUT_FORMAT=BASE64&timezone=UTC") if err != nil { - return nil, err + return nil, errors.Join(errInvalidSpec, err) } err = db.Ping() diff --git a/plugins/destination/snowflake/client/test_connection.go b/plugins/destination/snowflake/client/test_connection.go new file mode 100644 index 00000000000000..de1c9dd75c4786 --- /dev/null +++ b/plugins/destination/snowflake/client/test_connection.go @@ -0,0 +1,46 @@ +package client + +import ( + "context" + "errors" + + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/rs/zerolog" + "github.com/snowflakedb/gosnowflake" +) + +type NewClientFn func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) + +const ( + codeConnectionFailed string = "CONNECTION_FAILED" + codeInvalidSpec string = "INVALID_SPEC" + codeUnauthorized string = "UNAUTHORIZED" + codeUnreachable string = "UNREACHABLE" +) + +func NewConnectionTester(createClientFn NewClientFn) plugin.ConnectionTester { + return func(ctx context.Context, logger zerolog.Logger, spec []byte) error { + _, err := createClientFn(ctx, logger, spec, plugin.NewClientOptions{}) + if err == nil { + return nil + } + + var snowflakeErr *gosnowflake.SnowflakeError + if errors.As(err, &snowflakeErr) { + switch snowflakeErr.Number { + case gosnowflake.ErrCodeFailedToConnect, gosnowflake.ErrCodeServiceUnavailable, gosnowflake.ErrCodeIdpConnectionError: + return plugin.NewTestConnError(codeUnreachable, err) + case gosnowflake.ErrFailedToAuth, gosnowflake.ErrFailedToAuthSAML, gosnowflake.ErrFailedToAuthOKTA, gosnowflake.ErrObjectNotExistOrAuthorized: + return plugin.NewTestConnError(codeUnauthorized, err) + default: + return plugin.NewTestConnError(codeConnectionFailed, err) + } + } + + if errors.Is(err, errInvalidSpec) { + return plugin.NewTestConnError(codeInvalidSpec, err) + } + + return plugin.NewTestConnError(codeConnectionFailed, err) + } +} diff --git a/plugins/destination/snowflake/client/test_connection_test.go b/plugins/destination/snowflake/client/test_connection_test.go new file mode 100644 index 00000000000000..491bdb26453058 --- /dev/null +++ b/plugins/destination/snowflake/client/test_connection_test.go @@ -0,0 +1,75 @@ +package client + +import ( + "context" + "testing" + + "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/rs/zerolog" + "github.com/snowflakedb/gosnowflake" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnectionTester(t *testing.T) { + cases := []struct { + name string + spec []byte + err *plugin.TestConnError + clientBuilder func() (plugin.Client, error) + }{ + { + name: "ok", + spec: []byte(`{}`), + clientBuilder: func() (plugin.Client, error) { return &Client{}, nil }, + }, + { + name: "error/unauthorized", + spec: []byte(`{}`), + err: plugin.NewTestConnError(codeUnauthorized, assert.AnError), + clientBuilder: func() (plugin.Client, error) { + return nil, &gosnowflake.SnowflakeError{Number: gosnowflake.ErrFailedToAuth} + }, + }, + { + name: "error/unreachable", + spec: []byte(`{}`), + err: plugin.NewTestConnError(codeUnreachable, assert.AnError), + clientBuilder: func() (plugin.Client, error) { + return nil, &gosnowflake.SnowflakeError{Number: gosnowflake.ErrCodeServiceUnavailable} + }, + }, + { + name: "error/spec", + spec: []byte(`{null}`), + err: plugin.NewTestConnError(codeInvalidSpec, assert.AnError), + clientBuilder: func() (plugin.Client, error) { return nil, errInvalidSpec }, + }, + { + name: "error/connection_failed", + spec: []byte(`{}`), + err: plugin.NewTestConnError(codeConnectionFailed, assert.AnError), + clientBuilder: func() (plugin.Client, error) { return nil, assert.AnError }, + }, + } + + for idx := range cases { + tc := cases[idx] + + t.Run(tc.name, func(t *testing.T) { + tester := NewConnectionTester(func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) { + return tc.clientBuilder() + }) + + err := tester(context.Background(), zerolog.Nop(), tc.spec) + if tc.err == nil { + require.NoError(t, err) + return + } + + var e *plugin.TestConnError + require.ErrorAs(t, err, &e) + require.Equal(t, tc.err.Code, err.(*plugin.TestConnError).Code) + }) + } +} diff --git a/plugins/destination/snowflake/main.go b/plugins/destination/snowflake/main.go index 42526e826e40b4..b745fff442fb44 100644 --- a/plugins/destination/snowflake/main.go +++ b/plugins/destination/snowflake/main.go @@ -25,6 +25,7 @@ func main() { plugin.WithKind(internalPlugin.Kind), plugin.WithTeam(internalPlugin.Team), plugin.WithJSONSchema(client.JSONSchema), + plugin.WithConnectionTester(client.NewConnectionTester(client.New)), ) if err := serve.Plugin(p, serve.WithDestinationV0V1Server()).Serve(context.Background()); err != nil {