11import tensorflow as tf
22
class DNNClassifier(tf.keras.Model):
    """Feed-forward classifier: optional DenseFeatures input layer, a stack
    of Dense hidden layers, and a softmax output layer."""

    def __init__(self, feature_columns=None, hidden_units=(10, 10), n_classes=3):
        """DNNClassifier

        :param feature_columns: feature columns. When None, inputs are
            flattened directly in call() instead of going through a
            DenseFeatures layer.
        :type feature_columns: list[tf.feature_column].
        :param hidden_units: number of units for each hidden layer.
        :type hidden_units: list[int].
        :param n_classes: number of output classes.
        :type n_classes: int.
        """
        super(DNNClassifier, self).__init__()
        self.feature_layer = None
        if feature_columns is not None:
            # Combines all the feature columns into a single dense tensor.
            self.feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
        # Build the fallback Flatten layer once here: constructing a new
        # layer object inside call() would leak one layer per forward pass.
        self.flatten_layer = tf.keras.layers.Flatten()
        self.hidden_layers = []
        for hidden_unit in hidden_units:
            self.hidden_layers.append(tf.keras.layers.Dense(hidden_unit))
        self.prediction_layer = tf.keras.layers.Dense(n_classes, activation='softmax')

    def call(self, inputs, training=True):
        """Forward pass; returns per-class probabilities (softmax output)."""
        if self.feature_layer is not None:
            x = self.feature_layer(inputs)
        else:
            x = self.flatten_layer(inputs)
        for hidden_layer in self.hidden_layers:
            x = hidden_layer(x)
        return self.prediction_layer(x)
def optimizer(learning_rate=0.1):
    """Default optimizer. Used in model.compile.

    :param learning_rate: Adagrad learning rate.
    :type learning_rate: float.
    """
    # `learning_rate` is the supported keyword in tf.keras 2.x;
    # `lr` is a deprecated alias.
    return tf.keras.optimizers.Adagrad(learning_rate=learning_rate)
3135
def loss(output, labels):
    """Default loss function. Used in model.compile.

    :param output: model predictions (per-class probabilities).
    :param labels: sparse (integer) ground-truth class labels.
    :return: mean sparse categorical cross-entropy over the batch.
    """
    return tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, output))

# FIXME(typhoonzero): use the name loss once ElasticDL has updated.
def loss_new(y_true, y_pred):
    """Mean sparse categorical cross-entropy, with the (y_true, y_pred)
    argument order ElasticDL expects."""
    per_example = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
    return tf.reduce_mean(per_example)
3546
def prepare_prediction_column(prediction):
    """Return the class label of highest probability."""
    # argmax over the last axis selects the most probable class index.
    most_probable = prediction.argmax(axis=-1)
    return most_probable
50+
def eval_metrics_fn():
    """Return the default evaluation metrics, keyed by metric name."""

    def _accuracy(labels, predictions):
        # Element-wise match between the predicted class (argmax over the
        # class axis) and the flattened integer labels.
        predicted = tf.argmax(predictions, 1, output_type=tf.int32)
        expected = tf.cast(tf.reshape(labels, [-1]), tf.int32)
        return tf.equal(predicted, expected)

    return {"accuracy": _accuracy}
58+
# dataset_fn is only used to test using this model in ElasticDL.
# TODO(typhoonzero): remove dataset_fn once https://github.com/sql-machine-learning/elasticdl/issues/1482 is done.
def dataset_fn(dataset, mode, metadata):
    """Parse raw numeric-string records into (features, labels) tuples.

    :param dataset: a tf.data.Dataset of delimited numeric string records.
    :param mode: an elasticdl Mode value (TRAINING / EVALUATION / PREDICTION).
    :param metadata: table metadata; its column_names list is used to locate
        the label column.
    :return: the mapped (and, for training, shuffled) dataset.
    :raises ValueError: in non-prediction modes when the label column is
        missing from metadata.column_names.
    """
    # Imported lazily so this file can be used without ElasticDL installed.
    from elasticdl.python.common.constants import Mode

    def _parse_data(record):
        label_col_name = "class"
        # Convert every string field of the record to float32.
        record = tf.strings.to_number(record, tf.float32)

        def _get_features_without_labels(
            record, label_col_ind, features_shape
        ):
            # Drop the label column by concatenating the slices on either
            # side of it, then reshape to the expected feature shape.
            features = [
                record[:label_col_ind],
                record[label_col_ind + 1:],  # noqa: E203
            ]
            features = tf.concat(features, -1)
            return tf.reshape(features, features_shape)

        # NOTE(review): hard-coded shapes — assumes each record carries
        # exactly 4 numeric feature columns and one scalar label; confirm
        # against the source table schema.
        features_shape = (4, 1)
        labels_shape = (1,)
        if mode != Mode.PREDICTION:
            # Training/evaluation requires the label column to be present.
            if label_col_name not in metadata.column_names:
                raise ValueError(
                    "Missing the label column '%s' in the retrieved "
                    "ODPS table." % label_col_name
                )
            label_col_ind = metadata.column_names.index(label_col_name)
            labels = tf.reshape(record[label_col_ind], labels_shape)
            return (
                _get_features_without_labels(
                    record, label_col_ind, features_shape
                ),
                labels,
            )
        else:
            # Prediction: strip the label column if it happens to be
            # present; otherwise the whole record is features.
            if label_col_name in metadata.column_names:
                label_col_ind = metadata.column_names.index(label_col_name)
                return _get_features_without_labels(
                    record, label_col_ind, features_shape
                )
            else:
                return tf.reshape(record, features_shape)

    dataset = dataset.map(_parse_data)

    # Shuffle only during training so evaluation/prediction stay ordered.
    if mode == Mode.TRAINING:
        dataset = dataset.shuffle(buffer_size=200)
    return dataset
0 commit comments