diff --git a/src/arithmetic/arithmetic_expression_construct.c b/src/arithmetic/arithmetic_expression_construct.c index 2df0833fdb..7b2cdb7028 100644 --- a/src/arithmetic/arithmetic_expression_construct.c +++ b/src/arithmetic/arithmetic_expression_construct.c @@ -666,6 +666,11 @@ static AR_ExpNode *_AR_ExpNodeFromComprehensionFunction // build a FilterTree to represent this predicate if(predicate_node) { AST_ConvertFilters(&ctx->ft, predicate_node); + // in case of list comprehension, validate that aggregation function is not used in predicate + if(strcmp(func_name, "LIST_COMPREHENSION") == 0 && !FilterTree_Valid(ctx->ft, CYPHER_AST_LIST_COMPREHENSION)) { + rm_free(ctx); + return AR_EXP_NewConstOperandNode(SI_NullVal()); + } } else if(type != CYPHER_AST_LIST_COMPREHENSION) { // Functions like any() and all() must have a predicate node. ErrorCtx_SetError("'%s' function requires a WHERE predicate", func_name); @@ -679,7 +684,14 @@ static AR_ExpNode *_AR_ExpNodeFromComprehensionFunction // in the above query, this will be an operation node representing "val * 2" // this will always be NULL for comprehensions like any() and all() const cypher_astnode_t *eval_node = cypher_ast_list_comprehension_get_eval(comp_exp); - if(eval_node) ctx->eval_exp = _AR_EXP_FromASTNode(eval_node); + if(eval_node) { + ctx->eval_exp = _AR_EXP_FromASTNode(eval_node); + // validate that aggregation function is not used in evaluation node + if(AR_EXP_ContainsAgg(ctx->eval_exp)) { + rm_free(ctx); + return AR_EXP_NewConstOperandNode(SI_NullVal()); + } + } // build an operation node to represent the list comprehension AR_ExpNode *op = AR_EXP_NewOpNode(func_name, true, 2); @@ -888,3 +900,23 @@ AR_ExpNode *AR_EXP_FromASTNode(const cypher_astnode_t *expr) { return root; } +bool AR_EXP_ContainsAgg(const AR_ExpNode *root) { + // Is this an aggregation node? + if(root->type == AR_EXP_OP && root->op.f->aggregate == true) { + ErrorCtx_SetError("Invalid use of aggregating function '%s'", root->op.f->name); + return true; + } + + if(root->type == AR_EXP_OP) { + // Scan child nodes. + for(int i = 0; i < root->op.child_count; i++) { + AR_ExpNode *child = root->op.children[i]; + if(child->type == AR_EXP_OP && child->op.f->aggregate == true) { + ErrorCtx_SetError("Invalid use of aggregating function '%s'", child->op.f->name); + return true; + } + } + } + + return false; +} diff --git a/src/arithmetic/arithmetic_expression_construct.h b/src/arithmetic/arithmetic_expression_construct.h index 3b68704e09..a67a8c278a 100644 --- a/src/arithmetic/arithmetic_expression_construct.h +++ b/src/arithmetic/arithmetic_expression_construct.h @@ -12,3 +12,7 @@ // Construct arithmetic expression from AST node AR_ExpNode *AR_EXP_FromASTNode(const cypher_astnode_t *expr); +// Check if the AR_ExpNode tree contains an aggregating function. +// if an aggregating function is found, then sets the error message and returns true +// otherwise, returns false +bool AR_EXP_ContainsAgg(const AR_ExpNode *root); diff --git a/src/ast/ast_build_filter_tree.c b/src/ast/ast_build_filter_tree.c index 2b5b3774db..c7db4c6625 100644 --- a/src/ast/ast_build_filter_tree.c +++ b/src/ast/ast_build_filter_tree.c @@ -322,7 +322,7 @@ FT_FilterNode *AST_BuildFilterTree(AST *ast) { array_free(call_clauses); } - if(!FilterTree_Valid(filter_tree)) { + if(!FilterTree_Valid(filter_tree, UINT8_MAX)) { // Invalid filter tree structure, a compile-time error has been set. FilterTree_Free(filter_tree); return NULL; @@ -361,7 +361,7 @@ FT_FilterNode *AST_BuildFilterTreeFromClauses if(predicate) AST_ConvertFilters(&filter_tree, predicate); } - if(!FilterTree_Valid(filter_tree)) { + if(!FilterTree_Valid(filter_tree, UINT8_MAX)) { // Invalid filter tree structure, a compile-time error has been set. FilterTree_Free(filter_tree); return NULL; diff --git a/src/execution_plan/execution_plan_build/build_pattern_comprehension_ops.c b/src/execution_plan/execution_plan_build/build_pattern_comprehension_ops.c index e4b3c0ab99..2cdf0860d4 100644 --- a/src/execution_plan/execution_plan_build/build_pattern_comprehension_ops.c +++ b/src/execution_plan/execution_plan_build/build_pattern_comprehension_ops.c @@ -7,6 +7,7 @@ #include "execution_plan_construct.h" #include "RG.h" #include "../ops/ops.h" +#include "../../errors.h" #include "../../query_ctx.h" #include "../../util/rax_extensions.h" #include "../../ast/ast_build_filter_tree.h" @@ -86,6 +87,11 @@ void buildPatternComprehensionOps eval_node = cypher_ast_pattern_comprehension_get_eval(pc); AR_ExpNode *eval_exp = AR_EXP_FromASTNode(eval_node); + // check that evaluation node does not contains aggregating function + if (AR_EXP_ContainsAgg(eval_exp)) { + break; + } + // collect evaluation results into an array using `collect` AR_ExpNode *collect_exp = AR_EXP_NewOpNode("collect", false, 1); collect_exp->op.children[0] = eval_exp; @@ -105,7 +111,7 @@ void buildPatternComprehensionOps FT_FilterNode *filter_tree = NULL; AST_ConvertFilters(&filter_tree, predicate); - if(!FilterTree_Valid(filter_tree)) { + if(!FilterTree_Valid(filter_tree, CYPHER_AST_PATTERN_COMPREHENSION)) { // Invalid filter tree structure, a compile-time error has been set. FilterTree_Free(filter_tree); } else { diff --git a/src/filter_tree/filter_tree.c b/src/filter_tree/filter_tree.c index 8002cf8a2f..27b5b8a10d 100644 --- a/src/filter_tree/filter_tree.c +++ b/src/filter_tree/filter_tree.c @@ -13,6 +13,7 @@ #include "../util/rmalloc.h" #include "../ast/ast_shared.h" #include "../datatypes/array.h" +#include "../arithmetic/arithmetic_expression_construct.h" // forward declarations void _FilterTree_DeMorgan @@ -666,7 +667,8 @@ static inline bool _FilterTree_ValidExpressionNode bool FilterTree_Valid ( - const FT_FilterNode *root + const FT_FilterNode *root, + cypher_astnode_type_t type ) { // An empty tree is has a valid structure. if(!root) return true; @@ -681,6 +683,12 @@ bool FilterTree_Valid ErrorCtx_SetError("Filter predicate did not compare two expressions."); return false; } + // Aggregate functions can't be used as part of filters predicate in + // either pattern comprehension node or list comprehension node + if ((type == CYPHER_AST_PATTERN_COMPREHENSION || type == CYPHER_AST_LIST_COMPREHENSION) && + (AR_EXP_ContainsAgg(root->pred.lhs) || AR_EXP_ContainsAgg(root->pred.rhs))) { + return false; + } break; case FT_N_COND: // Empty condition, invalid structure. @@ -694,8 +702,8 @@ bool FilterTree_Valid ErrorCtx_SetError("Invalid usage of 'NOT' filter."); return false; } - if(!FilterTree_Valid(root->cond.left)) return false; - if(!FilterTree_Valid(root->cond.right)) return false; + if(!FilterTree_Valid(root->cond.left, type)) return false; + if(!FilterTree_Valid(root->cond.right, type)) return false; break; default: ASSERT("Unknown filter tree node" && false); diff --git a/src/filter_tree/filter_tree.h b/src/filter_tree/filter_tree.h index fa09e33cf6..f0a49c1cd3 100644 --- a/src/filter_tree/filter_tree.h +++ b/src/filter_tree/filter_tree.h @@ -181,7 +181,8 @@ FT_FilterNode *FilterTree_Combine // a condition or predicate node can't be childless bool FilterTree_Valid ( - const FT_FilterNode *root + const FT_FilterNode *root, + cypher_astnode_type_t type ); // remove NOT nodes by applying DeMorgan laws diff --git a/tests/flow/test_comprehension_functions.py b/tests/flow/test_comprehension_functions.py index dd1dd2707c..3a982fd743 100644 --- a/tests/flow/test_comprehension_functions.py +++ b/tests/flow/test_comprehension_functions.py @@ -41,6 +41,13 @@ def populate_graph(self): redis_graph.commit() + def expect_error(self, query, expected_err_msg): + try: + redis_graph.query(query) + assert(False) + except redis.exceptions.ResponseError as e: + self.env.assertIn(expected_err_msg, str(e)) + # Test list comprehension queries with scalar inputs and a single result row def test01_list_comprehension_single_return(self): expected_result = [[[2, 6]]] @@ -462,3 +469,27 @@ def test20_pattern_comprehension_in_switch_case(self): expected_result = [[[1, 1]], [[2]], [[3]], [[]]] self.env.assertEquals(actual_result.result_set, expected_result) + def test21_pattern_comprehension_with_invalid_filters(self): + # A list of queries and errors which are expected to occur with the specified query. + queries_with_errors = { + "MATCH (x) RETURN [(x)--(z) WHERE collect(z.a) > 10 | z.b]" : "Invalid use of aggregating function 'collect'", + "MATCH (x) RETURN [(x)--(z) WHERE z.a > collect(10) | z.b]" : "Invalid use of aggregating function 'collect'", + "MATCH (x) RETURN [(x)--(z) WHERE z.a > 10 | collect(z.b)]" : "Invalid use of aggregating function 'collect'", + "MATCH (x) RETURN [(x)--(z) WHERE (min(z.a) - sum(z.b)) > z.d | z.b]" : "Invalid use of aggregating function 'min'" + } + for query, error in queries_with_errors.items(): + self.expect_error(query, error) + + def test22_list_comprehension(self): + # A list of queries and errors which are expected to occur with the specified query. + queries_with_errors = { + "RETURN [x IN range(1,10) | sum(1)]" : "Invalid use of aggregating function 'sum'", + "RETURN [x IN range(1,10) | sum(x)]" : "Invalid use of aggregating function 'sum'", + "RETURN [x IN range(1,10) | 1 + min(x)]" : "Invalid use of aggregating function 'min'", + "RETURN [x IN range(1,10) | max(x) + min(x)]" : "Invalid use of aggregating function 'max'", + "RETURN [x IN range(1,10) | collect(x)]" : "Invalid use of aggregating function 'collect'", + "RETURN [x IN range(1,10) WHERE x % 2 > avg(x) ] AS r" : "Invalid use of aggregating function 'avg'", + "RETURN [x IN range(1,10) WHERE avg(x) > x - 1 ] AS r" : "Invalid use of aggregating function 'avg'", + } + for query, error in queries_with_errors.items(): + self.expect_error(query, error)