From ff15af87f85582e863ee9fc3bdce3de87f1cb717 Mon Sep 17 00:00:00 2001 From: shimonp21 Date: Fri, 25 Mar 2022 11:26:13 +0300 Subject: [PATCH] feat: Limit scope of views --- pkg/policy/execute.go | 47 ++++++++++++++++++++++++++++++++++---- pkg/policy/execute_test.go | 32 +++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/pkg/policy/execute.go b/pkg/policy/execute.go index 92e0dd5919748d..f014e063acd5d5 100644 --- a/pkg/policy/execute.go +++ b/pkg/policy/execute.go @@ -139,9 +139,6 @@ func (e *Executor) Execute(ctx context.Context, req *ExecuteRequest, policy *Pol if err := e.checkFetches(ctx, policy.Config); err != nil { return nil, fmt.Errorf("%s: %w, please run `cloudquery fetch` before running policy", policy.Name, err) } - if err := e.createViews(ctx, policy); err != nil { - return nil, err - } for _, p := range policy.Policies { executor := e.with(p.Name) @@ -158,6 +155,13 @@ func (e *Executor) Execute(ctx context.Context, req *ExecuteRequest, policy *Pol } } + // TODO: A better idea here is to create a new session, create the views, run queries, and close the session. + // This will remove the need for 'deleteViews'. + if err := e.createViews(ctx, policy); err != nil { + return nil, err + } + defer e.deleteViews(ctx, policy) + for _, q := range policy.Checks { e.log = e.log.With("query", q.Name) qr, err := e.executeQuery(ctx, q) @@ -265,17 +269,50 @@ func (e *Executor) executeQuery(ctx context.Context, q *Check) (*QueryResult, er return result, nil } -// createViews creates temporary views for given config.Policy, and any views defined by sub-policies +// createViews creates temporary views for the given policy (but not for its subpolicies) func (e *Executor) createViews(ctx context.Context, policy *Policy) error { for _, v := range policy.Views { e.log.Info("creating policy view", "view", v.Name, "query", v.Query) - if err := e.conn.Exec(ctx, fmt.Sprintf("CREATE OR REPLACE TEMPORARY VIEW %s AS %s", v.Name, v.Query)); err != nil { + if err := e.conn.Exec(ctx, fmt.Sprintf("CREATE TEMPORARY VIEW %s AS %s", v.Name, v.Query)); err != nil { return fmt.Errorf("failed to create view %s/%s: %w", policy.Name, v.Name, err) } } return nil } +// deleteView deletes the temporary views for the given policy (but not for its subpolicies). +// This method should be executed in 'defer' statements, so it doesn't return an error. +func (e *Executor) deleteViews(ctx context.Context, policy *Policy) { + for _, v := range policy.Views { + + // Validate that the view is actually a temp view + data, err := e.conn.Query(ctx, fmt.Sprintf("SELECT table_name FROM INFORMATION_SCHEMA.VIEWS WHERE TABLE_NAME = '%s' and TABLE_SCHEMA LIKE 'pg_temp%%'", v.Name)) + if err != nil { + e.log.Error("Failed to check if view is temporary", "policy", policy.Name, "view", v.Name, "err", err) + continue + } + count := 0 + for data.Next() { + count += 1 + } + if data.Err() != nil { + e.log.Error("Failed to check if view is temporary", "policy", policy.Name, "view", v.Name, "err", data.Err()) + continue + } + // If count is 0 then that means that no temp views with the correct name were found + if count == 0 { + continue + } + + e.log.Info("deleting policy view", "view", v.Name) + + if err := e.conn.Exec(ctx, fmt.Sprintf("DROP VIEW %s", v.Name)); err != nil { + e.log.Error("failed to drop view", "policy", policy.Name, "view", v.Name, "err", err) + continue + } + } +} + func GenerateExecutionResultFile(result *ExecutionResult, outputDir string) error { fs := afero.NewOsFs() diff --git a/pkg/policy/execute_test.go b/pkg/policy/execute_test.go index fa0894f229320b..736bfe243277bb 100644 --- a/pkg/policy/execute_test.go +++ b/pkg/policy/execute_test.go @@ -249,6 +249,28 @@ var ( ExpectOutput: true, }}, } + // views cannot be inherited from parent policies. + multiLayerWithInheritedView = &Policy{ + Name: "test", + Views: []*View{ + { + Name: "testview", + Query: "SELECT 'something'", + }, + }, + Policies: Policies{ + { + Name: "subpolicy", + Checks: []*Check{ + { + Name: "query-with-view", + ExpectOutput: true, + Query: "SELECT * from testview", + }, + }, + }, + }, + } ) func TestExecutor_Execute(t *testing.T) { @@ -322,6 +344,11 @@ func TestExecutor_Execute(t *testing.T) { Pass: true, TotalExpectedResults: 1, }, + { + Name: "multilayer policy w/ using view inherited from parent", + Policy: multiLayerWithInheritedView, + ErrorOutput: "relation \"testview\" does not exist", + }, } conn, tearDownFunc := setupPolicyDatabase(t, t.Name()) @@ -338,7 +365,10 @@ func TestExecutor_Execute(t *testing.T) { filtered := tc.Policy.Filter(tc.Selector) res, err := executor.Execute(context.Background(), execReq, &filtered) if tc.ErrorOutput != "" { - assert.EqualError(t, err, tc.ErrorOutput) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), tc.ErrorOutput) + } + return } else { assert.NoError(t, err) }