@@ -47,17 +47,11 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
4747 }
4848
4949 var attrs = new Dictionary < string , object > ( ) ;
50+
51+ // Perform input type inference
5052 var inputs = new List < Tensor > ( ) ;
5153 var input_types = new List < DataType > ( ) ;
52-
53- foreach ( var attr in op_def . Attr )
54- {
55- if ( keywords . ContainsKey ( attr . Name ) )
56- {
57- attrs [ attr . Name ] = keywords [ attr . Name ] ;
58- }
59- }
60-
54+
6155 foreach ( var input_arg in op_def . InputArg )
6256 {
6357 var input_name = input_arg . Name ;
@@ -70,18 +64,38 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
7064 {
7165 attrs [ input_arg . TypeAttr ] = DataType . DtFloat ;
7266 }
67+
68+ if ( input_arg . IsRef )
69+ {
70+
71+ }
72+ else
73+ {
74+ input_types . Add ( ( keywords [ input_name ] as Tensor ) . dtype ) ;
75+ }
7376 }
7477
78+ // Process remaining attrs
79+ foreach ( var attr in op_def . Attr )
80+ {
81+ if ( keywords . ContainsKey ( attr . Name ) )
82+ {
83+ attrs [ attr . Name ] = keywords [ attr . Name ] ;
84+ }
85+ }
86+
87+ // Convert attr values to AttrValue protos.
7588 var attr_protos = new Dictionary < string , AttrValue > ( ) ;
7689 foreach ( var attr_def in op_def . Attr )
7790 {
7891 var key = attr_def . Name ;
92+ var value = attrs [ key ] ;
7993 var attr_value = new AttrValue ( ) ;
8094
8195 switch ( attr_def . Type )
8296 {
8397 case "type" :
84- attr_value . Type = ( DataType ) keywords [ "dtype" ] ;
98+ attr_value . Type = _MakeType ( value , attr_def ) ;
8599 break ;
86100 case "shape" :
87101 attr_value . Shape = new TensorShapeProto ( ) ;
@@ -91,6 +105,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
91105 attr_protos [ key ] = attr_value ;
92106 }
93107
108+ // Determine output types (possibly using attrs)
94109 var output_types = new List < DataType > ( ) ;
95110
96111 foreach ( var arg in op_def . OutputArg )
@@ -105,6 +120,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
105120 }
106121 }
107122
123+ // Add Op to graph
108124 var op = g . create_op ( op_type_name , inputs , output_types . ToArray ( ) ,
109125 name : scope ,
110126 input_types : input_types . ToArray ( ) ,
@@ -113,5 +129,10 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
113129
114130 return op ;
115131 }
132+
133+ public DataType _MakeType ( Object v , AttrDef attr_def )
134+ {
135+ return DataType . DtFloat ;
136+ }
116137 }
117138}
0 commit comments