forked from oceanbase/oceanbase
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathob_expr_case.cpp
More file actions
356 lines (336 loc) · 13.2 KB
/
Copy pathob_expr_case.cpp
File metadata and controls
356 lines (336 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
/**
* Copyright (c) 2021 OceanBase
* OceanBase CE is licensed under Mulan PubL v2.
* You can use this software according to the terms and conditions of the Mulan PubL v2.
* You may obtain a copy of Mulan PubL v2 at:
* http://license.coscl.org.cn/MulanPubL-2.0
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PubL v2 for more details.
*/
#define USING_LOG_PREFIX SQL_ENG
#include "sql/engine/expr/ob_expr_case.h"
#include "sql/engine/expr/ob_expr_operator.h"
//#include "sql/engine/expr/ob_expr_promotion_util.h"
#include "sql/session/ob_sql_session_info.h"
#include "sql/engine/ob_exec_context.h"
namespace oceanbase
{
using namespace common;
namespace sql
{
// Function-pointer type for a WHEN-condition matcher: it reads a condition
// datum and reports through `match_when` whether the branch is taken.
// NOTE(review): this alias is not referenced anywhere in this file, and its
// `const ObDatum *` parameter differs from check_is_match's `const ObDatum &`.
using CheckIsMatchFunc = int (*)(const ObDatum *when_datum, bool &match_when);
// Searched CASE operator (T_OP_CASE). It takes more than one parameter, is
// valid for generated columns and does not operate on row dimensions.
ObExprCase::ObExprCase(ObIAllocator &alloc)
  : ObExprOperator(alloc,
                   T_OP_CASE,
                   N_CASE,
                   MORE_THAN_ONE,
                   VALID_FOR_GENERATED_COL,
                   NOT_ROW_DIMENSION)
{
  // calc_result_typeN assigns the calc types of every operand itself, so
  // automatic operand casting is turned off (presumably what this disables).
  disable_operand_auto_cast();
  // Arguments are evaluated on demand: calc_case_expr / eval_case_batch only
  // evaluate the WHEN/THEN/ELSE arguments they actually need.
  this->param_lazy_eval_ = true;
}
// Nothing extra to release here.
ObExprCase::~ObExprCase() {}
/*
 * Deduce the result type of a searched CASE expression:
 *
 *   CASE WHEN c1 THEN v1
 *        WHEN c2 THEN v2
 *        ...
 *        ELSE vN
 *   END
 *
 * Layout of types_stack (param_num entries; param_num must be odd and >= 3):
 *   [0, cond_type_count)         types of the WHEN conditions
 *   [cond_type_count, param_num) types of the THEN values, followed by the
 *                                type of the trailing ELSE value
 * where cond_type_count = param_num / 2. Because an even count is rejected,
 * an ELSE entry is always expected (presumably the resolver supplies a NULL
 * ELSE when the SQL has none -- verify against the caller).
 *
 * Effects:
 *   - `type` is the aggregate of the THEN/ELSE types, with result flags
 *     computed from the same types.
 *   - WHEN conditions keep their own type if it is int/uint/number/
 *     decimal-int/null, otherwise they are evaluated as double.
 *   - THEN/ELSE values keep their own meta when they are NULL or when both the
 *     value and the result are integer types; otherwise they are cast to the
 *     result meta (and, for decimal-int results, to the result accuracy).
 *
 * @return OB_INVALID_ARGUMENT for a NULL types_stack or a bad param_num,
 *         otherwise the status of aggregate_result_type_for_case.
 */
int ObExprCase::calc_result_typeN(ObExprResType &type,
                                  ObExprResType *types_stack,
                                  int64_t param_num,
                                  ObExprTypeCtx &type_ctx) const
{
  int ret = OB_SUCCESS;
  if (OB_ISNULL(types_stack)) {
    ret = OB_INVALID_ARGUMENT;
    LOG_WARN("null types", K(ret));
  } else if (OB_UNLIKELY(param_num < 3 || param_num % 2 == 0)) {
    ret = OB_INVALID_ARGUMENT;
    LOG_WARN("param num is not correct", K(ret), K(param_num));
  } else {
    // param_num >= 3 and param_num is odd.
    // For MySQL compatibility, types_stack includes the condition (WHEN) types
    // both here and in ob_expr_arg_case.cpp; unlike the arg-case variant, a
    // searched CASE has no leading compared-argument type.
    const int64_t cond_type_count = param_num / 2;
    const int64_t val_type_count = param_num - cond_type_count;
    const ObLengthSemantics default_length_semantics =
        (OB_NOT_NULL(type_ctx.get_session())
             ? type_ctx.get_session()->get_actual_nls_length_semantics()
             : LS_BYTE);
    // The result type is decided by the THEN/ELSE values only.
    if (OB_FAIL(aggregate_result_type_for_case(
                    type,
                    types_stack + cond_type_count,
                    val_type_count,
                    type_ctx.get_coll_type(),
                    lib::is_oracle_mode(),
                    default_length_semantics,
                    type_ctx.get_session(),
                    true, false,
                    is_called_in_sql_))) {
      LOG_WARN("failed to aggregate result type", K(ret));
    } else {
      ObExprOperator::calc_result_flagN(type, types_stack + cond_type_count, val_type_count);
    }
    if (OB_SUCC(ret)) {
      // WHEN conditions: numeric/null types are used as-is, anything else is
      // evaluated as a double.
      for (int64_t i = 0; i < cond_type_count; ++i) {
        const ObObjType cond_type = types_stack[i].get_type();
        const ObObjTypeClass cond_tc = ob_obj_type_class(cond_type);
        if (ObIntTC == cond_tc || ObUIntTC == cond_tc
            || ObNumberTC == cond_tc || ObDecimalIntTC == cond_tc
            || ObNullTC == cond_tc) {
          types_stack[i].set_calc_type(cond_type);
          types_stack[i].set_calc_collation(types_stack[i]);
        } else {
          types_stack[i].set_calc_type(ObDoubleType);
          types_stack[i].set_calc_collation_type(CS_TYPE_BINARY);
          types_stack[i].set_calc_collation_level(CS_LEVEL_NUMERIC);
        }
      }
      // THEN/ELSE values: cast to the result type unless both sides are
      // integers or the value is NULL.
      bool is_expr_integer_type = (ob_is_int_tc(type.get_type()) ||
                                   ob_is_uint_tc(type.get_type()));
      for (int64_t i = cond_type_count; OB_SUCC(ret) && i < param_num; ++i) {
        bool is_arg_integer_type = (ob_is_int_tc(types_stack[i].get_type()) ||
                                    ob_is_uint_tc(types_stack[i].get_type()));
        if ((is_arg_integer_type && is_expr_integer_type) ||
            ObNullType == types_stack[i].get_type()) {
          // see ObExprCoalesce::calc_result_typeN
          types_stack[i].set_calc_meta(types_stack[i].get_obj_meta());
        } else {
          types_stack[i].set_calc_meta(type.get_obj_meta());
          if (ObDecimalIntType == type.get_obj_meta().get_type()) {
            types_stack[i].set_calc_accuracy(type.get_accuracy());
          }
        }
      }
    }
  }
  return ret;
}
/*
 * Code generation for a searched CASE expression.
 *
 * Validates that every WHEN expression of the raw expression returns an
 * integer or NULL (the boolean-semantics contract that check_is_match relies
 * on), then installs the row-at-a-time and batch evaluation functions.
 *
 * @return OB_ERR_UNEXPECTED if a WHEN expression is NULL or does not return
 *         an integer/NULL type; OB_SUCCESS otherwise.
 */
int ObExprCase::cg_expr(ObExprCGCtx &op_cg_ctx,
                        const ObRawExpr &raw_expr,
                        ObExpr &rt_expr) const
{
  int ret = OB_SUCCESS;
  UNUSED(op_cg_ctx);
  const ObCaseOpRawExpr &case_expr = dynamic_cast<const ObCaseOpRawExpr&>(raw_expr);
  // In the new engine every WHEN expr of a CASE expression must return
  // int/null, i.e. each WHEN expr must be a boolean-semantics expression.
  for (int64_t i = 0; OB_SUCC(ret) && i < case_expr.get_when_expr_size(); ++i) {
    const ObRawExpr *when_expr = case_expr.get_when_param_expr(i);
    if (OB_ISNULL(when_expr)) {
      // Guard the dereference below instead of crashing on a malformed expr.
      ret = OB_ERR_UNEXPECTED;
      LOG_WARN("when expr is null", K(ret), K(i));
    } else {
      const ObObjType when_expr_res_type = when_expr->get_result_type().get_type();
      if (OB_UNLIKELY(ObNullType != when_expr_res_type &&
                      !ob_is_integer_type(when_expr_res_type))) {
        ret = OB_ERR_UNEXPECTED;
        LOG_WARN("when expr must return integer", K(ret), K(when_expr_res_type));
      }
    }
  }
  if (OB_SUCC(ret)) {
    rt_expr.eval_func_ = calc_case_expr;
    rt_expr.eval_batch_func_ = eval_case_batch;
  }
  return ret;
}
static int check_is_match(const ObDatum &when_datum, bool &match_when)
{
int ret = OB_SUCCESS;
if (when_datum.is_null()) {
match_when = false;
} else {
int64_t v = when_datum.get_int();
match_when = (v != 0) ? true : false;
}
return ret;
}
int ObExprCase::calc_case_expr(const ObExpr &expr, ObEvalCtx &ctx, ObDatum &res_datum)
{
int ret = OB_SUCCESS;
const bool has_else = (expr.arg_cnt_ % 2 != 0);
int64_t loop = (has_else) ? expr.arg_cnt_ - 1 : expr.arg_cnt_;
bool match_when = false;
ObDatum *when_datum = NULL;
ObDatum *then_datum = NULL;
bool has_result = false;
int64_t expr_idx = 0;
for ( ; OB_SUCC(ret) && !match_when && expr_idx < loop; expr_idx += 2) {
if (OB_FAIL(expr.args_[expr_idx]->eval(ctx, when_datum))) {
LOG_WARN("eval when expr failed", K(ret), K(expr_idx));
} else if (OB_FAIL(check_is_match(*when_datum, match_when))) {
LOG_WARN("check is when expr match failed", K(ret), K(expr_idx));
} else if (match_when) {
if (OB_FAIL(expr.args_[expr_idx+1]->eval(ctx, then_datum))) {
LOG_WARN("eval then expr failed", K(ret), K(expr_idx+1));
} else {
has_result = true;
}
}
}
if (OB_SUCC(ret)) {
if (!match_when) {
if (has_else) {
if (OB_FAIL(expr.args_[expr.arg_cnt_-1]->eval(ctx, then_datum))) {
LOG_WARN("eval else expr failed for case when", K(ret));
} else {
has_result = true;
}
}
}
}
if (OB_SUCC(ret)) {
if (!has_result) {
res_datum.set_null();
} else {
if (OB_ISNULL(then_datum)) {
ret = OB_ERR_UNEXPECTED;
LOG_WARN("then_datum is NULL", K(ret));
} else {
res_datum.set_datum(*then_datum);
}
}
}
return ret;
}
// Oracle模式下,在deduce type阶段需要将when/then expr类型要一致
int ObExprCase::is_same_kind_type_for_case(const ObIArray<ObExprResType> &type_arr)
{
int ret = OB_SUCCESS;
if (OB_SUCC(ret)) {
bool match = false;
int64_t first_not_null_idx = OB_INVALID_ID;
for (int64_t i = 0; OB_SUCC(ret) &&
OB_INVALID_ID == first_not_null_idx && i < type_arr.count(); ++i) {
if (!ob_is_null(type_arr.at(i).get_type())) {
first_not_null_idx = i;
}
}
first_not_null_idx = OB_INVALID_ID == first_not_null_idx ? 0 : first_not_null_idx;
const ObExprResType &res_type = type_arr.at(first_not_null_idx);
for (int64_t i = first_not_null_idx+1; OB_SUCC(ret) && i < type_arr.count(); ++i) {
if (OB_FAIL(ObExprOperator::is_same_kind_type_for_case(res_type,
type_arr.at(i), match))) {
LOG_WARN("fail to judge same type", K(i), K(res_type), K(type_arr.at(i)), K(ret));
} else if (!match) {
ret = OB_ERR_INVALID_TYPE_FOR_OP;
LOG_WARN("fail to judge same type", K(i), K(res_type), K(type_arr.at(i)), K(ret));
}
}
}
return ret;
}
/*
 * Vectorized evaluation of a searched CASE expression over one batch.
 *
 * Arguments are WHEN0, THEN0, WHEN1, THEN1, ... plus a trailing ELSE when the
 * argument count is odd. Two temporary bit vectors drive the evaluation:
 *   case_when_match - rows that need no further WHEN/THEN work: rows in
 *                     `skip`, rows already evaluated (eval_flags), and rows
 *                     whose WHEN matched in an earlier or the current round.
 *   case_not_match  - skip vector for the current THEN branch: set for every
 *                     row except those whose WHEN matched in this round.
 *
 * For each WHEN/THEN pair:
 *   1. evaluate the WHEN expr on rows not in case_when_match and mark each row
 *      as matched (case_when_match) or not matched (case_not_match);
 *   2. evaluate the THEN expr with case_not_match as skip and copy its datums
 *      into the output for the rows that matched in this round;
 *   3. reset case_not_match to case_when_match for the next round.
 * Finally the ELSE expr (or NULL when there is no ELSE) fills every row that
 * is still not in case_when_match. Every row written gets its eval flag set.
 *
 * @return OB_ERR_UNEXPECTED if the result frame is missing,
 *         OB_ALLOCATE_MEMORY_FAILED if a temporary bit vector cannot be
 *         allocated, or the error of any argument evaluation.
 */
int ObExprCase::eval_case_batch(const ObExpr &expr,
                                ObEvalCtx &ctx,
                                const ObBitVector &skip,
                                const int64_t batch_size)
{
  int ret = OB_SUCCESS;
  const bool has_else = (expr.arg_cnt_ % 2 != 0);
  int64_t loop = (has_else) ? expr.arg_cnt_ - 1 : expr.arg_cnt_;
  bool match_when = false;
  ObDatum *results = expr.locate_batch_datums(ctx);
  LOG_DEBUG("eval_case_batch", K(expr.arg_cnt_));
  if (OB_ISNULL(results)) {
    ret = OB_ERR_UNEXPECTED;
    LOG_WARN("results frame is not init", K(ret));
  } else {
    ObBitVector &eval_flags = expr.get_evaluated_flags(ctx);
    ObBitVector *case_when_match = nullptr;
    ObBitVector *case_not_match = nullptr;
    void * data = nullptr;
    void * data1 = nullptr;
    // Both bit vectors live in temporary memory released with alloc_guard.
    ObEvalCtx::TempAllocGuard alloc_guard(ctx);
    if (OB_ISNULL(data = alloc_guard.get_allocator().alloc(ObBitVector::memory_size(batch_size)))) {
      ret = OB_ALLOCATE_MEMORY_FAILED;
      LOG_WARN("failed to alloc memory for case_when_match", K(ret));
    } else if (OB_ISNULL(data1 = alloc_guard.get_allocator().alloc(ObBitVector::memory_size(batch_size)))) {
      ret = OB_ALLOCATE_MEMORY_FAILED;
      LOG_WARN("failed to alloc memory for case_not_match", K(ret));
    } else {
      case_when_match = to_bit_vector(data);
      case_not_match = to_bit_vector(data1);
      case_when_match->reset(batch_size);
      case_not_match->reset(batch_size);
      // case_when_match = eval_flags | skip
      case_when_match->bit_calculate(skip, eval_flags, batch_size,
          [](const uint64_t l, const uint64_t r) { return (l | r); });
      // case_not_match = eval_flags | skip
      case_not_match->bit_calculate(skip, eval_flags, batch_size,
          [](const uint64_t l, const uint64_t r) { return (l | r); });
    }
    // E.G.
    // SELECT CASE WHEN expr1 THEN expr2 WHEN expr3 THEN expr4 ... ELSE exprN END
    // the logic is
    // 1. calc when branch, save result in when_datums and use the match_when
    //    flag to mark which rows are matched in the when branch; these rows
    //    are the ones calculated in the then branch
    // 2. calc then branch, put the matching result (then_datums) into the
    //    output datums (results)
    // REPEAT 1. and 2.
    // ...
    // LAST.
    // calc else branch and put the matching result into the output datums
    for (int64_t expr_idx = 0; OB_SUCC(ret) && expr_idx < loop; expr_idx += 2) {
      if (OB_FAIL(expr.args_[expr_idx]->eval_batch(ctx, *case_when_match, batch_size))) {
        LOG_WARN("failed to eval batch", K(ret), K(expr_idx));
      } else {
        ObDatumVector when_datums = expr.args_[expr_idx]->locate_expr_datumvector(ctx);
        // first eval when datums
        for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
          if (case_when_match->at(j)) {
            continue;
          }
          if (OB_FAIL(check_is_match(*when_datums.at(j), match_when))) {
            LOG_WARN("check is when expr match failed", K(ret), K(j));
          } else if (match_when) {
            case_when_match->set(j);
          } else {
            // not match, mark case_not_match to stop calculating then branch
            case_not_match->set(j);
          }
        }
        // now eval then datums; unset bits of case_not_match are exactly the
        // rows that matched in this round
        if (OB_FAIL(ret)) {
        } else if (OB_FAIL(expr.args_[expr_idx + 1]->eval_batch(ctx, *case_not_match, batch_size))) {
          LOG_WARN("failed to eval batch", K(ret), K(expr_idx + 1));
        } else {
          ObDatumVector then_datums = expr.args_[expr_idx + 1]->locate_expr_datumvector(ctx);
          for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
            if (case_not_match->at(j)) {
              continue;
            }
            results[j].set_datum(*then_datums.at(j));
            eval_flags.set(j);
          }
          // rows matched in this round must not match again in the next
          // round, therefore copy the accumulated matched-rows flags
          // (case_when_match) into case_not_match
          case_not_match->deep_copy(*case_when_match, batch_size);
        }
      }
    }
    // now set the result of the rest, skipping rows already matched (case_when_match)
    if (OB_SUCC(ret)) {
      if (has_else) {
        if (OB_FAIL(expr.args_[expr.arg_cnt_ - 1]->eval_batch(ctx, *case_when_match, batch_size))) {
          LOG_WARN("failed to eval batch", K(ret));
        } else {
          ObDatumVector else_datums = expr.args_[expr.arg_cnt_ - 1]->locate_expr_datumvector(ctx);
          for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
            if (case_when_match->at(j)) {
              continue;
            }
            results[j].set_datum(*else_datums.at(j));
            eval_flags.set(j);
          }
        }
      } else {
        // no ELSE: unmatched rows evaluate to NULL
        for (int64_t j = 0; OB_SUCC(ret) && j < batch_size; ++j) {
          if (case_when_match->at(j)) {
            continue;
          }
          results[j].set_null();
          eval_flags.set(j);
        }
      }
    }
  }
  return ret;
}
}
}