Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/simple_plugin/plugin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ func (*Client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
return nil
}

func (*Client) TransformSchema(_ context.Context, _ *arrow.Schema) (*arrow.Schema, error) {
// Not implemented, just used for testing destination packaging
return nil, nil
}

func Configure(_ context.Context, logger zerolog.Logger, spec []byte, opts plugin.NewClientOptions) (plugin.Client, error) {
if opts.NoConnection {
return &Client{
Expand Down
4 changes: 4 additions & 0 deletions internal/memdb/memdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ func (*client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arro
return nil
}

func (*client) TransformSchema(_ context.Context, _ *arrow.Schema) (*arrow.Schema, error) {
return nil, nil
}

func evaluatePredicate(pred message.Predicate, record arrow.Record) bool {
sc := record.Schema()
indices := sc.FieldIndices(pred.Column)
Expand Down
4 changes: 4 additions & 0 deletions internal/reversertransformer/reversertransformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func (c *client) Transform(ctx context.Context, recvRecords <-chan arrow.Record,
}
}

func (*client) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
return old, nil
}

func (*client) reverseStrings(record arrow.Record) (arrow.Record, error) {
for i, column := range record.Columns() {
if column.DataType().ID() != arrow.STRING {
Expand Down
4 changes: 1 addition & 3 deletions internal/reversertransformer/reversertransformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import (
"google.golang.org/grpc/metadata"
)

var mem = memory.NewGoAllocator()

func TestReverserTransformer(t *testing.T) {
p := plugin.NewPlugin("test", "development", GetNewClient())
s := internalPlugin.Server{
Expand Down Expand Up @@ -58,7 +56,7 @@ func makeRequestFromString(s string) *pb.Transform_Request {
}

func makeRecordFromString(s string) arrow.Record {
str := array.NewStringBuilder(mem)
str := array.NewStringBuilder(memory.DefaultAllocator)
str.AppendString(s)
arr := str.NewStringArray()
schema := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil)
Expand Down
16 changes: 16 additions & 0 deletions internal/servers/plugin/v3/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,22 @@ func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
return eg.Wait()
}

func (s *Server) TransformSchema(ctx context.Context, req *pb.TransformSchema_Request) (*pb.TransformSchema_Response, error) {
sc, err := pb.NewSchemaFromBytes(req.Schema)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to create schema from bytes: %v", err)
}
newSchema, err := s.Plugin.TransformSchema(ctx, sc)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to transform schema: %v", err)
}
encoded, err := pb.SchemaToBytes(newSchema)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to encode schema: %v", err)
}
return &pb.TransformSchema_Response{Schema: encoded}, nil
}

func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) {
return &pb.Close_Response{}, s.Plugin.Close(ctx)
}
65 changes: 65 additions & 0 deletions internal/servers/plugin/v3/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/cloudquery/plugin-sdk/v4/internal/memdb"
"github.com/cloudquery/plugin-sdk/v4/plugin"
"github.com/cloudquery/plugin-sdk/v4/schema"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
Expand Down Expand Up @@ -183,3 +185,66 @@ func TestPluginSync(t *testing.T) {
t.Fatal(err)
}
}

func TestTransformSchema(t *testing.T) {
ctx := context.Background()
s := Server{
Plugin: plugin.NewPlugin("test", "development", getColumnAdderPlugin()),
}

_, err := s.Init(ctx, &pb.Init_Request{})
if err != nil {
t.Fatal(err)
}

table := &schema.Table{
Name: "test",
Columns: []schema.Column{
{
Name: "test",
Type: arrow.BinaryTypes.String,
},
},
}
sc := table.ToArrowSchema()

schemaBytes, err := pb.SchemaToBytes(sc)
require.NoError(t, err)

resp, err := s.TransformSchema(ctx, &pb.TransformSchema_Request{Schema: schemaBytes})
if err != nil {
t.Fatal(err)
}

newSchema, err := pb.NewSchemaFromBytes(resp.Schema)
require.NoError(t, err)

require.Len(t, newSchema.Fields(), 2)
require.Equal(t, "test", newSchema.Fields()[0].Name)
require.Equal(t, "source", newSchema.Fields()[1].Name)
require.Equal(t, "utf8", newSchema.Fields()[1].Type.(*arrow.StringType).Name())

if _, err := s.Close(ctx, &pb.Close_Request{}); err != nil {
t.Fatal(err)
}
}

type mockSourceColumnAdderPluginClient struct {
plugin.UnimplementedDestination
plugin.UnimplementedSource
}

func getColumnAdderPlugin(...plugin.Option) plugin.NewClientFunc {
c := &mockSourceColumnAdderPluginClient{}
return func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) {
return c, nil
}
}

func (*mockSourceColumnAdderPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
return nil
}
func (*mockSourceColumnAdderPluginClient) TransformSchema(_ context.Context, old *arrow.Schema) (*arrow.Schema, error) {
return old.AddField(1, arrow.Field{Name: "source", Type: arrow.BinaryTypes.String})
}
func (*mockSourceColumnAdderPluginClient) Close(context.Context) error { return nil }
3 changes: 3 additions & 0 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type UnimplementedTransformer struct{}
func (UnimplementedTransformer) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
return ErrNotImplemented
}
func (UnimplementedTransformer) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
return nil, ErrNotImplemented
}

// Plugin is the base structure required to pass to sdk.serve
// We take a declarative approach to API here similar to Cobra
Expand Down
3 changes: 3 additions & 0 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ func (*testPluginClient) Close(context.Context) error {
func (*testPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
return nil
}
func (*testPluginClient) TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error) {
return nil, nil
}

func TestPluginSuccess(t *testing.T) {
ctx := context.Background()
Expand Down
4 changes: 4 additions & 0 deletions plugin/plugin_transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (

type TransformerClient interface {
Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error
TransformSchema(context.Context, *arrow.Schema) (*arrow.Schema, error)
}

func (p *Plugin) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
return p.client.Transform(ctx, recvRecords, sendRecords)
}
func (p *Plugin) TransformSchema(ctx context.Context, old *arrow.Schema) (*arrow.Schema, error) {
return p.client.TransformSchema(ctx, old)
}