Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions include/af/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ typedef af_err (*af_memory_manager_shutdown_fn)(af_memory_manager handle);

\param[in] handle a pointer to the active \ref af_memory_manager handle
\param[out] ptr pointer to the allocated buffer
\param[in] bytes number of bytes to allocate
\param[in] user_lock a truthy value corresponding to whether or not the
memory should have a user lock associated with it
\param[in] ndims the number of dimensions associated with the allocated
Expand Down Expand Up @@ -118,9 +117,9 @@ typedef af_err (*af_memory_manager_signal_memory_cleanup_fn)(
enforced and can include any information that could be useful to the user.
This function is only called by \ref af_print_mem_info.

\param[in] handle a pointer to the active \ref af_memory_manager handle
\param[out] a buffer to which a message will be populated
\param[in] the device id for which to print memory
\param[in] handle a pointer to the active \ref af_memory_manager handle
\param[out] buffer a buffer to which a message will be populated
\param[in] id the device id for which to print memory
\returns AF_SUCCESS

\ingroup memory_manager_api
Expand Down Expand Up @@ -174,8 +173,8 @@ typedef af_err (*af_memory_manager_is_user_locked_fn)(af_memory_manager handle,

\ingroup memory_manager_api
*/
typedef af_err (*af_memory_manager_get_memory_pressure_fn)(af_memory_manager,
float* pressure);
typedef af_err (*af_memory_manager_get_memory_pressure_fn)(
af_memory_manager handle, float* pressure);

/**
\brief Called to query if additions to the JIT tree would exert too much
Expand Down Expand Up @@ -225,8 +224,8 @@ typedef void (*af_memory_manager_add_memory_management_fn)(

\ingroup memory_manager_api
*/
typedef void (*af_memory_manager_remove_memory_management_fn)(af_memory_manager,
int id);
typedef void (*af_memory_manager_remove_memory_management_fn)(
af_memory_manager handle, int id);

/**
\brief Creates an \ref af_memory_manager handle
Expand Down
4 changes: 0 additions & 4 deletions include/af/signal.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ AFAPI void fftInPlace(array& in, const double norm_factor = 1);

\param[inout] in is the input array on entry and the output of 2D forward fourier transform on exit
\param[in] norm_factor is the normalization factor with which the input is scaled after the transformation is applied
\return the transformed array

\note The input \p in must be complex

Expand All @@ -199,7 +198,6 @@ AFAPI void fft2InPlace(array& in, const double norm_factor = 1);

\param[inout] in is the input array on entry and the output of 3D forward fourier transform on exit
\param[in] norm_factor is the normalization factor with which the input is scaled after the transformation is applied
\return the transformed array

\note The input \p in must be complex

Expand Down Expand Up @@ -351,7 +349,6 @@ AFAPI void ifftInPlace(array& in, const double norm_factor = 1);

\param[inout] in is the input array on entry and the output of 2D inverse fourier transform on exit
\param[in] norm_factor is the normalization factor with which the input is scaled after the transformation is applied
\return the transformed array

\note The input \p in must be complex

Expand All @@ -366,7 +363,6 @@ AFAPI void ifft2InPlace(array& in, const double norm_factor = 1);

\param[inout] in is the input array on entry and the output of 3D inverse fourier transform on exit
\param[in] norm_factor is the normalization factor with which the input is scaled after the transformation is applied
\return the transformed array

\note The input \p in must be complex

Expand Down
9 changes: 5 additions & 4 deletions src/backend/common/KernelInterface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#pragma once

#include <cstddef>
#include <utility>
#include <string>

namespace common {

Expand All @@ -21,10 +21,11 @@ class KernelInterface {
private:
ModuleType mModuleHandle;
KernelType mKernelHandle;
std::string mName;

public:
KernelInterface(ModuleType mod, KernelType ker)
: mModuleHandle(mod), mKernelHandle(ker) {}
KernelInterface(std::string name, ModuleType mod, KernelType ker)
: mModuleHandle(mod), mKernelHandle(ker), mName(name) {}

/// \brief Set kernel
///
Expand Down Expand Up @@ -95,7 +96,7 @@ class KernelInterface {
template<typename EnqueueArgsType, typename... Args>
void operator()(const EnqueueArgsType& qArgs, Args... args) {
EnqueuerType launch;
launch(mKernelHandle, qArgs, std::forward<Args>(args)...);
launch(mName, mKernelHandle, qArgs, std::forward<Args>(args)...);
}
};

Expand Down
1 change: 1 addition & 0 deletions src/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ if(UNIX)
if(CUDA_VERSION VERSION_GREATER 10.0)
target_link_libraries(af_cuda_static_cuda_library
PRIVATE
spdlog

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already added in the common interface target; don't we use the common interface target with this static lib?

Also, it will fail when we look for the system/vcpkg-provided spdlog target spdlog::spdlog_header_only -- perhaps I can handle this in my PR #3139.

${CUDA_cublasLt_static_LIBRARY})
endif()
if(CUDA_VERSION VERSION_GREATER 9.5)
Expand Down
22 changes: 19 additions & 3 deletions src/backend/cuda/Kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,35 @@
#pragma once

#include <common/KernelInterface.hpp>
#include <common/Logger.hpp>

#include <EnqueueArgs.hpp>
#include <backend.hpp>
#include <cu_check_macro.hpp>
#include <cstdlib>
#include <string>

namespace cuda {

struct Enqueuer {
static auto getLogger() {
static auto logger = common::loggerFactory("kernel");
return logger.get();
};

template<typename... Args>
void operator()(void* ker, const EnqueueArgs& qArgs, Args... args) {
void operator()(std::string name, void* ker, const EnqueueArgs& qArgs,
Args... args) {
void* params[] = {reinterpret_cast<void*>(&args)...};
for (auto& event : qArgs.mEvents) {
CU_CHECK(cuStreamWaitEvent(qArgs.mStream, event, 0));
}
AF_TRACE(
"Launching {}: Blocks: [{}, {}, {}] Threads: [{}, {}, {}] Shared "
"Memory: {}",
name, qArgs.mBlocks.x, qArgs.mBlocks.y, qArgs.mBlocks.z,
qArgs.mThreads.x, qArgs.mThreads.y, qArgs.mThreads.z,
qArgs.mSharedMemSize);
CU_CHECK(cuLaunchKernel(static_cast<CUfunction>(ker), qArgs.mBlocks.x,
qArgs.mBlocks.y, qArgs.mBlocks.z,
qArgs.mThreads.x, qArgs.mThreads.y,
Expand All @@ -42,8 +57,9 @@ class Kernel
using BaseClass =
common::KernelInterface<ModuleType, KernelType, Enqueuer, DevPtrType>;

Kernel() : BaseClass(nullptr, nullptr) {}
Kernel(ModuleType mod, KernelType ker) : BaseClass(mod, ker) {}
Kernel() : BaseClass("", nullptr, nullptr) {}
Kernel(std::string name, ModuleType mod, KernelType ker)
: BaseClass(name, mod, ker) {}

DevPtrType getDevPtr(const char* name) final;

Expand Down
8 changes: 5 additions & 3 deletions src/backend/cuda/compile_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,17 @@
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

using namespace cuda;

Expand All @@ -69,7 +72,6 @@ using std::end;
using std::extent;
using std::find_if;
using std::make_pair;
using std::map;
using std::ofstream;
using std::pair;
using std::string;
Expand Down Expand Up @@ -479,7 +481,7 @@ Kernel getKernel(const Module &mod, const string &nameExpr,
std::string name = (sourceWasJIT ? nameExpr : mod.mangledName(nameExpr));
CUfunction kernel = nullptr;
CU_CHECK(cuModuleGetFunction(&kernel, mod.get(), name.c_str()));
return {mod.get(), kernel};
return {nameExpr, mod.get(), kernel};
}

} // namespace common
35 changes: 33 additions & 2 deletions src/backend/cuda/debug_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,42 @@
********************************************************/

#pragma once
#include <common/Logger.hpp>
#include <err_cuda.hpp>
#include <platform.hpp>
#include <string>

#define CUDA_LAUNCH_SMEM(fn, blks, thrds, smem_size, ...) \
fn<<<blks, thrds, smem_size, cuda::getActiveStream()>>>(__VA_ARGS__)
namespace cuda {
namespace kernel_logger {

/// \brief Accessor for the logger used when tracing kernel launches.
///
/// The logger is requested from common::loggerFactory("kernel") the first
/// time this function runs and is kept in a function-local static. Every
/// call returns (by value) a copy of that stored handle.
inline auto getLogger() {
    static auto kernelLogger = common::loggerFactory("kernel");
    return kernelLogger;
}

}  // namespace kernel_logger
}  // namespace cuda

/// \brief fmt formatter specialization that lets a dim3 be passed directly
///        to fmt-style format strings.
///
/// The value is rendered as its three components separated by single
/// spaces ("x y z"). parse() is inherited from fmt::formatter<std::string>,
/// so the format specs accepted are those of the string formatter.
template<>
struct fmt::formatter<dim3> : fmt::formatter<std::string> {
    template<typename FormatContext>
    auto format(dim3 c, FormatContext& ctx) {
        // Build the textual form first, then let the string formatter
        // write it to the output context.
        const std::string text = fmt::format("{} {} {}", c.x, c.y, c.z);
        return fmt::formatter<std::string>::format(text, ctx);
    }
};

/// \brief Logs a trace message and then launches kernel \p fn with an
///        explicit dynamic shared memory size.
///
/// \param fn        kernel to launch; its spelling is stringized (#fn) for the
///                  trace message
/// \param blks      grid dimensions passed to the <<<...>>> launch
/// \param thrds     block dimensions passed to the <<<...>>> launch
/// \param smem_size shared memory size passed to the <<<...>>> launch
/// \param ...       arguments forwarded to the kernel
///
/// The launch is issued on cuda::getActiveStream(). The body is wrapped in
/// do { ... } while (false) so the macro behaves as a single statement.
/// The inner scope pulls in cuda::kernel_logger so its getLogger() is visible
/// at the AF_TRACE call site (presumably what AF_TRACE looks up -- confirm
/// against common/Logger.hpp).
///
/// \note \p blks, \p thrds and \p smem_size each appear twice in the
///       expansion (trace call and launch); avoid passing expressions with
///       side effects.
#define CUDA_LAUNCH_SMEM(fn, blks, thrds, smem_size, ...) \
do { \
{ \
using namespace cuda::kernel_logger; \
AF_TRACE( \
"Launching {}: Blocks: [{}] Threads: [{}] " \
"Shared Memory: {}", \
#fn, blks, thrds, smem_size); \
} \
fn<<<blks, thrds, smem_size, cuda::getActiveStream()>>>(__VA_ARGS__); \
} while (false)

/// \brief Convenience form of CUDA_LAUNCH_SMEM that passes 0 as the shared
///        memory size; all other arguments are forwarded unchanged.
#define CUDA_LAUNCH(fn, blks, thrds, ...) \
CUDA_LAUNCH_SMEM(fn, blks, thrds, 0, __VA_ARGS__)
Expand Down
9 changes: 8 additions & 1 deletion src/backend/cuda/jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#include <platform.hpp>
Comment thread
9prady9 marked this conversation as resolved.
#include <af/dim4.hpp>

#include <cstdio>
#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
Expand Down Expand Up @@ -299,6 +300,12 @@ void evalNodes(vector<Param<T>> &outputs, const vector<Node *> &output_nodes) {
args.push_back(static_cast<void *>(&blocks_x_total));
args.push_back(static_cast<void *>(&num_odims));

{
using namespace cuda::kernel_logger;
AF_TRACE("Launching : Blocks: [{}] Threads: [{}] ",
dim3(blocks_x, blocks_y, blocks_z),
dim3(threads_x, threads_y));
}
CU_CHECK(cuLaunchKernel(ker, blocks_x, blocks_y, blocks_z, threads_x,
threads_y, 1, 0, getActiveStream(), args.data(),
NULL));
Expand Down
19 changes: 15 additions & 4 deletions src/backend/opencl/Kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,27 @@
#pragma once

#include <common/KernelInterface.hpp>
#include <common/Logger.hpp>

#include <backend.hpp>
#include <cl2hpp.hpp>
#include <string>

namespace opencl {
namespace kernel_logger {
/// \brief Accessor for the logger used when tracing kernel launches.
///
/// The logger is requested from common::loggerFactory("kernel") the first
/// time this function runs and is kept in a function-local static.
///
/// \returns the raw spdlog::logger pointer obtained from the stored handle's
///          get(); the static handle itself is never handed out
inline auto getLogger() -> spdlog::logger* {
    static auto kernelLogger = common::loggerFactory("kernel");
    spdlog::logger* const rawLogger = kernelLogger.get();
    return rawLogger;
}
}  // namespace kernel_logger

struct Enqueuer {
template<typename... Args>
void operator()(cl::Kernel ker, const cl::EnqueueArgs& qArgs,
Args&&... args) {
void operator()(std::string name, cl::Kernel ker,
const cl::EnqueueArgs& qArgs, Args&&... args) {
auto launchOp = cl::KernelFunctor<Args...>(ker);
using namespace kernel_logger;
AF_TRACE("Launching {}", name);
launchOp(qArgs, std::forward<Args>(args)...);
}
};
Expand All @@ -35,8 +45,9 @@ class Kernel
using BaseClass =
common::KernelInterface<ModuleType, KernelType, Enqueuer, DevPtrType>;

Kernel() : BaseClass(nullptr, cl::Kernel{nullptr, false}) {}
Kernel(ModuleType mod, KernelType ker) : BaseClass(mod, ker) {}
Kernel() : BaseClass("", nullptr, cl::Kernel{nullptr, false}) {}
Kernel(std::string name, ModuleType mod, KernelType ker)
: BaseClass(name, mod, ker) {}

// clang-format off
[[deprecated("OpenCL backend doesn't need Kernel::getDevPtr method")]]
Expand Down
2 changes: 1 addition & 1 deletion src/backend/opencl/compile_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Module loadModuleFromDisk(const int device, const string &moduleKey,
Kernel getKernel(const Module &mod, const string &nameExpr,
const bool sourceWasJIT) {
UNUSED(sourceWasJIT);
return {&mod.get(), cl::Kernel(mod.get(), nameExpr.c_str())};
return {nameExpr, &mod.get(), cl::Kernel(mod.get(), nameExpr.c_str())};
}

} // namespace common
1 change: 1 addition & 0 deletions src/backend/opencl/kernel/scan_by_key/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ foreach(SBK_BINARY_OP ${SBK_BINARY_OPS})
../common
../../../include
${CMAKE_CURRENT_BINARY_DIR}
$<TARGET_PROPERTY:spdlog,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:OpenCL::OpenCL,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:OpenCL::cl2hpp,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:Boost::boost,INTERFACE_INCLUDE_DIRECTORIES>
Expand Down