-
Notifications
You must be signed in to change notification settings - Fork 548
Expand file tree
/
Copy pathKernel.hpp
More file actions
76 lines (63 loc) · 2.45 KB
/
Kernel.hpp
File metadata and controls
76 lines (63 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
/*******************************************************
* Copyright (c) 2020, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#pragma once
#include <common/KernelInterface.hpp>
#include <common/Logger.hpp>
#include <EnqueueArgs.hpp>
#include <backend.hpp>
#include <err_cuda.hpp>
#include <cstdlib>
#include <string>
namespace arrayfire {
namespace cuda {
/// Launch functor plugged into common::KernelInterface (as its Enqueuer
/// template argument) for kernels started through the CUDA driver API.
struct Enqueuer {
    /// Returns `get()` of a function-local static logger that is created once
    /// via common::loggerFactory("kernel").
    /// NOTE(review): presumably this is how AF_TRACE below locates its logger
    /// -- confirm against common/Logger.hpp.
    static auto getLogger() {
        static auto logger = common::loggerFactory("kernel");
        return logger.get();
    }

    /// Makes the target stream wait on the pending events, traces the launch
    /// configuration and launches the kernel.
    ///
    /// @param name   Kernel name; only used in the trace message.
    /// @param ker    Kernel handle; cast to CUfunction for the launch.
    /// @param qArgs  Supplies the grid size (mBlocks), block size (mThreads),
    ///               dynamic shared memory size (mSharedMemSize), the stream
    ///               (mStream) and the events to wait on (mEvents).
    /// @param args   Kernel arguments. They are taken by value and the launch
    ///               receives pointers to these local copies.
    ///
    /// Driver errors are routed through CU_CHECK.
    template<typename... Args>
    void operator()(std::string name, void* ker, const EnqueueArgs& qArgs,
                    Args... args) {
        // One slot per kernel argument plus a trailing nullptr. The extra slot
        // keeps the array non-empty when the pack is empty: an unknown-bound
        // array with an empty initializer is ill-formed in standard C++ and
        // only accepted as a compiler extension. The driver takes the
        // parameter count from the kernel image, so the extra slot is never
        // read.
        void* params[sizeof...(Args) + 1] = {static_cast<void*>(&args)...,
                                             nullptr};

        // Order the launch after every event recorded in the enqueue args.
        for (const auto& event : qArgs.mEvents) {
            CU_CHECK(cuStreamWaitEvent(qArgs.mStream, event, 0));
        }

        AF_TRACE(
            "Launching {}: Blocks: [{}, {}, {}] Threads: [{}, {}, {}] Shared "
            "Memory: {}",
            name, qArgs.mBlocks.x, qArgs.mBlocks.y, qArgs.mBlocks.z,
            qArgs.mThreads.x, qArgs.mThreads.y, qArgs.mThreads.z,
            qArgs.mSharedMemSize);

        // Last argument is the `extra` option list; it is unused because all
        // arguments are passed through `params`.
        CU_CHECK(cuLaunchKernel(static_cast<CUfunction>(ker), qArgs.mBlocks.x,
                                qArgs.mBlocks.y, qArgs.mBlocks.z,
                                qArgs.mThreads.x, qArgs.mThreads.y,
                                qArgs.mThreads.z, qArgs.mSharedMemSize,
                                qArgs.mStream, params, nullptr));
    }
};
/// CUDA specialisation of common::KernelInterface: binds the interface to
/// CUmodule / CUfunction / CUdeviceptr and uses Enqueuer as the launcher.
///
/// The four member functions below are declared `final`, so they are virtual
/// functions of the base interface; their definitions are not visible in this
/// file.
class Kernel
    : public common::KernelInterface<CUmodule, CUfunction, Enqueuer,
                                     CUdeviceptr> {
   public:
    // Same specialisation as the one named in the base-class list above.
    using BaseClass =
        common::KernelInterface<CUmodule, CUfunction, Enqueuer, CUdeviceptr>;
    using ModuleType = CUmodule;
    using KernelType = CUfunction;
    using DevPtrType = CUdeviceptr;

    /// Builds a kernel with an empty name and null module/function handles.
    Kernel()
        : BaseClass("", nullptr, nullptr) {}

    /// Builds a kernel from its name, owning module handle and function
    /// handle; all three are forwarded to the base class.
    Kernel(std::string name, ModuleType mod, KernelType ker)
        : BaseClass(name, mod, ker) {}

    // NOTE(review): presumably resolves the device symbol called `name` --
    // confirm against the definition.
    DevPtrType getDevPtr(const char* name) final;

    // NOTE(review): presumably copies `bytes` bytes from `src` to `dst` on
    // the device -- confirm against the definition.
    void copyToReadOnly(DevPtrType dst, DevPtrType src, size_t bytes) final;

    // NOTE(review): presumably writes *scalarValPtr to `dst`, blocking when
    // `syncCopy` is true -- confirm against the definition.
    void setFlag(DevPtrType dst, int* scalarValPtr,
                 const bool syncCopy = false) final;

    // NOTE(review): presumably reads an int back from `src` -- confirm
    // against the definition.
    int getFlag(DevPtrType src) final;
};
} // namespace cuda
} // namespace arrayfire