3737import java .util .ArrayList ;
3838import java .util .Arrays ;
3939import java .util .Collection ;
40- import java .util .Deque ;
4140import java .util .HashMap ;
42- import java .util .HashSet ;
4341import java .util .LinkedHashMap ;
44- import java .util .LinkedList ;
4542import java .util .List ;
4643import java .util .Map ;
4744import java .util .Map .Entry ;
8178 */
8279public abstract class AbstractSQLConfig implements SQLConfig {
8380 private static final String TAG = "AbstractSQLConfig" ;
84-
81+
82+ public static int MAX_COMBINE_DEPTH = 2 ;
83+ public static int MAX_WHERE_COUNT = 3 ;
84+ public static int MAX_COMBINE_COUNT = 5 ;
85+ public static int MAX_COMBINE_KEY_COUNT = 2 ;
86+ public static float MAX_COMBINE_RATIO = 1.0f ;
87+
8588 public static String DEFAULT_DATABASE = DATABASE_MYSQL ;
8689 public static String DEFAULT_SCHEMA = "sys" ;
8790 public static String PREFFIX_DISTINCT = "DISTINCT " ;
@@ -102,6 +105,7 @@ public abstract class AbstractSQLConfig implements SQLConfig {
102105 // 允许调用的 SQL 函数:当 substring 为 null 时忽略;当 substring 为 "" 时整个 value 是 raw SQL;其它情况则只是 substring 这段为 raw SQL
103106 public static final Map <String , String > SQL_FUNCTION_MAP ;
104107
108+
105109 static { // 凡是 SQL 边界符、分隔符、注释符 都不允许,例如 ' " ` ( ) ; # -- /**/ ,以免拼接 SQL 时被注入意外可执行指令
106110 PATTERN_RANGE = Pattern .compile ("^[0-9%,!=\\ <\\ >/\\ .\\ +\\ -\\ *\\ ^]+$" ); // ^[a-zA-Z0-9_*%!=<>(),"]+$ 导致 exists(select*from(Comment)) 通过!
107111 PATTERN_FUNCTION = Pattern .compile ("^[A-Za-z0-9%,:_@&~`!=\\ <\\ >\\ |\\ [\\ ]\\ {\\ } /\\ .\\ +\\ -\\ *\\ ^\\ ?\\ (\\ )\\ $]+$" ); //TODO 改成更好的正则,校验前面为单词,中间为操作符,后面为值
@@ -2140,6 +2144,23 @@ public SQLConfig setCast(Map<String, String> cast) {
21402144
21412145 //WHERE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
21422146
2147+ protected int getMaxWhereCount () {
2148+ return MAX_WHERE_COUNT ;
2149+ }
2150+ protected int getMaxCombineDepth () {
2151+ return MAX_COMBINE_DEPTH ;
2152+ }
2153+ protected int getMaxCombineCount () {
2154+ return MAX_COMBINE_COUNT ;
2155+ }
2156+ protected int getMaxCombineKeyCount () {
2157+ return MAX_COMBINE_KEY_COUNT ;
2158+ }
2159+ protected float getMaxCombineRatio () {
2160+ return MAX_COMBINE_RATIO ;
2161+ }
2162+
2163+
21432164 @ Override
21442165 public String getCombineExpression () {
21452166 return combineExpression ;
@@ -2329,10 +2350,26 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23292350 + "逻辑连接符 & | 左右必须各一个相邻空格!左括号 ( 右边和右括号 ) 左边都不允许有相邻空格!" );
23302351 }
23312352
2353+ if (where == null ) {
2354+ where = new HashMap <>();
2355+ }
2356+ int whereSize = where .size ();
2357+
2358+ int maxWhereCount = getMaxWhereCount ();
2359+ if (maxWhereCount > 0 && whereSize > maxWhereCount ) {
2360+ throw new IllegalArgumentException (table + ":{ key0:value0, key1:value1... } 中条件 key:value 数量 " + whereSize + " 已超过最大数量,必须在 0-" + maxWhereCount + " 内!" );
2361+ }
23322362
23332363 String whereString = "" ;
23342364
2365+ int maxDepth = getMaxCombineDepth ();
2366+ int maxCombineCount = getMaxCombineCount ();
2367+ int maxCombineKeyCount = getMaxCombineKeyCount ();
2368+ float maxCombineRatio = getMaxCombineRatio ();
2369+
23352370 int depth = 0 ;
2371+ int allCount = 0 ;
2372+
23362373 int n = s .length ();
23372374 int i = 0 ;
23382375
@@ -2341,7 +2378,7 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23412378 boolean first = true ;
23422379
23432380 String key = "" ;
2344- Set <String > usedKeySet = new HashSet <>(where . size () );
2381+ Map <String , Integer > usedKeyCountMap = new HashMap <>(whereSize );
23452382 while (i <= n ) { // "date> | (contactIdList<> & (name*~ | tag&$))"
23462383 boolean isOver = i >= n ;
23472384 char c = isOver ? 0 : s .charAt (i );
@@ -2359,6 +2396,17 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23592396 throw new IllegalArgumentException (table + ":{ @combine: '" + combine + "' } 中字符 '" + s .substring (i - key .length () - (isOver ? 1 : 0 ))
23602397 + "' 不合法!左边缺少 & | 其中一个逻辑连接符!" );
23612398 }
2399+
2400+ allCount ++;
2401+ if (allCount > maxCombineCount && maxCombineCount > 0 ) {
2402+ throw new IllegalArgumentException (table + ":{ @combine: '" + combine + "' } 中字符 '" + s + "' 不合法!"
2403+ + "其中 key 数量 " + allCount + " 已超过最大值,必须在条件键值对数量 0-" + maxCombineCount + " 内!" );
2404+ }
2405+ if (1.0f *allCount /whereSize > maxCombineRatio && maxCombineRatio > 0 ) {
2406+ throw new IllegalArgumentException (table + ":{ @combine: '" + combine + "' } 中字符 '" + s + "' 不合法!"
2407+ + "其中 key 数量 " + allCount + " / 条件键值对数量 " + whereSize + " = " + (1.0f *allCount /whereSize )
2408+ + " 已超过 最大倍数,必须在条件键值对数量 0-" + maxCombineRatio + " 倍内!" );
2409+ }
23622410
23632411 Object value = where .get (key );
23642412 if (value == null ) {
@@ -2370,7 +2418,16 @@ public String getWhereString(boolean hasPrefix, RequestMethod method, Map<String
23702418 throw new IllegalArgumentException (table + ":{ @combine: '" + combine + "' } 中字符 '" + key + "' 对应的 " + key + ":value 不是有效条件键值对!" );
23712419 }
23722420
2373- usedKeySet .add (key );
2421+ Integer count = usedKeyCountMap .get (key );
2422+ count = count == null ? 1 : count + 1 ;
2423+ if (count > maxCombineKeyCount && maxCombineKeyCount > 0 ) {
2424+ throw new IllegalArgumentException (table + ":{ @combine: '" + combine + "' } 中字符 '" + s + "' 不合法!其中 '" + key
2425+ + "' 重复引用,次数 " + count + " 已超过最大值,必须在 0-" + maxCombineKeyCount + " 内!" );
2426+ }
2427+
2428+ usedKeyCountMap .put (key , count );
2429+
2430+
23742431 whereString += "( " + wi + " )" ;
23752432 first = false ;
23762433 }
@@ -2422,6 +2479,10 @@ else if (c == '(') {
24222479 }
24232480
24242481 depth ++;
2482+ if (depth > maxDepth && maxDepth > 0 ) {
2483+ throw new IllegalArgumentException (table + ":{ @combine: '" + combine + "' } 中字符 '" + s .substring (0 , i + 1 ) + "' 不合法!括号 (()) 嵌套层级 " + depth + " 已超过最大值,必须在 0-" + maxDepth + " 内!" );
2484+ }
2485+
24252486 whereString += c ;
24262487 lastLogic = 0 ;
24272488 first = true ;
@@ -2454,7 +2515,7 @@ else if (c == ')') {
24542515
24552516 for (Entry <String , Object > entry : set ) {
24562517 key = entry == null ? null : entry .getKey ();
2457- if (key == null || usedKeySet . contains (key )) {
2518+ if (key == null || usedKeyCountMap . containsKey (key )) {
24582519 continue ;
24592520 }
24602521
@@ -2609,6 +2670,8 @@ else if (isSideJoin) { // ^ SIDE JOIN: ! (A & B)
26092670 return result ;
26102671 }
26112672
2673+
2674+
26122675 public String getWhereString (boolean hasPrefix , RequestMethod method , Map <String , Object > where , Map <String , List <String >> combine , List <Join > joinList , boolean verifyName ) throws Exception {
26132676 Set <Entry <String , List <String >>> combineSet = combine == null ? null : combine .entrySet ();
26142677 if (combineSet == null || combineSet .isEmpty ()) {
0 commit comments