Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions src/backend/cuda/ThrustArrayFirePolicy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,11 @@
#include <backend.hpp>
#include <memory.hpp>
#include <platform.hpp>
#include <thrust/execution_policy.h>
#include <thrust/system/cuda/execution_policy.h>

namespace cuda {
struct ThrustArrayFirePolicy
: thrust::device_execution_policy<ThrustArrayFirePolicy> {};

namespace {
__DH__
inline cudaStream_t get_stream(ThrustArrayFirePolicy) {
#if defined(__CUDA_ARCH__)
return 0;
#else
return getActiveStream();
#endif
}

__DH__
inline cudaError_t synchronize_stream(ThrustArrayFirePolicy) {
#if defined(__CUDA_ARCH__)
return cudaDeviceSynchronize();
#else
return cudaStreamSynchronize(getActiveStream());
#endif
}
} // namespace
: thrust::cuda::execution_policy<ThrustArrayFirePolicy> {};

template<typename T>
thrust::pair<thrust::pointer<T, ThrustArrayFirePolicy>, std::ptrdiff_t>
Expand All @@ -53,3 +33,27 @@ inline void return_temporary_buffer(ThrustArrayFirePolicy, Pointer p) {
}

} // namespace cuda

namespace thrust {
namespace cuda_cub {
template<>
__DH__ inline cudaStream_t get_stream<::cuda::ThrustArrayFirePolicy>(
execution_policy<::cuda::ThrustArrayFirePolicy> &) {
#if defined(__CUDA_ARCH__)
return 0;
#else
return ::cuda::getActiveStream();
#endif
}

__DH__
inline cudaError_t synchronize_stream(const ::cuda::ThrustArrayFirePolicy &) {
#if defined(__CUDA_ARCH__)
return cudaDeviceSynchronize();
#else
return cudaStreamSynchronize(::cuda::getActiveStream());
#endif
}

} // namespace cuda_cub
} // namespace thrust