@@ -40,14 +40,16 @@ namespace tensorflow {
4040static Status BuildLaunchNode (
4141 const string& nodename, const string& function_name,
4242 const AttrValueMap& function_attr, const string& device_name,
43- const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes,
44- const DataTypeVector& result_dtypes, Graph* graph, Node** node) {
43+ const DataTypeVector& constant_dtypes, int num_resources,
44+ const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
45+ Graph* graph, Node** node) {
4546 NodeDef def;
4647 def.set_name (graph->NewName (nodename));
4748 def.set_op (" _XlaLaunch" );
4849 def.set_device (device_name);
4950 AddNodeAttr (" Tconstants" , constant_dtypes, &def);
5051 AddNodeAttr (" Targs" , arg_dtypes, &def);
52+ AddNodeAttr (" Nresources" , num_resources, &def);
5153 AddNodeAttr (" Tresults" , result_dtypes, &def);
5254 NameAttrList function;
5355 function.set_name (function_name);
@@ -62,25 +64,32 @@ static Status BuildLaunchNode(
6264static Status ReplaceNodeWithXlaLaunch (Graph* graph, Node* node) {
6365 VLOG (2 ) << " Replacing " << node->name () << " with XlaLaunch" ;
6466
65- int num_constant_args;
67+ int num_constant_args, num_resource_args ;
6668 TF_RETURN_IF_ERROR (
6769 GetNodeAttr (node->def (), kXlaNumConstantArgsAttr , &num_constant_args));
70+ TF_RETURN_IF_ERROR (
71+ GetNodeAttr (node->def (), kXlaNumResourceArgsAttr , &num_resource_args));
6872
69- if (num_constant_args < 0 || num_constant_args > node->input_types ().size ()) {
73+ if (num_constant_args < 0 || num_resource_args < 0 ||
74+ num_constant_args + num_resource_args > node->num_inputs ()) {
7075 return errors::InvalidArgument (
71- " Invalid number of constant arguments to XLA kernel" );
76+ " Invalid number of constant/resource arguments to XLA kernel. " );
7277 }
78+ const int num_nonconst_args =
79+ node->num_inputs () - num_constant_args - num_resource_args;
80+
7381 DataTypeVector const_dtypes (node->input_types ().begin (),
7482 node->input_types ().begin () + num_constant_args);
75- DataTypeVector arg_dtypes (node->input_types ().begin () + num_constant_args,
76- node->input_types ().end ());
83+ DataTypeVector arg_dtypes (
84+ node->input_types ().begin () + num_constant_args,
85+ node->input_types ().begin () + num_constant_args + num_nonconst_args);
7786
7887 // Build a _XlaLaunch operator to execute the function body.
7988 Node* launch_node;
80- TF_RETURN_IF_ERROR (
81- BuildLaunchNode ( graph->NewName (node->name ()), node->type_string (),
82- node->def ().attr (), node-> def (). device (), const_dtypes,
83- arg_dtypes, node->output_types (), graph, &launch_node));
89+ TF_RETURN_IF_ERROR (BuildLaunchNode (
90+ graph->NewName (node->name ()), node->type_string (), node-> def (). attr (),
91+ node->def ().device (), const_dtypes, num_resource_args, arg_dtypes ,
92+ node->output_types (), graph, &launch_node));
8493 launch_node->set_assigned_device_name (node->assigned_device_name ());
8594
8695 // Copy incoming edges to the launch node.
@@ -128,6 +137,11 @@ Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
128137 TF_RETURN_IF_ERROR (ReplaceNodeWithXlaLaunch (graph, n));
129138 }
130139 }
140+
141+ if (VLOG_IS_ON (1 )) {
142+ dump_graph::DumpGraphToFile (" build_xla_launch_ops" , *graph,
143+ options.flib_def );
144+ }
131145 return Status::OK ();
132146}
133147
@@ -179,6 +193,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
179193 launch_def.set_op (" _XlaLaunch" );
180194 launch_def.set_device (flr->device ()->name ());
181195 AddNodeAttr (" Tconstants" , DataTypeVector{}, &launch_def);
196+ AddNodeAttr (" Nresources" , 0 , &launch_def);
182197 AddNodeAttr (" Targs" , fbody->arg_types , &launch_def);
183198 AddNodeAttr (" Tresults" , fbody->ret_types , &launch_def);
184199 NameAttrList func;
0 commit comments