diff --git a/Makefile b/Makefile index a3d83b56..55beb3c3 100644 --- a/Makefile +++ b/Makefile @@ -53,7 +53,7 @@ build-local-test-docker: docker build -t feast:local -f infra/docker/tests/Dockerfile . build-ingestion-jar-no-tests: - cd spark/ingestion && ${MVN} --no-transfer-progress -Dmaven.javadoc.skip=true -Dgpg.skip -DskipUTs=true -DskipITs=true -Drevision=${REVISION} clean package + cd spark/ingestion && ${MVN} --no-transfer-progress -Dmaven.javadoc.skip=true -Dgpg.skip -DskipUTs=true -D"spotless.check.skip"=true -DskipITs=true -Drevision=${REVISION} clean package build-jobservice-docker: docker build -t $(REGISTRY)/feast-jobservice:$(VERSION) -f infra/docker/jobservice/Dockerfile . @@ -68,3 +68,11 @@ push-spark-docker: docker push $(REGISTRY)/feast-spark:$(VERSION) install-ci-dependencies: install-python-ci-dependencies + +build-ingestion-jar-push: + docker build -t $(REGISTRY)/feast-spark:$(VERSION) --build-arg VERSION=$(VERSION) -f infra/docker/spark/Dockerfile . + rm -f feast-ingestion-spark-latest.jar + docker create -ti --name dummy $(REGISTRY)/feast-spark:latest bash + docker cp dummy:/opt/spark/jars/feast-ingestion-spark-latest.jar feast-ingestion-spark-latest.jar + docker rm -f dummy + python python/feast_spark/copy_to_azure_blob.py \ No newline at end of file diff --git a/README.md b/README.md index 0e00aad6..44c8ff0c 100644 --- a/README.md +++ b/README.md @@ -57,4 +57,14 @@ client.apply(entity, ft) # Start spark streaming ingestion job that reads from kafka and writes to the online store feast_spark.Client(client).start_stream_to_online_ingestion(ft) -``` \ No newline at end of file +``` + +Build and push to BLOB storage + +In order to build the Spark Ingestion jar and copy it to BLOB storage, you have to set these 3 environment variables: + +```bash +export VERSION=latest +export REGISTRY=your_registry_name +export AZURE_STORAGE_CONNECTION_STRING="your_azure_storage_connection_string" +``` diff --git a/pom.xml b/pom.xml index 4eeffaac..0b7f608b 
100644 --- a/pom.xml +++ b/pom.xml @@ -18,8 +18,8 @@ 1.8 1.8 2.12 - ${scala.version}.12 - 3.0.2 + ${scala.version}.10 + 3.1.2 4.4.0 3.3.0 3.12.2 diff --git a/python/feast_spark/constants.py b/python/feast_spark/constants.py index 8ea25ec7..6d34810a 100644 --- a/python/feast_spark/constants.py +++ b/python/feast_spark/constants.py @@ -93,6 +93,27 @@ class ConfigOptions(metaclass=ConfigMeta): # SparkApplication resource template SPARK_K8S_JOB_TEMPLATE_PATH = None + # Synapse dev url + AZURE_SYNAPSE_DEV_URL: Optional[str] = None + + # Synapse pool name + AZURE_SYNAPSE_POOL_NAME: Optional[str] = None + + # Datalake directory that linked to Synapse + AZURE_SYNAPSE_DATALAKE_DIR: Optional[str] = None + + # Synapse pool executor size: Small, Medium or Large + AZURE_SYNAPSE_EXECUTOR_SIZE = "Small" + + # Synapse pool executor count + AZURE_SYNAPSE_EXECUTORS = "2" + + # Azure EventHub Connection String (with Kafka API). See more details here: + # https://docs.microsoft.com/en-us/azure/event-hubs/apache-kafka-migration-guide + # Code Sample is here: + # https://github.com/Azure/azure-event-hubs-for-kafka/blob/master/tutorials/spark/sparkConsumer.scala + AZURE_EVENTHUB_KAFKA_CONNECTION_STRING = "" + #: File format of historical retrieval features HISTORICAL_FEATURE_OUTPUT_FORMAT: str = "parquet" @@ -108,6 +129,9 @@ class ConfigOptions(metaclass=ConfigMeta): #: Enable or disable TLS/SSL to Redis REDIS_SSL: Optional[str] = "False" + #: Auth string for redis + REDIS_AUTH: str = "" + #: BigTable Project ID BIGTABLE_PROJECT: Optional[str] = "" diff --git a/python/feast_spark/copy_to_azure_blob.py b/python/feast_spark/copy_to_azure_blob.py new file mode 100644 index 00000000..7f8f719f --- /dev/null +++ b/python/feast_spark/copy_to_azure_blob.py @@ -0,0 +1,41 @@ +# coding: utf-8 + +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +FILE: blob_samples_copy_blob.py +DESCRIPTION: + This sample demos how to copy a blob from a URL. +USAGE: python blob_samples_copy_blob.py + Set the environment variables with your own values before running the sample. + 1) AZURE_STORAGE_CONNECTION_STRING - the connection string to your storage account +""" + +from __future__ import print_function +import os +import sys +import time +from azure.storage.blob import BlobServiceClient + +def main(): + try: + CONNECTION_STRING = os.environ['AZURE_STORAGE_CONNECTION_STRING'] + + except KeyError: + print("AZURE_STORAGE_CONNECTION_STRING must be set.") + sys.exit(1) + + blob_service_client = BlobServiceClient.from_connection_string(CONNECTION_STRING) + copied_blob = blob_service_client.get_blob_client("feastjar", 'feast-ingestion-spark-latest.jar') + # hard code to the current path + SOURCE_FILE = "./feast-ingestion-spark-latest.jar" + + with open(SOURCE_FILE, "rb") as data: + copied_blob.upload_blob(data, blob_type="BlockBlob",overwrite=True) + +if __name__ == "__main__": + main() diff --git a/python/feast_spark/pyspark/abc.py b/python/feast_spark/pyspark/abc.py index db5c041d..154f8d08 100644 --- a/python/feast_spark/pyspark/abc.py +++ b/python/feast_spark/pyspark/abc.py @@ -340,13 +340,10 @@ def __init__( feature_table: Dict, source: Dict, jar: str, - redis_host: Optional[str] = None, - redis_port: Optional[int] = None, - redis_ssl: Optional[bool] = None, - bigtable_project: Optional[str] = None, - bigtable_instance: Optional[str] = None, - cassandra_host: Optional[str] = None, - cassandra_port: Optional[int] = None, + redis_host: str, + redis_port: int, + redis_ssl: bool, + redis_auth: str, statsd_host: Optional[str] = None, statsd_port: Optional[int] = None, deadletter_path: Optional[str] = None, @@ -359,10 +356,7 @@ def __init__( self._redis_host = redis_host 
self._redis_port = redis_port self._redis_ssl = redis_ssl - self._bigtable_project = bigtable_project - self._bigtable_instance = bigtable_instance - self._cassandra_host = cassandra_host - self._cassandra_port = cassandra_port + self._redis_auth = redis_auth self._statsd_host = statsd_host self._statsd_port = statsd_port self._deadletter_path = deadletter_path @@ -370,15 +364,7 @@ def __init__( self._drop_invalid_rows = drop_invalid_rows def _get_redis_config(self): - return dict(host=self._redis_host, port=self._redis_port, ssl=self._redis_ssl) - - def _get_bigtable_config(self): - return dict( - project_id=self._bigtable_project, instance_id=self._bigtable_instance - ) - - def _get_cassandra_config(self): - return dict(host=self._cassandra_host, port=self._cassandra_port) + return dict(host=self._redis_host, port=self._redis_port, ssl=self._redis_ssl, auth=self._redis_auth) def _get_statsd_config(self): return ( @@ -405,17 +391,10 @@ def get_arguments(self) -> List[str]: json.dumps(self._feature_table), "--source", json.dumps(self._source), + "--redis", + json.dumps(self._get_redis_config()), ] - if self._redis_host and self._redis_port: - args.extend(["--redis", json.dumps(self._get_redis_config())]) - - if self._bigtable_project and self._bigtable_instance: - args.extend(["--bigtable", json.dumps(self._get_bigtable_config())]) - - if self._cassandra_host and self._cassandra_port: - args.extend(["--cassandra", json.dumps(self._get_cassandra_config())]) - if self._get_statsd_config(): args.extend(["--statsd", json.dumps(self._get_statsd_config())]) @@ -444,13 +423,14 @@ def __init__( start: datetime, end: datetime, jar: str, - redis_host: Optional[str], - redis_port: Optional[int], - redis_ssl: Optional[bool], - bigtable_project: Optional[str], - bigtable_instance: Optional[str], + redis_host: str, + redis_port: int, + redis_ssl: bool, + redis_auth: str, + bigtable_project: Optional[str] = None, + bigtable_instance: Optional[str] = None, cassandra_host: 
Optional[str] = None, - cassandra_port: Optional[int] = None, + cassandra_port: Optional[str] = None, statsd_host: Optional[str] = None, statsd_port: Optional[int] = None, deadletter_path: Optional[str] = None, @@ -463,10 +443,7 @@ def __init__( redis_host, redis_port, redis_ssl, - bigtable_project, - bigtable_instance, - cassandra_host, - cassandra_port, + redis_auth, statsd_host, statsd_port, deadletter_path, @@ -494,7 +471,6 @@ def get_arguments(self) -> List[str]: self._end.strftime("%Y-%m-%dT%H:%M:%S"), ] - class ScheduledBatchIngestionJobParameters(IngestionJobParameters): def __init__( self, @@ -559,21 +535,20 @@ def __init__( source: Dict, jar: str, extra_jars: List[str], - redis_host: Optional[str], - redis_port: Optional[int], - redis_ssl: Optional[bool], - bigtable_project: Optional[str], - bigtable_instance: Optional[str], - cassandra_host: Optional[str] = None, - cassandra_port: Optional[int] = None, + redis_host: str, + redis_port: int, + redis_ssl: bool, + redis_auth: str, statsd_host: Optional[str] = None, statsd_port: Optional[int] = None, deadletter_path: Optional[str] = None, checkpoint_path: Optional[str] = None, stencil_url: Optional[str] = None, - drop_invalid_rows: bool = False, - triggering_interval: Optional[int] = None, + drop_invalid_rows: Optional[bool] = False, + kafka_sasl_auth: Optional[str] = None, ): + stencil_url: Optional[str] = None, + drop_invalid_rows: bool = False, super().__init__( feature_table, source, @@ -581,10 +556,7 @@ def __init__( redis_host, redis_port, redis_ssl, - bigtable_project, - bigtable_instance, - cassandra_host, - cassandra_port, + redis_auth, statsd_host, statsd_port, deadletter_path, @@ -593,7 +565,7 @@ def __init__( ) self._extra_jars = extra_jars self._checkpoint_path = checkpoint_path - self._triggering_interval = triggering_interval + self._kafka_sasl_auth = kafka_sasl_auth def get_name(self) -> str: return f"{self.get_job_type().to_pascal_case()}-{self.get_feature_table_name()}" @@ -609,8 +581,8 @@ 
def get_arguments(self) -> List[str]: args.extend(["--mode", "online"]) if self._checkpoint_path: args.extend(["--checkpoint-path", self._checkpoint_path]) - if self._triggering_interval: - args.extend(["--triggering-interval", str(self._triggering_interval)]) + if self._kafka_sasl_auth: + args.extend(["--kafka_sasl_auth", self._kafka_sasl_auth]) return args def get_job_hash(self) -> str: @@ -705,29 +677,6 @@ def offline_to_online_ingestion( """ raise NotImplementedError - @abc.abstractmethod - def schedule_offline_to_online_ingestion( - self, ingestion_job_params: ScheduledBatchIngestionJobParameters - ): - """ - Submits a scheduled batch ingestion job to a Spark cluster. - - Raises: - SparkJobFailure: The spark job submission failed, encountered error - during execution, or timeout. - - Returns: - ScheduledBatchIngestionJob: wrapper around remote job that can be used to check when job completed. - """ - raise NotImplementedError - - @abc.abstractmethod - def unschedule_offline_to_online_ingestion(self, project: str, feature_table: str): - """ - Unschedule a scheduled batch ingestion job. 
- """ - raise NotImplementedError - @abc.abstractmethod def start_stream_to_online_ingestion( self, ingestion_job_params: StreamIngestionJobParameters diff --git a/python/feast_spark/pyspark/launcher.py b/python/feast_spark/pyspark/launcher.py index 9f5e95ac..f8f8dc2d 100644 --- a/python/feast_spark/pyspark/launcher.py +++ b/python/feast_spark/pyspark/launcher.py @@ -83,11 +83,24 @@ def _k8s_launcher(config: Config) -> JobLauncher: ) +def _synapse_launcher(config: Config) -> JobLauncher: + from feast_spark.pyspark.launchers import synapse + + return synapse.SynapseJobLauncher( + synapse_dev_url=config.get(opt.AZURE_SYNAPSE_DEV_URL), + pool_name=config.get(opt.AZURE_SYNAPSE_POOL_NAME), + datalake_dir=config.get(opt.AZURE_SYNAPSE_DATALAKE_DIR), + executor_size=config.get(opt.AZURE_SYNAPSE_EXECUTOR_SIZE), + executors=int(config.get(opt.AZURE_SYNAPSE_EXECUTORS)) + ) + + _launchers = { "standalone": _standalone_launcher, "dataproc": _dataproc_launcher, "emr": _emr_launcher, "k8s": _k8s_launcher, + 'synapse': _synapse_launcher, } @@ -347,6 +360,7 @@ def start_offline_to_online_ingestion( redis_port=bool(client.config.get(opt.REDIS_HOST)) and client.config.getint(opt.REDIS_PORT), redis_ssl=client.config.getboolean(opt.REDIS_SSL), + redis_auth=client.config.get(opt.REDIS_AUTH), bigtable_project=client.config.get(opt.BIGTABLE_PROJECT), bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE), cassandra_host=client.config.get(opt.CASSANDRA_HOST), @@ -423,11 +437,9 @@ def get_stream_to_online_ingestion_params( source=_source_to_argument(feature_table.stream_source, client.config), feature_table=_feature_table_to_argument(client, project, feature_table), redis_host=client.config.get(opt.REDIS_HOST), - redis_port=bool(client.config.get(opt.REDIS_HOST)) - and client.config.getint(opt.REDIS_PORT), + redis_port=client.config.getint(opt.REDIS_PORT), redis_ssl=client.config.getboolean(opt.REDIS_SSL), - bigtable_project=client.config.get(opt.BIGTABLE_PROJECT), - 
bigtable_instance=client.config.get(opt.BIGTABLE_INSTANCE), + redis_auth=client.config.get(opt.REDIS_AUTH), statsd_host=client.config.getboolean(opt.STATSD_ENABLED) and client.config.get(opt.STATSD_HOST), statsd_port=client.config.getboolean(opt.STATSD_ENABLED) @@ -436,11 +448,9 @@ def get_stream_to_online_ingestion_params( checkpoint_path=client.config.get(opt.CHECKPOINT_PATH), stencil_url=client.config.get(opt.STENCIL_URL), drop_invalid_rows=client.config.get(opt.INGESTION_DROP_INVALID_ROWS), - triggering_interval=client.config.getint( - opt.SPARK_STREAMING_TRIGGERING_INTERVAL, default=None - ), - ) + kafka_sasl_auth=client.config.get(opt.AZURE_EVENTHUB_KAFKA_CONNECTION_STRING), + ) def start_stream_to_online_ingestion( client: "Client", project: str, feature_table: FeatureTable, extra_jars: List[str] diff --git a/python/feast_spark/pyspark/launchers/synapse/__init__.py b/python/feast_spark/pyspark/launchers/synapse/__init__.py new file mode 100644 index 00000000..59dadfa4 --- /dev/null +++ b/python/feast_spark/pyspark/launchers/synapse/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +from .synapse import ( + SynapseBatchIngestionJob, + SynapseJobLauncher, + SynapseRetrievalJob, + SynapseStreamIngestionJob, +) + +__all__ = [ + "SynapseRetrievalJob", + "SynapseBatchIngestionJob", + "SynapseStreamIngestionJob", + "SynapseJobLauncher", +] diff --git a/python/feast_spark/pyspark/launchers/synapse/synapse.py b/python/feast_spark/pyspark/launchers/synapse/synapse.py new file mode 100644 index 00000000..3a42a95a --- /dev/null +++ b/python/feast_spark/pyspark/launchers/synapse/synapse.py @@ -0,0 +1,298 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import time +from datetime import datetime +from typing import List, Optional, cast + +from azure.synapse.spark.models import SparkBatchJob +from azure.identity import DefaultAzureCredential, DeviceCodeCredential, ChainedTokenCredential, ManagedIdentityCredential,EnvironmentCredential + +from feast_spark.pyspark.abc import ( + BatchIngestionJob, + BatchIngestionJobParameters, + JobLauncher, + RetrievalJob, + RetrievalJobParameters, + SparkJob, + SparkJobFailure, + SparkJobStatus, + StreamIngestionJob, + StreamIngestionJobParameters, +) + +from .synapse_utils import ( + HISTORICAL_RETRIEVAL_JOB_TYPE, + LABEL_JOBTYPE, + LABEL_FEATURE_TABLE, + METADATA_JOBHASH, + METADATA_OUTPUT_URI, + OFFLINE_TO_ONLINE_JOB_TYPE, + STREAM_TO_ONLINE_JOB_TYPE, + SynapseJobRunner, + DataLakeFiler, + _prepare_job_tags, + _job_feast_state, + _job_start_time, + _cancel_job_by_id, + _get_job_by_id, + _list_jobs, + _submit_job, +) + + +class SynapseJobMixin: + def __init__(self, api: SynapseJobRunner, job_id: int): + self._api = api + self._job_id = job_id + + def get_id(self) -> str: + return self._job_id + + def get_status(self) -> SparkJobStatus: + job = _get_job_by_id(self._api, self._job_id) + assert job is not None + return _job_feast_state(job) + + def get_start_time(self) -> datetime: + job = _get_job_by_id(self._api, self._job_id) + assert job is not None + return _job_start_time(job) + + def cancel(self): + _cancel_job_by_id(self._api, self._job_id) + + def _wait_for_complete(self, timeout_seconds: Optional[float]) -> bool: + """ Returns true if the job completed successfully """ + start_time = time.time() + while (timeout_seconds is None) or (time.time() - start_time < timeout_seconds): + status = self.get_status() + if status == SparkJobStatus.COMPLETED: + return True + elif status == SparkJobStatus.FAILED: + return False + else: + time.sleep(1) + else: + raise TimeoutError("Timeout waiting for job to complete") + + +class SynapseRetrievalJob(SynapseJobMixin, RetrievalJob): + """ 
+ Historical feature retrieval job result for a synapse cluster + """ + + def __init__( + self, api: SynapseJobRunner, job_id: int, output_file_uri: str + ): + """ + This is the job object representing the historical retrieval job, returned by SynapseClusterLauncher. + + Args: + output_file_uri (str): Uri to the historical feature retrieval job output file. + """ + super().__init__(api, job_id) + self._output_file_uri = output_file_uri + + def get_output_file_uri(self, timeout_sec=None, block=True): + if not block: + return self._output_file_uri + + if self._wait_for_complete(timeout_sec): + return self._output_file_uri + else: + raise SparkJobFailure("Spark job failed") + + +class SynapseBatchIngestionJob(SynapseJobMixin, BatchIngestionJob): + """ + Ingestion job result for a synapse cluster + """ + + def __init__( + self, api: SynapseJobRunner, job_id: int, feature_table: str + ): + super().__init__(api, job_id) + self._feature_table = feature_table + + def get_feature_table(self) -> str: + return self._feature_table + + +class SynapseStreamIngestionJob(SynapseJobMixin, StreamIngestionJob): + """ + Ingestion streaming job for a synapse cluster + """ + + def __init__( + self, + api: SynapseJobRunner, + job_id: int, + job_hash: str, + feature_table: str, + ): + super().__init__(api, job_id) + self._job_hash = job_hash + self._feature_table = feature_table + + def get_hash(self) -> str: + return self._job_hash + + def get_feature_table(self) -> str: + return self._feature_table + +login_credential_cache = None + +class SynapseJobLauncher(JobLauncher): + """ + Submits spark jobs to a spark cluster. Currently supports only historical feature retrieval jobs. 
+ """ + + def __init__( + self, + synapse_dev_url: str, + pool_name: str, + datalake_dir: str, + executor_size: str, + executors: int + ): + tenant_id='72f988bf-86f1-41af-91ab-2d7cd011db47' + authority_host_uri = 'login.microsoftonline.com' + client_id = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' + + global login_credential_cache + # use a global cache to store the credential, to avoid users from multiple login + + if login_credential_cache is None: + # use DeviceCodeCredential if EnvironmentCredential is not available + self.credential = ChainedTokenCredential(EnvironmentCredential(), DeviceCodeCredential(client_id, authority=authority_host_uri, tenant=tenant_id)) + login_credential_cache = self.credential + else: + self.credential = login_credential_cache + + self._api = SynapseJobRunner(synapse_dev_url, pool_name, executor_size = executor_size, executors = executors, credential=self.credential) + self._datalake = DataLakeFiler(datalake_dir,credential=self.credential) + + def _job_from_job_info(self, job_info: SparkBatchJob) -> SparkJob: + job_type = job_info.tags[LABEL_JOBTYPE] + if job_type == HISTORICAL_RETRIEVAL_JOB_TYPE: + assert METADATA_OUTPUT_URI in job_info.tags + return SynapseRetrievalJob( + api=self._api, + job_id=job_info.id, + output_file_uri=job_info.tags[METADATA_OUTPUT_URI], + ) + elif job_type == OFFLINE_TO_ONLINE_JOB_TYPE: + return SynapseBatchIngestionJob( + api=self._api, + job_id=job_info.id, + feature_table=job_info.tags.get(LABEL_FEATURE_TABLE, ""), + ) + elif job_type == STREAM_TO_ONLINE_JOB_TYPE: + # job_hash must not be None for stream ingestion jobs + assert METADATA_JOBHASH in job_info.tags + return SynapseStreamIngestionJob( + api=self._api, + job_id=job_info.id, + job_hash=job_info.tags[METADATA_JOBHASH], + feature_table=job_info.tags.get(LABEL_FEATURE_TABLE, ""), + ) + else: + # We should never get here + raise ValueError(f"Unknown job type {job_type}") + + def historical_feature_retrieval( + self, job_params: RetrievalJobParameters + 
) -> RetrievalJob: + """ + Submits a historical feature retrieval job to a Spark cluster. + + Raises: + SparkJobFailure: The spark job submission failed, encountered error + during execution, or timeout. + + Returns: + RetrievalJob: wrapper around remote job that returns file uri to the result file. + """ + + main_file = self._datalake.upload_file(job_params.get_main_file_path()) + job_info = _submit_job(self._api, "Historical-Retrieval", main_file, + arguments = job_params.get_arguments(), + tags = {LABEL_JOBTYPE: HISTORICAL_RETRIEVAL_JOB_TYPE, + METADATA_OUTPUT_URI: job_params.get_destination_path()}) + + return cast(RetrievalJob, self._job_from_job_info(job_info)) + + def offline_to_online_ingestion( + self, ingestion_job_params: BatchIngestionJobParameters + ) -> BatchIngestionJob: + """ + Submits a batch ingestion job to a Spark cluster. + + Raises: + SparkJobFailure: The spark job submission failed, encountered error + during execution, or timeout. + + Returns: + BatchIngestionJob: wrapper around remote job that can be used to check when job completed. + """ + + main_file = self._datalake.upload_file(ingestion_job_params.get_main_file_path()) + + job_info = _submit_job(self._api, ingestion_job_params.get_project()+"_offline_to_online_ingestion", main_file, + main_class = ingestion_job_params.get_class_name(), + arguments = ingestion_job_params.get_arguments(), + reference_files=[main_file], + tags = _prepare_job_tags(ingestion_job_params, OFFLINE_TO_ONLINE_JOB_TYPE),configuration=None) + + return cast(BatchIngestionJob, self._job_from_job_info(job_info)) + + def start_stream_to_online_ingestion( + self, ingestion_job_params: StreamIngestionJobParameters + ) -> StreamIngestionJob: + """ + Starts a stream ingestion job to a Spark cluster. + + Raises: + SparkJobFailure: The spark job submission failed, encountered error + during execution, or timeout. + + Returns: + StreamIngestionJob: wrapper around remote job. 
+ """ + + main_file = self._datalake.upload_file(ingestion_job_params.get_main_file_path()) + + extra_jar_paths: List[str] = [] + for extra_jar in ingestion_job_params.get_extra_jar_paths(): + extra_jar_paths.append(self._datalake.upload_file(extra_jar)) + + tags = _prepare_job_tags(ingestion_job_params, STREAM_TO_ONLINE_JOB_TYPE) + tags[METADATA_JOBHASH] = ingestion_job_params.get_job_hash() + job_info = _submit_job(self._api, ingestion_job_params.get_project()+"_stream_to_online_ingestion", main_file, + main_class = ingestion_job_params.get_class_name(), + arguments = ingestion_job_params.get_arguments(), + reference_files = extra_jar_paths, + configuration=None, + tags = tags) + + return cast(StreamIngestionJob, self._job_from_job_info(job_info)) + + def get_job_by_id(self, job_id: int) -> SparkJob: + job_info = _get_job_by_id(self._api, job_id) + if job_info is None: + raise KeyError(f"Job with id {job_id} not found") + else: + return self._job_from_job_info(job_info) + + def list_jobs( + self, + include_terminated: bool, + project: Optional[str] = None, + table_name: Optional[str] = None, + ) -> List[SparkJob]: + return [ + self._job_from_job_info(job) + for job in _list_jobs(self._api, project, table_name) + if include_terminated + or _job_feast_state(job) not in (SparkJobStatus.COMPLETED, SparkJobStatus.FAILED) + ] diff --git a/python/feast_spark/pyspark/launchers/synapse/synapse_utils.py b/python/feast_spark/pyspark/launchers/synapse/synapse_utils.py new file mode 100644 index 00000000..91fba2eb --- /dev/null +++ b/python/feast_spark/pyspark/launchers/synapse/synapse_utils.py @@ -0,0 +1,277 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +import re +import hashlib +import urllib.request +from datetime import datetime +from typing import Any, Dict, List, Optional +from azure.core.configuration import Configuration + +from azure.identity import DefaultAzureCredential + +from azure.synapse.spark import SparkClient +from azure.synapse.spark.models import SparkBatchJobOptions, SparkBatchJob + +from azure.storage.filedatalake import DataLakeServiceClient + +from feast_spark.pyspark.abc import SparkJobStatus + +__all__ = [ + "_cancel_job_by_id", + "_prepare_job_tags", + "_list_jobs", + "_get_job_by_id", + "_generate_project_table_hash", + "STREAM_TO_ONLINE_JOB_TYPE", + "OFFLINE_TO_ONLINE_JOB_TYPE", + "HISTORICAL_RETRIEVAL_JOB_TYPE", + "METADATA_JOBHASH", + "METADATA_OUTPUT_URI", +] + +STREAM_TO_ONLINE_JOB_TYPE = "STREAM_TO_ONLINE_JOB" +OFFLINE_TO_ONLINE_JOB_TYPE = "OFFLINE_TO_ONLINE_JOB" +HISTORICAL_RETRIEVAL_JOB_TYPE = "HISTORICAL_RETRIEVAL_JOB" + +LABEL_JOBID = "feast.dev/jobid" +LABEL_JOBTYPE = "feast.dev/type" +LABEL_FEATURE_TABLE = "feast.dev/table" +LABEL_FEATURE_TABLE_HASH = "feast.dev/tablehash" +LABEL_PROJECT = "feast.dev/project" + +# Can't store these bits of info due to 64-character limit, so we store them as +# sparkConf +METADATA_OUTPUT_URI = "dev.feast.outputuri" +METADATA_JOBHASH = "dev.feast.jobhash" + + +def _generate_project_table_hash(project: str, table_name: str) -> str: + return hashlib.md5(f"{project}:{table_name}".encode()).hexdigest() + + +def _truncate_label(label: str) -> str: + return label[:63] + + +def _prepare_job_tags(job_params, job_type: str) -> Dict[str, Any]: + """ Prepare Synapse job tags """ + return {LABEL_JOBTYPE:job_type, + LABEL_FEATURE_TABLE: _truncate_label( + job_params.get_feature_table_name() + ), + LABEL_FEATURE_TABLE_HASH: _generate_project_table_hash( + job_params.get_project(), + job_params.get_feature_table_name(), + ), + LABEL_PROJECT: job_params.get_project() + } + + +STATE_MAP = { + "": SparkJobStatus.STARTING, + "not_started": 
SparkJobStatus.STARTING, + 'starting': SparkJobStatus.STARTING, + "running": SparkJobStatus.IN_PROGRESS, + "success": SparkJobStatus.COMPLETED, + "dead": SparkJobStatus.FAILED, + "killed": SparkJobStatus.FAILED, + "Uncertain": SparkJobStatus.IN_PROGRESS, + "Succeeded": SparkJobStatus.COMPLETED, + "Failed": SparkJobStatus.FAILED, + "Cancelled": SparkJobStatus.FAILED, +} + + +def _job_feast_state(job: SparkBatchJob) -> SparkJobStatus: + return STATE_MAP[job.state] + + +def _job_start_time(job: SparkBatchJob) -> datetime: + return job.scheduler.scheduled_at + + +EXECUTOR_SIZE = {'Small': {'Cores': 4, 'Memory': '28g'}, 'Medium': {'Cores': 8, 'Memory': '56g'}, + 'Large': {'Cores': 16, 'Memory': '112g'}} + + +def categorized_files(reference_files): + if reference_files == None: + return None, None + + files = [] + jars = [] + for file in reference_files: + file = file.strip() + if file.endswith(".jar"): + jars.append(file) + else: + files.append(file) + return files, jars + + +class SynapseJobRunner(object): + def __init__(self, synapse_dev_url, spark_pool_name, credential = None, executor_size = 'Small', executors = 2): + if credential is None: + credential = DefaultAzureCredential() + + self.client = SparkClient( + credential=credential, + endpoint=synapse_dev_url, + spark_pool_name=spark_pool_name + ) + + self._executor_size = executor_size + self._executors = executors + + def get_spark_batch_job(self, job_id): + + return self.client.spark_batch.get_spark_batch_job(job_id, detailed=True) + + def get_spark_batch_jobs(self): + + return self.client.spark_batch.get_spark_batch_jobs(detailed=True) + + def cancel_spark_batch_job(self, job_id): + + return self.client.spark_batch.cancel_spark_batch_job(job_id) + + def create_spark_batch_job(self, job_name, main_definition_file, class_name = None, + arguments=None, reference_files=None, archives=None, configuration=None, tags=None): + + file = main_definition_file + + files, jars = categorized_files(reference_files) + 
driver_cores = EXECUTOR_SIZE[self._executor_size]['Cores'] + driver_memory = EXECUTOR_SIZE[self._executor_size]['Memory'] + executor_cores = EXECUTOR_SIZE[self._executor_size]['Cores'] + executor_memory = EXECUTOR_SIZE[self._executor_size]['Memory'] + + # SDK source code is here: https://github.com/Azure/azure-sdk-for-python/tree/master/sdk/synapse/azure-synapse + # Exact code is here: https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/synapse/azure-synapse-spark/azure/synapse/spark/operations/_spark_batch_operations.py#L114 + # Adding spaces between brackets. This is to workaround this known YARN issue (when running Spark on YARN): + # https://issues.apache.org/jira/browse/SPARK-17814?focusedCommentId=15567964&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-15567964 + # print(arguments) + updated_arguments = [] + for elem in arguments: + if type(elem) == str: + updated_arguments.append(elem.replace("}", " }")) + else: + updated_arguments.append(elem) + + + spark_batch_job_options = SparkBatchJobOptions( + tags=tags, + name=job_name, + file=file, + class_name=class_name, + arguments=updated_arguments, + jars=jars, + files=files, + archives=archives, + configuration=configuration, + driver_memory=driver_memory, + driver_cores=driver_cores, + executor_memory=executor_memory, + executor_cores=executor_cores, + executor_count=self._executors) + + return self.client.spark_batch.create_spark_batch_job(spark_batch_job_options, detailed=True) + + +class DataLakeFiler(object): + def __init__(self, datalake_dir, credential = None): + datalake = list(filter(None, re.split('/|@', datalake_dir))) + assert len(datalake) >= 3 + + if credential is None: + credential = DefaultAzureCredential() + + account_url = "https://" + datalake[2] + datalake_client = DataLakeServiceClient( + credential=credential, + account_url=account_url + ).get_file_system_client(datalake[1]) + + if len(datalake) > 3: + datalake_client = 
datalake_client.get_directory_client('/'.join(datalake[3:])) + datalake_client.create_directory() + + self.datalake_dir = datalake_dir + '/' if datalake_dir[-1] != '/' else datalake_dir + self.dir_client = datalake_client + + def upload_file(self, local_file): + + file_name = os.path.basename(local_file) + file_client = self.dir_client.create_file(file_name) + + if local_file.startswith('http'): + # remote_file = local_file + # local_file = './' + file_name + # urllib.request.urlretrieve(remote_file, local_file) + with urllib.request.urlopen(local_file) as f: + data = f.read() + file_client.append_data(data, 0, len(data)) + file_client.flush_data(len(data)) + else: + with open(local_file, 'r') as f: + data = f.read() + file_client.append_data(data, 0, len(data)) + file_client.flush_data(len(data)) + + return self.datalake_dir + file_name + + +def _submit_job( + api: SynapseJobRunner, + name: str, + main_file: str, + main_class = None, + arguments = None, + reference_files = None, + tags = None, + configuration = None, +) -> SparkBatchJob: + return api.create_spark_batch_job(name, main_file, class_name = main_class, arguments = arguments, + reference_files = reference_files, tags = tags, configuration=configuration) + + +def _list_jobs( + api: SynapseJobRunner, + project: Optional[str] = None, + table_name: Optional[str] = None, +) -> List[SparkBatchJob]: + + job_infos = api.get_spark_batch_jobs() + + # Batch, Streaming Ingestion jobs + if project and table_name: + result = [] + table_name_hash = _generate_project_table_hash(project, table_name) + for job_info in job_infos: + if LABEL_FEATURE_TABLE_HASH in job_info.tags: + if table_name_hash == job_info.tags[LABEL_FEATURE_TABLE_HASH]: + result.append(job_info) + elif project: + result = [] + for job_info in job_infos: + if LABEL_PROJECT in job_info.tags: + if project == job_info.tags[LABEL_PROJECT]: + result.append(job_info) + else: + result = job_infos + + return result + + +def _get_job_by_id( + api: 
SynapseJobRunner, + job_id: int +) -> Optional[SparkBatchJob]: + return api.get_spark_batch_job(job_id) + + +def _cancel_job_by_id(api: SynapseJobRunner, job_id: int): + api.cancel_spark_batch_job(job_id) + diff --git a/python/setup.py b/python/setup.py index 49cd6eb6..195679ea 100644 --- a/python/setup.py +++ b/python/setup.py @@ -52,6 +52,11 @@ "grpcio-tools==1.31.0", "mypy-protobuf==2.5", "croniter==1.*", + "azure-synapse-spark", + "azure-synapse", + "azure-identity", + "azure-storage-file-datalake", + "azure-storage-blob", ] # README file from Feast repo root directory diff --git a/spark/ingestion/pom.xml b/spark/ingestion/pom.xml index 1fc560c5..b8bb1612 100644 --- a/spark/ingestion/pom.xml +++ b/spark/ingestion/pom.xml @@ -51,6 +51,26 @@ ${protobuf.version} + + com.microsoft.azure + azure-eventhubs-spark_2.12 + 2.3.18 + + + + + org.glassfish + javax.el + 3.0.1-b08 + + + + com.microsoft.azure + azure-eventhubs-spark_2.12 + 2.3.18 + + + com.gojek stencil diff --git a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala index 98a47a9e..03e558de 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/BasePipeline.scala @@ -33,11 +33,12 @@ object BasePipeline { val conf = new SparkConf() jobConfig.store match { - case RedisConfig(host, port, ssl) => + case RedisConfig(host, port, auth, ssl) => conf .set("spark.redis.host", host) .set("spark.redis.port", port.toString) .set("spark.redis.ssl", ssl.toString) + .set("spark.redis.auth", auth.toString) case BigTableConfig(projectId, instanceId) => conf .set("spark.bigtable.projectId", projectId) diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala index 352fe8c4..8b80e46c 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala +++ 
b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJob.scala @@ -22,6 +22,8 @@ import org.joda.time.{DateTime, DateTimeZone} import org.json4s._ import org.json4s.ext.JavaEnumNameSerializer import org.json4s.jackson.JsonMethods.{parse => parseJSON} +import org.json4s.ext.JavaEnumNameSerializer +import scala.collection.mutable.ArrayBuffer object IngestionJob { import Modes._ @@ -116,10 +118,22 @@ object IngestionJob { opt[Int](name = "triggering-interval") .action((x, c) => c.copy(streamingTriggeringSecs = x)) + + opt[String](name = "kafka_sasl_auth") + .action((x, c) => c.copy(kafkaSASL = Some(x))) } def main(args: Array[String]): Unit = { - parser.parse(args, IngestionJobConfig()) match { + val args_modified = new Array[String](args.length) + for ( i <- 0 to (args_modified.length - 1)) { + // Removing spaces between brackets. This is to workaround this known YARN issue (when running Spark on YARN): + // https://issues.apache.org/jira/browse/SPARK-17814?focusedCommentId=15567964&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-15567964 + // Also remove the unncessary back slashes + args_modified(i) = args(i).replace(" }", "}"); + args_modified(i) = args_modified(i).replace("\\", "\\\""); + } + println("arguments received:",args_modified.toList) + parser.parse(args_modified, IngestionJobConfig()) match { case Some(config) => println(s"Starting with config $config") config.mode match { diff --git a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala index 87150493..8e9a8fe3 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/IngestionJobConfig.scala @@ -26,7 +26,7 @@ object Modes extends Enumeration { abstract class StoreConfig -case class RedisConfig(host: String, port: Int, ssl: Boolean) extends StoreConfig +case class RedisConfig(host: String, port: 
Int, auth: String, ssl: Boolean) extends StoreConfig case class BigTableConfig(projectId: String, instanceId: String) extends StoreConfig case class CassandraConfig( connection: CassandraConnection, @@ -90,6 +90,7 @@ case class KafkaSource( override val datePartitionColumn: Option[String] = None ) extends StreamingSource + case class Sources( file: Option[FileSource] = None, bq: Option[BQSource] = None, @@ -119,12 +120,13 @@ case class IngestionJobConfig( source: Source = null, startTime: DateTime = DateTime.now(), endTime: DateTime = DateTime.now(), - store: StoreConfig = RedisConfig("localhost", 6379, false), + store: StoreConfig = RedisConfig("localhost", 6379, "", false), metrics: Option[MetricConfig] = None, deadLetterPath: Option[String] = None, stencilURL: Option[String] = None, streamingTriggeringSecs: Int = 0, validationConfig: Option[ValidationConfig] = None, doNotIngestInvalidRows: Boolean = false, - checkpointPath: Option[String] = None + checkpointPath: Option[String] = None, + kafkaSASL: Option[String] = None ) diff --git a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala index bebef92e..33f5746e 100644 --- a/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala +++ b/spark/ingestion/src/main/scala/feast/ingestion/StreamingPipeline.scala @@ -33,6 +33,9 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener} import org.apache.spark.sql.types.BooleanType import org.apache.spark.{SparkEnv, SparkFiles} +import org.apache.spark.eventhubs._ +import org.apache.kafka.common.security.plain.PlainLoginModule +import org.apache.kafka.common.security.JaasContext import java.io.File import java.sql.Timestamp @@ -73,17 +76,38 @@ object StreamingPipeline extends BasePipeline with Serializable { val validationUDF = createValidationUDF(sparkSession, config) + val EH_SASL = 
"org.apache.kafka.common.security.plain.PlainLoginModule required username=\"$ConnectionString\" password=\"Endpoint=sb://xxx.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=yyy=;EntityPath=driver_trips\";" + val input = config.source match { case source: KafkaSource => - sparkSession.readStream - .format("kafka") - .option("kafka.bootstrap.servers", source.bootstrapServers) - .option("subscribe", source.topic) - .load() + if (config.kafkaSASL.nonEmpty) + { + // if we have authentication enabled + println("config.kafkaSASL value:", config.kafkaSASL.get) + sparkSession.readStream + .format("kafka") + .option("subscribe", source.topic) + .option("kafka.bootstrap.servers", source.bootstrapServers) + .option("kafka.sasl.mechanism", "PLAIN") + .option("kafka.security.protocol", "SASL_SSL") + .option("kafka.sasl.jaas.config", config.kafkaSASL.get) + .option("kafka.request.timeout.ms", "60000") + .option("kafka.session.timeout.ms", "60000") + .option("failOnDataLoss", "false") + .load() + } + else + { + println("config.kafkaSASL is empty.") + sparkSession.readStream + .format("kafka") + .option("kafka.bootstrap.servers", source.bootstrapServers) + .option("subscribe", source.topic) + .load() + } case source: MemoryStreamingSource => - source.read + source.read } - val featureStruct = config.source.asInstanceOf[StreamingSource].format match { case ProtoFormat(classPath) => val parser = protoParser(sparkSession, classPath) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index d63d311c..534468cc 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -10,7 +10,7 @@ def pytest_addoption(parser): parser.addoption("--kafka-brokers", action="store", default="localhost:9092") parser.addoption( - "--env", action="store", help="local|aws|gcloud|k8s", default="local" + "--env", action="store", help="local|aws|gcloud|k8s|synapse", default="local" ) parser.addoption("--with-job-service", action="store_true") 
parser.addoption("--staging-path", action="store") @@ -23,6 +23,11 @@ def pytest_addoption(parser): parser.addoption("--dataproc-executor-cores", action="store", default="2") parser.addoption("--dataproc-executor-memory", action="store", default="2g") parser.addoption("--k8s-namespace", action="store", default="sparkop-e2e") + parser.addoption("--azure-synapse-dev-url", action="store", default="") + parser.addoption("--azure-synapse-pool-name", action="store", default="") + parser.addoption("--azure-synapse-datalake-dir", action="store", default="") + parser.addoption("--azure-blob-account-name", action="store", default="") + parser.addoption("--azure-blob-account-access-key", action="store", default="") parser.addoption("--ingestion-jar", action="store") parser.addoption("--redis-url", action="store", default="localhost:6379") parser.addoption("--redis-cluster", action="store_true") diff --git a/tests/e2e/fixtures/client.py b/tests/e2e/fixtures/client.py index 7ffa936f..3d331fd7 100644 --- a/tests/e2e/fixtures/client.py +++ b/tests/e2e/fixtures/client.py @@ -102,6 +102,24 @@ def feast_client( enable_auth=pytestconfig.getoption("enable_auth"), **job_service_env, ) + elif pytestconfig.getoption("env") == "synapse": + return Client( + core_url=f"{feast_core[0]}:{feast_core[1]}", + serving_url=f"{feast_serving[0]}:{feast_serving[1]}", + spark_launcher="synapse", + azure_synapse_dev_url = pytestconfig.getoption("azure_synapse_dev_url"), + azure_synapse_pool_name = pytestconfig.getoption("azure_synapse_pool_name"), + azure_synapse_datalake_dir = pytestconfig.getoption("azure_synapse_datalake_dir"), + spark_staging_location=os.path.join(local_staging_path, "synapse"), + azure_blob_account_name=pytestconfig.getoption("azure_blob_account_name"), + azure_blob_account_access_key=pytestconfig.getoption("azure_blob_account_access_key"), + spark_ingestion_jar=ingestion_job_jar, + redis_host=pytestconfig.getoption("redis_url").split(":")[0], + 
redis_port=pytestconfig.getoption("redis_url").split(":")[1], + historical_feature_output_location=os.path.join( + local_staging_path, "historical_output" + ), + ) else: raise KeyError(f"Unknown environment {pytestconfig.getoption('env')}") @@ -184,6 +202,21 @@ def tfrecord_feast_client( ), **job_service_env, ) + elif pytestconfig.getoption("env") == "synapse": + return Client( + core_url=f"{feast_core[0]}:{feast_core[1]}", + spark_launcher="synapse", + azure_synapse_dev_url = pytestconfig.getoption("azure_synapse_dev_url"), + azure_synapse_pool_name = pytestconfig.getoption("azure_synapse_pool_name"), + azure_synapse_datalake_dir = pytestconfig.getoption("azure_synapse_datalake_dir"), + spark_staging_location=os.path.join(local_staging_path, "synapse"), + azure_blob_account_name=pytestconfig.getoption("azure_blob_account_name"), + azure_blob_account_access_key=pytestconfig.getoption("azure_blob_account_access_key"), + historical_feature_output_format="tfrecord", + historical_feature_output_location=os.path.join( + local_staging_path, "historical_output" + ), + ) else: raise KeyError(f"Unknown environment {pytestconfig.getoption('env')}")