Skip to content

Commit 6c4b817

Browse files
authored
Update dlpack implementation for PbTensor (triton-inference-server#223)
* Update dlpack implementation for PbTensor: handle new API + bools
1 parent 894b074 commit 6c4b817

6 files changed

Lines changed: 171 additions & 23 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ FetchContent_MakeAvailable(pybind11)
8989
FetchContent_Declare(
9090
dlpack
9191
GIT_REPOSITORY "https://github.com/dmlc/dlpack"
92-
GIT_TAG "v0.7"
92+
GIT_TAG "v0.8"
9393
GIT_SHALLOW ON
9494
)
9595
FetchContent_MakeAvailable(dlpack)

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,6 +1226,15 @@ class TritonPythonModel:
12261226
# tensor.
12271227
input0 = pb_utils.Tensor.from_dlpack("INPUT0", to_dlpack(pytorch_tensor))
12281228
```
1229+
The Python backend allows tensors that implement the
1230+
[`__dlpack__`](https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.__dlpack__.html)
1231+
and [`__dlpack_device__`](https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.__dlpack_device__.html)
1232+
[interface](https://dmlc.github.io/dlpack/latest/python_spec.html)
1233+
to be converted to Python backend tensors. For instance:
1234+
1235+
```python
1236+
input0 = pb_utils.Tensor.from_dlpack("INPUT0", pytorch_tensor)
1237+
```
12291238

12301239
This method only supports contiguous Tensors that are in C-order. If the tensor
12311240
is not C-order contiguous, an exception will be raised.

src/pb_stub.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,9 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
14881488
.def("to_dlpack", &PbTensor::ToDLPack)
14891489
.def("is_cpu", &PbTensor::IsCPU)
14901490
.def("shape", &PbTensor::Dims)
1491-
.def("from_dlpack", &PbTensor::FromDLPack);
1491+
.def("from_dlpack", &PbTensor::FromDLPack)
1492+
.def("__dlpack__", &PbTensor::DLPack, py::arg("stream") = py::none())
1493+
.def("__dlpack_device__", &PbTensor::DLPackDevice);
14921494

14931495
py::class_<InferResponse, std::shared_ptr<InferResponse>>(
14941496
module, "InferenceResponse")

src/pb_stub_utils.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -189,8 +189,8 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype)
189189
dl_dtype.lanes = 1;
190190
switch (triton_dtype) {
191191
case TRITONSERVER_TYPE_BOOL:
192-
dl_code = DLDataTypeCode::kDLInt;
193-
dt_size = 1;
192+
dl_code = DLDataTypeCode::kDLBool;
193+
dt_size = 8;
194194
break;
195195
case TRITONSERVER_TYPE_UINT8:
196196
dl_code = DLDataTypeCode::kDLUInt;
@@ -279,8 +279,6 @@ dlpack_to_triton_type(const DLDataType& data_type)
279279
return TRITONSERVER_TYPE_INT32;
280280
} else if (data_type.bits == 64) {
281281
return TRITONSERVER_TYPE_INT64;
282-
} else if (data_type.bits == 1) {
283-
return TRITONSERVER_TYPE_BOOL;
284282
}
285283
}
286284

@@ -296,6 +294,12 @@ dlpack_to_triton_type(const DLDataType& data_type)
296294
}
297295
}
298296

297+
if (data_type.code == DLDataTypeCode::kDLBool) {
298+
if (data_type.bits == 8) {
299+
return TRITONSERVER_TYPE_BOOL;
300+
}
301+
}
302+
299303
return TRITONSERVER_TYPE_INVALID;
300304
}
301305
}}} // namespace triton::backend::python

src/pb_tensor.cc

Lines changed: 127 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,36 @@ PbTensor::FromNumpy(const std::string& name, py::array& numpy_array)
231231
return std::make_shared<PbTensor>(name, numpy_array);
232232
}
233233

234+
DLDeviceType
235+
PbTensor::DeviceType()
236+
{
237+
DLDeviceType device_type{};
238+
239+
switch (memory_type_) {
240+
case TRITONSERVER_MEMORY_GPU:
241+
device_type = DLDeviceType::kDLCUDA;
242+
break;
243+
case TRITONSERVER_MEMORY_CPU:
244+
device_type = DLDeviceType::kDLCPU;
245+
break;
246+
case TRITONSERVER_MEMORY_CPU_PINNED:
247+
device_type = DLDeviceType::kDLCUDAHost;
248+
break;
249+
}
250+
251+
return device_type;
252+
}
253+
254+
py::capsule
255+
PbTensor::DLPack(const py::object& stream)
256+
{
257+
// Here an external tensor requests PbTensor's `__dlpack__` method to provide
258+
// a PyCapsule. By the design of PbTensor, in a GPU case no pending work
259+
// is scheduled to work with PbTensor's data and we can simply pass
260+
// the capsule without synchronization.
261+
return this->ToDLPack();
262+
}
263+
234264
py::capsule
235265
PbTensor::ToDLPack()
236266
{
@@ -269,23 +299,19 @@ PbTensor::ToDLPack()
269299
tensor_handle.inc_ref();
270300

271301
dlpack_tensor->dl_tensor.device.device_id = memory_type_id_;
302+
dlpack_tensor->dl_tensor.device.device_type = this->DeviceType();
272303
dlpack_tensor->dl_tensor.dtype = triton_to_dlpack_type(dtype_);
273304

274-
switch (memory_type_) {
275-
case TRITONSERVER_MEMORY_GPU:
276-
dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCUDA;
277-
break;
278-
case TRITONSERVER_MEMORY_CPU:
279-
dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCPU;
280-
break;
281-
case TRITONSERVER_MEMORY_CPU_PINNED:
282-
dlpack_tensor->dl_tensor.device.device_type = DLDeviceType::kDLCUDAHost;
283-
break;
284-
}
285-
286305
return py::capsule(
287306
static_cast<void*>(dlpack_tensor), "dltensor", &delete_unused_dltensor);
288307
}
308+
309+
std::pair<int32_t, int64_t>
310+
PbTensor::DLPackDevice()
311+
{
312+
return std::pair<int32_t, int64_t>(this->DeviceType(), memory_type_id_);
313+
}
314+
289315
#endif // TRITON_PB_STUB
290316

291317
void
@@ -305,12 +331,100 @@ PbTensor::Memory()
305331

306332
#ifdef TRITON_PB_STUB
307333
std::shared_ptr<PbTensor>
308-
PbTensor::FromDLPack(const std::string& name, const py::capsule& dlpack_tensor)
334+
PbTensor::FromDLPack(const std::string& name, const py::object& tensor)
309335
{
310336
if (name == "") {
311337
throw PythonBackendException("Tensor name cannot be an empty string.");
312338
}
339+
if (py::isinstance<py::capsule>(tensor)) {
340+
return FromDLPackCapsule(name, tensor);
341+
}
342+
343+
if (!py::hasattr(tensor, "__dlpack__") ||
344+
!py::hasattr(tensor, "__dlpack_device__")) {
345+
throw PythonBackendException(
346+
"Provided tensor is not supported. Tensor must be a DLPack capsule \
347+
or have `__dlpack__` and `__dlpack_device__` attributes");
348+
}
349+
350+
auto capsule_device_info =
351+
tensor.attr("__dlpack_device__")().cast<std::pair<int32_t, int64_t>>();
352+
if (capsule_device_info.first == DLDeviceType::kDLCUDA) {
353+
#ifdef TRITON_ENABLE_GPU
354+
int current_device;
355+
cudaError_t err = cudaGetDevice(&current_device);
356+
if (err != cudaSuccess) {
357+
throw PythonBackendException("Failed to get current CUDA device id.");
358+
}
359+
360+
bool overridden = (current_device != capsule_device_info.second);
361+
err = overridden ? cudaSetDevice(capsule_device_info.second) : cudaSuccess;
362+
if (err != cudaSuccess) {
363+
throw PythonBackendException(
364+
"Failed to set CUDA device to device with id " +
365+
std::to_string(capsule_device_info.second));
366+
}
367+
// In case there is pending work on the data that this capsule
367+
// points to, we need to wait for it before consuming.
369+
// This is important when the data is located on a different
370+
// context (GPU) and work is done on the default stream.
371+
// For this scenario, __dlpack__ implementation may skip
372+
// synchronization (since the work is on the default stream)
373+
// and we will return a pointer to the data on a different GPU too early
374+
// (i.e. before pending work is done). Thus we sync on the default stream
375+
// only in the case we switched to a different context.
376+
err = overridden ? cudaStreamSynchronize(0) : cudaSuccess;
377+
if (err != cudaSuccess) {
378+
throw PythonBackendException(
379+
"Failed to synchronize CUDA device with id " +
380+
std::to_string(
381+
overridden ? capsule_device_info.second : current_device));
382+
}
383+
384+
// Array API requirements for the stream argument:
385+
// stream = 1 the legacy default stream (in this case should
386+
// synchronize on CUDA stream 0)
387+
// For CPU, `stream=None` is the only accepted argument
388+
// according to array API. For GPU, when `stream=None` producer
389+
// must assume the legacy default stream. Reference:
390+
// https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
391+
auto ptr_to_tensor = FromDLPackCapsule(
392+
name, tensor.attr("__dlpack__")(py::arg("stream") = py::int_(1)));
393+
394+
err = overridden ? cudaSetDevice(current_device) : cudaSuccess;
395+
if (err != cudaSuccess) {
396+
throw PythonBackendException(
397+
"Failed to set CUDA device back to initial compute device "
398+
"with id " +
399+
std::to_string(current_device));
400+
}
401+
return ptr_to_tensor;
402+
#else
403+
throw PythonBackendException(
404+
"DLPack capsule passed pointer to memory allocated on GPU device, \
405+
when GPU is not available");
406+
#endif
407+
} else if (
408+
capsule_device_info.first != DLDeviceType::kDLCPU &&
409+
capsule_device_info.first != DLDeviceType::kDLCUDAHost) {
410+
throw PythonBackendException(
411+
"DLDevice type " + std::to_string(capsule_device_info.first) +
412+
" is not support by Python backend.");
413+
}
414+
415+
// If data is located on CPU, `stream=None` is the only accepted argument
416+
// according to array API. For GPU, when `stream=None` producer must
417+
// assume the legacy default stream.
418+
// Reference:
419+
// https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
420+
return FromDLPackCapsule(
421+
name, tensor.attr("__dlpack__")(py::arg("stream") = py::none()));
422+
}
313423

424+
std::shared_ptr<PbTensor>
425+
PbTensor::FromDLPackCapsule(
426+
const std::string& name, const py::capsule& dlpack_tensor)
427+
{
314428
DLManagedTensor* dl_managed_tensor =
315429
static_cast<DLManagedTensor*>(dlpack_tensor.get_pointer());
316430

src/pb_tensor.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -112,11 +112,16 @@ class PbTensor {
112112
DISALLOW_COPY_AND_ASSIGN(PbTensor);
113113

114114
#ifdef TRITON_PB_STUB
115-
/// Construct a Python backend tensor using a DLPack
116-
/// capsule.
115+
/// Construct a Python backend tensor from an
116+
/// external tensor.
117117
/// \param dlpack source dlpack tensor
118118
/// \param name name of the tensor
119119
static std::shared_ptr<PbTensor> FromDLPack(
120+
const std::string& name, const py::object& dlpack);
121+
122+
/// Construct a Python backend tensor using a DLPack
123+
/// capsule.
124+
static std::shared_ptr<PbTensor> FromDLPackCapsule(
120125
const std::string& name, const py::capsule& dlpack);
121126

122127
/// Construct a Python backend tensor using a NumPy object.
@@ -125,9 +130,23 @@ class PbTensor {
125130
static std::shared_ptr<PbTensor> FromNumpy(
126131
const std::string& name, py::array& numpy_array);
127132

133+
/// Get device type in DLPack format.
134+
DLDeviceType DeviceType();
135+
136+
/// Exports tensor for consumption by `from_dlpack()` as a DLPack capsule.
137+
/// \param stream a Python integer representing a pointer to a stream,
138+
/// on devices that support streams
139+
/// \return Capsule object containing pointer to a DLPack object.
140+
py::capsule DLPack(const py::object& stream);
141+
128142
/// Get a PyCapsule object containing the DLPack representation of the tensor.
129143
/// \return Capsule object containing pointer to a DLPack object.
130144
py::capsule ToDLPack();
145+
146+
/// Returns device type and device ID.
147+
/// Meant for use within `from_dlpack()`.
148+
/// \return a pair (device_type, device_id).
149+
std::pair<int32_t, int64_t> DLPackDevice();
131150
#endif
132151

133152
/// Get the name of the tensor

0 commit comments

Comments
 (0)