diff --git a/Makefile b/Makefile
index a3d83b56..55beb3c3 100644
--- a/Makefile
+++ b/Makefile
@@ -53,7 +53,7 @@ build-local-test-docker:
docker build -t feast:local -f infra/docker/tests/Dockerfile .
build-ingestion-jar-no-tests:
- cd spark/ingestion && ${MVN} --no-transfer-progress -Dmaven.javadoc.skip=true -Dgpg.skip -DskipUTs=true -DskipITs=true -Drevision=${REVISION} clean package
+ cd spark/ingestion && ${MVN} --no-transfer-progress -Dmaven.javadoc.skip=true -Dgpg.skip -DskipUTs=true -D"spotless.check.skip"=true -DskipITs=true -Drevision=${REVISION} clean package
build-jobservice-docker:
docker build -t $(REGISTRY)/feast-jobservice:$(VERSION) -f infra/docker/jobservice/Dockerfile .
@@ -68,3 +68,11 @@ push-spark-docker:
docker push $(REGISTRY)/feast-spark:$(VERSION)
install-ci-dependencies: install-python-ci-dependencies
+
+build-ingestion-jar-push:
+ docker build -t $(REGISTRY)/feast-spark:$(VERSION) --build-arg VERSION=$(VERSION) -f infra/docker/spark/Dockerfile .
+ rm -f feast-ingestion-spark-latest.jar
+ docker create -ti --name dummy $(REGISTRY)/feast-spark:$(VERSION) bash
+ docker cp dummy:/opt/spark/jars/feast-ingestion-spark-latest.jar feast-ingestion-spark-latest.jar
+ docker rm -f dummy
+ python python/feast_spark/copy_to_azure_blob.py
\ No newline at end of file
diff --git a/README.md b/README.md
index 0e00aad6..44c8ff0c 100644
--- a/README.md
+++ b/README.md
@@ -57,4 +57,14 @@ client.apply(entity, ft)
# Start spark streaming ingestion job that reads from kafka and writes to the online store
feast_spark.Client(client).start_stream_to_online_ingestion(ft)
-```
\ No newline at end of file
+```
+
+Build and push to BLOB storage
+
+In order to build the Spark Ingestion jar and copy it to BLOB storage, you have to set these 3 environment variables:
+
+```bash
+export VERSION=latest
+export REGISTRY=your_registry_name
+export AZURE_STORAGE_CONNECTION_STRING="your_azure_storage_connection_string"
+```
diff --git a/pom.xml b/pom.xml
index 4eeffaac..0b7f608b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -18,8 +18,8 @@
1.8
1.8
2.12
- ${scala.version}.12
- 3.0.2
+ ${scala.version}.10
+ 3.1.2
4.4.0
3.3.0
3.12.2
diff --git a/python/feast_spark/constants.py b/python/feast_spark/constants.py
index 8ea25ec7..6d34810a 100644
--- a/python/feast_spark/constants.py
+++ b/python/feast_spark/constants.py
@@ -93,6 +93,27 @@ class ConfigOptions(metaclass=ConfigMeta):
# SparkApplication resource template
SPARK_K8S_JOB_TEMPLATE_PATH = None
+ # Synapse dev url
+ AZURE_SYNAPSE_DEV_URL: Optional[str] = None
+
+ # Synapse pool name
+ AZURE_SYNAPSE_POOL_NAME: Optional[str] = None
+
+ # Datalake directory that linked to Synapse
+ AZURE_SYNAPSE_DATALAKE_DIR: Optional[str] = None
+
+ # Synapse pool executor size: Small, Medium or Large
+ AZURE_SYNAPSE_EXECUTOR_SIZE = "Small"
+
+ # Synapse pool executor count
+ AZURE_SYNAPSE_EXECUTORS = "2"
+
+ # Azure EventHub Connection String (with Kafka API). See more details here:
+ # https://docs.microsoft.com/en-us/azure/event-hubs/apache-kafka-migration-guide
+ # Code Sample is here:
+ # https://github.com/Azure/azure-event-hubs-for-kafka/blob/master/tutorials/spark/sparkConsumer.scala
+ AZURE_EVENTHUB_KAFKA_CONNECTION_STRING = ""
+
#: File format of historical retrieval features
HISTORICAL_FEATURE_OUTPUT_FORMAT: str = "parquet"
@@ -108,6 +129,9 @@ class ConfigOptions(metaclass=ConfigMeta):
#: Enable or disable TLS/SSL to Redis
REDIS_SSL: Optional[str] = "False"
+ #: Auth string for redis
+ REDIS_AUTH: str = ""
+
#: BigTable Project ID
BIGTABLE_PROJECT: Optional[str] = ""
diff --git a/python/feast_spark/copy_to_azure_blob.py b/python/feast_spark/copy_to_azure_blob.py
new file mode 100644
index 00000000..7f8f719f
--- /dev/null
+++ b/python/feast_spark/copy_to_azure_blob.py
@@ -0,0 +1,41 @@
+# coding: utf-8
+
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+"""
+FILE: blob_samples_copy_blob.py
+DESCRIPTION:
+ This sample demos how to copy a blob from a URL.
+USAGE: python blob_samples_copy_blob.py
+ Set the environment variables with your own values before running the sample.
+ 1) AZURE_STORAGE_CONNECTION_STRING - the connection string to your storage account
+"""
+
+from __future__ import print_function
+import os
+import sys
+import time
+from azure.storage.blob import BlobServiceClient
+
+def main():
+ try:
+ CONNECTION_STRING = os.environ['AZURE_STORAGE_CONNECTION_STRING']
+
+ except KeyError:
+ print("AZURE_STORAGE_CONNECTION_STRING must be set.")
+ sys.exit(1)
+
+ blob_service_client = BlobServiceClient.from_connection_string(CONNECTION_STRING)
+ copied_blob = blob_service_client.get_blob_client("feastjar", 'feast-ingestion-spark-latest.jar')
+ # hard code to the current path
+ SOURCE_FILE = "./feast-ingestion-spark-latest.jar"
+
+ with open(SOURCE_FILE, "rb") as data:
+ copied_blob.upload_blob(data, blob_type="BlockBlob",overwrite=True)
+
+if __name__ == "__main__":
+ main()
diff --git a/python/feast_spark/pyspark/abc.py b/python/feast_spark/pyspark/abc.py
index db5c041d..154f8d08 100644
--- a/python/feast_spark/pyspark/abc.py
+++ b/python/feast_spark/pyspark/abc.py
@@ -340,13 +340,10 @@ def __init__(
feature_table: Dict,
source: Dict,
jar: str,
- redis_host: Optional[str] = None,
- redis_port: Optional[int] = None,
- redis_ssl: Optional[bool] = None,
- bigtable_project: Optional[str] = None,
- bigtable_instance: Optional[str] = None,
- cassandra_host: Optional[str] = None,
- cassandra_port: Optional[int] = None,
+ redis_host: str,
+ redis_port: int,
+ redis_ssl: bool,
+ redis_auth: str,
statsd_host: Optional[str] = None,
statsd_port: Optional[int] = None,
deadletter_path: Optional[str] = None,
@@ -359,10 +356,7 @@ def __init__(
self._redis_host = redis_host
self._redis_port = redis_port
self._redis_ssl = redis_ssl
- self._bigtable_project = bigtable_project
- self._bigtable_instance = bigtable_instance
- self._cassandra_host = cassandra_host
- self._cassandra_port = cassandra_port
+ self._redis_auth = redis_auth
self._statsd_host = statsd_host
self._statsd_port = statsd_port
self._deadletter_path = deadletter_path
@@ -370,15 +364,7 @@ def __init__(
self._drop_invalid_rows = drop_invalid_rows
def _get_redis_config(self):
- return dict(host=self._redis_host, port=self._redis_port, ssl=self._redis_ssl)
-
- def _get_bigtable_config(self):
- return dict(
- project_id=self._bigtable_project, instance_id=self._bigtable_instance
- )
-
- def _get_cassandra_config(self):
- return dict(host=self._cassandra_host, port=self._cassandra_port)
+ return dict(host=self._redis_host, port=self._redis_port, ssl=self._redis_ssl, auth=self._redis_auth)
def _get_statsd_config(self):
return (
@@ -405,17 +391,10 @@ def get_arguments(self) -> List[str]:
json.dumps(self._feature_table),
"--source",
json.dumps(self._source),
+ "--redis",
+ json.dumps(self._get_redis_config()),
]
- if self._redis_host and self._redis_port:
- args.extend(["--redis", json.dumps(self._get_redis_config())])
-
- if self._bigtable_project and self._bigtable_instance:
- args.extend(["--bigtable", json.dumps(self._get_bigtable_config())])
-
- if self._cassandra_host and self._cassandra_port:
- args.extend(["--cassandra", json.dumps(self._get_cassandra_config())])
-
if self._get_statsd_config():
args.extend(["--statsd", json.dumps(self._get_statsd_config())])
@@ -444,13 +423,14 @@ def __init__(
start: datetime,
end: datetime,
jar: str,
- redis_host: Optional[str],
- redis_port: Optional[int],
- redis_ssl: Optional[bool],
- bigtable_project: Optional[str],
- bigtable_instance: Optional[str],
+ redis_host: str,
+ redis_port: int,
+ redis_ssl: bool,
+ redis_auth: str,
+ bigtable_project: Optional[str] = None,
+ bigtable_instance: Optional[str] = None,
cassandra_host: Optional[str] = None,
- cassandra_port: Optional[int] = None,
+ cassandra_port: Optional[int] = None,
statsd_host: Optional[str] = None,
statsd_port: Optional[int] = None,
deadletter_path: Optional[str] = None,
@@ -463,10 +443,7 @@ def __init__(
redis_host,
redis_port,
redis_ssl,
- bigtable_project,
- bigtable_instance,
- cassandra_host,
- cassandra_port,
+ redis_auth,
statsd_host,
statsd_port,
deadletter_path,
@@ -494,7 +471,6 @@ def get_arguments(self) -> List[str]:
self._end.strftime("%Y-%m-%dT%H:%M:%S"),
]
-
class ScheduledBatchIngestionJobParameters(IngestionJobParameters):
def __init__(
self,
@@ -559,21 +535,20 @@ def __init__(
source: Dict,
jar: str,
extra_jars: List[str],
- redis_host: Optional[str],
- redis_port: Optional[int],
- redis_ssl: Optional[bool],
- bigtable_project: Optional[str],
- bigtable_instance: Optional[str],
- cassandra_host: Optional[str] = None,
- cassandra_port: Optional[int] = None,
+ redis_host: str,
+ redis_port: int,
+ redis_ssl: bool,
+ redis_auth: str,
statsd_host: Optional[str] = None,
statsd_port: Optional[int] = None,
deadletter_path: Optional[str] = None,
checkpoint_path: Optional[str] = None,
stencil_url: Optional[str] = None,
- drop_invalid_rows: bool = False,
- triggering_interval: Optional[int] = None,
+ drop_invalid_rows: Optional[bool] = False,
+ kafka_sasl_auth: Optional[str] = None,
):
super().__init__(
feature_table,
source,
@@ -581,10 +556,7 @@ def __init__(
redis_host,
redis_port,
redis_ssl,
- bigtable_project,
- bigtable_instance,
- cassandra_host,
- cassandra_port,
+ redis_auth,
statsd_host,
statsd_port,
deadletter_path,
@@ -593,7 +565,7 @@ def __init__(
)
self._extra_jars = extra_jars
self._checkpoint_path = checkpoint_path
- self._triggering_interval = triggering_interval
+ self._kafka_sasl_auth = kafka_sasl_auth
def get_name(self) -> str:
return f"{self.get_job_type().to_pascal_case()}-{self.get_feature_table_name()}"
@@ -609,8 +581,8 @@ def get_arguments(self) -> List[str]:
args.extend(["--mode", "online"])
if self._checkpoint_path:
args.extend(["--checkpoint-path", self._checkpoint_path])
- if self._triggering_interval:
- args.extend(["--triggering-interval", str(self._triggering_interval)])
+ if self._kafka_sasl_auth:
+ args.extend(["--kafka_sasl_auth", self._kafka_sasl_auth])
return args
def get_job_hash(self) -> str:
@@ -705,29 +677,6 @@ def offline_to_online_ingestion(
"""
raise NotImplementedError
- @abc.abstractmethod
- def schedule_offline_to_online_ingestion(
- self, ingestion_job_params: ScheduledBatchIngestionJobParameters
- ):
- """
- Submits a scheduled batch ingestion job to a Spark cluster.
-
- Raises:
- SparkJobFailure: The spark job submission failed, encountered error
- during execution, or timeout.
-
- Returns:
- ScheduledBatchIngestionJob: wrapper around remote job that can be used to check when job completed.
- """
- raise NotImplementedError
-
- @abc.abstractmethod
- def unschedule_offline_to_online_ingestion(self, project: str, feature_table: str):
- """
- Unschedule a scheduled batch ingestion job.
- """
- raise NotImplementedError
-
@abc.abstractmethod
def start_stream_to_online_ingestion(
self, ingestion_job_params: StreamIngestionJobParameters
diff --git a/python/feast_spark/pyspark/launcher.py b/python/feast_spark/pyspark/launcher.py
index 9f5e95ac..f8f8dc2d 100644
--- a/python/feast_spark/pyspark/launcher.py
+++ b/python/feast_spark/pyspark/launcher.py
@@ -83,11 +83,24 @@ def _k8s_launcher(config: Config) -> JobLauncher:
)
+def _synapse_launcher(config: Config) -> JobLauncher:
+ from feast_spark.pyspark.launchers import synapse
+
+ return synapse.SynapseJobLauncher(
+ synapse_dev_url=config.get(opt.AZURE_SYNAPSE_DEV_URL),
+ pool_name=config.get(opt.AZURE_SYNAPSE_POOL_NAME),
+ datalake_dir=config.get(opt.AZURE_SYNAPSE_DATALAKE_DIR),
+ executor_size=config.get(opt.AZURE_SYNAPSE_EXECUTOR_SIZE),
+ executors=int(config.get(opt.AZURE_SYNAPSE_EXECUTORS))
+ )
+
+
_launchers = {
"standalone": _standalone_launcher,
"dataproc": _dataproc_launcher,
"emr": _emr_launcher,
"k8s": _k8s_launcher,
+ 'synapse': _synapse_launcher,
}
@@ -347,6 +360,7 @@ def start_offline_to_online_ingestion(
redis_port=bool(client.config.get(opt.REDIS_HOST))
and client.config.getint(opt.REDIS_PORT),
redis_ssl=client.config.getboolean(opt.REDIS_SSL),
+ redis_auth=client.config.get(opt.REDIS_AUTH),
bigtable_project=client.config.get(opt.BIGTABLE_PROJECT),
bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE),
cassandra_host=client.config.get(opt.CASSANDRA_HOST),
@@ -423,11 +437,9 @@ def get_stream_to_online_ingestion_params(
source=_source_to_argument(feature_table.stream_source, client.config),
feature_table=_feature_table_to_argument(client, project, feature_table),
redis_host=client.config.get(opt.REDIS_HOST),
- redis_port=bool(client.config.get(opt.REDIS_HOST))
- and client.config.getint(opt.REDIS_PORT),
+ redis_port=client.config.getint(opt.REDIS_PORT),
redis_ssl=client.config.getboolean(opt.REDIS_SSL),
- bigtable_project=client.config.get(opt.BIGTABLE_PROJECT),
- bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE),
+ redis_auth=client.config.get(opt.REDIS_AUTH),
statsd_host=client.config.getboolean(opt.STATSD_ENABLED)
and client.config.get(opt.STATSD_HOST),
statsd_port=client.config.getboolean(opt.STATSD_ENABLED)
@@ -436,11 +448,9 @@ def get_stream_to_online_ingestion_params(
checkpoint_path=client.config.get(opt.CHECKPOINT_PATH),
stencil_url=client.config.get(opt.STENCIL_URL),
drop_invalid_rows=client.config.get(opt.INGESTION_DROP_INVALID_ROWS),
- triggering_interval=client.config.getint(
- opt.SPARK_STREAMING_TRIGGERING_INTERVAL, default=None
- ),
- )
+ kafka_sasl_auth=client.config.get(opt.AZURE_EVENTHUB_KAFKA_CONNECTION_STRING),
+ )
def start_stream_to_online_ingestion(
client: "Client", project: str, feature_table: FeatureTable, extra_jars: List[str]
diff --git a/python/feast_spark/pyspark/launchers/synapse/__init__.py b/python/feast_spark/pyspark/launchers/synapse/__init__.py
new file mode 100644
index 00000000..59dadfa4
--- /dev/null
+++ b/python/feast_spark/pyspark/launchers/synapse/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+from .synapse import (
+ SynapseBatchIngestionJob,
+ SynapseJobLauncher,
+ SynapseRetrievalJob,
+ SynapseStreamIngestionJob,
+)
+
+__all__ = [
+ "SynapseRetrievalJob",
+ "SynapseBatchIngestionJob",
+ "SynapseStreamIngestionJob",
+ "SynapseJobLauncher",
+]
diff --git a/python/feast_spark/pyspark/launchers/synapse/synapse.py b/python/feast_spark/pyspark/launchers/synapse/synapse.py
new file mode 100644
index 00000000..3a42a95a
--- /dev/null
+++ b/python/feast_spark/pyspark/launchers/synapse/synapse.py
@@ -0,0 +1,298 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+import time
+from datetime import datetime
+from typing import List, Optional, cast
+
+from azure.synapse.spark.models import SparkBatchJob
+from azure.identity import DefaultAzureCredential, DeviceCodeCredential, ChainedTokenCredential, ManagedIdentityCredential,EnvironmentCredential
+
+from feast_spark.pyspark.abc import (
+ BatchIngestionJob,
+ BatchIngestionJobParameters,
+ JobLauncher,
+ RetrievalJob,
+ RetrievalJobParameters,
+ SparkJob,
+ SparkJobFailure,
+ SparkJobStatus,
+ StreamIngestionJob,
+ StreamIngestionJobParameters,
+)
+
+from .synapse_utils import (
+ HISTORICAL_RETRIEVAL_JOB_TYPE,
+ LABEL_JOBTYPE,
+ LABEL_FEATURE_TABLE,
+ METADATA_JOBHASH,
+ METADATA_OUTPUT_URI,
+ OFFLINE_TO_ONLINE_JOB_TYPE,
+ STREAM_TO_ONLINE_JOB_TYPE,
+ SynapseJobRunner,
+ DataLakeFiler,
+ _prepare_job_tags,
+ _job_feast_state,
+ _job_start_time,
+ _cancel_job_by_id,
+ _get_job_by_id,
+ _list_jobs,
+ _submit_job,
+)
+
+
+class SynapseJobMixin:
+ def __init__(self, api: SynapseJobRunner, job_id: int):
+ self._api = api
+ self._job_id = job_id
+
+ def get_id(self) -> str:
+ return self._job_id
+
+ def get_status(self) -> SparkJobStatus:
+ job = _get_job_by_id(self._api, self._job_id)
+ assert job is not None
+ return _job_feast_state(job)
+
+ def get_start_time(self) -> datetime:
+ job = _get_job_by_id(self._api, self._job_id)
+ assert job is not None
+ return _job_start_time(job)
+
+ def cancel(self):
+ _cancel_job_by_id(self._api, self._job_id)
+
+ def _wait_for_complete(self, timeout_seconds: Optional[float]) -> bool:
+ """ Returns true if the job completed successfully """
+ start_time = time.time()
+ while (timeout_seconds is None) or (time.time() - start_time < timeout_seconds):
+ status = self.get_status()
+ if status == SparkJobStatus.COMPLETED:
+ return True
+ elif status == SparkJobStatus.FAILED:
+ return False
+ else:
+ time.sleep(1)
+ else:
+ raise TimeoutError("Timeout waiting for job to complete")
+
+
+class SynapseRetrievalJob(SynapseJobMixin, RetrievalJob):
+ """
+ Historical feature retrieval job result for a synapse cluster
+ """
+
+ def __init__(
+ self, api: SynapseJobRunner, job_id: int, output_file_uri: str
+ ):
+ """
+ This is the job object representing the historical retrieval job, returned by SynapseClusterLauncher.
+
+ Args:
+ output_file_uri (str): Uri to the historical feature retrieval job output file.
+ """
+ super().__init__(api, job_id)
+ self._output_file_uri = output_file_uri
+
+ def get_output_file_uri(self, timeout_sec=None, block=True):
+ if not block:
+ return self._output_file_uri
+
+ if self._wait_for_complete(timeout_sec):
+ return self._output_file_uri
+ else:
+ raise SparkJobFailure("Spark job failed")
+
+
+class SynapseBatchIngestionJob(SynapseJobMixin, BatchIngestionJob):
+ """
+ Ingestion job result for a synapse cluster
+ """
+
+ def __init__(
+ self, api: SynapseJobRunner, job_id: int, feature_table: str
+ ):
+ super().__init__(api, job_id)
+ self._feature_table = feature_table
+
+ def get_feature_table(self) -> str:
+ return self._feature_table
+
+
+class SynapseStreamIngestionJob(SynapseJobMixin, StreamIngestionJob):
+ """
+ Ingestion streaming job for a synapse cluster
+ """
+
+ def __init__(
+ self,
+ api: SynapseJobRunner,
+ job_id: int,
+ job_hash: str,
+ feature_table: str,
+ ):
+ super().__init__(api, job_id)
+ self._job_hash = job_hash
+ self._feature_table = feature_table
+
+ def get_hash(self) -> str:
+ return self._job_hash
+
+ def get_feature_table(self) -> str:
+ return self._feature_table
+
+login_credential_cache = None
+
+class SynapseJobLauncher(JobLauncher):
+ """
+ Submits spark jobs to a spark cluster. Currently supports only historical feature retrieval jobs.
+ """
+
+ def __init__(
+ self,
+ synapse_dev_url: str,
+ pool_name: str,
+ datalake_dir: str,
+ executor_size: str,
+ executors: int
+ ):
+ tenant_id='72f988bf-86f1-41af-91ab-2d7cd011db47'
+ authority_host_uri = 'login.microsoftonline.com'
+ client_id = '04b07795-8ddb-461a-bbee-02f9e1bf7b46'
+
+ global login_credential_cache
+ # use a global cache to store the credential, to avoid users from multiple login
+
+ if login_credential_cache is None:
+ # use DeviceCodeCredential if EnvironmentCredential is not available
+ self.credential = ChainedTokenCredential(EnvironmentCredential(), DeviceCodeCredential(client_id, authority=authority_host_uri, tenant=tenant_id))
+ login_credential_cache = self.credential
+ else:
+ self.credential = login_credential_cache
+
+ self._api = SynapseJobRunner(synapse_dev_url, pool_name, executor_size = executor_size, executors = executors, credential=self.credential)
+ self._datalake = DataLakeFiler(datalake_dir,credential=self.credential)
+
+ def _job_from_job_info(self, job_info: SparkBatchJob) -> SparkJob:
+ job_type = job_info.tags[LABEL_JOBTYPE]
+ if job_type == HISTORICAL_RETRIEVAL_JOB_TYPE:
+ assert METADATA_OUTPUT_URI in job_info.tags
+ return SynapseRetrievalJob(
+ api=self._api,
+ job_id=job_info.id,
+ output_file_uri=job_info.tags[METADATA_OUTPUT_URI],
+ )
+ elif job_type == OFFLINE_TO_ONLINE_JOB_TYPE:
+ return SynapseBatchIngestionJob(
+ api=self._api,
+ job_id=job_info.id,
+ feature_table=job_info.tags.get(LABEL_FEATURE_TABLE, ""),
+ )
+ elif job_type == STREAM_TO_ONLINE_JOB_TYPE:
+ # job_hash must not be None for stream ingestion jobs
+ assert METADATA_JOBHASH in job_info.tags
+ return SynapseStreamIngestionJob(
+ api=self._api,
+ job_id=job_info.id,
+ job_hash=job_info.tags[METADATA_JOBHASH],
+ feature_table=job_info.tags.get(LABEL_FEATURE_TABLE, ""),
+ )
+ else:
+ # We should never get here
+ raise ValueError(f"Unknown job type {job_type}")
+
+ def historical_feature_retrieval(
+ self, job_params: RetrievalJobParameters
+ ) -> RetrievalJob:
+ """
+ Submits a historical feature retrieval job to a Spark cluster.
+
+ Raises:
+ SparkJobFailure: The spark job submission failed, encountered error
+ during execution, or timeout.
+
+ Returns:
+ RetrievalJob: wrapper around remote job that returns file uri to the result file.
+ """
+
+ main_file = self._datalake.upload_file(job_params.get_main_file_path())
+ job_info = _submit_job(self._api, "Historical-Retrieval", main_file,
+ arguments = job_params.get_arguments(),
+ tags = {LABEL_JOBTYPE: HISTORICAL_RETRIEVAL_JOB_TYPE,
+ METADATA_OUTPUT_URI: job_params.get_destination_path()})
+
+ return cast(RetrievalJob, self._job_from_job_info(job_info))
+
+ def offline_to_online_ingestion(
+ self, ingestion_job_params: BatchIngestionJobParameters
+ ) -> BatchIngestionJob:
+ """
+ Submits a batch ingestion job to a Spark cluster.
+
+ Raises:
+ SparkJobFailure: The spark job submission failed, encountered error
+ during execution, or timeout.
+
+ Returns:
+ BatchIngestionJob: wrapper around remote job that can be used to check when job completed.
+ """
+
+ main_file = self._datalake.upload_file(ingestion_job_params.get_main_file_path())
+
+ job_info = _submit_job(self._api, ingestion_job_params.get_project()+"_offline_to_online_ingestion", main_file,
+ main_class = ingestion_job_params.get_class_name(),
+ arguments = ingestion_job_params.get_arguments(),
+ reference_files=[main_file],
+ tags = _prepare_job_tags(ingestion_job_params, OFFLINE_TO_ONLINE_JOB_TYPE),configuration=None)
+
+ return cast(BatchIngestionJob, self._job_from_job_info(job_info))
+
+ def start_stream_to_online_ingestion(
+ self, ingestion_job_params: StreamIngestionJobParameters
+ ) -> StreamIngestionJob:
+ """
+ Starts a stream ingestion job to a Spark cluster.
+
+ Raises:
+ SparkJobFailure: The spark job submission failed, encountered error
+ during execution, or timeout.
+
+ Returns:
+ StreamIngestionJob: wrapper around remote job.
+ """
+
+ main_file = self._datalake.upload_file(ingestion_job_params.get_main_file_path())
+
+ extra_jar_paths: List[str] = []
+ for extra_jar in ingestion_job_params.get_extra_jar_paths():
+ extra_jar_paths.append(self._datalake.upload_file(extra_jar))
+
+ tags = _prepare_job_tags(ingestion_job_params, STREAM_TO_ONLINE_JOB_TYPE)
+ tags[METADATA_JOBHASH] = ingestion_job_params.get_job_hash()
+ job_info = _submit_job(self._api, ingestion_job_params.get_project()+"_stream_to_online_ingestion", main_file,
+ main_class = ingestion_job_params.get_class_name(),
+ arguments = ingestion_job_params.get_arguments(),
+ reference_files = extra_jar_paths,
+ configuration=None,
+ tags = tags)
+
+ return cast(StreamIngestionJob, self._job_from_job_info(job_info))
+
+ def get_job_by_id(self, job_id: int) -> SparkJob:
+ job_info = _get_job_by_id(self._api, job_id)
+ if job_info is None:
+ raise KeyError(f"Job with id {job_id} not found")
+ else:
+ return self._job_from_job_info(job_info)
+
+ def list_jobs(
+ self,
+ include_terminated: bool,
+ project: Optional[str] = None,
+ table_name: Optional[str] = None,
+ ) -> List[SparkJob]:
+ return [
+ self._job_from_job_info(job)
+ for job in _list_jobs(self._api, project, table_name)
+ if include_terminated
+ or _job_feast_state(job) not in (SparkJobStatus.COMPLETED, SparkJobStatus.FAILED)
+ ]
diff --git a/python/feast_spark/pyspark/launchers/synapse/synapse_utils.py b/python/feast_spark/pyspark/launchers/synapse/synapse_utils.py
new file mode 100644
index 00000000..91fba2eb
--- /dev/null
+++ b/python/feast_spark/pyspark/launchers/synapse/synapse_utils.py
@@ -0,0 +1,277 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import os
+import re
+import hashlib
+import urllib.request
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+from azure.core.configuration import Configuration
+
+from azure.identity import DefaultAzureCredential
+
+from azure.synapse.spark import SparkClient
+from azure.synapse.spark.models import SparkBatchJobOptions, SparkBatchJob
+
+from azure.storage.filedatalake import DataLakeServiceClient
+
+from feast_spark.pyspark.abc import SparkJobStatus
+
+__all__ = [
+ "_cancel_job_by_id",
+ "_prepare_job_tags",
+ "_list_jobs",
+ "_get_job_by_id",
+ "_generate_project_table_hash",
+ "STREAM_TO_ONLINE_JOB_TYPE",
+ "OFFLINE_TO_ONLINE_JOB_TYPE",
+ "HISTORICAL_RETRIEVAL_JOB_TYPE",
+ "METADATA_JOBHASH",
+ "METADATA_OUTPUT_URI",
+]
+
+STREAM_TO_ONLINE_JOB_TYPE = "STREAM_TO_ONLINE_JOB"
+OFFLINE_TO_ONLINE_JOB_TYPE = "OFFLINE_TO_ONLINE_JOB"
+HISTORICAL_RETRIEVAL_JOB_TYPE = "HISTORICAL_RETRIEVAL_JOB"
+
+LABEL_JOBID = "feast.dev/jobid"
+LABEL_JOBTYPE = "feast.dev/type"
+LABEL_FEATURE_TABLE = "feast.dev/table"
+LABEL_FEATURE_TABLE_HASH = "feast.dev/tablehash"
+LABEL_PROJECT = "feast.dev/project"
+
+# Can't store these bits of info due to 64-character limit, so we store them as
+# sparkConf
+METADATA_OUTPUT_URI = "dev.feast.outputuri"
+METADATA_JOBHASH = "dev.feast.jobhash"
+
+
+def _generate_project_table_hash(project: str, table_name: str) -> str:
+ return hashlib.md5(f"{project}:{table_name}".encode()).hexdigest()
+
+
+def _truncate_label(label: str) -> str:
+ return label[:63]
+
+
+def _prepare_job_tags(job_params, job_type: str) -> Dict[str, Any]:
+ """ Prepare Synapse job tags """
+ return {LABEL_JOBTYPE:job_type,
+ LABEL_FEATURE_TABLE: _truncate_label(
+ job_params.get_feature_table_name()
+ ),
+ LABEL_FEATURE_TABLE_HASH: _generate_project_table_hash(
+ job_params.get_project(),
+ job_params.get_feature_table_name(),
+ ),
+ LABEL_PROJECT: job_params.get_project()
+ }
+
+
+STATE_MAP = {
+ "": SparkJobStatus.STARTING,
+ "not_started": SparkJobStatus.STARTING,
+ 'starting': SparkJobStatus.STARTING,
+ "running": SparkJobStatus.IN_PROGRESS,
+ "success": SparkJobStatus.COMPLETED,
+ "dead": SparkJobStatus.FAILED,
+ "killed": SparkJobStatus.FAILED,
+ "Uncertain": SparkJobStatus.IN_PROGRESS,
+ "Succeeded": SparkJobStatus.COMPLETED,
+ "Failed": SparkJobStatus.FAILED,
+ "Cancelled": SparkJobStatus.FAILED,
+}
+
+
+def _job_feast_state(job: SparkBatchJob) -> SparkJobStatus:
+ return STATE_MAP[job.state]
+
+
+def _job_start_time(job: SparkBatchJob) -> datetime:
+ return job.scheduler.scheduled_at
+
+
+EXECUTOR_SIZE = {'Small': {'Cores': 4, 'Memory': '28g'}, 'Medium': {'Cores': 8, 'Memory': '56g'},
+ 'Large': {'Cores': 16, 'Memory': '112g'}}
+
+
+def categorized_files(reference_files):
+ if reference_files == None:
+ return None, None
+
+ files = []
+ jars = []
+ for file in reference_files:
+ file = file.strip()
+ if file.endswith(".jar"):
+ jars.append(file)
+ else:
+ files.append(file)
+ return files, jars
+
+
+class SynapseJobRunner(object):
+ def __init__(self, synapse_dev_url, spark_pool_name, credential = None, executor_size = 'Small', executors = 2):
+ if credential is None:
+ credential = DefaultAzureCredential()
+
+ self.client = SparkClient(
+ credential=credential,
+ endpoint=synapse_dev_url,
+ spark_pool_name=spark_pool_name
+ )
+
+ self._executor_size = executor_size
+ self._executors = executors
+
+ def get_spark_batch_job(self, job_id):
+
+ return self.client.spark_batch.get_spark_batch_job(job_id, detailed=True)
+
+ def get_spark_batch_jobs(self):
+
+ return self.client.spark_batch.get_spark_batch_jobs(detailed=True)
+
+ def cancel_spark_batch_job(self, job_id):
+
+ return self.client.spark_batch.cancel_spark_batch_job(job_id)
+
+ def create_spark_batch_job(self, job_name, main_definition_file, class_name = None,
+ arguments=None, reference_files=None, archives=None, configuration=None, tags=None):
+
+ file = main_definition_file
+
+ files, jars = categorized_files(reference_files)
+ driver_cores = EXECUTOR_SIZE[self._executor_size]['Cores']
+ driver_memory = EXECUTOR_SIZE[self._executor_size]['Memory']
+ executor_cores = EXECUTOR_SIZE[self._executor_size]['Cores']
+ executor_memory = EXECUTOR_SIZE[self._executor_size]['Memory']
+
+ # SDK source code is here: https://github.com/Azure/azure-sdk-for-python/tree/master/sdk/synapse/azure-synapse
+ # Exact code is here: https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/synapse/azure-synapse-spark/azure/synapse/spark/operations/_spark_batch_operations.py#L114
+ # Adding spaces between brackets. This is to workaround this known YARN issue (when running Spark on YARN):
+ # https://issues.apache.org/jira/browse/SPARK-17814?focusedCommentId=15567964&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-15567964
+ # print(arguments)
+ updated_arguments = []
+ for elem in arguments:
+ if type(elem) == str:
+ updated_arguments.append(elem.replace("}", " }"))
+ else:
+ updated_arguments.append(elem)
+
+
+ spark_batch_job_options = SparkBatchJobOptions(
+ tags=tags,
+ name=job_name,
+ file=file,
+ class_name=class_name,
+ arguments=updated_arguments,
+ jars=jars,
+ files=files,
+ archives=archives,
+ configuration=configuration,
+ driver_memory=driver_memory,
+ driver_cores=driver_cores,
+ executor_memory=executor_memory,
+ executor_cores=executor_cores,
+ executor_count=self._executors)
+
+ return self.client.spark_batch.create_spark_batch_job(spark_batch_job_options, detailed=True)
+
+
+class DataLakeFiler(object):
+ def __init__(self, datalake_dir, credential = None):
+ datalake = list(filter(None, re.split('/|@', datalake_dir)))
+ assert len(datalake) >= 3
+
+ if credential is None:
+ credential = DefaultAzureCredential()
+
+ account_url = "https://" + datalake[2]
+ datalake_client = DataLakeServiceClient(
+ credential=credential,
+ account_url=account_url
+ ).get_file_system_client(datalake[1])
+
+ if len(datalake) > 3:
+ datalake_client = datalake_client.get_directory_client('/'.join(datalake[3:]))
+ datalake_client.create_directory()
+
+ self.datalake_dir = datalake_dir + '/' if datalake_dir[-1] != '/' else datalake_dir
+ self.dir_client = datalake_client
+
+ def upload_file(self, local_file):
+
+ file_name = os.path.basename(local_file)
+ file_client = self.dir_client.create_file(file_name)
+
+ if local_file.startswith('http'):
+ # remote_file = local_file
+ # local_file = './' + file_name
+ # urllib.request.urlretrieve(remote_file, local_file)
+ with urllib.request.urlopen(local_file) as f:
+ data = f.read()
+ file_client.append_data(data, 0, len(data))
+ file_client.flush_data(len(data))
+ else:
+ with open(local_file, 'rb') as f:
+ data = f.read()
+ file_client.append_data(data, 0, len(data))
+ file_client.flush_data(len(data))
+
+ return self.datalake_dir + file_name
+
+
+def _submit_job(
+ api: SynapseJobRunner,
+ name: str,
+ main_file: str,
+ main_class = None,
+ arguments = None,
+ reference_files = None,
+ tags = None,
+ configuration = None,
+) -> SparkBatchJob:
+ return api.create_spark_batch_job(name, main_file, class_name = main_class, arguments = arguments,
+ reference_files = reference_files, tags = tags, configuration=configuration)
+
+
+def _list_jobs(
+ api: SynapseJobRunner,
+ project: Optional[str] = None,
+ table_name: Optional[str] = None,
+) -> List[SparkBatchJob]:
+
+ job_infos = api.get_spark_batch_jobs()
+
+ # Batch, Streaming Ingestion jobs
+ if project and table_name:
+ result = []
+ table_name_hash = _generate_project_table_hash(project, table_name)
+ for job_info in job_infos:
+ if LABEL_FEATURE_TABLE_HASH in job_info.tags:
+ if table_name_hash == job_info.tags[LABEL_FEATURE_TABLE_HASH]:
+ result.append(job_info)
+ elif project:
+ result = []
+ for job_info in job_infos:
+ if LABEL_PROJECT in job_info.tags:
+ if project == job_info.tags[LABEL_PROJECT]:
+ result.append(job_info)
+ else:
+ result = job_infos
+
+ return result
+
+
+def _get_job_by_id(
+ api: SynapseJobRunner,
+ job_id: int
+) -> Optional[SparkBatchJob]:
+ return api.get_spark_batch_job(job_id)
+
+
+def _cancel_job_by_id(api: SynapseJobRunner, job_id: int):
+ api.cancel_spark_batch_job(job_id)
+
diff --git a/python/setup.py b/python/setup.py
index 49cd6eb6..195679ea 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -52,6 +52,11 @@
"grpcio-tools==1.31.0",
"mypy-protobuf==2.5",
"croniter==1.*",
+ "azure-synapse-spark",
+ "azure-synapse",
+ "azure-identity",
+ "azure-storage-file-datalake",
+ "azure-storage-blob",
]
# README file from Feast repo root directory
diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml
index 1fc560c5..b8bb1612 100644
--- a/spark/ingestion/pom.xml
+++ b/spark/ingestion/pom.xml
@@ -51,6 +51,26 @@
${protobuf.version}
+
+ com.microsoft.azure
+ azure-eventhubs-spark_2.12
+ 2.3.18
+
+
+
+
+ org.glassfish
+ javax.el
+ 3.0.1-b08
+
+
+
+ com.microsoft.azure
+ azure-eventhubs-spark_2.12
+ 2.3.18
+
+
+
com.gojek
stencil
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
index 98a47a9e..03e558de 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala
@@ -33,11 +33,12 @@ object BasePipeline {
val conf = new SparkConf()
jobConfig.store match {
- case RedisConfig(host, port, ssl) =>
+ case RedisConfig(host, port, auth, ssl) =>
conf
.set("spark.redis.host", host)
.set("spark.redis.port", port.toString)
.set("spark.redis.ssl", ssl.toString)
+ .set("spark.redis.auth", auth.toString)
case BigTableConfig(projectId, instanceId) =>
conf
.set("spark.bigtable.projectId", projectId)
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala
index 352fe8c4..8b80e46c 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala
@@ -22,6 +22,8 @@ import org.joda.time.{DateTime, DateTimeZone}
import org.json4s._
import org.json4s.ext.JavaEnumNameSerializer
import org.json4s.jackson.JsonMethods.{parse => parseJSON}
+import org.json4s.ext.JavaEnumNameSerializer
+import scala.collection.mutable.ArrayBuffer
object IngestionJob {
import Modes._
@@ -116,10 +118,22 @@ object IngestionJob {
opt[Int](name = "triggering-interval")
.action((x, c) => c.copy(streamingTriggeringSecs = x))
+
+ opt[String](name = "kafka_sasl_auth")
+ .action((x, c) => c.copy(kafkaSASL = Some(x)))
}
def main(args: Array[String]): Unit = {
- parser.parse(args, IngestionJobConfig()) match {
+ val args_modified = new Array[String](args.length)
+ for ( i <- 0 to (args_modified.length - 1)) {
+ // Remove the space before a closing bracket (" }" -> "}"). This works around a known issue when running Spark on YARN:
+ // https://issues.apache.org/jira/browse/SPARK-17814?focusedCommentId=15567964&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-15567964
+ // Also remove the unnecessary backslashes
+ args_modified(i) = args(i).replace(" }", "}");
+ args_modified(i) = args_modified(i).replace("\\", "\\\"");
+ }
+ println("arguments received:",args_modified.toList)
+ parser.parse(args_modified, IngestionJobConfig()) match {
case Some(config) =>
println(s"Starting with config $config")
config.mode match {
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala
index 87150493..8e9a8fe3 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala
@@ -26,7 +26,7 @@ object Modes extends Enumeration {
abstract class StoreConfig
-case class RedisConfig(host: String, port: Int, ssl: Boolean) extends StoreConfig
+case class RedisConfig(host: String, port: Int, auth: String, ssl: Boolean) extends StoreConfig
case class BigTableConfig(projectId: String, instanceId: String) extends StoreConfig
case class CassandraConfig(
connection: CassandraConnection,
@@ -90,6 +90,7 @@ case class KafkaSource(
override val datePartitionColumn: Option[String] = None
) extends StreamingSource
+
case class Sources(
file: Option[FileSource] = None,
bq: Option[BQSource] = None,
@@ -119,12 +120,13 @@ case class IngestionJobConfig(
source: Source = null,
startTime: DateTime = DateTime.now(),
endTime: DateTime = DateTime.now(),
- store: StoreConfig = RedisConfig("localhost", 6379, false),
+ store: StoreConfig = RedisConfig("localhost", 6379, "", false),
metrics: Option[MetricConfig] = None,
deadLetterPath: Option[String] = None,
stencilURL: Option[String] = None,
streamingTriggeringSecs: Int = 0,
validationConfig: Option[ValidationConfig] = None,
doNotIngestInvalidRows: Boolean = false,
- checkpointPath: Option[String] = None
+ checkpointPath: Option[String] = None,
+ kafkaSASL: Option[String] = None
)
diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
index bebef92e..33f5746e 100644
--- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
+++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala
@@ -33,6 +33,9 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.{SparkEnv, SparkFiles}
+import org.apache.spark.eventhubs._
+import org.apache.kafka.common.security.plain.PlainLoginModule
+import org.apache.kafka.common.security.JaasContext
import java.io.File
import java.sql.Timestamp
@@ -73,17 +76,38 @@ object StreamingPipeline extends BasePipeline with Serializable {
val validationUDF = createValidationUDF(sparkSession, config)
+ val EH_SASL = "org.apache.kafka.common.security.plain.PlainLoginModule required username=\"$ConnectionString\" password=\"Endpoint=sb://xxx.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=yyy=;EntityPath=driver_trips\";"
+
val input = config.source match {
case source: KafkaSource =>
- sparkSession.readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", source.bootstrapServers)
- .option("subscribe", source.topic)
- .load()
+ if (config.kafkaSASL.nonEmpty)
+ {
+ // if we have authentication enabled
+ println("config.kafkaSASL value:", config.kafkaSASL.get)
+ sparkSession.readStream
+ .format("kafka")
+ .option("subscribe", source.topic)
+ .option("kafka.bootstrap.servers", source.bootstrapServers)
+ .option("kafka.sasl.mechanism", "PLAIN")
+ .option("kafka.security.protocol", "SASL_SSL")
+ .option("kafka.sasl.jaas.config", config.kafkaSASL.get)
+ .option("kafka.request.timeout.ms", "60000")
+ .option("kafka.session.timeout.ms", "60000")
+ .option("failOnDataLoss", "false")
+ .load()
+ }
+ else
+ {
+ println("config.kafkaSASL is empty.")
+ sparkSession.readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", source.bootstrapServers)
+ .option("subscribe", source.topic)
+ .load()
+ }
case source: MemoryStreamingSource =>
- source.read
+ source.read
}
-
val featureStruct = config.source.asInstanceOf[StreamingSource].format match {
case ProtoFormat(classPath) =>
val parser = protoParser(sparkSession, classPath)
diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index d63d311c..534468cc 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -10,7 +10,7 @@ def pytest_addoption(parser):
parser.addoption("--kafka-brokers", action="store", default="localhost:9092")
parser.addoption(
- "--env", action="store", help="local|aws|gcloud|k8s", default="local"
+ "--env", action="store", help="local|aws|gcloud|k8s|synapse", default="local"
)
parser.addoption("--with-job-service", action="store_true")
parser.addoption("--staging-path", action="store")
@@ -23,6 +23,11 @@ def pytest_addoption(parser):
parser.addoption("--dataproc-executor-cores", action="store", default="2")
parser.addoption("--dataproc-executor-memory", action="store", default="2g")
parser.addoption("--k8s-namespace", action="store", default="sparkop-e2e")
+ parser.addoption("--azure-synapse-dev-url", action="store", default="")
+ parser.addoption("--azure-synapse-pool-name", action="store", default="")
+ parser.addoption("--azure-synapse-datalake-dir", action="store", default="")
+ parser.addoption("--azure-blob-account-name", action="store", default="")
+ parser.addoption("--azure-blob-account-access-key", action="store", default="")
parser.addoption("--ingestion-jar", action="store")
parser.addoption("--redis-url", action="store", default="localhost:6379")
parser.addoption("--redis-cluster", action="store_true")
diff --git a/tests/e2e/fixtures/client.py b/tests/e2e/fixtures/client.py
index 7ffa936f..3d331fd7 100644
--- a/tests/e2e/fixtures/client.py
+++ b/tests/e2e/fixtures/client.py
@@ -102,6 +102,24 @@ def feast_client(
enable_auth=pytestconfig.getoption("enable_auth"),
**job_service_env,
)
+ elif pytestconfig.getoption("env") == "synapse":
+ return Client(
+ core_url=f"{feast_core[0]}:{feast_core[1]}",
+ serving_url=f"{feast_serving[0]}:{feast_serving[1]}",
+ spark_launcher="synapse",
+ azure_synapse_dev_url = pytestconfig.getoption("azure_synapse_dev_url"),
+ azure_synapse_pool_name = pytestconfig.getoption("azure_synapse_pool_name"),
+ azure_synapse_datalake_dir = pytestconfig.getoption("azure_synapse_datalake_dir"),
+ spark_staging_location=os.path.join(local_staging_path, "synapse"),
+ azure_blob_account_name=pytestconfig.getoption("azure_blob_account_name"),
+ azure_blob_account_access_key=pytestconfig.getoption("azure_blob_account_access_key"),
+ spark_ingestion_jar=ingestion_job_jar,
+ redis_host=pytestconfig.getoption("redis_url").split(":")[0],
+ redis_port=pytestconfig.getoption("redis_url").split(":")[1],
+ historical_feature_output_location=os.path.join(
+ local_staging_path, "historical_output"
+ ),
+ )
else:
raise KeyError(f"Unknown environment {pytestconfig.getoption('env')}")
@@ -184,6 +202,21 @@ def tfrecord_feast_client(
),
**job_service_env,
)
+ elif pytestconfig.getoption("env") == "synapse":
+ return Client(
+ core_url=f"{feast_core[0]}:{feast_core[1]}",
+ spark_launcher="synapse",
+ azure_synapse_dev_url = pytestconfig.getoption("azure_synapse_dev_url"),
+ azure_synapse_pool_name = pytestconfig.getoption("azure_synapse_pool_name"),
+ azure_synapse_datalake_dir = pytestconfig.getoption("azure_synapse_datalake_dir"),
+ spark_staging_location=os.path.join(local_staging_path, "synapse"),
+ azure_blob_account_name=pytestconfig.getoption("azure_blob_account_name"),
+ azure_blob_account_access_key=pytestconfig.getoption("azure_blob_account_access_key"),
+ historical_feature_output_format="tfrecord",
+ historical_feature_output_location=os.path.join(
+ local_staging_path, "historical_output"
+ ),
+ )
else:
raise KeyError(f"Unknown environment {pytestconfig.getoption('env')}")