Skip to content

Commit 2be80cd

Browse files
authored
Initialize cluster provider / jedis pool only once (#157)
Signed-off-by: khorshuheng <khor.heng@gojek.com> Co-authored-by: khorshuheng <khor.heng@gojek.com>
1 parent 96b9336 commit 2be80cd

File tree

5 files changed

+117
-64
lines changed

5 files changed

+117
-64
lines changed

spark/ingestion/src/main/scala/feast/ingestion/stores/redis/ClusterPipelineProvider.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
*/
1717
package feast.ingestion.stores.redis
1818

19-
import redis.clients.jedis.{ClusterPipeline, DefaultJedisClientConfig, HostAndPort}
19+
import redis.clients.jedis.commands.PipelineBinaryCommands
20+
import redis.clients.jedis.{ClusterPipeline, DefaultJedisClientConfig, HostAndPort, Response}
2021
import redis.clients.jedis.providers.ClusterConnectionProvider
2122

2223
import scala.collection.JavaConverters._
@@ -34,9 +35,14 @@ case class ClusterPipelineProvider(endpoint: RedisEndpoint) extends PipelineProv
3435
val provider = new ClusterConnectionProvider(nodes, DEFAULT_CLIENT_CONFIG)
3536

3637
/**
37-
* @return a cluster pipeline
38+
* @return execute commands within a pipeline and return the result
3839
*/
39-
override def pipeline(): UnifiedPipeline = new ClusterPipeline(provider)
40+
/**
 * Execute `ops` within a cluster pipeline and return its result.
 *
 * The pipeline is closed in a `finally` block so the underlying cluster
 * connections are released even when `ops` throws (the original leaked
 * the pipeline on exception).
 */
override def withPipeline[T](ops: PipelineBinaryCommands => T): T = {
  val pipeline = new ClusterPipeline(provider)
  try ops(pipeline)
  finally pipeline.close()
}
4046

4147
/**
4248
* Close client connection

spark/ingestion/src/main/scala/feast/ingestion/stores/redis/PipelineProvider.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
package feast.ingestion.stores.redis
1818

19+
import redis.clients.jedis.Response
1920
import redis.clients.jedis.commands.PipelineBinaryCommands
2021

2122
import java.io.Closeable
@@ -25,12 +26,7 @@ import java.io.Closeable
2526
*/
2627
trait PipelineProvider {
2728

28-
type UnifiedPipeline = PipelineBinaryCommands with Closeable
29-
30-
/**
31-
* @return an interface for executing pipeline commands
32-
*/
33-
def pipeline(): UnifiedPipeline
29+
def withPipeline[T](ops: PipelineBinaryCommands => T): T
3430

3531
/**
3632
* Close client connection
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright 2018-2022 The Feast Authors
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package feast.ingestion.stores.redis
18+
19+
import redis.clients.jedis.Jedis
20+
21+
import scala.collection.mutable
22+
import scala.util.Try
23+
24+
object PipelineProviderFactory {

  // One provider per endpoint, created lazily and cached so that the cluster
  // connection provider / jedis pool is initialized only once per executor.
  private lazy val providers: mutable.Map[RedisEndpoint, PipelineProvider] = mutable.Map.empty

  /**
   * Build a bare Jedis client for the endpoint, authenticating when a
   * password is configured.
   */
  private def newJedisClient(endpoint: RedisEndpoint): Jedis = {
    val jedis = new Jedis(endpoint.host, endpoint.port)
    if (endpoint.password.nonEmpty) {
      jedis.auth(endpoint.password)
    }
    jedis
  }

  /**
   * Probe the endpoint: `CLUSTER INFO` succeeds only against a
   * cluster-enabled Redis node, so a failure means single-node mode.
   * The probe client is always closed, even if the probe throws.
   */
  private def checkIfInClusterMode(endpoint: RedisEndpoint): Boolean = {
    val jedis = newJedisClient(endpoint)
    try Try(jedis.clusterInfo()).isSuccess
    finally jedis.close()
  }

  private def clusterPipelineProvider(endpoint: RedisEndpoint): PipelineProvider =
    ClusterPipelineProvider(endpoint)

  private def singleNodePipelineProvider(endpoint: RedisEndpoint): PipelineProvider =
    SingleNodePipelineProvider(endpoint)

  /**
   * Create a fresh provider matching the endpoint's deployment mode.
   *
   * BUG FIX: the original discarded the result of the `if` branch (no `else`),
   * so it always returned the single-node provider even for cluster endpoints.
   */
  def newProvider(endpoint: RedisEndpoint): PipelineProvider =
    if (checkIfInClusterMode(endpoint)) clusterPipelineProvider(endpoint)
    else singleNodePipelineProvider(endpoint)

  /**
   * Cached lookup. Synchronized because `mutable.Map.getOrElseUpdate` is not
   * thread-safe and Spark may call this concurrently from multiple task threads.
   */
  def provider(endpoint: RedisEndpoint): PipelineProvider =
    providers.synchronized {
      providers.getOrElseUpdate(endpoint, newProvider(endpoint))
    }
}

spark/ingestion/src/main/scala/feast/ingestion/stores/redis/RedisSinkRelation.scala

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,6 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
6565
pipelineSize = sparkConf.get("spark.redis.properties.pipelineSize").toInt
6666
)
6767

68-
lazy val isClusterMode: Boolean = checkIfInClusterMode(endpoint)
69-
70-
def newJedisClient(endpoint: RedisEndpoint): Jedis = {
71-
val jedis = new Jedis(endpoint.host, endpoint.port)
72-
if (endpoint.password.nonEmpty) {
73-
jedis.auth(endpoint.password)
74-
}
75-
jedis
76-
}
77-
78-
def checkIfInClusterMode(endpoint: RedisEndpoint): Boolean = {
79-
val jedis = newJedisClient(endpoint)
80-
val isCluster = Try(jedis.clusterInfo()).isSuccess
81-
jedis.close()
82-
isCluster
83-
}
84-
8568
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
8669
// repartition for deduplication
8770
val dataToStore =
@@ -95,23 +78,19 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
9578
java.security.Security.setProperty("networkaddress.cache.ttl", "3");
9679
java.security.Security.setProperty("networkaddress.cache.negative.ttl", "0");
9780

98-
val pipelineProvider = if (isClusterMode) {
99-
ClusterPipelineProvider(endpoint)
100-
} else {
101-
SingleNodePipelineProvider(newJedisClient(endpoint))
102-
}
81+
val pipelineProvider = PipelineProviderFactory.provider(endpoint)
10382

10483
// grouped iterator to only allocate memory for a portion of rows
10584
partition.grouped(properties.pipelineSize).foreach { batch =>
10685
// group by key and keep only latest row per each key
10786
val rowsWithKey: Map[String, Row] =
10887
compactRowsToLatestTimestamp(batch.map(row => dataKeyId(row) -> row)).toMap
10988

110-
val keys = rowsWithKey.keysIterator.toList
111-
val readPipeline = pipelineProvider.pipeline()
112-
val readResponses =
113-
keys.map(key => persistence.get(readPipeline, key.getBytes()))
114-
readPipeline.close()
89+
val keys = rowsWithKey.keysIterator.toList
90+
val readResponses = pipelineProvider.withPipeline(pipeline => {
91+
keys.map(key => persistence.get(pipeline, key.getBytes()))
92+
})
93+
11594
val storedValues = readResponses.map(_.get())
11695
val timestamps = storedValues.map(persistence.storedTimestamp)
11796
val timestampByKey = keys.zip(timestamps).toMap
@@ -122,31 +101,30 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
122101
}
123102
.toMap
124103

125-
val writePipeline = pipelineProvider.pipeline()
126-
rowsWithKey.foreach { case (key, row) =>
127-
timestampByKey(key) match {
128-
case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) =>
129-
()
130-
case _ =>
131-
if (metricSource.nonEmpty) {
132-
val lag = System.currentTimeMillis() - row
133-
.getAs[java.sql.Timestamp](config.timestampColumn)
134-
.getTime
135-
136-
metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc()
137-
metricSource.get.METRIC_ROWS_LAG.update(lag)
138-
}
139-
persistence.save(
140-
writePipeline,
141-
key.getBytes(),
142-
row,
143-
expiryTimestampByKey(key)
144-
)
104+
pipelineProvider.withPipeline(pipeline => {
105+
rowsWithKey.foreach { case (key, row) =>
106+
timestampByKey(key) match {
107+
case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) =>
108+
()
109+
case _ =>
110+
if (metricSource.nonEmpty) {
111+
val lag = System.currentTimeMillis() - row
112+
.getAs[java.sql.Timestamp](config.timestampColumn)
113+
.getTime
114+
115+
metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc()
116+
metricSource.get.METRIC_ROWS_LAG.update(lag)
117+
}
118+
persistence.save(
119+
pipeline,
120+
key.getBytes(),
121+
row,
122+
expiryTimestampByKey(key)
123+
)
124+
}
145125
}
146-
}
147-
writePipeline.close()
126+
})
148127
}
149-
pipelineProvider.close()
150128
}
151129
dataToStore.unpersist()
152130
}

spark/ingestion/src/main/scala/feast/ingestion/stores/redis/SingleNodePipelineProvider.scala

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,32 @@
1616
*/
1717
package feast.ingestion.stores.redis
1818

19-
import redis.clients.jedis.Jedis
19+
import redis.clients.jedis.commands.PipelineBinaryCommands
20+
import redis.clients.jedis.{JedisPool, Response}
2021

2122
/**
2223
* Provide pipeline for single node Redis.
2324
*/
24-
case class SingleNodePipelineProvider(jedis: Jedis) extends PipelineProvider {
25+
case class SingleNodePipelineProvider(endpoint: RedisEndpoint) extends PipelineProvider {
26+
27+
val jedisPool = new JedisPool(endpoint.host, endpoint.port)
2528

2629
/**
27-
* @return a single node redis pipeline
30+
* @return execute command within a pipeline and return the result
2831
*/
29-
override def pipeline(): UnifiedPipeline = jedis.pipelined()
32+
/**
 * Execute `ops` within a single-node pipeline and return its result.
 *
 * The borrowed Jedis connection is returned to the pool in a `finally`
 * block so it is not leaked when `auth` or `ops` throws (the original
 * skipped `close()` on exception, exhausting the pool over time).
 */
override def withPipeline[T](ops: PipelineBinaryCommands => T): T = {
  val jedis = jedisPool.getResource
  try {
    // NOTE(review): auth is re-issued on every checkout; presumably the pool
    // is not password-configured — consider passing the password to JedisPool.
    if (endpoint.password.nonEmpty) {
      jedis.auth(endpoint.password)
    }
    ops(jedis.pipelined())
  } finally {
    jedis.close()
  }
}
3041

3142
/**
3243
* Close client connection
3344
*/
34-
override def close(): Unit = jedis.close()
45+
override def close(): Unit = jedisPool.close()
46+
3547
}

0 commit comments

Comments
 (0)