Skip to content

Commit b3ebd13

Browse files
sirtorrybusunkim96
authored andcommitted
fix(automl): pass params to underlying client (#9794)
1 parent b72f0f8 commit b3ebd13

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

packages/google-cloud-automl/google/cloud/automl_v1beta1/tables/tables_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2596,6 +2596,7 @@ def predict(
25962596
model=None,
25972597
model_name=None,
25982598
model_display_name=None,
2599+
params=None,
25992600
project=None,
26002601
region=None,
26012602
**kwargs
@@ -2642,6 +2643,9 @@ def predict(
26422643
The `model` instance you want to predict with . This must be
26432644
supplied if `model_display_name` or `model_name` are not
26442645
supplied.
2646+
params (dict[str, str]):
2647+
`feature_importance` can be set as True to enable local
2648+
explainability. The default is false.
26452649
26462650
Returns:
26472651
A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
@@ -2683,7 +2687,7 @@ def predict(
26832687

26842688
request = {"row": {"values": values}}
26852689

2686-
return self.prediction_client.predict(model.name, request, **kwargs)
2690+
return self.prediction_client.predict(model.name, request, params, **kwargs)
26872691

26882692
def batch_predict(
26892693
self,

packages/google-cloud-automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ def test_predict_from_array(self):
11171117
client = self.tables_client({"get_model.return_value": model}, {})
11181118
client.predict(["1"], model_name="my_model")
11191119
client.prediction_client.predict.assert_called_with(
1120-
"my_model", {"row": {"values": [{"string_value": "1"}]}}
1120+
"my_model", {"row": {"values": [{"string_value": "1"}]}}, None
11211121
)
11221122

11231123
def test_predict_from_dict(self):
@@ -1134,6 +1134,7 @@ def test_predict_from_dict(self):
11341134
client.prediction_client.predict.assert_called_with(
11351135
"my_model",
11361136
{"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}},
1137+
None,
11371138
)
11381139

11391140
def test_predict_from_dict_missing(self):
@@ -1148,7 +1149,9 @@ def test_predict_from_dict_missing(self):
11481149
client = self.tables_client({"get_model.return_value": model}, {})
11491150
client.predict({"a": "1"}, model_name="my_model")
11501151
client.prediction_client.predict.assert_called_with(
1151-
"my_model", {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}}
1152+
"my_model",
1153+
{"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}},
1154+
None,
11521155
)
11531156

11541157
def test_predict_all_types(self):
@@ -1210,6 +1213,7 @@ def test_predict_all_types(self):
12101213
]
12111214
}
12121215
},
1216+
None,
12131217
)
12141218

12151219
def test_predict_from_array_missing(self):

0 commit comments

Comments
 (0)