diff --git a/cli/cmd/sync.go b/cli/cmd/sync.go index 41b82da2c48b1c..ab671203bb42e7 100644 --- a/cli/cmd/sync.go +++ b/cli/cmd/sync.go @@ -109,6 +109,8 @@ func sync(cmd *cobra.Command, args []string) error { sources := specReader.Sources destinations := specReader.Destinations + transformers := specReader.Transformers + sourcePluginClients := make(managedplugin.Clients, 0) defer func() { if err := sourcePluginClients.Terminate(); err != nil { @@ -220,6 +222,43 @@ func sync(cmd *cobra.Command, args []string) error { destinationPluginClients = append(destinationPluginClients, destPluginClient) } + transformerPluginClients := make(managedplugin.Clients, 0) + defer func() { + if err := transformerPluginClients.Terminate(); err != nil { + fmt.Println(err) + } + }() + for _, transformer := range transformers { + opts := []managedplugin.Option{ + managedplugin.WithLogger(log.Logger), + managedplugin.WithAuthToken(authToken.Value), + managedplugin.WithTeamName(teamName), + managedplugin.WithLicenseFile(licenseFile), + } + if logConsole { + opts = append(opts, managedplugin.WithNoProgress()) + } + if cqDir != "" { + opts = append(opts, managedplugin.WithDirectory(cqDir)) + } + if disableSentry { + opts = append(opts, managedplugin.WithNoSentry()) + } + + cfg := managedplugin.Config{ + Name: transformer.Name, + Registry: SpecRegistryToPlugin(transformer.Registry), + Version: transformer.Version, + Path: transformer.Path, + DockerAuth: transformer.DockerRegistryAuthToken, + } + transPluginClient, err := managedplugin.NewClient(ctx, managedplugin.PluginTransformer, cfg, opts...) 
+ if err != nil { + return enrichClientError(managedplugin.Clients{}, []bool{transformer.RegistryInferred()}, err) + } + transformerPluginClients = append(transformerPluginClients, transPluginClient) + } + for _, source := range sources { cl := sourcePluginClients.ClientByName(source.Name) versions, err := cl.Versions(ctx) @@ -230,12 +269,26 @@ func sync(cmd *cobra.Command, args []string) error { var destinationClientsForSource []*managedplugin.Client var destinationForSourceSpec []specs.Destination + var transformerClientsForDestination = map[string][]*managedplugin.Client{} + var transformerForDestinationSpec = map[string][]specs.Transformer{} var backendClientForSource *managedplugin.Client var destinationForSourceBackendSpec *specs.Destination for _, destination := range destinations { if slices.Contains(source.Destinations, destination.Name) { destinationClientsForSource = append(destinationClientsForSource, destinationPluginClients.ClientByName(destination.Name)) destinationForSourceSpec = append(destinationForSourceSpec, *destination) + + // Each destination defines their own transformers + ts := []*managedplugin.Client{} + tsSpecs := []specs.Transformer{} + for _, transformer := range transformers { + if slices.Contains(destination.Transformers, transformer.Name) { + ts = append(ts, transformerPluginClients.ClientByName(transformer.Name)) + tsSpecs = append(tsSpecs, *transformer) + } + } + transformerClientsForDestination[destination.Name] = ts + transformerForDestinationSpec[destination.Name] = tsSpecs continue } @@ -264,6 +317,12 @@ func sync(cmd *cobra.Command, args []string) error { for field, msg := range destWarnings { log.Warn().Str("destination", destination.Name()).Str("field", field).Msg(msg) } + for _, transformer := range transformerClientsForDestination[destination.Name()] { + transformerWarnings := specReader.GetTransformerWarningsByName(transformer.Name()) + for field, msg := range transformerWarnings { + log.Warn().Str("transformer",
transformer.Name()).Str("field", field).Msg(msg) + } + } } src := v3source{ @@ -277,6 +336,16 @@ func sync(cmd *cobra.Command, args []string) error { spec: destinationForSourceSpec[i], }) } + transfs := map[string][]v3transformer{} + for destinationName, transformerClients := range transformerClientsForDestination { + for i, transformer := range transformerClients { + transfs[destinationName] = append(transfs[destinationName], v3transformer{ + client: transformer, + spec: transformerForDestinationSpec[destinationName][i], + }) + } + } + var backend *v3destination if backendClientForSource != nil && destinationForSourceBackendSpec != nil { backend = &v3destination{ @@ -290,7 +359,7 @@ func sync(cmd *cobra.Command, args []string) error { return err } - if err := syncConnectionV3(ctx, src, dests, backend, invocationUUID.String(), noMigrate, summaryLocation); err != nil { + if err := syncConnectionV3(ctx, src, dests, transfs, backend, invocationUUID.String(), noMigrate, summaryLocation); err != nil { return fmt.Errorf("failed to sync v3 source %s: %w", cl.Name(), err) } diff --git a/cli/cmd/sync_v3.go b/cli/cmd/sync_v3.go index 14f05eb5b09d8a..0af5da45efcc1b 100644 --- a/cli/cmd/sync_v3.go +++ b/cli/cmd/sync_v3.go @@ -16,6 +16,7 @@ import ( "github.com/cloudquery/cloudquery/cli/internal/api" "github.com/cloudquery/cloudquery/cli/internal/specs/v0" "github.com/cloudquery/cloudquery/cli/internal/transformer" + "github.com/cloudquery/cloudquery/cli/internal/transformerpipeline" "github.com/cloudquery/plugin-pb-go/managedplugin" "github.com/cloudquery/plugin-pb-go/metrics" "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" @@ -23,6 +24,7 @@ import ( "github.com/rs/zerolog/log" "github.com/schollz/progressbar/v3" "github.com/vnteamopen/godebouncer" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/timestamppb" cloudquery_api "github.com/cloudquery/cloudquery-api-go" @@ -39,6 +41,11 @@ type v3destination struct { spec specs.Destination } +type 
v3transformer struct { + client *managedplugin.Client + spec specs.Transformer +} + func getProgressAPIClient() (*cloudquery_api.ClientWithResponses, error) { authClient := auth.NewTokenClient() if authClient.GetTokenType() != auth.SyncRunAPIKey { @@ -53,7 +60,7 @@ func getProgressAPIClient() (*cloudquery_api.ClientWithResponses, error) { } // nolint:dupl -func syncConnectionV3(ctx context.Context, source v3source, destinations []v3destination, backend *v3destination, uid string, noMigrate bool, summaryLocation string) (syncErr error) { +func syncConnectionV3(ctx context.Context, source v3source, destinations []v3destination, transformersByDestination map[string][]v3transformer, backend *v3destination, uid string, noMigrate bool, summaryLocation string) (syncErr error) { var mt metrics.Metrics var exitReason = ExitReasonStopped skippedFromDeleteStale := make(map[string]bool, 0) @@ -109,14 +116,35 @@ func syncConnectionV3(ctx context.Context, source v3source, destinations []v3des for i := range destinationsClients { destinationStrings[i] = destinationSpecs[i].VersionString() } - log.Info().Str("source", sourceSpec.VersionString()).Strs("destinations", destinationStrings).Time("sync_time", syncTime).Msg("Start sync") - defer log.Info().Str("source", sourceSpec.VersionString()).Strs("destinations", destinationStrings).Time("sync_time", syncTime).Msg("End sync") + + // Get all distinct transformer version strings + transformerStrings := []string{} + _transformerSet := make(map[string]struct{}) + for _, transformers := range transformersByDestination { + for _, tf := range transformers { + name := tf.spec.Name + if _, ok := _transformerSet[name]; !ok { + transformerStrings = append(transformerStrings, tf.spec.VersionString()) + _transformerSet[name] = struct{}{} + } + } + } + + log.Info().Str("source", sourceSpec.VersionString()).Strs("destinations", destinationStrings).Strs("transformers", transformerStrings).Time("sync_time", syncTime).Msg("Start sync") + defer 
log.Info().Str("source", sourceSpec.VersionString()).Strs("destinations", destinationStrings).Strs("transformers", transformerStrings).Time("sync_time", syncTime).Msg("End sync") variables := specs.Variables{ Plugins: make(map[string]specs.PluginVariables), } sourcePbClient := plugin.NewPluginClient(sourceClient.Conn) destinationsPbClients := make([]plugin.PluginClient, len(destinationsClients)) + transformerPbClientsByDestination := map[string][]plugin.PluginClient{} + for name, transformers := range transformersByDestination { + for _, tf := range transformers { + transformerPbClientsByDestination[name] = append(transformerPbClientsByDestination[name], plugin.NewPluginClient(tf.client.Conn)) + } + } + destinationTransformers := make([]*transformer.RecordTransformer, len(destinationsClients)) backendPbClient := plugin.PluginClient(nil) for i := range destinationsClients { @@ -160,6 +188,13 @@ func syncConnectionV3(ctx context.Context, source v3source, destinations []v3des return fmt.Errorf("failed to init backend %v: %w", backend.spec.Name, err) } } + for name, transformers := range transformersByDestination { + for i, tf := range transformers { + if err := initPlugin(ctx, transformerPbClientsByDestination[name][i], tf.spec.Spec, false, uid); err != nil { + return fmt.Errorf("failed to init transformer %v: %w", tf.spec.Name, err) + } + } + } // replace @@plugins.name.connection with the actual GRPC connection string from the client // NOTE: if this becomes a stable feature, it can move out of sync_v3 and into sync.go @@ -180,11 +215,23 @@ func syncConnectionV3(ctx context.Context, source v3source, destinations []v3des } writeClients := make([]plugin.Plugin_WriteClient, len(destinationsPbClients)) + writeClientsByName := map[string]plugin.Plugin_WriteClient{} for i := range destinationsPbClients { writeClients[i], err = destinationsPbClients[i].Write(ctx) if err != nil { return err } + writeClientsByName[destinationSpecs[i].Name] = writeClients[i] + } + 
transformClientsByDestination := map[string][]plugin.Plugin_TransformClient{} + for name, transformerPbClients := range transformerPbClientsByDestination { + for _, transformerPbClient := range transformerPbClients { + transformClient, err := transformerPbClient.Transform(ctx) + if err != nil { + return err + } + transformClientsByDestination[name] = append(transformClientsByDestination[name], transformClient) + } } log.Info().Str("source", sourceSpec.VersionString()).Strs("destinations", destinationStrings).Msg("Start fetching resources") @@ -283,97 +330,155 @@ func syncConnectionV3(ctx context.Context, source v3source, destinations []v3des defer remoteProgressReporter.Cancel() } - for { - r, err := syncClient.Recv() + // Note: we want to stop this errorgroup if ctx is cancelled, but we don't want to cancel ctx if gctx is cancelled. + // gctx is always cancelled when the errorgroup returns, and this isn't necessarily an error. + eg, gctx := errgroup.WithContext(ctx) + pipelineByDestinationName := map[string]*transformerpipeline.TransformerPipeline{} + + // Each destination has its own transformer pipeline + for i := range destinationsPbClients { + destinationName := destinationSpecs[i].Name + + // Start a pipeline of transformers that will receive & transform the source records + var ( + pipeline *transformerpipeline.TransformerPipeline + err error + ) + pipeline, gctx, err = transformerpipeline.New(gctx, transformClientsByDestination[destinationName]) if err != nil { - if errors.Is(err, io.EOF) { - break - } - return fmt.Errorf("unexpected error from sync client receive: %w", err) + return fmt.Errorf("failed to create transformer pipeline: %w", err) } - syncResponseMsg := r.GetMessage() - switch m := syncResponseMsg.(type) { - case *plugin.Sync_Response_Insert: - record, err := plugin.NewRecordFromBytes(m.Insert.Record) - if err != nil { - return fmt.Errorf("failed to get record from bytes: %w", err) + err = pipeline.OnOutput(func(recordBytes []byte) error { + wr 
:= &plugin.Write_Request{ + Message: &plugin.Write_Request_Insert{ + Insert: &plugin.Write_MessageInsert{ + Record: recordBytes, + }, + }, + } + if err := writeClientsByName[destinationName].Send(wr); err != nil { + return handleSendError(err, writeClientsByName[destinationName], "insert") } + return nil + }) + if err != nil { + return fmt.Errorf("failed to create register pipeline output: %w", err) + } + eg.Go(pipeline.RunBlocking) // each transformer runs in its own goroutine + pipelineByDestinationName[destinationName] = pipeline + } - atomic.AddInt64(&newResources, record.NumRows()) - atomic.AddInt64(&totalResources, record.NumRows()) - if remoteProgressReporter != nil { - remoteProgressReporter.SendSignal() + eg.Go(func() error { + // Close all transformation pipelines when the source is done + defer func() { + for _, pipeline := range pipelineByDestinationName { + if err := pipeline.Close(); err != nil { + log.Warn().Err(err).Msg("Failed to close transformer pipeline") + } + } + }() + for { + r, err := syncClient.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("unexpected error from sync client receive: %w", err) } - for i := range destinationsPbClients { - transformedRecord := destinationTransformers[i].Transform(record) - transformedRecordBytes, err := plugin.RecordToBytes(transformedRecord) + syncResponseMsg := r.GetMessage() + switch m := syncResponseMsg.(type) { + case *plugin.Sync_Response_Insert: + recordBytes := m.Insert.Record + record, err := plugin.NewRecordFromBytes(recordBytes) if err != nil { - return fmt.Errorf("failed to transform record bytes: %w", err) + return fmt.Errorf("failed to get record from bytes: %w", err) } - wr := &plugin.Write_Request{} - wr.Message = &plugin.Write_Request_Insert{ - Insert: &plugin.Write_MessageInsert{ - Record: transformedRecordBytes, - }, + + atomic.AddInt64(&newResources, record.NumRows()) + atomic.AddInt64(&totalResources, record.NumRows()) + if remoteProgressReporter != 
nil { + remoteProgressReporter.SendSignal() } - if err := writeClients[i].Send(wr); err != nil { - return handleSendError(err, writeClients[i], "insert") + for i := range destinationsPbClients { + destinationName := destinationSpecs[i].Name + transformedRecord := destinationTransformers[i].Transform(record) + transformedRecordBytes, err := plugin.RecordToBytes(transformedRecord) + if err != nil { + return fmt.Errorf("failed to transform record bytes: %w", err) + } + if err := pipelineByDestinationName[destinationName].Send(transformedRecordBytes); err != nil { + return err + } } - } - case *plugin.Sync_Response_DeleteRecord: - for i := range destinationsPbClients { - wr := &plugin.Write_Request{} - // Transformations aren't required here because DeleteRecord is only in V3 - wr.Message = &plugin.Write_Request_DeleteRecord{ - DeleteRecord: &plugin.Write_MessageDeleteRecord{ - TableName: m.DeleteRecord.TableName, - TableRelations: m.DeleteRecord.TableRelations, - WhereClause: m.DeleteRecord.WhereClause, - }, + case *plugin.Sync_Response_DeleteRecord: + for i := range destinationsPbClients { + wr := &plugin.Write_Request{} + // Transformations aren't required here because DeleteRecord is only in V3 + wr.Message = &plugin.Write_Request_DeleteRecord{ + DeleteRecord: &plugin.Write_MessageDeleteRecord{ + TableName: m.DeleteRecord.TableName, + TableRelations: m.DeleteRecord.TableRelations, + WhereClause: m.DeleteRecord.WhereClause, + }, + } + if err := writeClients[i].Send(wr); err != nil { + return handleSendError(err, writeClients[i], "delete") + } } - if err := writeClients[i].Send(wr); err != nil { - return handleSendError(err, writeClients[i], "delete") + case *plugin.Sync_Response_MigrateTable: + sc, err := plugin.NewSchemaFromBytes(m.MigrateTable.Table) + if err != nil { + return err } - } - case *plugin.Sync_Response_MigrateTable: - sc, err := plugin.NewSchemaFromBytes(m.MigrateTable.Table) - if err != nil { - return err - } - table, err := 
schema.NewTableFromArrowSchema(sc) - if err != nil { - return err - } - - // This works since we sync and send migrate messages for parents before children - if isStateBackendEnabled && (table.IsIncremental || (table.Parent != nil && skippedFromDeleteStale[table.Parent.Name])) { - skippedFromDeleteStale[table.Name] = true - } else { - tablesForDeleteStale[table.Name] = true - } - if noMigrate { - continue - } - for i := range destinationsPbClients { - transformedSchema := destinationTransformers[i].TransformSchema(sc) - transformedSchemaBytes, err := plugin.SchemaToBytes(transformedSchema) + table, err := schema.NewTableFromArrowSchema(sc) if err != nil { return err } - wr := &plugin.Write_Request{} - wr.Message = &plugin.Write_Request_MigrateTable{ - MigrateTable: &plugin.Write_MessageMigrateTable{ - MigrateForce: destinationSpecs[i].MigrateMode == specs.MigrateModeForced, - Table: transformedSchemaBytes, - }, + + // This works since we sync and send migrate messages for parents before children + if isStateBackendEnabled && (table.IsIncremental || (table.Parent != nil && skippedFromDeleteStale[table.Parent.Name])) { + skippedFromDeleteStale[table.Name] = true + } else { + tablesForDeleteStale[table.Name] = true + } + if noMigrate { + continue } - if err := writeClients[i].Send(wr); err != nil { - return handleSendError(err, writeClients[i], "migrate") + for i := range destinationsPbClients { + destinationName := destinationSpecs[i].Name + transformedSchema := destinationTransformers[i].TransformSchema(sc) + transformedSchemaBytes, err := plugin.SchemaToBytes(transformedSchema) + if err != nil { + return err + } + // Sequentially apply schema transformations from transformers + for _, transformerPbClient := range transformerPbClientsByDestination[destinationName] { + resp, err := transformerPbClient.TransformSchema(ctx, &plugin.TransformSchema_Request{Schema: transformedSchemaBytes}) + if err != nil { + return err + } + transformedSchemaBytes = resp.Schema + } + + 
wr := &plugin.Write_Request{} + wr.Message = &plugin.Write_Request_MigrateTable{ + MigrateTable: &plugin.Write_MessageMigrateTable{ + MigrateForce: destinationSpecs[i].MigrateMode == specs.MigrateModeForced, + Table: transformedSchemaBytes, + }, + } + if err := writeClients[i].Send(wr); err != nil { + return handleSendError(err, writeClients[i], "migrate") + } } + default: + return fmt.Errorf("unknown message type: %T", m) } - default: - return fmt.Errorf("unknown message type: %T", m) } + return nil + }) + if err := eg.Wait(); err != nil { // wait for source & transformers to finish. If any fails, sync fails. + return err } err = syncClient.CloseSend() diff --git a/cli/internal/specs/v0/destination.go b/cli/internal/specs/v0/destination.go index ea2666203b376b..e89fe3c7fa82eb 100644 --- a/cli/internal/specs/v0/destination.go +++ b/cli/internal/specs/v0/destination.go @@ -33,6 +33,9 @@ type Destination struct { SyncSummary bool `json:"send_sync_summary,omitempty"` + // Transformers are the names of transformer plugins to send sync data through + Transformers []string `json:"transformers,omitempty"` + // Destination plugin own (nested) spec Spec map[string]any `json:"spec,omitempty"` } diff --git a/cli/internal/specs/v0/kind.go b/cli/internal/specs/v0/kind.go index ff2560faeb5e1d..ede70c4951e775 100644 --- a/cli/internal/specs/v0/kind.go +++ b/cli/internal/specs/v0/kind.go @@ -13,12 +13,14 @@ type Kind int const ( KindSource Kind = iota KindDestination + KindTransformer ) var ( AllKinds = [...]string{ KindSource: "source", KindDestination: "destination", + KindTransformer: "transformer", } ) diff --git a/cli/internal/specs/v0/spec.go b/cli/internal/specs/v0/spec.go index 1639f3d4bb2bad..db4aeda0c7dbf9 100644 --- a/cli/internal/specs/v0/spec.go +++ b/cli/internal/specs/v0/spec.go @@ -38,6 +38,8 @@ func (s *Spec) UnmarshalJSON(data []byte) error { s.Spec = new(Source) case KindDestination: s.Spec = new(Destination) + case KindTransformer: + s.Spec = new(Transformer) 
default: return fmt.Errorf("unknown kind %s", s.Kind) } @@ -55,11 +57,12 @@ func (Spec) JSONSchemaExtend(sc *jsonschema.Schema) { // delete & obtain the values source, _ := sc.Properties.Delete("Source") destination, _ := sc.Properties.Delete("Destination") + transformer, _ := sc.Properties.Delete("Transformer") // update `spec` property spec := sc.Properties.Value("spec") // we can use `one_of because source & destination specs are mutually exclusive based on the kind - spec.OneOf = []*jsonschema.Schema{source, destination} + spec.OneOf = []*jsonschema.Schema{source, destination, transformer} sc.AllOf = []*jsonschema.Schema{ { @@ -102,6 +105,26 @@ func (Spec) JSONSchemaExtend(sc *jsonschema.Schema) { }(), }, }, + { + // `kind: transformer` implies transformer spec + If: &jsonschema.Schema{ + Properties: func() *orderedmap.OrderedMap[string, *jsonschema.Schema] { + properties := jsonschema.NewProperties() + kind := *sc.Properties.Value("kind") + kind.Const = "transformer" + kind.Enum = nil + properties.Set("kind", &kind) + return properties + }(), + }, + Then: &jsonschema.Schema{ + Properties: func() *orderedmap.OrderedMap[string, *jsonschema.Schema] { + properties := jsonschema.NewProperties() + properties.Set("spec", transformer) + return properties + }(), + }, + }, } } diff --git a/cli/internal/specs/v0/spec_reader.go b/cli/internal/specs/v0/spec_reader.go index bb9e83630878fb..cf52c47c387e09 100644 --- a/cli/internal/specs/v0/spec_reader.go +++ b/cli/internal/specs/v0/spec_reader.go @@ -20,12 +20,15 @@ import ( type SpecReader struct { sourcesMap map[string]*Source destinationsMap map[string]*Destination + transformersMap map[string]*Transformer sourceWarningsMap map[string]Warnings destinationWarningsMap map[string]Warnings + transformerWarningsMap map[string]Warnings Sources []*Source Destinations []*Destination + Transformers []*Transformer } var fileRegex = regexp.MustCompile(`\$\{file:([^}]+)\}`) @@ -132,6 +135,24 @@ func (r *SpecReader) 
loadSpecsFromFile(path string) error { } r.destinationsMap[destination.Name] = destination r.Destinations = append(r.Destinations, destination) + case KindTransformer: + transformer := s.Spec.(*Transformer) + if r.transformersMap[transformer.Name] != nil { + return fmt.Errorf("duplicate transformer name %s", transformer.Name) + } + r.transformerWarningsMap[transformer.Name] = transformer.GetWarnings() + transformer.SetDefaults() + if err := transformer.Validate(); err != nil { + return fmt.Errorf("failed to validate transformer %s: %w", transformer.Name, err) + } + if transformer.Registry == RegistryGitHub { + log.Warn(). + Str("name", transformer.Name). + Str("kind", "transformer"). + Msg("registry: github is deprecated & will be removed in future releases") + } + r.transformersMap[transformer.Name] = transformer + r.Transformers = append(r.Transformers, transformer) default: return fmt.Errorf("unknown kind %s", s.Kind) } @@ -184,6 +205,11 @@ func (r *SpecReader) validate() error { if destination.SyncGroupId != "" && destination.WriteMode == WriteModeOverwriteDeleteStale { err = errors.Join(err, fmt.Errorf("destination %s: sync_group_id is not supported with write_mode: %s", destination.Name, destination.WriteMode)) } + for _, transformer := range destination.Transformers { + if r.transformersMap[transformer] == nil { + err = errors.Join(err, fmt.Errorf("destination %s references unknown transformer %s", destination.Name, transformer)) + } + } } return err @@ -220,6 +246,10 @@ func (r *SpecReader) GetDestinationWarningsByName(name string) Warnings { return r.destinationWarningsMap[name] } +func (r *SpecReader) GetTransformerWarningsByName(name string) Warnings { + return r.transformerWarningsMap[name] +} + func (r *SpecReader) GetDestinationNamesForSource(name string) []string { var destinations []string source := r.sourcesMap[name] @@ -261,10 +291,13 @@ func newSpecReader(paths []string) (*SpecReader, error) { reader := &SpecReader{ sourcesMap: 
make(map[string]*Source), destinationsMap: make(map[string]*Destination), + transformersMap: make(map[string]*Transformer), Sources: make([]*Source, 0), Destinations: make([]*Destination, 0), + Transformers: make([]*Transformer, 0), sourceWarningsMap: make(map[string]Warnings), destinationWarningsMap: make(map[string]Warnings), + transformerWarningsMap: make(map[string]Warnings), } for _, path := range paths { file, err := os.Open(path) diff --git a/cli/internal/specs/v0/transformer.go b/cli/internal/specs/v0/transformer.go new file mode 100644 index 00000000000000..f328936c6fa897 --- /dev/null +++ b/cli/internal/specs/v0/transformer.go @@ -0,0 +1,34 @@ +package specs + +import ( + "bytes" + "encoding/json" +) + +// Transformer plugin spec +type Transformer struct { + Metadata + + // Transformer plugin own (nested) spec + Spec map[string]any `json:"spec,omitempty"` +} + +func (*Transformer) GetWarnings() Warnings { + warnings := make(map[string]string) + return warnings +} + +func (d *Transformer) UnmarshalSpec(out any) error { + b, err := json.Marshal(d.Spec) + if err != nil { + return err + } + dec := json.NewDecoder(bytes.NewReader(b)) + dec.UseNumber() + dec.DisallowUnknownFields() + return dec.Decode(out) +} + +func (d *Transformer) Validate() error { + return d.Metadata.Validate() +} diff --git a/cli/internal/transformerpipeline/identity.go b/cli/internal/transformerpipeline/identity.go new file mode 100644 index 00000000000000..66d1b4a77ea6b6 --- /dev/null +++ b/cli/internal/transformerpipeline/identity.go @@ -0,0 +1,47 @@ +package transformerpipeline + +import ( + "context" + "io" + + "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "google.golang.org/grpc/metadata" +) + +// identityTransformer is a transformer mock that does nothing to the data +// it exists so that we can have at least one transformer in the pipeline +type identityTransformer struct { + ch chan []byte +} + +func newIdentityTransformer() *identityTransformer { + return 
&identityTransformer{ + ch: make(chan []byte), + } +} + +func (t *identityTransformer) Send(req *plugin.Transform_Request) error { + t.ch <- req.Record + return nil +} + +func (t *identityTransformer) Recv() (*plugin.Transform_Response, error) { + bs, ok := <-t.ch + if !ok { + return nil, io.EOF + } + return &plugin.Transform_Response{Record: bs}, nil +} + +// Close the channel! +func (t *identityTransformer) CloseSend() error { + close(t.ch) + return nil +} + +// Must satisfy the Plugin_TransformClient interface +func (identityTransformer) Header() (metadata.MD, error) { return metadata.MD{}, nil } +func (identityTransformer) Trailer() metadata.MD { return metadata.MD{} } +func (identityTransformer) Context() context.Context { return nil } +func (identityTransformer) SendMsg(m any) error { return nil } +func (identityTransformer) RecvMsg(m any) error { return nil } diff --git a/cli/internal/transformerpipeline/identity_test.go b/cli/internal/transformerpipeline/identity_test.go new file mode 100644 index 00000000000000..32aee2908bd4bb --- /dev/null +++ b/cli/internal/transformerpipeline/identity_test.go @@ -0,0 +1,33 @@ +package transformerpipeline + +import ( + "testing" + + "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/stretchr/testify/require" +) + +func TestIdentityTransformer(t *testing.T) { + transformer := newIdentityTransformer() + + // Test sending and receiving a message + testRecord := []byte("test record") + req := &plugin.Transform_Request{Record: testRecord} + + go func() { + require.NoError(t, transformer.Send(req)) + }() + + resp, err := transformer.Recv() + require.NoError(t, err) + + // Since it's an identityTransformer, the record should be the same + require.Equal(t, testRecord, resp.Record, "Records should be the same") + + // Test closing the channel + require.NoError(t, transformer.CloseSend()) + + // Test channel is closed after call to CloseSend + _, ok := <-transformer.ch + require.False(t, ok, "Channel should be closed 
but it's not") +} diff --git a/cli/internal/transformerpipeline/pipeline.go b/cli/internal/transformerpipeline/pipeline.go new file mode 100644 index 00000000000000..c633dd9f8bf83a --- /dev/null +++ b/cli/internal/transformerpipeline/pipeline.go @@ -0,0 +1,115 @@ +package transformerpipeline + +import ( + "context" + "errors" + "io" + + "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "golang.org/x/sync/errgroup" +) + +// TransformerPipeline runs a pipeline of transform clients. +// +// Ideally we'd just call the result of each transform to the next one, but transformations are not synchronous calls, +// so orchestration is needed. That's what this does: it hides the orchestration of the transform clients. +// +// Use it like this: +// +// - Construct a new TransformerPipeline with `New`. Give it a context and a slice of transform clients. +// - Register a callback for transformed records with `OnOutput`. +// - Start all transformers with `RunBlocking`. +// - Send records to the pipeline with `Send`. +// - When done, close the pipeline with `Close`. Otherwise, `RunBlocking` won't finish. 
+type TransformerPipeline struct { + clientWrappers []clientWrapper + eg *errgroup.Group +} + +func New(ctx context.Context, transformClients []plugin.Plugin_TransformClient) (*TransformerPipeline, context.Context, error) { + var ( + eg, gctx = errgroup.WithContext(ctx) + tp = &TransformerPipeline{clientWrappers: make([]clientWrapper, len(transformClients)), eg: eg} + ) + + // Make sure there's at least one transformer + if len(transformClients) == 0 { + tp.clientWrappers = append(tp.clientWrappers, clientWrapper{client: newIdentityTransformer()}) + } + + // Wrap the clients to add orchestration logic + for i, client := range transformClients { + tp.clientWrappers[i] = clientWrapper{i: i, client: client} + } + + // Connect each client to the next one + for i := 0; i < len(transformClients)-1; i++ { + tp.clientWrappers[i].nextSendFn = tp.clientWrappers[i+1].client.Send + tp.clientWrappers[i].nextClose = tp.clientWrappers[i+1].client.CloseSend + } + + // The last client in the pipeline has nothing else to close + tp.clientWrappers[len(tp.clientWrappers)-1].nextClose = func() error { return nil } + + // The last client sends to the output. This connection happens in `OnOutput`. 
+ + return tp, gctx, nil +} + +func (lp *TransformerPipeline) RunBlocking() error { + for i := len(lp.clientWrappers) - 1; i >= 0; i-- { + lp.eg.Go(lp.clientWrappers[i].startBlocking) + } + return lp.eg.Wait() +} + +func (lp *TransformerPipeline) Send(data []byte) error { + // Constructor makes sure that there is at least one "identity" transform client + if lp.clientWrappers[len(lp.clientWrappers)-1].nextSendFn == nil { + return errors.New("OnOutput must be registered before Send is called, otherwise what do I do with the transformed data?") + } + return lp.clientWrappers[0].client.Send(&plugin.Transform_Request{Record: data}) +} + +func (lp *TransformerPipeline) OnOutput(fn func([]byte) error) error { + if fn == nil { + return errors.New("argument to OnOutput cannot be nil") + } + lp.clientWrappers[len(lp.clientWrappers)-1].nextSendFn = func(req *plugin.Transform_Request) error { + return fn(req.Record) + } + return nil +} + +func (lp *TransformerPipeline) Close() error { + // Close the first transformer. The rest will follow gracefully, otherwise records will be lost. 
+ return lp.clientWrappers[0].client.CloseSend() +} + +type clientWrapper struct { + i int + client plugin.Plugin_TransformClient + nextSendFn func(*plugin.Transform_Request) error + nextClose func() error +} + +func (s clientWrapper) startBlocking() error { + if s.nextSendFn == nil { + return errors.New("nextSendFn is nil") + } + for { + data, err := s.client.Recv() + if err == io.EOF { + err := s.nextClose() + return err + } + if err != nil { + return err + } + if err := s.nextSendFn( + &plugin.Transform_Request{Record: data.Record}, + ); err != nil { + return err + } + } +} diff --git a/cli/internal/transformerpipeline/transformerpipeline_test.go b/cli/internal/transformerpipeline/transformerpipeline_test.go new file mode 100644 index 00000000000000..5557c8aff40371 --- /dev/null +++ b/cli/internal/transformerpipeline/transformerpipeline_test.go @@ -0,0 +1,158 @@ +package transformerpipeline + +import ( + "context" + "io" + "testing" + + "github.com/cloudquery/plugin-pb-go/pb/plugin/v3" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" +) + +func TestTransformerPipelineDoesntChangeInputsWithTwoIdentityTransformers(t *testing.T) { + var ( + inputs = []string{"test data 1", "test data 2", "test data 3"} + expectedOutputs = []string{"test data 1", "test data 2", "test data 3"} + actualOutputs = []string{} + recordOutputs = func(output []byte) error { actualOutputs = append(actualOutputs, string(output)); return nil } + ) + + transformers := []plugin.Plugin_TransformClient{newIdentityTransformer(), newIdentityTransformer()} + + pipeline, _, err := New(context.Background(), transformers) + require.NoError(t, err) + require.NoError(t, pipeline.OnOutput(recordOutputs)) + + // `Send` and `Close` affect the buffer of the initial transformer, so + // they should succeed even if `RunBlocking` hasn't run yet and will + // take a while to start the `Recv` loops. 
	//
	// In this case, the transformer implementations block `Send` with an
	// unbuffered channel, so this goroutine will block until the pipeline
	// starts.
	//
	// NOTE(review): require.* calls t.FailNow, which must only run on the
	// test goroutine (go vet's testinggoroutine check flags this pattern);
	// consider reporting errors over a channel instead — confirm.
	go func() {
		for _, input := range inputs {
			require.NoError(t, pipeline.Send([]byte(input)))
		}
		require.NoError(t, pipeline.Close())
	}()

	// Blocks until pipeline is closed and all messages passed through
	// so above goroutine must have finished after this line.
	require.NoError(t, pipeline.RunBlocking())

	require.Equal(t, expectedOutputs, actualOutputs)
}

// Verifies that a single rune-reversing stage reverses each record.
func TestTransformerPipelineReversesInputs(t *testing.T) {
	var (
		inputs          = []string{"test data 1", "test data 2", "test data 3"}
		expectedOutputs = []string{"1 atad tset", "2 atad tset", "3 atad tset"}
		actualOutputs   = []string{}
		recordOutputs   = func(output []byte) error { actualOutputs = append(actualOutputs, string(output)); return nil }
	)

	transformers := []plugin.Plugin_TransformClient{newReverserTransformer()}

	pipeline, _, err := New(context.Background(), transformers)
	require.NoError(t, err)
	require.NoError(t, pipeline.OnOutput(recordOutputs))

	// `Send` and `Close` affect the buffer of the initial transformer, so
	// they should succeed even if `RunBlocking` hasn't run yet and will
	// take a while to start the `Recv` loops.
	//
	// In this case, the transformer implementations block `Send` with an
	// unbuffered channel, so this goroutine will block until the pipeline
	// starts.
	go func() {
		for _, input := range inputs {
			require.NoError(t, pipeline.Send([]byte(input)))
		}
		require.NoError(t, pipeline.Close())
	}()

	// Blocks until pipeline is closed and all messages passed through
	// so above goroutine must have finished after this line.
	require.NoError(t, pipeline.RunBlocking())

	require.Equal(t, expectedOutputs, actualOutputs)
}

// Verifies that two rune-reversing stages compose to the identity.
func TestTransformerPipelineDoesntChangeInputsWithTwoReversers(t *testing.T) {
	var (
		inputs          = []string{"test data 1", "test data 2", "test data 3"}
		expectedOutputs = []string{"test data 1", "test data 2", "test data 3"} // Reversed twice!
		actualOutputs   = []string{}
		recordOutputs   = func(output []byte) error { actualOutputs = append(actualOutputs, string(output)); return nil }
	)

	transformers := []plugin.Plugin_TransformClient{newReverserTransformer(), newReverserTransformer()}

	pipeline, _, err := New(context.Background(), transformers)
	require.NoError(t, err)
	require.NoError(t, pipeline.OnOutput(recordOutputs))

	// `Send` and `Close` affect the buffer of the initial transformer, so
	// they should succeed even if `RunBlocking` hasn't run yet and will
	// take a while to start the `Recv` loops.
	//
	// In this case, the transformer implementations block `Send` with an
	// unbuffered channel, so this goroutine will block until the pipeline
	// starts.
	go func() {
		for _, input := range inputs {
			require.NoError(t, pipeline.Send([]byte(input)))
		}
		require.NoError(t, pipeline.Close())
	}()

	// Blocks until pipeline is closed and all messages passed through
	// so above goroutine must have finished after this line.
	require.NoError(t, pipeline.RunBlocking())

	require.Equal(t, expectedOutputs, actualOutputs)
}

// reverserTransformer is a transformer mock that reverses the bytes, as runes, of the data.
// Its unbuffered channel makes Send block until the matching Recv runs, which
// the tests rely on to exercise pipeline start-up ordering.
type reverserTransformer struct {
	ch chan []byte // unbuffered: hands each record directly from Send to Recv
}

// newReverserTransformer returns a mock stage with an unbuffered channel.
func newReverserTransformer() *reverserTransformer {
	return &reverserTransformer{
		ch: make(chan []byte),
	}
}

// Send hands the record to the Recv side; blocks until Recv picks it up.
func (t *reverserTransformer) Send(req *plugin.Transform_Request) error {
	t.ch <- req.Record
	return nil
}

// Recv returns the next record with its runes reversed, or io.EOF once
// CloseSend has closed the channel.
func (t *reverserTransformer) Recv() (*plugin.Transform_Response, error) {
	bs, ok := <-t.ch
	if !ok {
		return nil, io.EOF
	}
	reversed := []rune(string(bs))
	for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 {
		reversed[i], reversed[j] = reversed[j], reversed[i]
	}
	reversedBytes := []byte(string(reversed))
	return &plugin.Transform_Response{Record: reversedBytes}, nil
}

// Close the channel!
func (t *reverserTransformer) CloseSend() error {
	close(t.ch)
	return nil
}

// Must satisfy the Plugin_TransformClient interface.
// NOTE(review): Context() returns nil; any caller that uses the stream's
// context would panic — consider context.Background() if that ever matters.
func (reverserTransformer) Header() (metadata.MD, error) { return metadata.MD{}, nil }
func (reverserTransformer) Trailer() metadata.MD         { return metadata.MD{} }
func (reverserTransformer) Context() context.Context     { return nil }
func (reverserTransformer) SendMsg(m any) error          { return nil }
func (reverserTransformer) RecvMsg(m any) error          { return nil }