forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathndarray_function.cu
More file actions
48 lines (45 loc) · 1.7 KB
/
ndarray_function.cu
File metadata and controls
48 lines (45 loc) · 1.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// This file is compiled by nvcc to build the GPU specializations of Copy.
#include <dmlc/logging.h>
#include "./ndarray_function.h"
#include "./ndarray_function-inl.h"
namespace mxnet {
namespace ndarray {
template<>
void Copy<cpu, gpu>(const TBlob &from, TBlob *to,
                    Context from_ctx, Context to_ctx,
                    RunContext ctx) {
  // Host -> device copy. Both blobs are viewed as flattened 2D real_t
  // tensors and the transfer is issued on the GPU stream carried by `ctx`.
  // The device contexts are not consulted in this direction.
  mshadow::Stream<gpu> *gpu_stream =
      static_cast<mshadow::Stream<gpu>*>(ctx.stream);
  mshadow::Copy(to->FlatTo2D<gpu, real_t>(), from.FlatTo2D<cpu, real_t>(),
                gpu_stream);
}
template<>
void Copy<gpu, cpu>(const TBlob &from, TBlob *to,
                    Context from_ctx, Context to_ctx,
                    RunContext ctx) {
  // Device -> host copy. Mirrors Copy<cpu, gpu>: flatten both blobs to 2D
  // real_t tensors and let mshadow::Copy perform the transfer on the GPU
  // stream from `ctx`. The device contexts are not consulted here.
  mshadow::Stream<gpu> *gpu_stream =
      static_cast<mshadow::Stream<gpu>*>(ctx.stream);
  mshadow::Copy(to->FlatTo2D<cpu, real_t>(), from.FlatTo2D<gpu, real_t>(),
                gpu_stream);
}
template<>
void Copy<gpu, gpu>(const TBlob &from, TBlob *to,
                    Context from_ctx, Context to_ctx,
                    RunContext ctx) {
  // Device -> device copy.
  // Same device:      delegate to mshadow::Copy on flattened 2D views.
  // Different device: issue a raw peer-to-peer memcpy on the stream in `ctx`.
  //                   This requires both blobs to be contiguous and of equal
  //                   element count.
  mshadow::Stream<gpu> *s = static_cast<mshadow::Stream<gpu>*>(ctx.stream);
  if (from_ctx.dev_id == to_ctx.dev_id) {
    mshadow::Copy(to->FlatTo2D<gpu, real_t>(),
                  from.FlatTo2D<gpu, real_t>(),
                  s);
  } else {
    CHECK(from.CheckContiguous() && to->CheckContiguous())
        << "copy across devices only supports contiguous memory";
    // The byte count below is derived from `from`. A smaller `to` would be
    // overrun, so require matching sizes, in line with the shape check that
    // mshadow::Copy performs on the same-device path.
    CHECK_EQ(from.shape_.Size(), to->shape_.Size())
        << "copy across devices: source and destination sizes differ";
    CHECK(s != NULL) << "need stream in GPU context";
    cudaError_t err = cudaMemcpyPeerAsync(to->dptr_,
                                          to_ctx.dev_id,
                                          from.dptr_,
                                          from_ctx.dev_id,
                                          from.shape_.Size() * sizeof(real_t),
                                          s->stream_);
    // Report launch/argument failures instead of silently dropping them;
    // asynchronous execution errors will still surface at the next sync.
    CHECK(err == cudaSuccess)
        << "cudaMemcpyPeerAsync failed: " << cudaGetErrorString(err);
  }
}
} // namespace ndarray
} // namespace mxnet