Skip to content

Commit 446adc0

Browse files
committed
More WIP for DL R script.
1 parent b3498a8 commit 446adc0

1 file changed

Lines changed: 49 additions & 18 deletions

File tree

tutorials/deeplearning/deeplearning.R

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,19 @@ response <- "Cover_Type"
4141
predictors <- setdiff(names(df), response)
4242

4343
## ~17% 17s
44-
m <- h2o.deeplearning(model_id="dl_model_defaults", training_frame = train, validation_frame = valid,
45-
x=predictors, y=response, variable_importances=T, epochs=1)
44+
m <- h2o.deeplearning(model_id="dl_model_defaults",
45+
training_frame = train,
46+
validation_frame = valid,
47+
x=predictors, y=response,
48+
variable_importances=T,
49+
epochs=1)
4650
m
4751
summary(m)
4852
h2o.varimp(m)
4953

50-
51-
## smaller network, train longer ~15% 20s
52-
m <- h2o.deeplearning(model_id="dl_model_faster", training_frame = train, validation_frame = valid,
53-
x=predictors, y=response,
54-
hidden=c(32,32,32), epochs=20)
55-
## show convergence
56-
plot(m)
57-
58-
## early stopping as soon as misclassification doesn't improve by at least 1%
54+
## smaller network, run until convergence
55+
## (stop if misclassification on 10k validation rows does not improve by at least 1%)
56+
## ~15% in 30s
5957
m <- h2o.deeplearning(
6058
model_id="dl_model_faster",
6159
training_frame = train,
@@ -64,36 +62,69 @@ m <- h2o.deeplearning(
6462
y=response,
6563
hidden=c(32,32,32),
6664
epochs=1000000,
65+
score_validation_samples = 10000,
6766
stopping_rounds=1,
6867
stopping_metric="misclassification",
6968
stopping_tolerance=0.01
7069
)
7170
summary(m)
7271

73-
## with some tuning: ~6% in 160s
72+
## show convergence
73+
plot(m)
74+
75+
## with some tuning: ~8% in 40s
7476
m <- h2o.deeplearning(
7577
model_id="dl_model_tuned",
7678
training_frame = train,
7779
validation_frame = valid,
7880
x=predictors,
7981
y=response,
82+
overwrite_with_best_model=F,
83+
hidden=c(100,100,100), ## more hidden layers -> more complex interactions
84+
epochs=10, ## long enough to converge
85+
score_validation_samples=10000, ## downsample validation set for faster scoring
86+
score_duty_cycle=0.025, ## don't score more than 2.5% of the wall time
87+
adaptive_rate=F, ## manually tuned learning rate
88+
rate=0.02,
89+
rate_annealing=2e-6,
90+
momentum_start = 0.2, ## manually tuned momentum
91+
momentum_stable = 0.4,
92+
momentum_ramp = 1e7,
93+
l2=1e-5, ## add some L2 regularization
94+
max_w2 = 10 ## helps stability for Rectifier
95+
)
96+
97+
## Optional - continue training the previous model
98+
if (FALSE) {
99+
max_epochs <- 1000 ## Takes a few minutes
100+
} else {
101+
max_epochs <- 20 ## Takes about 30s
102+
}
103+
m <- h2o.deeplearning(
104+
model_id="dl_model_tuned_continued",
105+
checkpoint="dl_model_tuned",
106+
training_frame = train,
107+
validation_frame = valid,
108+
x=predictors,
109+
y=response,
80110
hidden=c(100,100,100), ## more hidden layers -> more complex interactions
81-
epochs=100, ## long enough to converge
111+
epochs=max_epochs, ## hopefully long enough to converge (otherwise restart again)
82112
stopping_metric="logloss",
83-
stopping_tolerance=1e-2, ## stop when logloss does not improve by >=1% for 2 scoring events
113+
stopping_tolerance=1e-2, ## stop when validation logloss does not improve by >=1% for 2 scoring events
84114
stopping_rounds=2,
85115
score_validation_samples=10000, ## downsample validation set for faster scoring
86116
score_duty_cycle=0.025, ## don't score more than 2.5% of the wall time
87117
adaptive_rate=F, ## manually tuned learning rate
88118
rate=0.02,
89-
rate_annealing=2e-6, ## manually tuned momentum
90-
momentum_start = 0.2,
119+
rate_annealing=2e-6,
120+
momentum_start = 0.2, ## manually tuned momentum
91121
momentum_stable = 0.4,
92122
momentum_ramp = 1e7,
93123
l2=1e-5, ## add some L2 regularization
94124
max_w2 = 10 ## helps stability for Rectifier
95125
)
96126

127+
## Now score on the full validation set and the test set
97128
summary(m)
98129
h2o.confusionMatrix(h2o.performance(m, train=T)) ## training
99130
h2o.confusionMatrix(h2o.performance(m, valid=T)) ## sampled validation
@@ -110,7 +141,7 @@ plot(m)
110141

111142

112143
## Grid search
113-
if (TRUE) {
144+
if (FALSE) {
114145
hyper_params <- list(
115146
hidden = list(c(64,64,64),c(128,128,128),c(512,512)),
116147
l1 = c(0, 1e-5),
@@ -122,7 +153,7 @@ if (TRUE) {
122153
momentum_stable =c(0.75,0.9,0.99),
123154
momentum_ramp = c(1e6, 1e7, 1e8)
124155
)
125-
hyper_params
156+
hyper_params # 3*2*2*2*3*3*3*3*3 = 5832 combinations
126157

127158
h2o.grid(
128159
"deeplearning",

0 commit comments

Comments
 (0)