diff --git a/cmd/policy_run.go b/cmd/policy_run.go index 7dcbc4b4b5c331..595a2e3591aeb5 100644 --- a/cmd/policy_run.go +++ b/cmd/policy_run.go @@ -6,6 +6,7 @@ import ( "github.com/cloudquery/cloudquery/pkg/ui/console" "github.com/spf13/cobra" + "github.com/spf13/viper" ) const policyRunHelpMsg = "Executes a policy on CloudQuery database" @@ -47,6 +48,10 @@ func init() { flags := policyRunCmd.Flags() flags.StringVar(&outputDir, "output-dir", "", "Generates a new file for each policy at the given dir with the output") flags.BoolVar(&noResults, "no-results", false, "Do not show policies results") + flags.Bool("disable-fetch-check", false, "Disable checking if a respective fetch happened before running policies") + + _ = viper.BindPFlag("disable-fetch-check", flags.Lookup("disable-fetch-check")) + policyRunCmd.SetUsageTemplate(usageTemplateWithFlags) policyCmd.AddCommand(policyRunCmd) } diff --git a/pkg/policy/execute.go b/pkg/policy/execute.go index f014e063acd5d5..291d93534bedac 100644 --- a/pkg/policy/execute.go +++ b/pkg/policy/execute.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-version" "github.com/spf13/afero" + "github.com/spf13/viper" ) var ErrPolicyOrQueryNotFound = errors.New("selected policy/query not found") @@ -136,8 +137,11 @@ func (e *Executor) Execute(ctx context.Context, req *ExecuteRequest, policy *Pol if err := e.checkVersions(policy.Config, req.ProviderVersions); err != nil { return nil, fmt.Errorf("%s: %w", policy.Name, err) } - if err := e.checkFetches(ctx, policy.Config); err != nil { - return nil, fmt.Errorf("%s: %w, please run `cloudquery fetch` before running policy", policy.Name, err) + + if !viper.GetBool("disable-fetch-check") { + if err := e.checkFetches(ctx, policy.Config); err != nil { + return nil, fmt.Errorf("%s: %w, please run `cloudquery fetch` before running policy", policy.Name, err) + } } for _, p := range policy.Policies { diff --git a/pkg/policy/execute_test.go b/pkg/policy/execute_test.go index 44008b94dd086b..6357c7cbabfaf5 100644 --- a/pkg/policy/execute_test.go +++ b/pkg/policy/execute_test.go @@ -14,6 +14,8 @@ import ( "github.com/cloudquery/cq-provider-sdk/provider/execution" "github.com/google/uuid" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-version" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" ) @@ -405,6 +407,80 @@ func setupCheckFetchDatabase(db execution.QueryExecer, summary *meta_storage.Fet }, nil } +func TestExecuter_DisbleFetchCheckFlag(t *testing.T) { + db, err := sdkdb.New(context.Background(), hclog.NewNullLogger(), testDBConnection) + assert.NoError(t, err) + + metaStorage := meta_storage.NewClient(db, hclog.NewNullLogger()) + + _, de, err := database.GetExecutor(hclog.NewNullLogger(), testDBConnection, &history.Config{}) + if err != nil { + t.Fatal(fmt.Errorf("getExecutor: %w", err)) + } + + err = metaStorage.MigrateCore(context.Background(), de) + assert.NoError(t, err) + + executor := NewExecutor(db, hclog.Default(), nil) + + policy := &Policy{ + Name: "test", + Policies: nil, + Checks: []*Check{{ + Query: "SELECT 1 as result;", + ExpectOutput: true, + }}, + Config: &Configuration{ + Providers: []*Provider{ + { + Type: "testProvider", + Version: ">0.0.0", + }, + }, + }, + } + + testCases := []struct { + Name string + DisableFetchCheck bool + ExpectedError error + }{{ + Name: "fetch_check_enabled", + DisableFetchCheck: false, + ExpectedError: errors.New("could not find a completed fetch for requested provider"), + }, + { + Name: "fetch_check_disabled", + DisableFetchCheck: true, + ExpectedError: nil, + }, + } + + testProviderVersion, err := version.NewVersion("0.1.0") + assert.NoError(t, err) + + executeRequest := &ExecuteRequest{ + Policy: policy, + ProviderVersions: map[string]*version.Version{"testProvider": testProviderVersion}, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + defer viper.Reset() + viper.Set("disable-fetch-check", tc.DisableFetchCheck) + + _, err = executor.Execute(context.Background(), executeRequest, policy) + + if tc.ExpectedError == nil { + assert.NoError(t, err) + } else { + assert.Contains(t, err.Error(), tc.ExpectedError.Error()) + } + }) + } + +} + func TestExecutor_CheckFetches(t *testing.T) { // create database connection db, err := sdkdb.New(context.Background(), hclog.NewNullLogger(), testDBConnection)