Skip to content

Commit 795c8e9

Browse files
committed
条件组合:限制 @combine:value 中的 value 的括号嵌套深度、key 数量、key 重复次数等
1 parent d29d079 commit 795c8e9

File tree

1 file changed

+70
-7
lines changed

1 file changed

+70
-7
lines changed

APIJSONORM/src/main/java/apijson/orm/AbstractSQLConfig.java

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@
3737
import java.util.ArrayList;
3838
import java.util.Arrays;
3939
import java.util.Collection;
40-
import java.util.Deque;
4140
import java.util.HashMap;
42-
import java.util.HashSet;
4341
import java.util.LinkedHashMap;
44-
import java.util.LinkedList;
4542
import java.util.List;
4643
import java.util.Map;
4744
import java.util.Map.Entry;
@@ -81,7 +78,13 @@
8178
*/
8279
public abstract class AbstractSQLConfig implements SQLConfig {
8380
private static final String TAG = "AbstractSQLConfig";
84-
81+
82+
public static int MAX_COMBINE_DEPTH = 2;
83+
public static int MAX_WHERE_COUNT = 3;
84+
public static int MAX_COMBINE_COUNT = 5;
85+
public static int MAX_COMBINE_KEY_COUNT = 2;
86+
public static float MAX_COMBINE_RATIO = 1.0f;
87+
8588
public static String DEFAULT_DATABASE = DATABASE_MYSQL;
8689
public static String DEFAULT_SCHEMA = "sys";
8790
public static String PREFFIX_DISTINCT = "DISTINCT ";
@@ -102,6 +105,7 @@ public abstract class AbstractSQLConfig implements SQLConfig {
102105
// 允许调用的 SQL 函数:当 substring 为 null 时忽略;当 substring 为 "" 时整个 value 是 raw SQL;其它情况则只是 substring 这段为 raw SQL
103106
public static final Map<String, String> SQL_FUNCTION_MAP;
104107

108+
105109
static { // 凡是 SQL 边界符、分隔符、注释符 都不允许,例如 ' " ` ( ) ; # -- /**/ ,以免拼接 SQL 时被注入意外可执行指令
106110
PATTERN_RANGE = Pattern.compile("^[0-9%,!=\\<\\>/\\.\\+\\-\\*\\^]+$"); // ^[a-zA-Z0-9_*%!=<>(),"]+$ 导致 exists(select*from(Comment)) 通过!
107111
PATTERN_FUNCTION = Pattern.compile("^[A-Za-z0-9%,:_@&~`!=\\<\\>\\|\\[\\]\\{\\} /\\.\\+\\-\\*\\^\\?\\(\\)\\$]+$"); //TODO 改成更好的正则,校验前面为单词,中间为操作符,后面为值
@@ -2140,6 +2144,23 @@ public SQLConfig setCast(Map<String, String> cast) {
21402144

21412145
//WHERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
21422146

2147+
protected int getMaxWhereCount() {
2148+
return MAX_WHERE_COUNT;
2149+
}
2150+
protected int getMaxCombineDepth() {
2151+
return MAX_COMBINE_DEPTH;
2152+
}
2153+
protected int getMaxCombineCount() {
2154+
return MAX_COMBINE_COUNT;
2155+
}
2156+
protected int getMaxCombineKeyCount() {
2157+
return MAX_COMBINE_KEY_COUNT;
2158+
}
2159+
protected float getMaxCombineRatio() {
2160+
return MAX_COMBINE_RATIO;
2161+
}
2162+
2163+
21432164
@Override
21442165
public String getCombineExpression() {
21452166
return combineExpression;
@@ -2329,10 +2350,26 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23292350
+ "逻辑连接符 & | 左右必须各一个相邻空格!左括号 ( 右边和右括号 ) 左边都不允许有相邻空格!");
23302351
}
23312352

2353+
if (where == null) {
2354+
where = new HashMap<>();
2355+
}
2356+
int whereSize = where.size();
2357+
2358+
int maxWhereCount = getMaxWhereCount();
2359+
if (maxWhereCount > 0 && whereSize > maxWhereCount) {
2360+
throw new IllegalArgumentException(table + ":{ key0:value0, key1:value1... } 中条件 key:value 数量 " + whereSize + " 已超过最大数量,必须在 0-" + maxWhereCount + " 内!");
2361+
}
23322362

23332363
String whereString = "";
23342364

2365+
int maxDepth = getMaxCombineDepth();
2366+
int maxCombineCount = getMaxCombineCount();
2367+
int maxCombineKeyCount = getMaxCombineKeyCount();
2368+
float maxCombineRatio = getMaxCombineRatio();
2369+
23352370
int depth = 0;
2371+
int allCount = 0;
2372+
23362373
int n = s.length();
23372374
int i = 0;
23382375

@@ -2341,7 +2378,7 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23412378
boolean first = true;
23422379

23432380
String key = "";
2344-
Set<String> usedKeySet = new HashSet<>(where.size());
2381+
Map<String, Integer> usedKeyCountMap = new HashMap<>(whereSize);
23452382
while (i <= n) { // "date> | (contactIdList<> & (name*~ | tag&$))"
23462383
boolean isOver = i >= n;
23472384
char c = isOver ? 0 : s.charAt(i);
@@ -2359,6 +2396,17 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23592396
throw new IllegalArgumentException(table + ":{ @combine: '" + combine + "' } 中字符 '" + s.substring(i - key.length() - (isOver ? 1 : 0))
23602397
+ "' 不合法!左边缺少 & | 其中一个逻辑连接符!");
23612398
}
2399+
2400+
allCount ++;
2401+
if (allCount > maxCombineCount && maxCombineCount > 0) {
2402+
throw new IllegalArgumentException(table + ":{ @combine: '" + combine + "' } 中字符 '" + s + "' 不合法!"
2403+
+ "其中 key 数量 " + allCount + " 已超过最大值,必须在条件键值对数量 0-" + maxCombineCount + " 内!");
2404+
}
2405+
if (1.0f*allCount/whereSize > maxCombineRatio && maxCombineRatio > 0) {
2406+
throw new IllegalArgumentException(table + ":{ @combine: '" + combine + "' } 中字符 '" + s + "' 不合法!"
2407+
+ "其中 key 数量 " + allCount + " / 条件键值对数量 " + whereSize + " = " + (1.0f*allCount/whereSize)
2408+
+ " 已超过 最大倍数,必须在条件键值对数量 0-" + maxCombineRatio + " 倍内!");
2409+
}
23622410

23632411
Object value = where.get(key);
23642412
if (value == null) {
@@ -2370,7 +2418,16 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23702418
throw new IllegalArgumentException(table + ":{ @combine: '" + combine + "' } 中字符 '" + key + "' 对应的 " + key + ":value 不是有效条件键值对!");
23712419
}
23722420

2373-
usedKeySet.add(key);
2421+
Integer count = usedKeyCountMap.get(key);
2422+
count = count == null ? 1 : count + 1;
2423+
if (count > maxCombineKeyCount && maxCombineKeyCount > 0) {
2424+
throw new IllegalArgumentException(table + ":{ @combine: '" + combine + "' } 中字符 '" + s + "' 不合法!其中 '" + key
2425+
+ "' 重复引用,次数 " + count + " 已超过最大值,必须在 0-" + maxCombineKeyCount + " 内!");
2426+
}
2427+
2428+
usedKeyCountMap.put(key, count);
2429+
2430+
23742431
whereString += "( " + wi + " )";
23752432
first = false;
23762433
}
@@ -2422,6 +2479,10 @@ else if (c == '(') {
24222479
}
24232480

24242481
depth ++;
2482+
if (depth > maxDepth && maxDepth > 0) {
2483+
throw new IllegalArgumentException(table + ":{ @combine: '" + combine + "' } 中字符 '" + s.substring(0, i + 1) + "' 不合法!括号 (()) 嵌套层级 " + depth + " 已超过最大值,必须在 0-" + maxDepth + " 内!");
2484+
}
2485+
24252486
whereString += c;
24262487
lastLogic = 0;
24272488
first = true;
@@ -2454,7 +2515,7 @@ else if (c == ')') {
24542515

24552516
for (Entry<String, Object> entry : set) {
24562517
key = entry == null ? null : entry.getKey();
2457-
if (key == null || usedKeySet.contains(key)) {
2518+
if (key == null || usedKeyCountMap.containsKey(key)) {
24582519
continue;
24592520
}
24602521

@@ -2609,6 +2670,8 @@ else if (isSideJoin) { // ^ SIDE JOIN: ! (A & B)
26092670
return result;
26102671
}
26112672

2673+
2674+
26122675
public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String, Object> where, Map<String, List<String>> combine, List<Join> joinList, boolean verifyName) throws Exception {
26132676
Set<Entry<String, List<String>>> combineSet = combine == null ? null : combine.entrySet();
26142677
if (combineSet == null || combineSet.isEmpty()) {

0 commit comments

Comments
 (0)