forked from triton-inference-server/python_backend
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpb_memory.h
More file actions
166 lines (133 loc) · 5.73 KB
/
Copy pathpb_memory.h
File metadata and controls
166 lines (133 loc) · 5.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "pb_utils.h"
#include "shm_manager.h"
#include "triton/backend/backend_common.h"
#include "triton/backend/backend_memory.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU
namespace triton { namespace backend { namespace python {
//
// Represents a memory object in shared memory.
//
// Plain-data descriptor of a single memory object stored in shared memory.
// NOTE(review): this layout is presumably read by more than one process, so
// the order and types of the fields should be treated as fixed -- confirm
// before reordering or resizing anything here.
struct MemoryShm {
// If the memory type is a GPU pointer, the offset of the GPU pointer from the
// base address. For CPU memory type this field contains garbage data.
uint64_t gpu_pointer_offset;
// Triton memory type of the buffer described by this struct.
TRITONSERVER_MemoryType memory_type;
// Memory type id of the buffer.
// NOTE(review): presumably the device ordinal for GPU memory -- confirm.
int64_t memory_type_id;
// Size of the buffer in bytes.
uint64_t byte_size;
// NOTE(review): presumably set to true once a CUDA IPC handle has been
// stored for this buffer (see PbMemory::SetCudaIpcHandle) -- confirm.
bool is_cuda_handle_set;
// Identifier associated with releasing this memory.
// NOTE(review): presumably the value exposed through
// PbMemory::SetMemoryReleaseId / PbMemory::MemoryReleaseId -- confirm.
uint64_t memory_release_id;
};
// Handle to a memory object whose metadata is described by a MemoryShm
// struct. All constructors are private; instances are obtained through the
// static Create() and LoadFromSharedMemory() factories, which return
// std::unique_ptr<PbMemory>.
class PbMemory {
public:
/// Create a PbMemory object, given a shared memory pool.
/// \param shm_pool Shared memory manager supplied by the caller.
/// \param memory_type Triton memory type of the buffer.
/// \param memory_type_id Memory type id of the buffer.
/// \param byte_size Size of the buffer in bytes.
/// \param data Pointer to the data.
/// \param copy_gpu Defaults to true. NOTE(review): presumably controls
/// whether GPU data is copied -- confirm against the definition.
/// \return The newly created PbMemory object.
static std::unique_ptr<PbMemory> Create(
std::unique_ptr<SharedMemoryManager>& shm_pool,
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
uint64_t byte_size, char* data, bool copy_gpu = true);
/// Create a PbMemory object, given a shared memory pointer and its handle
/// instead of a shared memory pool.
/// NOTE(review): presumably the caller has already allocated the region that
/// `data_shm` points to -- confirm against the definition.
/// \param memory_type Triton memory type of the buffer.
/// \param memory_type_id Memory type id of the buffer.
/// \param byte_size Size of the buffer in bytes.
/// \param data Pointer to the data.
/// \param data_shm Pointer into shared memory.
/// \param handle Shared memory handle associated with `data_shm`.
/// \param copy_gpu Defaults to true. NOTE(review): presumably controls
/// whether GPU data is copied -- confirm against the definition.
/// \return The newly created PbMemory object.
static std::unique_ptr<PbMemory> Create(
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
uint64_t byte_size, char* data, char* data_shm,
bi::managed_external_buffer::handle_t handle, bool copy_gpu = true);
#ifndef TRITON_PB_STUB
/// Create a PbMemory object from a BackendMemory object. Only declared when
/// TRITON_PB_STUB is not defined.
/// \param shm_pool Shared memory manager supplied by the caller.
/// \param backend_memory Backend memory passed as an rvalue reference.
/// NOTE(review): presumably ownership is transferred into backend_memory_ --
/// confirm against the definition.
/// \param copy_gpu Defaults to true. NOTE(review): presumably controls
/// whether GPU data is copied -- confirm against the definition.
/// \return The newly created PbMemory object.
static std::unique_ptr<PbMemory> Create(
std::unique_ptr<SharedMemoryManager>& shm_pool,
std::unique_ptr<BackendMemory>&& backend_memory, bool copy_gpu = true);
#endif
#ifdef TRITON_ENABLE_GPU
/// Set the CUDA IPC handle. Only declared when TRITON_ENABLE_GPU is defined.
/// NOTE(review): presumably stores the handle alongside the shared memory
/// descriptor and sets MemoryShm::is_cuda_handle_set -- confirm.
/// \param cuda_ipc_handle Pointer to the CUDA IPC memory handle.
void SetCudaIpcHandle(cudaIpcMemHandle_t* cuda_ipc_handle);
#endif
/// Copy the contents of the source buffer into the destination buffer.
/// NOTE(review): the previous comment here described the opposite direction
/// ("destination to source"); the direction stated above is inferred from the
/// parameter names -- confirm against the definition.
/// \param dst Destination memory object.
/// \param src Source memory object.
static void CopyBuffer(
std::unique_ptr<PbMemory>& dst, std::unique_ptr<PbMemory>& src);
/// Load a PbMemory object from a shared memory handle, using a shared memory
/// pool.
/// \param shm_pool Shared memory manager supplied by the caller.
/// \param memory_handle Shared memory handle of the memory object.
/// \param open_cuda_handle NOTE(review): presumably selects whether the CUDA
/// IPC handle is opened for GPU memory -- confirm against the definition.
/// \return The loaded PbMemory object.
static std::unique_ptr<PbMemory> LoadFromSharedMemory(
std::unique_ptr<SharedMemoryManager>& shm_pool,
bi::managed_external_buffer::handle_t memory_handle,
bool open_cuda_handle);
/// Load a PbMemory object, given a shared memory handle and a pointer into
/// shared memory instead of a shared memory pool.
/// \param handle Shared memory handle of the memory object.
/// \param data_shm Pointer into shared memory.
/// \param open_cuda_handle NOTE(review): presumably selects whether the CUDA
/// IPC handle is opened for GPU memory -- confirm against the definition.
/// \return The loaded PbMemory object.
static std::unique_ptr<PbMemory> LoadFromSharedMemory(
bi::managed_external_buffer::handle_t handle, char* data_shm,
bool open_cuda_handle);
/// Compute a size from a memory type and a byte size.
/// NOTE(review): presumably the number of bytes that must be reserved in
/// shared memory for a memory object (descriptor plus any inline data) --
/// confirm against the definition.
/// \param memory_type Triton memory type of the buffer.
/// \param byte_size Size of the buffer in bytes.
/// \return The computed size.
static uint64_t ShmStructSize(
TRITONSERVER_MemoryType memory_type, uint64_t byte_size);
/// Get the shared memory handle of this memory object.
/// \return The shared memory handle.
bi::managed_external_buffer::handle_t ShmHandle();
/// Get the total byte size of the tensor.
/// \return The byte size.
uint64_t ByteSize() const;
/// Get the triton memory type.
/// \return the memory type of the tensor.
TRITONSERVER_MemoryType MemoryType() const;
/// Get the pointer.
/// \return The location to the memory where the data is stored.
char* DataPtr() const;
/// Get the memory type id.
/// \return The memory type id of the tensor.
int64_t MemoryTypeId() const;
/// Get the shm data.
/// \return Pointer to the shared memory data.
/// NOTE(review): the previous \return text here was a copy of the one on
/// MemoryTypeId(); the description above is inferred from the function name
/// and the char* return type -- confirm against the definition.
char* ShmData() const;
/// Set the memory release id.
/// \param memory_release_id The id to set.
void SetMemoryReleaseId(uint64_t memory_release_id);
/// Get the memory release id.
/// \return The memory release id.
uint64_t MemoryReleaseId();
/// Set the memory release callback.
/// NOTE(review): presumably invoked when this object is destroyed -- confirm
/// against the destructor's definition.
/// \param release_callback Callable taking no arguments and returning void.
void SetMemoryReleaseCallback(std::function<void(void)> release_callback);
~PbMemory();
private:
// Shared memory allocation associated with this object.
AllocatedSharedMemory<char> memory_shm_;
// Pointer to the MemoryShm descriptor of this object.
MemoryShm* memory_shm_ptr_;
#ifndef TRITON_PB_STUB
// Backend memory held by this object. Only declared when TRITON_PB_STUB is
// not defined.
std::unique_ptr<BackendMemory> backend_memory_;
#endif
// Callback stored via SetMemoryReleaseCallback (presumably -- confirm).
std::function<void()> release_callback_;
// Refers to the pointer that can hold the data. For CPU pointers this will be
// the same as memory_data_shm_ptr_.
// NOTE(review): no member named memory_data_shm_ptr_ is declared in this
// class; the sentence above may be stale -- confirm which pointer is meant.
char* data_ptr_;
// Shared memory handle of this memory object.
bi::managed_external_buffer::handle_t memory_shm_handle_;
// NOTE(review): presumably records whether this object opened a CUDA IPC
// handle that must later be closed -- confirm against the definition.
bool opened_cuda_ipc_handle_;
#ifdef TRITON_ENABLE_GPU
/// Calculate the pointer offset from the base address.
/// \return The offset of a device pointer.
/// \throws PythonBackendException if the tensor is stored in CPU.
uint64_t GetGPUPointerOffset();
/// Get the GPU start address.
/// \return The start address of a device pointer.
/// \throws PythonBackendException if the tensor is stored in CPU.
void* GetGPUStartAddress();
#endif
/// Fill shared memory data from the given arguments.
/// NOTE(review): presumably writes the MemoryShm descriptor (and, depending
/// on the memory type and `copy_gpu`, the data) at `data_shm` -- confirm
/// against the definition.
/// \param memory_type Triton memory type of the buffer.
/// \param memory_type_id Memory type id of the buffer.
/// \param byte_size Size of the buffer in bytes.
/// \param data Pointer to the data.
/// \param data_shm Pointer into shared memory.
/// \param handle Shared memory handle associated with `data_shm`.
/// \param copy_gpu Defaults to true. NOTE(review): presumably controls
/// whether GPU data is copied -- confirm against the definition.
static void FillShmData(
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id,
uint64_t byte_size, char* data, char* data_shm,
bi::managed_external_buffer::handle_t handle, bool copy_gpu = true);
// Private constructor taking an AllocatedSharedMemory object; use the static
// factories to obtain instances.
PbMemory(
AllocatedSharedMemory<char>& memory_shm, char* data,
bool opened_cuda_ipc_handle);
// Private constructor taking a raw shared memory pointer and its handle; use
// the static factories to obtain instances.
PbMemory(
char* memory_shm, char* data,
bi::managed_external_buffer::handle_t handle,
bool opened_cuda_ipc_handle);
};
}}} // namespace triton::backend::python