forked from arrayfire/arrayfire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_cache.cpp
More file actions
122 lines (101 loc) · 3.83 KB
/
kernel_cache.cpp
File metadata and controls
122 lines (101 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/*******************************************************
* Copyright (c) 2020, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#if !defined(AF_CPU)
#include <common/kernel_cache.hpp>
#include <common/compile_module.hpp>
#include <common/util.hpp>
#include <device_manager.hpp>
#include <platform.hpp>
#include <algorithm>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <vector>
using detail::Kernel;
using detail::Module;
using std::back_inserter;
using std::shared_timed_mutex;
using std::string;
using std::transform;
using std::unordered_map;
using std::vector;
namespace common {
using ModuleMap = unordered_map<string, Module>;
/// Returns the reader/writer mutex that guards the module cache of \p device.
///
/// One mutex exists per device slot (DeviceManager::MAX_DEVICES in total);
/// they live in a function-local static array so they are created on first
/// use.
///
/// \param device index of the device whose cache lock is requested. The
///               value is used directly as an array index and is not range
///               checked here.
/// \returns a reference to the mutex for that device slot
shared_timed_mutex& getCacheMutex(const int device) {
    static shared_timed_mutex perDeviceLocks[detail::DeviceManager::MAX_DEVICES];
    return perDeviceLocks[device];
}
/// Returns the module cache (hash key -> Module) belonging to \p device.
///
/// The per-device maps are allocated with new[] the first time this function
/// runs and are never deleted by this function.
/// NOTE(review): the missing delete looks deliberate (presumably to avoid
/// tearing the modules down during static destruction) - confirm before
/// changing it.
///
/// \param device index of the device whose cache is requested. The value is
///               used directly as an array index and is not range checked
///               here.
/// \returns a reference to the cache for that device slot; this function takes
///          no lock itself
ModuleMap& getCache(const int device) {
    static ModuleMap* const perDeviceCaches =
        new ModuleMap[detail::DeviceManager::MAX_DEVICES];
    return perDeviceCaches[device];
}
/// Looks up a previously cached module for \p device.
///
/// The lookup runs under a shared (read) lock on the device's cache mutex,
/// which is released when the function returns.
///
/// \param device index of the device whose cache is searched
/// \param key    cache key of the module
/// \returns a copy of the cached Module when \p key is present, otherwise a
///          default constructed Module
Module findModule(const int device, const string& key) {
    std::shared_lock<shared_timed_mutex> guard(getCacheMutex(device));

    auto& modules = getCache(device);
    auto hit      = modules.find(key);
    if (hit == modules.end()) { return Module{}; }
    return hit->second;
}
/// Fetches the kernel \p kernelName, compiling and caching its module if
/// needed.
///
/// The module is looked up in the active device's in-memory cache first. On a
/// miss it is requested from loadModuleFromDisk(), and if that also yields an
/// empty module it is built with compileModule(). The result is then stored
/// in the in-memory cache.
///
/// \param kernelName  name of the kernel (the function name for JIT kernels)
/// \param sources     kernel source strings
/// \param targs       template arguments; their `_tparam` strings are joined
///                    into `kernelName<a,b,...>`
/// \param options     compilation options
/// \param sourceIsJIT true when the sources were generated by the JIT
/// \returns the Kernel extracted from the cached (or freshly built) module
Kernel getKernel(const string& kernelName, const vector<string>& sources,
                 const vector<TemplateArg>& targs,
                 const vector<string>& options, const bool sourceIsJIT) {
    // Name of the template instantiation: "name" when there are no template
    // arguments, "name<arg0,arg1,...>" otherwise.
    string tInstance = kernelName;
    if (!targs.empty()) {
        const char* separator = "<";
        for (const TemplateArg& targ : targs) {
            const string& param = targ._tparam;
            tInstance += separator;
            tInstance += param;
            separator = ",";
        }
        tInstance += ">";
    }

    // Values that identify the module. A JIT function name is already unique,
    // so for JIT sources the instance name alone is hashed; regular kernels
    // additionally hash their sources and compile options.
    vector<string> hashInputs;
    hashInputs.reserve(sourceIsJIT ? 1 : 1 + sources.size() + options.size());
    hashInputs.push_back(tInstance);
    if (!sourceIsJIT) {
        hashInputs.insert(hashInputs.end(), sources.begin(), sources.end());
        hashInputs.insert(hashInputs.end(), options.begin(), options.end());
    }
    const string moduleKey = std::to_string(deterministicHash(hashInputs));

    const int device    = detail::getActiveDeviceId();
    Module kernelModule = findModule(device, moduleKey);
    if (!kernelModule) {
        // Not in the in-memory cache: try the on-disk copy, then compile.
        // No cache lock is held here, so several threads may end up building
        // the same module concurrently.
        kernelModule = loadModuleFromDisk(device, moduleKey, sourceIsJIT);
        if (!kernelModule) {
            kernelModule = compileModule(moduleKey, sources, options,
                                         {tInstance}, sourceIsJIT);
        }

        // Publish the module under the exclusive lock, re-checking the cache
        // because another thread may have inserted the same key meanwhile.
        std::unique_lock<shared_timed_mutex> writeLock(getCacheMutex(device));
        auto& cache   = getCache(device);
        auto existing = cache.find(moduleKey);
        if (existing != cache.end()) {
            // Another thread won the race: drop this thread's duplicate
            // build and use the module that is already cached.
            kernelModule.unload();
            kernelModule = existing->second;
        } else {
            // This thread is the first to produce the module; keep it.
            cache.emplace(moduleKey, kernelModule);
        }
    }

    // The CUDA branch looks the kernel up by its template instance name, the
    // OpenCL branch by its plain name.
#if defined(AF_CUDA)
    return getKernel(kernelModule, tInstance, sourceIsJIT);
#elif defined(AF_OPENCL)
    return getKernel(kernelModule, kernelName, sourceIsJIT);
#endif
}
} // namespace common
#endif