Skip to content

Commit a0c76ef

Browse files
committed
Add script for databricks talk.
1 parent 542a82f commit a0c76ef

1 file changed

Lines changed: 199 additions & 0 deletions

File tree

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// Databricks notebook source exported at Mon, 9 Nov 2015 04:31:10 UTC
2+
// Ensure you have included the table smsData
3+
4+
// COMMAND ----------
5+
6+
// Representation of a training message
7+
import org.apache.spark.mllib.linalg.Vector
8+
// One labeled training example: `target` is the response label (the table's
// first column — presumably "ham"/"spam"; confirm against smsData), and
// `fv` is the TF-IDF feature vector produced by the hashing/IDF pipeline.
case class SMS(target: String, fv: Vector)
9+
10+
// COMMAND ----------
11+
12+
// Define tokenizer function.
// Lower-cases each message, replaces ignored characters with spaces, splits on
// spaces, then keeps only distinct words longer than two characters that are
// not in the stop-word list.
def tokenize(data: RDD[String]): RDD[Seq[String]] = {
  val ignoredWords = Seq("the", "a", "", "in", "on", "at", "as", "not", "for")
  val ignoredChars = Seq(',', ':', ';', '/', '<', '>', '"', '.', '(', ')', '?', '-', '\'','!','0', '1')

  data.map { sms =>
    // Blank out every ignored character, then split into candidate words.
    val cleaned = ignoredChars.foldLeft(sms.toLowerCase) { (text, ch) =>
      text.replace(ch, ' ')
    }
    cleaned
      .split(" ")
      .filter(word => word.length > 2 && !ignoredWords.contains(word))
      .distinct
      .toSeq
  }
}
29+
30+
// COMMAND ----------
31+
32+
// Define function which builds an IDF model.
// Hashes token sequences into a fixed-size term-frequency space, fits IDF
// weights over them (terms in fewer than `minDocFreq` documents are
// down-weighted), and returns the hashing transformer, the fitted IDF model,
// and the resulting TF-IDF vectors.
import org.apache.spark.mllib.feature._

def buildIDFModel(tokens: RDD[Seq[String]],
                  minDocFreq: Int = 4,
                  hashSpaceSize: Int = 1 << 10): (HashingTF, IDFModel, RDD[Vector]) = {
  // Term frequencies via feature hashing into `hashSpaceSize` buckets.
  val hasher = new HashingTF(hashSpaceSize)
  val termFreqs = hasher.transform(tokens)
  // Fit inverse-document-frequency weights over the corpus.
  val idf = new IDF(minDocFreq = minDocFreq).fit(termFreqs)
  (hasher, idf, idf.transform(termFreqs))
}
46+
47+
// COMMAND ----------
48+
49+
// Define function which builds a DL model
import org.apache.spark.h2o._
import water.Key
import _root_.hex.deeplearning.DeepLearning
import _root_.hex.deeplearning.DeepLearningParameters
import _root_.hex.deeplearning.DeepLearningModel

// Trains an H2O Deep Learning model on `train`, validating against `valid`.
//
// Parameters:
//   epochs - number of passes over the training data
//   l1     - L1 regularization strength
//   l2     - L2 regularization strength
//   hidden - hidden-layer sizes
//
// FIX: `l2` was previously accepted but never assigned to the
// DeepLearningParameters, so caller-supplied L2 regularization was silently
// ignored. It is now passed through (default 0.0 preserves old behavior).
def buildDLModel(train: Frame, valid: Frame,
                 epochs: Int = 10, l1: Double = 0.001, l2: Double = 0.0,
                 hidden: Array[Int] = Array[Int](200, 200))
                (implicit h2oContext: H2OContext): DeepLearningModel = {
  import h2oContext._
  // Configure the model
  val dlParams = new DeepLearningParameters()
  dlParams._model_id = Key.make("dlModel.hex")
  dlParams._train = train
  dlParams._valid = valid
  dlParams._response_column = 'target
  dlParams._epochs = epochs
  dlParams._l1 = l1
  dlParams._l2 = l2 // was missing before this fix
  dlParams._hidden = hidden

  // Create a training job and block until the model is built
  val dl = new DeepLearning(dlParams)
  val dlModel = dl.trainModel.get

  // Score both datasets so metrics are attached to the model; the returned
  // prediction frames themselves are not needed and are deleted immediately.
  dlModel.score(train).delete()
  dlModel.score(valid).delete()

  dlModel
}
82+
83+
// COMMAND ----------
84+
85+
// Create SQL support (reuses an existing SQLContext if one is registered).
import org.apache.spark.sql._
implicit val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

// Start H2O services on top of the Spark cluster.
// @transient so the context is never captured inside serialized closures.
import org.apache.spark.h2o._
@transient val h2oContext = new H2OContext(sc).start()
93+
94+
95+
96+
// COMMAND ----------
97+
98+
// Open H2O Flow UI (side effect only; no value used downstream).
h2oContext.openFlow
100+
101+
// COMMAND ----------
102+
103+
// Build the application

import org.apache.spark.rdd.RDD
import org.apache.spark.examples.h2o.DemoUtils._
import scala.io.Source

// Load both columns from the registered smsData table.
val data = sqlContext.sql("SELECT * FROM smsData")
// Extract response (column 0) and message text (column 1).
val hamSpam = data.map( r => r(0).toString)
val message = data.map( r => r(1).toString)
// Tokenize message content
val tokens = tokenize(message)
// Build IDF model over the tokenized corpus.
val tokens = tokenize(message)
// Build IDF model
var (hashingTF, idfModel, tfidf) = buildIDFModel(tokens)

// Merge response with extracted vectors, one SMS row per message.
val resultRDD: DataFrame = hamSpam.zip(tfidf).map(v => SMS(v._1, v._2)).toDF

// Publish Spark DataFrame as H2OFrame
// This H2OFrame has to be transient because we do not want it to be serialized. When calling for example sc.parallelize(..) the object which we are trying to parallelize takes with itself all variables in its surroundings scope - apart from those marked as serialized.
//
@transient val table = h2oContext.asH2OFrame(resultRDD)
// NOTE(review): debug leftover? This prints the RDD reference, not its
// contents — likely only here to demonstrate the serialization note above.
println(sc.parallelize(Array(1,2)))
// Transform target column into categorical (required for classification);
// remove() drops the vec returned by replace, update(null) commits the change.
table.replace(table.find("target"), table.vec("target").toCategoricalVec()).remove()
table.update(null)

// Split table 80/20 into train and validation frames, then free the original.
val keys = Array[String]("train.hex", "valid.hex")
val ratios = Array[Double](0.8)
@transient val frs = split(table, keys, ratios)
@transient val train = frs(0)
@transient val valid = frs(1)
table.delete()

// Build a model
@transient val dlModel = buildDLModel(train, valid)(h2oContext)
141+
142+
143+
144+
// COMMAND ----------
145+
146+
// Display the trained model in the notebook cell output.
dlModel
147+
148+
// COMMAND ----------
149+
150+
// Evaluate model quality

// Collect model metrics on the training data and print the training AUC.
import water.app.ModelMetricsSupport
val trainMetrics = ModelMetricsSupport.binomialMM(dlModel, train)
println(trainMetrics.auc._auc)
156+
157+
// COMMAND ----------
158+
159+
// Collect model metrics on the validation data and print the validation AUC.
import water.app.ModelMetricsSupport
val validMetrics = ModelMetricsSupport.binomialMM(dlModel, valid)
println(validMetrics.auc._auc)
163+
164+
// COMMAND ----------
165+
166+
// Create a spam detector - a method which will return SPAM or HAM for a given
// text message. The message is pushed through the same tokenize -> hashing TF
// -> IDF pipeline used at training time, then scored by the stored model.
import water.DKV._

// Spam detector
def isSpam(msg: String,
           modelId: String,
           hashingTF: HashingTF,
           idfModel: IDFModel,
           h2oContext: H2OContext,
           hamThreshold: Double = 0.5): String = {
  // Fetch the trained model from H2O's distributed key/value store.
  val model: DeepLearningModel = water.DKV.getGet(modelId)

  // Vectorize the single message exactly as the training data was vectorized.
  val messages = sc.parallelize(Seq(msg))
  val features: DataFrame =
    idfModel
      .transform(hashingTF.transform(tokenize(messages)))
      .map(v => SMS("?", v))
      .toDF

  val frame: H2OFrame = h2oContext.asH2OFrame(features)
  frame.remove(0) // drop the first column (the placeholder "?" target)

  val scored = model.score(frame)
  // Column 1 of the prediction frame is read as a score — presumably a class
  // probability (verify against H2O's prediction-frame layout); values below
  // hamThreshold are reported as spam.
  if (scored.vecs()(1).at(0) < hamThreshold) "SPAM DETECTED!" else "HAM"
}
186+
187+
// COMMAND ----------
188+
189+
// Try to detect spam on a short personal-looking message.
isSpam("Michal, h2oworld party tonight in MV?", dlModel._key.toString, hashingTF, idfModel, h2oContext)
192+
193+
// COMMAND ----------
194+
195+
// Score a promotional-looking message through the same detector.
isSpam("We tried to contact you re your reply to our offer of a Video Handset? 750 anytime any networks mins? UNLIMITED TEXT?", dlModel._key.toString, hashingTF, idfModel, h2oContext)
196+
197+
// COMMAND ----------
198+
199+

0 commit comments

Comments
 (0)