forked from heavyai/heavydb
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCaseIR.cpp
More file actions
107 lines (104 loc) · 4.94 KB
/
CaseIR.cpp
File metadata and controls
107 lines (104 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*
* Copyright 2017 MapD Technologies, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Execute.h"
/**
 * Generates code for a CASE expression.
 *
 * Chooses the LLVM type of the CASE result from the expression's SQL type and
 * delegates the branch generation to codegenCase():
 *   - integer, time and decimal: integer of get_bit_width(case_ti) bits
 *   - floating point: float for kFLOAT, double otherwise
 *   - dictionary-encoded string and boolean: integer of
 *     8 * get_logical_size() bits
 *   - any other string: 64-bit integer, flagged as a "real" string
 * Any other type leaves the LLVM type unset and fails the CHECK below.
 *
 * @param case_expr CASE expression to compile; it must have an ELSE expression
 *                  whose type matches the type of the CASE itself.
 * @param co        Compilation options, forwarded unchanged to codegenCase().
 * @return A vector whose first element is the value produced by codegenCase().
 *         For real strings two more elements follow: the results of the
 *         "extract_str_ptr" and "extract_str_len" calls on that value.
 */
std::vector<llvm::Value*> Executor::codegen(const Analyzer::CaseExpr* case_expr, const CompilationOptions& co) {
  const auto case_ti = case_expr->get_type_info();
  llvm::Type* case_llvm_type = nullptr;
  bool is_real_str = false;
  if (case_ti.is_integer() || case_ti.is_time() || case_ti.is_decimal()) {
    case_llvm_type = get_int_type(get_bit_width(case_ti), cgen_state_->context_);
  } else if (case_ti.is_fp()) {
    case_llvm_type = case_ti.get_type() == kFLOAT ? llvm::Type::getFloatTy(cgen_state_->context_)
                                                  : llvm::Type::getDoubleTy(cgen_state_->context_);
  } else if (case_ti.is_string()) {
    if (case_ti.get_compression() == kENCODING_DICT) {
      case_llvm_type = get_int_type(8 * case_ti.get_logical_size(), cgen_state_->context_);
    } else {
      // Branch values of such strings are merged as a single 64-bit value
      // (presumably pointer and length packed together by "string_pack" --
      // confirm against the runtime function).
      is_real_str = true;
      case_llvm_type = get_int_type(64, cgen_state_->context_);
    }
  } else if (case_ti.is_boolean()) {
    case_llvm_type = get_int_type(8 * case_ti.get_logical_size(), cgen_state_->context_);
  }
  CHECK(case_llvm_type);
  // codegenCase() asserts that the ELSE expression exists, but only after this
  // function has already dereferenced it; assert here so a missing ELSE fails
  // cleanly instead of dereferencing a null pointer.
  const auto else_expr = case_expr->get_else_expr();
  CHECK(else_expr);
  const auto& else_ti = else_expr->get_type_info();
  CHECK_EQ(else_ti.get_type(), case_ti.get_type());
  llvm::Value* case_val = codegenCase(case_expr, case_llvm_type, is_real_str, co);
  std::vector<llvm::Value*> ret_vals{case_val};
  if (is_real_str) {
    // Callers get the merged value followed by its two extracted components.
    ret_vals.push_back(cgen_state_->emitCall("extract_str_ptr", {case_val}));
    ret_vals.push_back(cgen_state_->emitCall("extract_str_len", {case_val}));
  }
  return ret_vals;
}
/**
 * Emits the control flow of a CASE expression and merges the branch results
 * with a PHI node.
 *
 * For every WHEN/THEN pair the condition is generated in the current block and
 * the THEN expression in a fresh "then_case" block that jumps to "end_case";
 * the condition block then branches to either "then_case" or a new "when_case"
 * block, where generation continues. The ELSE expression is generated in the
 * last such block.
 *
 * @param case_expr      CASE expression to compile; its ELSE expression must be
 *                       present.
 * @param case_llvm_type Type of the PHI node which merges the branch values.
 * @param is_real_str    When true, a branch which yields three values is
 *                       reduced to one through a "string_pack" call.
 * @param co             Compilation options, forwarded to codegen().
 * @return The PHI node created in the "end_case" block.
 */
llvm::Value* Executor::codegenCase(const Analyzer::CaseExpr* case_expr,
                                   llvm::Type* case_llvm_type,
                                   const bool is_real_str,
                                   const CompilationOptions& co) {
  // Control flow stops being linear from here on: whatever gets cached while
  // generating one branch (currently only decoded columns) cannot be used after
  // the CASE is finished. FetchCacheAnchor takes a snapshot of the cache now and
  // puts it back once this function is done.
  FetchCacheAnchor anchor(cgen_state_.get());

  // Appends a new basic block with the given label to the row function.
  const auto make_block = [this](const char* label) {
    return llvm::BasicBlock::Create(cgen_state_->context_, label, cgen_state_->row_func_);
  };
  // Reduces a three-value result to a single value via "string_pack" on its
  // second and third elements; any other result contributes its first element.
  const auto pack_if_triple = [this](const std::vector<llvm::Value*>& lvs) -> llvm::Value* {
    if (lvs.size() == 3) {
      return cgen_state_->emitCall("string_pack", {lvs[1], lvs[2]});
    }
    return lvs.front();
  };

  const auto& when_then_pairs = case_expr->get_expr_pair_list();
  std::vector<llvm::Value*> branch_values;
  std::vector<llvm::BasicBlock*> branch_exit_bbs;
  const auto merge_bb = make_block("end_case");

  for (const auto& when_then : when_then_pairs) {
    // Every WHEN/THEN pair gets its own cache snapshot as well.
    FetchCacheAnchor branch_anchor(cgen_state_.get());
    const auto condition_lv = toBool(codegen(when_then.first.get(), true, co).front());
    // Generating the condition may have moved the insert point, so the block
    // which receives the conditional branch is read back from the builder.
    const auto condition_bb = cgen_state_->ir_builder_.GetInsertBlock();

    const auto then_bb = make_block("then_case");
    cgen_state_->ir_builder_.SetInsertPoint(then_bb);
    const auto then_values = codegen(when_then.second.get(), true, co);
    if (!is_real_str) {
      CHECK_EQ(size_t(1), then_values.size());
      branch_values.push_back(then_values.front());
    } else {
      branch_values.push_back(pack_if_triple(then_values));
    }
    // The PHI needs the block in which the THEN value generation ended, which
    // is not necessarily then_bb itself.
    branch_exit_bbs.push_back(cgen_state_->ir_builder_.GetInsertBlock());
    cgen_state_->ir_builder_.CreateBr(merge_bb);

    // Terminate the condition block, then continue in the fall-through block.
    const auto next_bb = make_block("when_case");
    cgen_state_->ir_builder_.SetInsertPoint(condition_bb);
    cgen_state_->ir_builder_.CreateCondBr(condition_lv, then_bb, next_bb);
    cgen_state_->ir_builder_.SetInsertPoint(next_bb);
  }

  const auto else_expr = case_expr->get_else_expr();
  CHECK(else_expr);
  const auto else_values = codegen(else_expr, true, co);
  // NOTE(review): unlike the THEN branches, the ELSE result is packed whenever
  // it has three values, regardless of is_real_str, and its size is not checked
  // otherwise.
  llvm::Value* const else_value = pack_if_triple(else_values);
  CHECK(else_value);
  const auto else_exit_bb = cgen_state_->ir_builder_.GetInsertBlock();
  cgen_state_->ir_builder_.CreateBr(merge_bb);

  // Merge: one incoming value per THEN branch plus one for the ELSE.
  cgen_state_->ir_builder_.SetInsertPoint(merge_bb);
  const auto result_phi = cgen_state_->ir_builder_.CreatePHI(case_llvm_type, when_then_pairs.size() + 1);
  CHECK_EQ(branch_exit_bbs.size(), branch_values.size());
  for (size_t branch_idx = 0; branch_idx < branch_exit_bbs.size(); ++branch_idx) {
    result_phi->addIncoming(branch_values[branch_idx], branch_exit_bbs[branch_idx]);
  }
  result_phi->addIncoming(else_value, else_exit_bb);
  return result_phi;
}