forked from triton-inference-server/python_backend
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpb_memory.cc
More file actions
482 lines (424 loc) · 14.9 KB
/
Copy pathpb_memory.cc
File metadata and controls
482 lines (424 loc) · 14.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "pb_memory.h"
namespace triton { namespace backend { namespace python {
std::unique_ptr<PbMemory>
PbMemory::Create(
std::unique_ptr<SharedMemoryManager>& shm_pool,
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
uint64_t byte_size, char* data, bool copy_gpu)
{
size_t requested_byte_size = sizeof(MemoryShm);
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
requested_byte_size += sizeof(cudaIpcMemHandle_t);
#endif
} else {
requested_byte_size += byte_size;
}
AllocatedSharedMemory<char> memory_shm =
shm_pool->Construct<char>(requested_byte_size);
PbMemory::FillShmData(
shm_pool->GetCUDAMemoryPoolManager(), memory_type, memory_type_id,
byte_size, data, memory_shm.data_.get(), memory_shm.handle_, copy_gpu);
if (memory_type == TRITONSERVER_MEMORY_CPU) {
data = memory_shm.data_.get() + sizeof(MemoryShm);
}
std::unique_ptr<PbMemory> pb_memory(
new PbMemory(memory_shm, data, false /* opened_cuda_ipc_handle */));
#ifdef TRITON_ENABLE_GPU
if (memory_type == TRITONSERVER_MEMORY_GPU) {
pb_memory->memory_shm_ptr_->gpu_pointer_offset =
pb_memory->GetGPUPointerOffset();
}
#endif
return pb_memory;
}
#ifndef TRITON_PB_STUB
// Builds a PbMemory describing the buffer held by `backend_memory` and then
// moves `backend_memory` into the returned object (backend_memory_).
std::unique_ptr<PbMemory>
PbMemory::Create(
    std::unique_ptr<SharedMemoryManager>& shm_pool,
    std::unique_ptr<BackendMemory>&& backend_memory, bool copy_gpu)
{
  // Query the buffer's properties while `backend_memory` is still owned here.
  const auto memory_type = backend_memory->MemoryType();
  const auto memory_type_id = backend_memory->MemoryTypeId();
  const auto byte_size = backend_memory->ByteSize();
  const auto buffer = backend_memory->MemoryPtr();

  std::unique_ptr<PbMemory> result = PbMemory::Create(
      shm_pool, memory_type, memory_type_id, byte_size, buffer, copy_gpu);

  result->backend_memory_ = std::move(backend_memory);
  return result;
}
#endif
// Variant of Create that writes into a shared-memory region supplied by the
// caller instead of allocating one. `data_shm` is the start of that region
// (MemoryShm header followed by the payload area) and `handle` is the
// shared-memory handle stored in the returned object. The returned object
// refers to the region through raw pointers only.
std::unique_ptr<PbMemory>
PbMemory::Create(
    std::unique_ptr<SharedMemoryManager>& shm_pool,
    TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
    uint64_t byte_size, char* data, char* data_shm,
    bi::managed_external_buffer::handle_t handle, bool copy_gpu)
{
  PbMemory::FillShmData(
      shm_pool->GetCUDAMemoryPoolManager(), memory_type, memory_type_id,
      byte_size, data, data_shm, handle, copy_gpu);

  // For CPU memory the object refers to the bytes stored after the header in
  // shared memory; every other memory type keeps the caller's pointer.
  char* payload = data;
  if (memory_type == TRITONSERVER_MEMORY_CPU) {
    payload = data_shm + sizeof(MemoryShm);
  }

  std::unique_ptr<PbMemory> result(new PbMemory(
      data_shm, payload, handle, false /* opened_cuda_ipc_handle */));

#ifdef TRITON_ENABLE_GPU
  if (memory_type == TRITONSERVER_MEMORY_GPU) {
    // Record how far `payload` is from the start of its CUDA allocation.
    result->memory_shm_ptr_->gpu_pointer_offset =
        result->GetGPUPointerOffset();
  }
#endif

  return result;
}
void
PbMemory::CopyBuffer(
std::unique_ptr<PbMemory>& dst, std::unique_ptr<PbMemory>& src)
{
if (src->ByteSize() != dst->ByteSize()) {
throw PythonBackendException(
"Failed to copy memory buffers. Source and destination byte size do "
"not match: " +
std::to_string(dst->ByteSize()) +
" != " + std::to_string(src->ByteSize()));
}
if (src->MemoryType() == TRITONSERVER_MEMORY_CPU &&
dst->MemoryType() == TRITONSERVER_MEMORY_CPU) {
std::memcpy(dst->DataPtr(), src->DataPtr(), dst->ByteSize());
return;
}
#ifdef TRITON_ENABLE_GPU
cudaMemcpyKind kind = cudaMemcpyHostToDevice;
if (src->MemoryType() == TRITONSERVER_MEMORY_CPU &&
dst->MemoryType() == TRITONSERVER_MEMORY_GPU) {
kind = cudaMemcpyHostToDevice;
} else if (
src->MemoryType() == TRITONSERVER_MEMORY_GPU &&
dst->MemoryType() == TRITONSERVER_MEMORY_CPU) {
kind = cudaMemcpyDeviceToHost;
} else if (
src->MemoryType() == TRITONSERVER_MEMORY_GPU &&
dst->MemoryType() == TRITONSERVER_MEMORY_GPU) {
kind = cudaMemcpyDeviceToDevice;
}
cudaError_t err;
if ((kind == cudaMemcpyDeviceToDevice) &&
(src->MemoryTypeId() != dst->MemoryTypeId())) {
err = cudaMemcpyPeer(
dst->DataPtr(), dst->MemoryTypeId(), src->DataPtr(),
src->MemoryTypeId(), src->ByteSize());
} else {
err = cudaMemcpy(dst->DataPtr(), src->DataPtr(), src->ByteSize(), kind);
}
if (err != cudaSuccess) {
throw PythonBackendException(
std::string(
"failed to copy data: " + std::string(cudaGetErrorString(err)))
.c_str());
}
if (kind == cudaMemcpyDeviceToDevice) {
// Synchronize the default stream for d2d copies.
// https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync
err = cudaStreamSynchronize(0);
if (err != cudaSuccess) {
throw PythonBackendException(
std::string(
"failed to synchronize the default CUDA stream. error: " +
std::string(cudaGetErrorString(err)))
.c_str());
}
}
#endif
}
void
PbMemory::FillShmData(
std::unique_ptr<CUDAMemoryPoolManager>& cuda_pool,
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
uint64_t byte_size, char* data, char* data_shm,
bi::managed_external_buffer::handle_t handle, bool copy_gpu)
{
char* memory_data_shm = data_shm + sizeof(MemoryShm);
MemoryShm* memory_shm_ptr = reinterpret_cast<MemoryShm*>(data_shm);
memory_shm_ptr->memory_release_id = 0;
bool use_cuda_shared_pool = false;
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
if (data != nullptr) {
if (copy_gpu) {
ScopedSetDevice scoped_set_device(memory_type_id);
THROW_IF_CUDA_ERROR(cudaIpcGetMemHandle(
reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm), data));
}
if (cuda_pool->UseCudaSharedPool(memory_type_id) &&
IsUsingCUDAPool(cuda_pool, memory_type_id, data)) {
use_cuda_shared_pool = true;
memory_shm_ptr->cuda_pool_offset =
data -
reinterpret_cast<char*>(cuda_pool->CUDAPoolAddress(memory_type_id));
}
}
#endif // TRITON_ENABLE_GPU
} else {
if (data != nullptr) {
std::copy(data, data + byte_size, memory_data_shm);
}
}
memory_shm_ptr->byte_size = byte_size;
memory_shm_ptr->memory_type_id = memory_type_id;
memory_shm_ptr->memory_type = memory_type;
memory_shm_ptr->use_cuda_shared_pool = use_cuda_shared_pool;
}
// Creates a PbMemory view over a region that is already mapped at `data_shm`
// (MemoryShm header followed by the payload area). `handle` is stored in the
// returned object as its shared-memory handle.
//
// Data pointer resolution:
//   * GPU memory with `open_cuda_handle` set (TRITON_ENABLE_GPU builds):
//     non-empty buffers are resolved either through the CUDA pool address
//     plus the stored offset, or by opening the stored CUDA IPC handle.
//     Empty buffers (byte_size == 0) get a null data pointer and no IPC
//     handle is opened; FillShmData only stores a handle when it was given a
//     non-null pointer, and the other LoadFromSharedMemory overload applies
//     the same guard.
//   * everything else (including GPU memory with `open_cuda_handle` unset):
//     the data pointer refers to the area right after the header.
std::unique_ptr<PbMemory>
PbMemory::LoadFromSharedMemory(
    std::unique_ptr<SharedMemoryManager>& shm_pool,
    bi::managed_external_buffer::handle_t handle, char* data_shm,
    bool open_cuda_handle)
{
  MemoryShm* memory_shm_ptr = reinterpret_cast<MemoryShm*>(data_shm);
  char* memory_data_shm = data_shm + sizeof(MemoryShm);

  char* data_ptr = nullptr;
  bool opened_cuda_ipc_handle = false;
  if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU &&
      open_cuda_handle) {
#ifdef TRITON_ENABLE_GPU
    if (memory_shm_ptr->byte_size > 0) {
      if (memory_shm_ptr->use_cuda_shared_pool) {
        // When CUDA shared memory pool is used, the stub will retrieve the
        // data pointer using the offset.
        data_ptr =
            (reinterpret_cast<char*>(
                 shm_pool->GetCUDAMemoryPoolManager()->CUDAPoolAddress(
                     memory_shm_ptr->memory_type_id)) +
             memory_shm_ptr->cuda_pool_offset);
      } else {
        cudaIpcMemHandle_t* cuda_handle =
            reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm);

        // The pointer opened by the cudaIpcOpenMemHandle will refer to the
        // base address. We need to manually correct the offset.
        void* data_ptr_base = nullptr;
        CUDAHandler& cuda_handler = CUDAHandler::getInstance();
        cuda_handler.OpenCudaHandle(
            memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);

        data_ptr =
            (reinterpret_cast<char*>(data_ptr_base) +
             memory_shm_ptr->gpu_pointer_offset);
        opened_cuda_ipc_handle = true;
      }
    }
#endif  // TRITON_ENABLE_GPU
  } else {
    data_ptr = memory_data_shm;
  }

  return std::unique_ptr<PbMemory>(new PbMemory(
      data_shm, data_ptr, handle,
      opened_cuda_ipc_handle /* opened_cuda_ipc_handle */));
}
std::unique_ptr<PbMemory>
PbMemory::LoadFromSharedMemory(
std::unique_ptr<SharedMemoryManager>& shm_pool,
bi::managed_external_buffer::handle_t handle, bool open_cuda_handle)
{
AllocatedSharedMemory<char> memory_shm = shm_pool->Load<char>(handle);
MemoryShm* memory_shm_ptr =
reinterpret_cast<MemoryShm*>(memory_shm.data_.get());
char* memory_data_shm = memory_shm.data_.get() + sizeof(MemoryShm);
char* data_ptr = nullptr;
bool opened_cuda_ipc_handle = false;
if (memory_shm_ptr->memory_type == TRITONSERVER_MEMORY_GPU) {
if (memory_shm_ptr->byte_size > 0 && open_cuda_handle) {
#ifdef TRITON_ENABLE_GPU
if (memory_shm_ptr->use_cuda_shared_pool) {
// When CUDA shared memory pool is used, the stub will retrieve the
// data pointer using the offset.
data_ptr =
(reinterpret_cast<char*>(
shm_pool->GetCUDAMemoryPoolManager()->CUDAPoolAddress(
memory_shm_ptr->memory_type_id)) +
memory_shm_ptr->cuda_pool_offset);
} else {
cudaIpcMemHandle_t* cuda_handle =
reinterpret_cast<cudaIpcMemHandle_t*>(memory_data_shm);
// The pointer opened by the cudaIpcOpenMemHandle will refer to the base
// address. We need to manually correct the offset.
void* data_ptr_base;
CUDAHandler& cuda_handler = CUDAHandler::getInstance();
cuda_handler.OpenCudaHandle(
memory_shm_ptr->memory_type_id, cuda_handle, &data_ptr_base);
data_ptr =
(reinterpret_cast<char*>(data_ptr_base) +
memory_shm_ptr->gpu_pointer_offset);
opened_cuda_ipc_handle = true;
}
#endif
}
} else {
data_ptr = memory_data_shm;
}
return std::unique_ptr<PbMemory>(new PbMemory(
memory_shm, data_ptr,
opened_cuda_ipc_handle /* opened_cuda_ipc_handle */));
}
// Constructs a PbMemory that takes over `memory_shm` (the argument is moved
// from). `data` becomes the data pointer and `opened_cuda_ipc_handle` tells
// the destructor whether a CUDA IPC handle has to be closed.
PbMemory::PbMemory(
    AllocatedSharedMemory<char>& memory_shm, char* data,
    bool opened_cuda_ipc_handle)
    : memory_shm_(std::move(memory_shm))
{
  data_ptr_ = data;
  opened_cuda_ipc_handle_ = opened_cuda_ipc_handle;

  // Both values are read from the member, since the argument was moved from.
  memory_shm_handle_ = memory_shm_.handle_;
  memory_shm_ptr_ = reinterpret_cast<MemoryShm*>(memory_shm_.data_.get());
}
// Constructs a PbMemory over a region that starts at `memory_shm` and is
// referenced through raw pointers only. `handle` is recorded as the
// shared-memory handle, `data` as the data pointer, and
// `opened_cuda_ipc_handle` tells the destructor whether a CUDA IPC handle has
// to be closed.
PbMemory::PbMemory(
    char* memory_shm, char* data, bi::managed_external_buffer::handle_t handle,
    bool opened_cuda_ipc_handle)
{
  memory_shm_handle_ = handle;
  memory_shm_ptr_ = reinterpret_cast<MemoryShm*>(memory_shm);
  opened_cuda_ipc_handle_ = opened_cuda_ipc_handle;
  data_ptr_ = data;
}
// Returns the shared-memory handle recorded when this object was constructed.
bi::managed_external_buffer::handle_t
PbMemory::ShmHandle()
{
  const bi::managed_external_buffer::handle_t shm_handle = memory_shm_handle_;
  return shm_handle;
}
#ifdef TRITON_ENABLE_GPU
void*
PbMemory::GetGPUStartAddress()
{
if (memory_shm_ptr_->memory_type == TRITONSERVER_MEMORY_GPU) {
CUDAHandler& cuda_api = CUDAHandler::getInstance();
CUdeviceptr start_address = 0;
// Skip this step for empty tensor as the CUDA API 'cuPointerGetAttribute'
// we use in this function does not accept nullptr.
if (data_ptr_) {
cuda_api.PointerGetAttribute(
&start_address, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
reinterpret_cast<CUdeviceptr>(data_ptr_));
}
return reinterpret_cast<void*>(start_address);
}
throw PythonBackendException(
"Calling GetGPUStartAddress function on CPU memory.");
}
// Returns the distance in bytes between the data pointer and the address
// returned by GetGPUStartAddress(). Throws PythonBackendException when the
// buffer is not GPU memory.
uint64_t
PbMemory::GetGPUPointerOffset()
{
  if (memory_shm_ptr_->memory_type != TRITONSERVER_MEMORY_GPU) {
    throw PythonBackendException(
        "Calling GetGPUPointerOffset function on CPU tensor.");
  }

  char* const range_start = reinterpret_cast<char*>(GetGPUStartAddress());
  return static_cast<uint64_t>(data_ptr_ - range_start);
}
#endif
// Returns the memory type stored in the shared-memory header.
TRITONSERVER_MemoryType
PbMemory::MemoryType() const
{
  const MemoryShm& header = *memory_shm_ptr_;
  return header.memory_type;
}
void
PbMemory::SetMemoryReleaseId(uint64_t memory_release_id)
{
memory_shm_ptr_->memory_release_id = memory_release_id;
}
int64_t
PbMemory::MemoryTypeId() const
{
return memory_shm_ptr_->memory_type_id;
}
uint64_t
PbMemory::ByteSize() const
{
return memory_shm_ptr_->byte_size;
}
// Returns the address of the first byte after the MemoryShm header.
char*
PbMemory::ShmData() const
{
  char* const header_bytes = reinterpret_cast<char*>(memory_shm_ptr_);
  return header_bytes + sizeof(MemoryShm);
}
// Returns the data pointer this object was constructed with. It can be null
// (see the empty-tensor handling in GetGPUStartAddress).
char*
PbMemory::DataPtr() const
{
  char* const payload = data_ptr_;
  return payload;
}
// Returns the number of shared-memory bytes needed for a buffer of the given
// memory type: the MemoryShm header plus either `byte_size` payload bytes
// (non-GPU memory) or a cudaIpcMemHandle_t (GPU memory in TRITON_ENABLE_GPU
// builds). For GPU memory without TRITON_ENABLE_GPU only the header is
// counted.
uint64_t
PbMemory::ShmStructSize(TRITONSERVER_MemoryType memory_type, uint64_t byte_size)
{
  uint64_t trailing_bytes = 0;
  if (memory_type != TRITONSERVER_MEMORY_GPU) {
    trailing_bytes = byte_size;
  } else {
#ifdef TRITON_ENABLE_GPU
    trailing_bytes = sizeof(cudaIpcMemHandle_t);
#endif
  }

  return sizeof(MemoryShm) + trailing_bytes;
}
#ifdef TRITON_ENABLE_GPU
// Copies `*cuda_ipc_handle` into the shared-memory area that follows the
// MemoryShm header (the location returned by ShmData()).
void
PbMemory::SetCudaIpcHandle(cudaIpcMemHandle_t* cuda_ipc_handle)
{
  cudaIpcMemHandle_t* shm_slot =
      reinterpret_cast<cudaIpcMemHandle_t*>(ShmData());
  *shm_slot = *cuda_ipc_handle;
}
// When `cuda_pool` reports (UseCudaSharedPool / IsUsingCUDAPool) that the
// data pointer belongs to the CUDA pool of this buffer's device, records the
// distance from the pool address to the data pointer in the shared-memory
// header and marks the buffer as using the pool. Otherwise the header is left
// untouched.
void
PbMemory::UpdateCUDAOffset(std::unique_ptr<CUDAMemoryPoolManager>& cuda_pool)
{
  const bool in_shared_pool =
      cuda_pool->UseCudaSharedPool(MemoryTypeId()) &&
      IsUsingCUDAPool(cuda_pool, MemoryTypeId(), DataPtr());
  if (!in_shared_pool) {
    return;
  }

  char* const pool_base =
      reinterpret_cast<char*>(cuda_pool->CUDAPoolAddress(MemoryTypeId()));
  memory_shm_ptr_->cuda_pool_offset = DataPtr() - pool_base;
  memory_shm_ptr_->use_cuda_shared_pool = true;
}
#endif
// Closes the CUDA IPC handle if this object opened one (TRITON_ENABLE_GPU
// builds only) and then runs the release callback, if one was set.
PbMemory::~PbMemory()
{
#ifdef TRITON_ENABLE_GPU
  if (opened_cuda_ipc_handle_) {
    // The handle is closed using the start address of the allocation, not
    // the (possibly offset) data pointer.
    CUDAHandler& cuda_handler = CUDAHandler::getInstance();
    cuda_handler.CloseCudaHandle(
        memory_shm_ptr_->memory_type_id, GetGPUStartAddress());
  }
#endif

  if (release_callback_) {
    release_callback_();
  }
}
// Registers the callback that the destructor invokes. A callback can be set
// only once; a second attempt throws PythonBackendException.
void
PbMemory::SetMemoryReleaseCallback(std::function<void(void)> release_callback)
{
  if (release_callback_) {
    throw PythonBackendException("Release callback is already set.");
  }

  release_callback_ = release_callback;
}
uint64_t
PbMemory::MemoryReleaseId()
{
return memory_shm_ptr_->memory_release_id;
}
}}} // namespace triton::backend::python