Skip to content

Commit 5730cf4

Browse files
authored
Reuse stencil client between Spark Tasks (#58)
* reuse stencil client between tasks Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com> * format Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com>
1 parent 4b4ade3 commit 5730cf4

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

spark/ingestion/src/main/scala/feast/ingestion/registry/proto/StencilProtoRegistry.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,21 @@ import com.gojek.de.stencil.StencilClientFactory
2222
import com.gojek.de.stencil.client.StencilClient
2323

2424
class StencilProtoRegistry(val url: String) extends ProtoRegistry {
25+
import StencilProtoRegistry.stencilClient
26+
27+
override def getProtoDescriptor(className: String): Descriptors.Descriptor = {
28+
stencilClient(url).get(className)
29+
}
30+
}
31+
32+
object StencilProtoRegistry {
2533
@transient
2634
private var _stencilClient: StencilClient = _
2735

28-
def stencilClient: StencilClient = {
36+
def stencilClient(url: String): StencilClient = {
2937
if (_stencilClient == null) {
3038
_stencilClient = StencilClientFactory.getClient(url, Collections.emptyMap[String, String])
3139
}
3240
_stencilClient
3341
}
34-
35-
override def getProtoDescriptor(className: String): Descriptors.Descriptor = {
36-
stencilClient.get(className)
37-
}
3842
}

spark/ingestion/src/main/scala/feast/ingestion/utils/ProtoReflection.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,13 @@ object ProtoReflection {
124124
}
125125

126126
def createMessageParser(protoRegistry: ProtoRegistry, className: String): Array[Byte] => Row = {
127-
// perform request to registry in driver, so serialized protoRegistry will have cached descriptor
128-
protoRegistry.getProtoDescriptor(className)
127+
bytes =>
128+
{
129+
val protoDescriptor = protoRegistry.getProtoDescriptor(className)
129130

130-
bytes => {
131-
val protoDescriptor = protoRegistry.getProtoDescriptor(className)
132-
133-
Try { DynamicMessage.parseFrom(protoDescriptor, bytes) }
134-
.map(messageToRow(protoDescriptor, _))
135-
.getOrElse(null)
136-
}
131+
Try { DynamicMessage.parseFrom(protoDescriptor, bytes) }
132+
.map(messageToRow(protoDescriptor, _))
133+
.getOrElse(null)
134+
}
137135
}
138136
}

0 commit comments

Comments
 (0)