|
1 | | -#include "private.h" |
2 | | - |
3 | 1 | #include "gpuarray/buffer.h" |
4 | 2 | #include "gpuarray/buffer_collectives.h" |
5 | 3 | #include "gpuarray/error.h" |
6 | 4 |
|
7 | | -int gpucomm_new(gpucomm** comm, gpucontext* ctx, |
8 | | - gpucommCliqueId comm_id, int ndev, int rank) { |
| 5 | +#include "private.h" |
| 6 | + |
| 7 | +int gpucomm_new(gpucomm** comm, gpucontext* ctx, gpucommCliqueId comm_id, int ndev, |
| 8 | + int rank) |
| 9 | +{ |
9 | 10 | if (ctx->comm_ops == NULL) { |
10 | 11 | *comm = NULL; |
11 | 12 | return GA_UNSUPPORTED_ERROR; |
12 | 13 | } |
13 | 14 | return ctx->comm_ops->comm_new(comm, ctx, comm_id, ndev, rank); |
14 | 15 | } |
15 | 16 |
|
16 | | -void gpucomm_free(gpucomm* comm) { |
| 17 | +void gpucomm_free(gpucomm* comm) |
| 18 | +{ |
17 | 19 | gpucontext* ctx = gpucomm_context(comm); |
18 | 20 | if (ctx->comm_ops != NULL) |
19 | 21 | ctx->comm_ops->comm_free(comm); |
20 | 22 | } |
21 | 23 |
|
22 | | -const char* gpucomm_error(gpucontext* ctx) { |
| 24 | +const char* gpucomm_error(gpucontext* ctx) |
| 25 | +{ |
23 | 26 | if (ctx->comm_ops != NULL) |
24 | | - return ctx->comm_error; |
25 | | - return "No collective ops available, API error. Is a collectives library installed?"; |
26 | | -} |
27 | | - |
28 | | -gpucontext* gpucomm_context(gpucomm* comm) { |
29 | | - return ((partial_gpucomm*) comm)->ctx; |
| 27 | + return ctx->error_msg; |
| 28 | + return "No collective ops available, API error. Is a collectives library " |
| 29 | + "installed?"; |
30 | 30 | } |
31 | 31 |
|
32 | | -int gpucomm_gen_clique_id(gpucontext* ctx, gpucommCliqueId* comm_id) { |
| 32 | +gpucontext* gpucomm_context(gpucomm* comm) { return ((partial_gpucomm*)comm)->ctx; } |
| 33 | +int gpucomm_gen_clique_id(gpucontext* ctx, gpucommCliqueId* comm_id) |
| 34 | +{ |
33 | 35 | if (ctx->comm_ops == NULL) |
34 | 36 | return GA_COMM_ERROR; |
35 | 37 | return ctx->comm_ops->generate_clique_id(ctx, comm_id); |
36 | 38 | } |
37 | 39 |
|
38 | | -int gpucomm_get_count(gpucomm* comm, int* count) { |
| 40 | +int gpucomm_get_count(gpucomm* comm, int* count) |
| 41 | +{ |
39 | 42 | gpucontext* ctx = gpucomm_context(comm); |
40 | 43 | if (ctx->comm_ops == NULL) |
41 | 44 | return GA_COMM_ERROR; |
42 | 45 | return ctx->comm_ops->get_count(comm, count); |
43 | 46 | } |
44 | 47 |
|
45 | | -int gpucomm_get_rank(gpucomm* comm, int* rank) { |
| 48 | +int gpucomm_get_rank(gpucomm* comm, int* rank) |
| 49 | +{ |
46 | 50 | gpucontext* ctx = gpucomm_context(comm); |
47 | 51 | if (ctx->comm_ops == NULL) |
48 | 52 | return GA_COMM_ERROR; |
49 | 53 | return ctx->comm_ops->get_rank(comm, rank); |
50 | 54 | } |
51 | 55 |
|
52 | | -int gpucomm_reduce(gpudata* src, size_t offsrc, |
53 | | - gpudata* dest, size_t offdest, |
54 | | - int count, int typecode, int opcode, |
55 | | - int root, gpucomm* comm) { |
| 56 | +int gpucomm_reduce(gpudata* src, size_t offsrc, gpudata* dest, size_t offdest, |
| 57 | + int count, int typecode, int opcode, int root, gpucomm* comm) |
| 58 | +{ |
56 | 59 | gpucontext* ctx = gpucomm_context(comm); |
57 | 60 | if (ctx->comm_ops == NULL) |
58 | 61 | return GA_COMM_ERROR; |
59 | | - return ctx->comm_ops->reduce(src, offsrc, dest, offdest, |
60 | | - count, typecode, opcode, root, comm); |
| 62 | + return ctx->comm_ops->reduce(src, offsrc, dest, offdest, count, typecode, opcode, |
| 63 | + root, comm); |
61 | 64 | } |
62 | 65 |
|
63 | | -int gpucomm_all_reduce(gpudata* src, size_t offsrc, |
64 | | - gpudata* dest, size_t offdest, |
65 | | - int count, int typecode, int opcode, |
66 | | - gpucomm* comm) { |
| 66 | +int gpucomm_all_reduce(gpudata* src, size_t offsrc, gpudata* dest, size_t offdest, |
| 67 | + int count, int typecode, int opcode, gpucomm* comm) |
| 68 | +{ |
67 | 69 | gpucontext* ctx = gpucomm_context(comm); |
68 | 70 | if (ctx->comm_ops == NULL) |
69 | 71 | return GA_COMM_ERROR; |
70 | | - return ctx->comm_ops->all_reduce(src, offsrc, dest, offdest, |
71 | | - count, typecode, opcode, comm); |
| 72 | + return ctx->comm_ops->all_reduce(src, offsrc, dest, offdest, count, typecode, |
| 73 | + opcode, comm); |
72 | 74 | } |
73 | 75 |
|
74 | | -int gpucomm_reduce_scatter(gpudata* src, size_t offsrc, |
75 | | - gpudata* dest, size_t offdest, |
76 | | - int count, int typecode, int opcode, |
77 | | - gpucomm* comm) { |
| 76 | +int gpucomm_reduce_scatter(gpudata* src, size_t offsrc, gpudata* dest, |
| 77 | + size_t offdest, int count, int typecode, int opcode, |
| 78 | + gpucomm* comm) |
| 79 | +{ |
78 | 80 | gpucontext* ctx = gpucomm_context(comm); |
79 | 81 | if (ctx->comm_ops == NULL) |
80 | 82 | return GA_COMM_ERROR; |
81 | | - return ctx->comm_ops->reduce_scatter(src, offsrc, dest, offdest, |
82 | | - count, typecode, opcode, comm); |
| 83 | + return ctx->comm_ops->reduce_scatter(src, offsrc, dest, offdest, count, typecode, |
| 84 | + opcode, comm); |
83 | 85 | } |
84 | 86 |
|
85 | | -int gpucomm_broadcast(gpudata* array, size_t offset, |
86 | | - int count, int typecode, |
87 | | - int root, gpucomm* comm) { |
| 87 | +int gpucomm_broadcast(gpudata* array, size_t offset, int count, int typecode, |
| 88 | + int root, gpucomm* comm) |
| 89 | +{ |
88 | 90 | gpucontext* ctx = gpucomm_context(comm); |
89 | 91 | if (ctx->comm_ops == NULL) |
90 | 92 | return GA_COMM_ERROR; |
91 | | - return ctx->comm_ops->broadcast(array, offset, |
92 | | - count, typecode, root, comm); |
| 93 | + return ctx->comm_ops->broadcast(array, offset, count, typecode, root, comm); |
93 | 94 | } |
94 | 95 |
|
95 | | -int gpucomm_all_gather(gpudata* src, size_t offsrc, |
96 | | - gpudata* dest, size_t offdest, |
97 | | - int count, int typecode, |
98 | | - gpucomm* comm) { |
| 96 | +int gpucomm_all_gather(gpudata* src, size_t offsrc, gpudata* dest, size_t offdest, |
| 97 | + int count, int typecode, gpucomm* comm) |
| 98 | +{ |
99 | 99 | gpucontext* ctx = gpucomm_context(comm); |
100 | 100 | if (ctx->comm_ops == NULL) |
101 | 101 | return GA_COMM_ERROR; |
102 | | - return ctx->comm_ops->all_gather(src, offsrc, dest, offdest, |
103 | | - count, typecode, comm); |
| 102 | + return ctx->comm_ops->all_gather(src, offsrc, dest, offdest, count, typecode, |
| 103 | + comm); |
104 | 104 | } |
0 commit comments