Skip to content

Commit 76be86f

Browse files
authored
Fix large payload for Python backend (triton-inference-server#18)
1 parent b1eb046 commit 76be86f

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

src/python.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ namespace triton { namespace backend { namespace python {
8787
} \
8888
} while (false)
8989

90+
constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX;
91+
9092
class ModelState;
9193

9294
struct BackendState {
@@ -261,8 +263,11 @@ TRITONSERVER_Error*
261263
ModelInstanceState::ConnectPythonInterpreter()
262264
{
263265
grpc_init();
264-
auto grpc_channel =
265-
grpc::CreateChannel(domain_socket_, grpc::InsecureChannelCredentials());
266+
grpc::ChannelArguments arguments;
267+
arguments.SetMaxSendMessageSize(MAX_GRPC_MESSAGE_SIZE);
268+
arguments.SetMaxReceiveMessageSize(MAX_GRPC_MESSAGE_SIZE);
269+
auto grpc_channel = grpc::CreateCustomChannel(
270+
domain_socket_, grpc::InsecureChannelCredentials(), arguments);
266271

267272
stub = PythonInterpreter::NewStub(grpc_channel);
268273

@@ -441,6 +446,12 @@ ModelInstanceState::GetInputTensor(
441446
in, &input_name, &input_dtype, &input_shape, &input_dims_count,
442447
&input_byte_size, &input_buffer_count));
443448

449+
if (input_byte_size >= MAX_GRPC_MESSAGE_SIZE)
450+
return TRITONSERVER_ErrorNew(
451+
TRITONSERVER_ERROR_UNSUPPORTED,
452+
"Python backend does not support input size larger than 2GBs, consider "
453+
"parititioning your input into multiple inputs.");
454+
444455
// Update input_tensor
445456
input_tensor->set_name(input_name);
446457
input_tensor->set_dtype(static_cast<int>(input_dtype));

src/resources/startup.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from python_host_pb2_grpc import PythonInterpreterServicer, add_PythonInterpreterServicer_to_server
4747
import grpc
4848

49+
MAX_GRPC_MESSAGE_SIZE = 2147483647
50+
4951

5052
def serialize_byte_tensor(input_tensor):
5153
"""
@@ -345,7 +347,13 @@ def watch_connections(address, event):
345347
if __name__ == "__main__":
346348
signal_received = False
347349
FLAGS = parse_startup_arguments()
348-
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
350+
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1),
351+
options=[
352+
('grpc.max_send_message_length',
353+
MAX_GRPC_MESSAGE_SIZE),
354+
('grpc.max_receive_message_length',
355+
MAX_GRPC_MESSAGE_SIZE),
356+
])
349357
channelz.add_channelz_servicer(server)
350358
# Create an Event to keep the GRPC server running
351359
event = threading.Event()

0 commit comments

Comments
 (0)