1515# limitations under the License.
1616#
1717
18- from typing import Optional , Dict
18+ from typing import Optional , Dict , Union
1919
2020from google .cloud .aiplatform_v1 .types import (
2121 io as gca_io_v1 ,
3939class _SkewDetectionConfig :
4040 def __init__ (
4141 self ,
42- data_source : str ,
43- skew_thresholds : Dict [str , float ],
44- target_field : str ,
45- attribute_skew_thresholds : Dict [str , float ],
42+ data_source : Optional [ str ] = None ,
43+ skew_thresholds : Union [ Dict [str , float ], float , None ] = None ,
44+ target_field : Optional [ str ] = None ,
45+ attribute_skew_thresholds : Optional [ Dict [str , float ]] = None ,
4646 data_format : Optional [str ] = None ,
4747 ):
4848 """Base class for training-serving skew detection.
4949 Args:
5050 data_source (str):
51- Required . Path to training dataset.
51+ Optional . Path to training dataset.
5252
53- skew_thresholds ( Dict[str, float]) :
53+ skew_thresholds (Union[Dict[str, float], float, None]):
5454 Optional. Key is the feature name and value is the
5555 threshold. If a feature needs to be monitored
5656 for skew, a value threshold must be configured
5757 for that feature. The threshold here is against
5858 feature distribution distance between the
59- training and prediction feature.
59+ training and prediction feature. If a float is passed,
60+ then all features will be monitored using the same
61+ threshold. If None is passed, all features will be monitored
62+ using alert threshold 0.3 (Backend default).
6063
6164 target_field (str):
62- Required . The target field name the model is to
65+ Optional . The target field name the model is to
6366 predict. This field will be excluded when doing
6467 Predict and (or) Explain for the training data.
6568
@@ -93,12 +96,18 @@ def as_proto(self):
9396 """Returns _SkewDetectionConfig as a proto message."""
9497 skew_thresholds_mapping = {}
9598 attribution_score_skew_thresholds_mapping = {}
99+ default_skew_threshold = None
96100 if self .skew_thresholds is not None :
97- for key in self .skew_thresholds . keys ( ):
98- skew_threshold = gca_model_monitoring .ThresholdConfig (
99- value = self .skew_thresholds [ key ]
101+ if isinstance ( self .skew_thresholds , float ):
102+ default_skew_threshold = gca_model_monitoring .ThresholdConfig (
103+ value = self .skew_thresholds
100104 )
101- skew_thresholds_mapping [key ] = skew_threshold
105+ else :
106+ for key in self .skew_thresholds .keys ():
107+ skew_threshold = gca_model_monitoring .ThresholdConfig (
108+ value = self .skew_thresholds [key ]
109+ )
110+ skew_thresholds_mapping [key ] = skew_threshold
102111 if self .attribute_skew_thresholds is not None :
103112 for key in self .attribute_skew_thresholds .keys ():
104113 attribution_score_skew_threshold = gca_model_monitoring .ThresholdConfig (
@@ -110,6 +119,7 @@ def as_proto(self):
110119 return gca_model_monitoring .ModelMonitoringObjectiveConfig .TrainingPredictionSkewDetectionConfig (
111120 skew_thresholds = skew_thresholds_mapping ,
112121 attribution_score_skew_thresholds = attribution_score_skew_thresholds_mapping ,
122+ default_skew_threshold = default_skew_threshold ,
113123 )
114124
115125
@@ -266,30 +276,33 @@ class SkewDetectionConfig(_SkewDetectionConfig):
266276
267277 def __init__ (
268278 self ,
269- data_source : str ,
270- target_field : str ,
271- skew_thresholds : Optional [Dict [str , float ]] = None ,
279+ data_source : Optional [ str ] = None ,
280+ target_field : Optional [ str ] = None ,
281+ skew_thresholds : Union [Dict [str , float ], float , None ] = None ,
272282 attribute_skew_thresholds : Optional [Dict [str , float ]] = None ,
273283 data_format : Optional [str ] = None ,
274284 ):
275285 """Initializer for SkewDetectionConfig.
276286
277287 Args:
278288 data_source (str):
279- Required . Path to training dataset.
289+ Optional . Path to training dataset.
280290
281291 target_field (str):
282- Required . The target field name the model is to
292+ Optional . The target field name the model is to
283293 predict. This field will be excluded when doing
284294 Predict and (or) Explain for the training data.
285295
286- skew_thresholds ( Dict[str, float]) :
296+ skew_thresholds (Union[Dict[str, float], float, None]):
287297 Optional. Key is the feature name and value is the
288298 threshold. If a feature needs to be monitored
289299 for skew, a value threshold must be configured
290300 for that feature. The threshold here is against
291301 feature distribution distance between the
292- training and prediction feature.
302+ training and prediction feature. If a float is passed,
303+ then all features will be monitored using the same
304+ threshold. If None is passed, all features will be monitored
305+ using alert threshold 0.3 (Backend default).
293306
294307 attribute_skew_thresholds (Dict[str, float]):
295308 Optional. Key is the feature name and value is the
@@ -315,11 +328,11 @@ def __init__(
315328 ValueError for unsupported data formats.
316329 """
317330 super ().__init__ (
318- data_source ,
319- skew_thresholds ,
320- target_field ,
321- attribute_skew_thresholds ,
322- data_format ,
331+ data_source = data_source ,
332+ skew_thresholds = skew_thresholds ,
333+ target_field = target_field ,
334+ attribute_skew_thresholds = attribute_skew_thresholds ,
335+ data_format = data_format ,
323336 )
324337
325338
0 commit comments