2323from google .api_core import exceptions
2424from google .auth .credentials import AnonymousCredentials
2525from google .cloud import automl_v1beta1
26- from google .cloud .automl_v1beta1 .proto import data_types_pb2
26+ from google .cloud .automl_v1beta1 .proto import data_types_pb2 , data_items_pb2
27+ from google .protobuf import struct_pb2
2728
2829PROJECT = "project"
2930REGION = "region"
@@ -1116,9 +1117,10 @@ def test_predict_from_array(self):
11161117 model .configure_mock (tables_model_metadata = model_metadata , name = "my_model" )
11171118 client = self .tables_client ({"get_model.return_value" : model }, {})
11181119 client .predict (["1" ], model_name = "my_model" )
1119- client . prediction_client . predict . assert_called_with (
1120- "my_model" , { " row" : { " values" : [{ " string_value" : "1" }]}}, None
1120+ payload = data_items_pb2 . ExamplePayload (
1121+ row = data_items_pb2 . Row ( values = [ struct_pb2 . Value ( string_value = "1" )])
11211122 )
1123+ client .prediction_client .predict .assert_called_with ("my_model" , payload , None )
11221124
11231125 def test_predict_from_dict (self ):
11241126 data_type = mock .Mock (type_code = data_types_pb2 .CATEGORY )
@@ -1131,10 +1133,16 @@ def test_predict_from_dict(self):
11311133 model .configure_mock (tables_model_metadata = model_metadata , name = "my_model" )
11321134 client = self .tables_client ({"get_model.return_value" : model }, {})
11331135 client .predict ({"a" : "1" , "b" : "2" }, model_name = "my_model" )
1136+ payload = data_items_pb2 .ExamplePayload (
1137+ row = data_items_pb2 .Row (
1138+ values = [
1139+ struct_pb2 .Value (string_value = "1" ),
1140+ struct_pb2 .Value (string_value = "2" ),
1141+ ]
1142+ )
1143+ )
11341144 client .prediction_client .predict .assert_called_with (
1135- "my_model" ,
1136- {"row" : {"values" : [{"string_value" : "1" }, {"string_value" : "2" }]}},
1137- None ,
1145+ "my_model" , payload , None ,
11381146 )
11391147
11401148 def test_predict_from_dict_with_feature_importance (self ):
@@ -1150,10 +1158,16 @@ def test_predict_from_dict_with_feature_importance(self):
11501158 client .predict (
11511159 {"a" : "1" , "b" : "2" }, model_name = "my_model" , feature_importance = True
11521160 )
1161+ payload = data_items_pb2 .ExamplePayload (
1162+ row = data_items_pb2 .Row (
1163+ values = [
1164+ struct_pb2 .Value (string_value = "1" ),
1165+ struct_pb2 .Value (string_value = "2" ),
1166+ ]
1167+ )
1168+ )
11531169 client .prediction_client .predict .assert_called_with (
1154- "my_model" ,
1155- {"row" : {"values" : [{"string_value" : "1" }, {"string_value" : "2" }]}},
1156- {"feature_importance" : "true" },
1170+ "my_model" , payload , {"feature_importance" : "true" },
11571171 )
11581172
11591173 def test_predict_from_dict_missing (self ):
@@ -1167,18 +1181,31 @@ def test_predict_from_dict_missing(self):
11671181 model .configure_mock (tables_model_metadata = model_metadata , name = "my_model" )
11681182 client = self .tables_client ({"get_model.return_value" : model }, {})
11691183 client .predict ({"a" : "1" }, model_name = "my_model" )
1184+ payload = data_items_pb2 .ExamplePayload (
1185+ row = data_items_pb2 .Row (
1186+ values = [struct_pb2 .Value (string_value = "1" ), struct_pb2 .Value (null_value = struct_pb2 .NullValue .NULL_VALUE )]
1187+ )
1188+ )
11701189 client .prediction_client .predict .assert_called_with (
1171- "my_model" ,
1172- {"row" : {"values" : [{"string_value" : "1" }, {"null_value" : 0 }]}},
1173- None ,
1190+ "my_model" , payload , None ,
11741191 )
11751192
11761193 def test_predict_all_types (self ):
11771194 float_type = mock .Mock (type_code = data_types_pb2 .FLOAT64 )
11781195 timestamp_type = mock .Mock (type_code = data_types_pb2 .TIMESTAMP )
11791196 string_type = mock .Mock (type_code = data_types_pb2 .STRING )
1180- array_type = mock .Mock (type_code = data_types_pb2 .ARRAY )
1181- struct_type = mock .Mock (type_code = data_types_pb2 .STRUCT )
1197+ array_type = mock .Mock (
1198+ type_code = data_types_pb2 .ARRAY ,
1199+ list_element_type = mock .Mock (type_code = data_types_pb2 .FLOAT64 ),
1200+ )
1201+ struct = data_types_pb2 .StructType ()
1202+ struct .fields ["a" ].CopyFrom (
1203+ data_types_pb2 .DataType (type_code = data_types_pb2 .CATEGORY )
1204+ )
1205+ struct .fields ["b" ].CopyFrom (
1206+ data_types_pb2 .DataType (type_code = data_types_pb2 .CATEGORY )
1207+ )
1208+ struct_type = mock .Mock (type_code = data_types_pb2 .STRUCT , struct_type = struct )
11821209 category_type = mock .Mock (type_code = data_types_pb2 .CATEGORY )
11831210 column_spec_float = mock .Mock (display_name = "float" , data_type = float_type )
11841211 column_spec_timestamp = mock .Mock (
@@ -1211,28 +1238,34 @@ def test_predict_all_types(self):
12111238 "timestamp" : "EST" ,
12121239 "string" : "text" ,
12131240 "array" : [1 ],
1214- "struct" : {"a" : "b " },
1241+ "struct" : {"a" : "label_a" , "b" : "label_b " },
12151242 "category" : "a" ,
12161243 "null" : None ,
12171244 },
12181245 model_name = "my_model" ,
12191246 )
1247+ struct = struct_pb2 .Struct ()
1248+ struct .fields ["a" ].CopyFrom (struct_pb2 .Value (string_value = "label_a" ))
1249+ struct .fields ["b" ].CopyFrom (struct_pb2 .Value (string_value = "label_b" ))
1250+ payload = data_items_pb2 .ExamplePayload (
1251+ row = data_items_pb2 .Row (
1252+ values = [
1253+ struct_pb2 .Value (number_value = 1.0 ),
1254+ struct_pb2 .Value (string_value = "EST" ),
1255+ struct_pb2 .Value (string_value = "text" ),
1256+ struct_pb2 .Value (
1257+ list_value = struct_pb2 .ListValue (
1258+ values = [struct_pb2 .Value (number_value = 1.0 )]
1259+ )
1260+ ),
1261+ struct_pb2 .Value (struct_value = struct ),
1262+ struct_pb2 .Value (string_value = "a" ),
1263+ struct_pb2 .Value (null_value = struct_pb2 .NullValue .NULL_VALUE ),
1264+ ]
1265+ )
1266+ )
12201267 client .prediction_client .predict .assert_called_with (
1221- "my_model" ,
1222- {
1223- "row" : {
1224- "values" : [
1225- {"number_value" : 1.0 },
1226- {"string_value" : "EST" },
1227- {"string_value" : "text" },
1228- {"list_value" : [1 ]},
1229- {"struct_value" : {"a" : "b" }},
1230- {"string_value" : "a" },
1231- {"null_value" : 0 },
1232- ]
1233- }
1234- },
1235- None ,
1268+ "my_model" , payload , None ,
12361269 )
12371270
12381271 def test_predict_from_array_missing (self ):
0 commit comments