Skip to content

Commit 542c3cb

Browse files
hawkinsptensorflower-gardener
authored andcommitted
[TF:XLA] Add support for resource variables to the Tensorflow/XLA bridge.
Change: 148176223
1 parent c061d6c commit 542c3cb

30 files changed

Lines changed: 956 additions & 183 deletions

tensorflow/compiler/aot/compile.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ Status CreateXlaArgs(const Graph& graph,
263263
TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
264264
for (const Node* node : arg_nodes) {
265265
XlaCompiler::Argument arg;
266+
arg.kind = XlaCompiler::Argument::kParameter;
266267
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
267-
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter));
268268
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
269269
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
270270
xla_args->push_back(arg);

tensorflow/compiler/jit/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ cc_library(
162162
"//tensorflow/core:lib",
163163
"//tensorflow/core:lib_internal",
164164
"//tensorflow/core:protos_all_cc",
165+
"//tensorflow/core/kernels:variable_ops",
165166
],
166167
)
167168

tensorflow/compiler/jit/build_xla_launch_ops_pass.cc

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,16 @@ namespace tensorflow {
4040
static Status BuildLaunchNode(
4141
const string& nodename, const string& function_name,
4242
const AttrValueMap& function_attr, const string& device_name,
43-
const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes,
44-
const DataTypeVector& result_dtypes, Graph* graph, Node** node) {
43+
const DataTypeVector& constant_dtypes, int num_resources,
44+
const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
45+
Graph* graph, Node** node) {
4546
NodeDef def;
4647
def.set_name(graph->NewName(nodename));
4748
def.set_op("_XlaLaunch");
4849
def.set_device(device_name);
4950
AddNodeAttr("Tconstants", constant_dtypes, &def);
5051
AddNodeAttr("Targs", arg_dtypes, &def);
52+
AddNodeAttr("Nresources", num_resources, &def);
5153
AddNodeAttr("Tresults", result_dtypes, &def);
5254
NameAttrList function;
5355
function.set_name(function_name);
@@ -62,25 +64,32 @@ static Status BuildLaunchNode(
6264
static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
6365
VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
6466

65-
int num_constant_args;
67+
int num_constant_args, num_resource_args;
6668
TF_RETURN_IF_ERROR(
6769
GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
70+
TF_RETURN_IF_ERROR(
71+
GetNodeAttr(node->def(), kXlaNumResourceArgsAttr, &num_resource_args));
6872

69-
if (num_constant_args < 0 || num_constant_args > node->input_types().size()) {
73+
if (num_constant_args < 0 || num_resource_args < 0 ||
74+
num_constant_args + num_resource_args > node->num_inputs()) {
7075
return errors::InvalidArgument(
71-
"Invalid number of constant arguments to XLA kernel");
76+
"Invalid number of constant/resource arguments to XLA kernel.");
7277
}
78+
const int num_nonconst_args =
79+
node->num_inputs() - num_constant_args - num_resource_args;
80+
7381
DataTypeVector const_dtypes(node->input_types().begin(),
7482
node->input_types().begin() + num_constant_args);
75-
DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args,
76-
node->input_types().end());
83+
DataTypeVector arg_dtypes(
84+
node->input_types().begin() + num_constant_args,
85+
node->input_types().begin() + num_constant_args + num_nonconst_args);
7786

7887
// Build a _XlaLaunch operator to execute the function body.
7988
Node* launch_node;
80-
TF_RETURN_IF_ERROR(
81-
BuildLaunchNode(graph->NewName(node->name()), node->type_string(),
82-
node->def().attr(), node->def().device(), const_dtypes,
83-
arg_dtypes, node->output_types(), graph, &launch_node));
89+
TF_RETURN_IF_ERROR(BuildLaunchNode(
90+
graph->NewName(node->name()), node->type_string(), node->def().attr(),
91+
node->def().device(), const_dtypes, num_resource_args, arg_dtypes,
92+
node->output_types(), graph, &launch_node));
8493
launch_node->set_assigned_device_name(node->assigned_device_name());
8594

8695
// Copy incoming edges to the launch node.
@@ -128,6 +137,11 @@ Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
128137
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
129138
}
130139
}
140+
141+
if (VLOG_IS_ON(1)) {
142+
dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
143+
options.flib_def);
144+
}
131145
return Status::OK();
132146
}
133147

@@ -179,6 +193,7 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
179193
launch_def.set_op("_XlaLaunch");
180194
launch_def.set_device(flr->device()->name());
181195
AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
196+
AddNodeAttr("Nresources", 0, &launch_def);
182197
AddNodeAttr("Targs", fbody->arg_types, &launch_def);
183198
AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
184199
NameAttrList func;

tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ namespace tensorflow {
4646

4747
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
4848
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
49+
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
4950

5051
namespace {
5152

@@ -563,6 +564,21 @@ Status EncapsulateSubgraphsInFunctions(
563564
return s;
564565
}
565566

567+
// Finds the types of the _Arg nodes, indexed by position.
568+
static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
569+
for (Node* n : graph.nodes()) {
570+
if (n->type_string() == kArgOp) {
571+
int index;
572+
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
573+
if (index < 0 || index >= types->size()) {
574+
return errors::InvalidArgument("Invalid argument number");
575+
}
576+
(*types)[index] = n->output_type(0);
577+
}
578+
}
579+
return Status::OK();
580+
}
581+
566582
// Renumber the indices of _Arg nodes in a graph, according to
567583
// 'permutation' that maps old indices to new indices.
568584
static Status RenumberArguments(Graph* graph,
@@ -604,19 +620,40 @@ Status EncapsulateSubgraphsPass::Run(
604620
// Optimize the subgraph.
605621
OptimizeGraph(flr.get(), subgraph);
606622

607-
std::vector<bool> const_args(input_permutation->size());
623+
const int num_args = input_permutation->size();
624+
std::vector<bool> const_args(num_args);
608625
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
609626

627+
DataTypeVector arg_types(num_args);
628+
TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
629+
610630
// Compute a permutation of the arguments such that the constant arguments
611631
// are first.
612632
const int num_consts =
613633
std::count(const_args.begin(), const_args.end(), true);
634+
635+
const int num_resources =
636+
std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
637+
const int num_nonconsts = num_args - num_resources - num_consts;
638+
if (num_nonconsts < 0) {
639+
return errors::Internal("num_nonconsts should be >= 0, was ",
640+
num_nonconsts);
641+
}
642+
614643
int const_pos = 0;
615644
int arg_pos = num_consts;
616-
for (int i = 0; i < const_args.size(); ++i) {
645+
int resource_pos = num_consts + num_nonconsts;
646+
for (int i = 0; i < num_args; ++i) {
617647
if (const_args[i]) {
648+
if (arg_types[i] == DT_RESOURCE) {
649+
return errors::Internal(
650+
"Resource arguments cannot be constant (argument ", i, ")");
651+
}
618652
(*input_permutation)[i] = const_pos;
619653
++const_pos;
654+
} else if (arg_types[i] == DT_RESOURCE) {
655+
(*input_permutation)[i] = resource_pos;
656+
++resource_pos;
620657
} else {
621658
(*input_permutation)[i] = arg_pos;
622659
++arg_pos;
@@ -631,6 +668,7 @@ Status EncapsulateSubgraphsPass::Run(
631668

632669
AddNodeAttr(kXlaCompiledKernelAttr, true, node);
633670
AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
671+
AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
634672
return Status::OK();
635673
};
636674

tensorflow/compiler/jit/encapsulate_subgraphs_pass.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,22 @@ extern const char* const kXlaCompiledKernelAttr;
7070
// Does `node` have the kXlaCompiledKernelAttr attribute?
7171
bool IsXlaCompiledKernel(const Node& node);
7272

73-
// Functions produce by the EncapsulateSubgraphs pass have their arguments
74-
// ordered such that compile-time constant arguments are first in the argument
75-
// order. The functions are annotated with the following attribute giving the
76-
// number of constant arguments.
73+
// Functions produced by the EncapsulateSubgraphs pass have their arguments in
74+
// the order:
75+
// 1) compile-time constant arguments, in host memory,
76+
// 2) other arguments, in device memory.
77+
// 3) resource variable arguments, in host memory. Note that only the resource
78+
// Tensor itself is in host memory; the underlying value may be in device
79+
// memory.
80+
// The functions are annotated with the following attributes that describe how
81+
// many constant and resource arguments there are:
82+
83+
// Name of the attribute containing the number of constant arguments.
7784
extern const char* const kXlaNumConstantArgsAttr;
7885

86+
// Name of the attribute containing the number of resource variable arguments.
87+
extern const char* const kXlaNumResourceArgsAttr;
88+
7989
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
8090
public:
8191
Status Run(const GraphOptimizationPassOptions& options) override;

tensorflow/compiler/jit/mark_for_compilation_pass.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,12 @@ Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
147147
return Status::OK();
148148
}
149149

150+
// Does `node` have a DT_RESOURCE typed argument?
151+
bool HasResourceArgument(const Node& node) {
152+
return std::find(node.input_types().begin(), node.input_types().end(),
153+
DT_RESOURCE) != node.input_types().end();
154+
}
155+
150156
Status FindCompilationCandidates(
151157
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
152158
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
@@ -174,6 +180,11 @@ Status FindCompilationCandidates(
174180
<< ": " << node->def().op();
175181
continue;
176182
}
183+
if (!registration->compile_resource_ops && HasResourceArgument(*node)) {
184+
VLOG(2) << "Compilation rejected node: resource argument " << node->name()
185+
<< ": " << node->def().op();
186+
continue;
187+
}
177188
if (node->def().op() == "While" &&
178189
!IsCompilableWhile(node->def(), jit_device_type, 0,
179190
lib_runtime.get())) {

0 commit comments

Comments
 (0)