diff --git a/examples/simple_plugin/plugin/client.go b/examples/simple_plugin/plugin/client.go
index 5ef3bf773d..72bb5131dc 100644
--- a/examples/simple_plugin/plugin/client.go
+++ b/examples/simple_plugin/plugin/client.go
@@ -69,6 +69,12 @@ func (*Client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
 	return nil
 }
 
+// TransformSchema is a stub: it ignores its arguments and returns a nil schema and a nil error.
+func (*Client) TransformSchema(_ context.Context, _ *arrow.Schema) (*arrow.Schema, error) {
+	// Not implemented, just used for testing destination packaging
+	return nil, nil
+}
+
 func Configure(_ context.Context, logger zerolog.Logger, spec []byte, opts plugin.NewClientOptions) (plugin.Client, error) {
 	if opts.NoConnection {
 		return &Client{
diff --git a/internal/memdb/memdb.go b/internal/memdb/memdb.go
index e0c7fa410c..5874e1c9b2 100644
--- a/internal/memdb/memdb.go
+++ b/internal/memdb/memdb.go
@@ -311,6 +311,11 @@ func (*client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
 	return nil
 }
 
+// TransformSchema is a stub that returns a nil schema and a nil error.
+func (*client) TransformSchema(_ context.Context, _ *arrow.Schema) (*arrow.Schema, error) {
+	return nil, nil
+}
+
 func evaluatePredicate(pred message.Predicate, record arrow.Record) bool {
 	sc := record.Schema()
 	indices := sc.FieldIndices(pred.Column)
diff --git a/internal/reversertransformer/reversertransformer.go b/internal/reversertransformer/reversertransformer.go
index 831443fec5..0932900796 100644
--- a/internal/reversertransformer/reversertransformer.go
+++ b/internal/reversertransformer/reversertransformer.go
@@ -58,6 +58,11 @@ func (c *client) Transform(ctx context.Context, recvRecords <-chan arrow.Record,
 	}
 }
 
+// TransformSchema returns the given schema unchanged.
+func (*client) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
+	return old, nil
+}
+
 func (*client) reverseStrings(record arrow.Record) (arrow.Record, error) {
 	for i, column := range record.Columns() {
 		if column.DataType().ID() != arrow.STRING {
diff --git a/internal/reversertransformer/reversertransformer_test.go b/internal/reversertransformer/reversertransformer_test.go
index 1f7c19225a..9750ed729c 100644
--- a/internal/reversertransformer/reversertransformer_test.go
+++ b/internal/reversertransformer/reversertransformer_test.go
@@ -16,8 +16,6 @@ import (
 	"google.golang.org/grpc/metadata"
 )
 
-var mem = memory.NewGoAllocator()
-
 func TestReverserTransformer(t *testing.T) {
 	p := plugin.NewPlugin("test", "development", GetNewClient())
 	s := internalPlugin.Server{
@@ -58,7 +56,7 @@ func makeRequestFromString(s string) *pb.Transform_Request {
 }
 
 func makeRecordFromString(s string) arrow.Record {
-	str := array.NewStringBuilder(mem)
+	str := array.NewStringBuilder(memory.DefaultAllocator)
 	str.AppendString(s)
 	arr := str.NewStringArray()
 	schema := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil)
diff --git a/internal/servers/plugin/v3/plugin.go b/internal/servers/plugin/v3/plugin.go
index 699fde1477..917fe3c58b 100644
--- a/internal/servers/plugin/v3/plugin.go
+++ b/internal/servers/plugin/v3/plugin.go
@@ -475,6 +475,31 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
 	return eg.Wait()
 }
 
+// TransformSchema decodes the schema bytes in req, passes the schema to the
+// plugin's TransformSchema and returns the resulting schema re-encoded as bytes.
+// Decode failures map to codes.InvalidArgument; plugin and encode failures map
+// to codes.Internal.
+func (s *Server) TransformSchema(ctx context.Context, req *pb.TransformSchema_Request) (*pb.TransformSchema_Response, error) {
+	sc, err := pb.NewSchemaFromBytes(req.Schema)
+	if err != nil {
+		return nil, status.Errorf(codes.InvalidArgument, "failed to create schema from bytes: %v", err)
+	}
+	newSchema, err := s.Plugin.TransformSchema(ctx, sc)
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to transform schema: %v", err)
+	}
+	// Clients may return (nil, nil) (the memdb and simple_plugin stubs do);
+	// reject that here instead of handing a nil schema to the encoder.
+	if newSchema == nil {
+		return nil, status.Error(codes.Internal, "failed to transform schema: plugin returned a nil schema")
+	}
+	encoded, err := pb.SchemaToBytes(newSchema)
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to encode schema: %v", err)
+	}
+	return &pb.TransformSchema_Response{Schema: encoded}, nil
+}
+
 func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) {
 	return &pb.Close_Response{}, s.Plugin.Close(ctx)
 }
diff --git a/internal/servers/plugin/v3/plugin_test.go b/internal/servers/plugin/v3/plugin_test.go
index dc5610994c..ac32af2091 100644
--- a/internal/servers/plugin/v3/plugin_test.go
+++ b/internal/servers/plugin/v3/plugin_test.go
@@ -12,6 +12,8 @@ import (
 	"github.com/cloudquery/plugin-sdk/v4/internal/memdb"
 	"github.com/cloudquery/plugin-sdk/v4/plugin"
 	"github.com/cloudquery/plugin-sdk/v4/schema"
+	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/require"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
 )
@@ -183,3 +185,66 @@ func TestPluginSync(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestTransformSchema(t *testing.T) {
+	ctx := context.Background()
+	s := Server{
+		Plugin: plugin.NewPlugin("test", "development", getColumnAdderPlugin()),
+	}
+
+	_, err := s.Init(ctx, &pb.Init_Request{})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	table := &schema.Table{
+		Name: "test",
+		Columns: []schema.Column{
+			{
+				Name: "test",
+				Type: arrow.BinaryTypes.String,
+			},
+		},
+	}
+	sc := table.ToArrowSchema()
+
+	schemaBytes, err := pb.SchemaToBytes(sc)
+	require.NoError(t, err)
+
+	resp, err := s.TransformSchema(ctx, &pb.TransformSchema_Request{Schema: schemaBytes})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	newSchema, err := pb.NewSchemaFromBytes(resp.Schema)
+	require.NoError(t, err)
+
+	require.Len(t, newSchema.Fields(), 2)
+	require.Equal(t, "test", newSchema.Fields()[0].Name)
+	require.Equal(t, "source", newSchema.Fields()[1].Name)
+	require.Equal(t, "utf8", newSchema.Fields()[1].Type.(*arrow.StringType).Name())
+
+	if _, err := s.Close(ctx, &pb.Close_Request{}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+type mockSourceColumnAdderPluginClient struct {
+	plugin.UnimplementedDestination
+	plugin.UnimplementedSource
+}
+
+func getColumnAdderPlugin(...plugin.Option) plugin.NewClientFunc {
+	c := &mockSourceColumnAdderPluginClient{}
+	return func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) {
+		return c, nil
+	}
+}
+
+func (*mockSourceColumnAdderPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
+	return nil
+}
+func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
+	return old.AddField(1, arrow.Field{Name: "source", Type: arrow.BinaryTypes.String})
+}
+func (*mockSourceColumnAdderPluginClient) Close(context.Context) error { return nil }
diff --git a/plugin/plugin.go b/plugin/plugin.go
index b7a00671dc..7ad2824cfd 100644
--- a/plugin/plugin.go
+++ b/plugin/plugin.go
@@ -58,6 +58,10 @@ type UnimplementedTransformer struct{}
 func (UnimplementedTransformer) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
 	return ErrNotImplemented
 }
+// TransformSchema always returns a nil schema and ErrNotImplemented.
+func (UnimplementedTransformer) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
+	return nil, ErrNotImplemented
+}
 
 // Plugin is the base structure required to pass to sdk.serve
 // We take a declarative approach to API here similar to Cobra
diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go
index be7af4ff81..22d20d5a25 100644
--- a/plugin/plugin_test.go
+++ b/plugin/plugin_test.go
@@ -59,6 +59,9 @@ func (*testPluginClient) Close(context.Context) error {
 func (*testPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
 	return nil
 }
+func (*testPluginClient) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
+	return nil, nil
+}
 
 func TestPluginSuccess(t *testing.T) {
 	ctx := context.Background()
diff --git a/plugin/plugin_transformer.go b/plugin/plugin_transformer.go
index a273021078..225b3e8129 100644
--- a/plugin/plugin_transformer.go
+++ b/plugin/plugin_transformer.go
@@ -8,8 +8,14 @@ import (
 
 type TransformerClient interface {
 	Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error
+	// TransformSchema receives an Arrow schema and returns the transformed schema.
+	TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error)
 }
 
 func (p *Plugin) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
 	return p.client.Transform(ctx, recvRecords, sendRecords)
 }
+// TransformSchema delegates to the underlying client's TransformSchema.
+func (p *Plugin) TransformSchema(ctx context.Context, old *arrow.Schema) (*arrow.Schema, error) {
+	return p.client.TransformSchema(ctx, old)
+}