@@ -488,9 +488,21 @@ Stub::Execute(ExecuteArgs* execute_args, ResponseBatch* response_batch)
488488void
489489Stub::Initialize (InitializeArgs* initialize_args)
490490{
491- py::module sys = py::module ::import (" sys" );
491+ py::module sys = py::module_ ::import (" sys" );
492492
493- std::string model_name = model_path_.substr (model_path_.find_last_of (" /" ) + 1 );
493+ std::string model_name =
494+ model_path_.substr (model_path_.find_last_of (" /" ) + 1 );
495+
496+ // Model name without the .py extension
497+ auto dotpy_pos = model_name.find_last_of (" .py" );
498+ if (dotpy_pos == std::string::npos || dotpy_pos != model_name.size () - 1 ) {
499+ throw PythonBackendException (
500+ " Model name must end with '.py'. Model name is \" " + model_name + " \" ." );
501+ }
502+
503+ // The position of last character of the string that is searched for is
504+ // returned by 'find_last_of'. Need to manually adjust the position.
505+ std::string model_name_trimmed = model_name.substr (0 , dotpy_pos - 2 );
494506 std::string model_path_parent =
495507 model_path_.substr (0 , model_path_.find_last_of (" /" ));
496508 std::string model_path_parent_parent =
@@ -501,9 +513,9 @@ Stub::Initialize(InitializeArgs* initialize_args)
501513 sys.attr (" path" ).attr (" append" )(python_backend_folder);
502514
503515 py::module python_backend_utils =
504- py::module ::import (" triton_python_backend_utils" );
516+ py::module_ ::import (" triton_python_backend_utils" );
505517 py::module c_python_backend_utils =
506- py::module ::import (" c_python_backend_utils" );
518+ py::module_ ::import (" c_python_backend_utils" );
507519 py::setattr (
508520 python_backend_utils, " Tensor" , c_python_backend_utils.attr (" Tensor" ));
509521 py::setattr (
@@ -520,7 +532,8 @@ Stub::Initialize(InitializeArgs* initialize_args)
520532 c_python_backend_utils.attr (" TritonModelException" ));
521533
522534 py::object TritonPythonModel =
523- py::module::import ((model_version_ + std::string (" .model" )).c_str ())
535+ py::module_::import (
536+ (std::string (model_version_) + " ." + model_name_trimmed).c_str ())
524537 .attr (" TritonPythonModel" );
525538 deserialize_bytes_ = python_backend_utils.attr (" deserialize_bytes_tensor" );
526539 serialize_bytes_ = python_backend_utils.attr (" serialize_byte_tensor" );
0 commit comments