Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP: Set error in AR_EXP_ContainsAgg()
  • Loading branch information
nafraf committed Feb 20, 2023
commit 9979fc45758f818f0ea6951e0a36743f19e42b6f
2 changes: 2 additions & 0 deletions src/arithmetic/arithmetic_expression_construct.c
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,7 @@ AR_ExpNode *AR_EXP_FromASTNode(const cypher_astnode_t *expr) {
bool AR_EXP_ContainsAgg(const AR_ExpNode *root) {
// Is this an aggregation node?
if(root->type == AR_EXP_OP && root->op.f->aggregate == true) {
ErrorCtx_SetError("Invalid use of aggregating function '%s'", root->op.f->name);
return true;
}

Expand All @@ -910,6 +911,7 @@ bool AR_EXP_ContainsAgg(const AR_ExpNode *root) {
for(int i = 0; i < root->op.child_count; i++) {
AR_ExpNode *child = root->op.children[i];
if(child->type == AR_EXP_OP && child->op.f->aggregate == true) {
ErrorCtx_SetError("Invalid use of aggregating function '%s'", child->op.f->name);
return true;
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/arithmetic/arithmetic_expression_construct.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
// Construct arithmetic expression from AST node
AR_ExpNode *AR_EXP_FromASTNode(const cypher_astnode_t *expr);

// Detect if expression contains aggreation function
bool AR_EXP_ContainsAgg(const AR_ExpNode *root);
// Check if the AR_ExpNode tree contains an aggregating function.
// if an aggregating function is found, then sets the error message and returns true
// otherwise, returns false
bool AR_EXP_ContainsAgg(const AR_ExpNode *root);
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ void buildPatternComprehensionOps
eval_node = cypher_ast_pattern_comprehension_get_eval(pc);
AR_ExpNode *eval_exp = AR_EXP_FromASTNode(eval_node);

// check that evaluation node does not contains aggregating function
if (AR_EXP_ContainsAgg(eval_exp)) {
// TO DO: Determine aggregation function name
ErrorCtx_SetError("Invalid use of aggregating function");
return;
}

Expand Down
13 changes: 3 additions & 10 deletions src/filter_tree/filter_tree.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "../util/rmalloc.h"
#include "../ast/ast_shared.h"
#include "../datatypes/array.h"
#include "../arithmetic/arithmetic_expression_construct.h"

// forward declarations
void _FilterTree_DeMorgan
Expand Down Expand Up @@ -683,17 +684,9 @@ bool FilterTree_Valid
return false;
}
// Aggregate functions can't be used as part of filters in pattern comprehension node
if (type == CYPHER_AST_PATTERN_COMPREHENSION) {
const char *func_name = AR_EXP_GetFuncName(root->pred.lhs);
if(AR_FuncIsAggregate(func_name)) {
ErrorCtx_SetError("Invalid use of aggregating function '%s' in pattern comprehension predicate", func_name);
if (type == CYPHER_AST_PATTERN_COMPREHENSION &&
(AR_EXP_ContainsAgg(root->pred.lhs) || AR_EXP_ContainsAgg(root->pred.rhs))) {
return false;
}
func_name = AR_EXP_GetFuncName(root->pred.rhs);
if(AR_FuncIsAggregate(func_name)) {
ErrorCtx_SetError("Invalid use of aggregating function '%s' in pattern comprehension predicate", func_name);
return false;
}
}
break;
case FT_N_COND:
Expand Down
7 changes: 4 additions & 3 deletions tests/flow/test_comprehension_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,10 @@ def test20_pattern_comprehension_in_switch_case(self):
def test21_pattern_comprehension_with_invalid_filters(self):
# A list of queries and errors which are expected to occur with the specified query.
queries_with_errors = {
"MATCH (x) RETURN [(x)--(z) WHERE collect(z.a) > 10 | z.b]": "Invalid use of aggregating function 'collect' in pattern comprehension predicate",
"MATCH (x) RETURN [(x)--(z) WHERE z.a > collect(10) | z.b]": "Invalid use of aggregating function 'collect' in pattern comprehension predicate",
"MATCH (x) RETURN [(x)--(z) WHERE z.a > 10 | collect(z.b)]": "Invalid use of aggregating function 'collect'",
"MATCH (x) RETURN [(x)--(z) WHERE collect(z.a) > 10 | z.b]" : "Invalid use of aggregating function 'collect'",
"MATCH (x) RETURN [(x)--(z) WHERE z.a > collect(10) | z.b]" : "Invalid use of aggregating function 'collect'",
"MATCH (x) RETURN [(x)--(z) WHERE z.a > 10 | collect(z.b)]" : "Invalid use of aggregating function 'collect'",
"MATCH (x) RETURN [(x)--(z) WHERE (min(z.a) - sum(z.b)) > z.d | z.b]" : "Invalid use of aggregating function 'min'"
}
for query, error in queries_with_errors.items():
self.expect_error(query, error)