Skip to content

Commit 76f05d1

Browse files
authored
Implement AbstractStagingClient for Azure Blob Storage (#1218)
* implement abstractstagingclient for azure blob storage — Signed-off-by: Jacob Klegar <jacob@tecton.ai>
* change arguments for list_files — Signed-off-by: Jacob Klegar <jacob@tecton.ai>
* linting — Signed-off-by: Jacob Klegar <jacob@tecton.ai>
* bugfix — Signed-off-by: Jacob Klegar <jacob@tecton.ai>
* check azure uri has the right account name — Signed-off-by: Jacob Klegar <jacob@tecton.ai>
1 parent 1d53f45 commit 76f05d1

File tree

3 files changed

+122
-19
lines changed

3 files changed

+122
-19
lines changed

sdk/python/feast/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ class ConfigOptions(metaclass=ConfigMeta):
128128
#: Endpoint URL for S3 storage_client
129129
S3_ENDPOINT_URL: Optional[str] = None
130130

131+
#: Account name for Azure blob storage_client
132+
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = None
133+
134+
#: Account access key for Azure blob storage_client
135+
AZURE_BLOB_ACCOUNT_ACCESS_KEY: Optional[str] = None
136+
131137
#: Authentication Provider - Google OpenID/OAuth
132138
#:
133139
#: Options: "google" / "oauth"

sdk/python/feast/loaders/file.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,20 @@ def export_source_to_staging_location(
4040
Source of data to be staged. Can be a pandas DataFrame or a file
4141
path.
4242
43-
Only three types of source are allowed:
43+
Only four types of source are allowed:
4444
* Pandas DataFrame
4545
* Local Avro file
4646
* GCS Avro file
4747
* S3 Avro file
48+
* Azure Blob storage Avro file
4849
4950
5051
staging_location_uri (str):
5152
Remote staging location where DataFrame should be written.
5253
Examples:
5354
* gs://bucket/path/
5455
* s3://bucket/path/
56+
* https://account_name.blob.core.windows.net/bucket/path/
5557
* file:///data/subfolder/
5658
5759
Returns:
@@ -76,11 +78,9 @@ def export_source_to_staging_location(
7678
os.path.join(source_uri.netloc, source_uri.path)
7779
)
7880
else:
79-
# gs, s3 file provided as a source.
81+
# gs, s3, azure blob file provided as a source.
8082
assert source_uri.hostname is not None
81-
return get_staging_client(source_uri.scheme).list_files(
82-
bucket=source_uri.hostname, path=source_uri.path
83-
)
83+
return get_staging_client(source_uri.scheme).list_files(uri=source_uri)
8484
else:
8585
raise Exception(
8686
f"Only string and DataFrame types are allowed as a "

sdk/python/feast/staging/storage_client.py

Lines changed: 111 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
GS = "gs"
3131
S3 = "s3"
3232
S3A = "s3a"
33+
AZURE_SCHEME = "https"
3334
LOCAL_FILE = "file"
3435

3536

@@ -82,7 +83,7 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
8283
pass
8384

8485
@abstractmethod
85-
def list_files(self, bucket: str, path: str) -> List[str]:
86+
def list_files(self, uri: ParseResult) -> List[str]:
8687
"""
8788
Lists all the files under a directory in an object store.
8889
"""
@@ -157,19 +158,19 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
157158
file_obj.seek(0)
158159
return file_obj
159160

160-
def list_files(self, bucket: str, path: str) -> List[str]:
161+
def list_files(self, uri: ParseResult) -> List[str]:
161162
"""
162163
Lists all the files under a directory in google cloud storage if path has wildcard(*) character.
163164
164165
Args:
165-
bucket (str): google cloud storage bucket name
166-
path (str): object location in google cloud storage.
166+
uri (urllib.parse.ParseResult): Parsed uri of this location
167167
168168
Returns:
169169
List[str]: A list containing the full path to the file(s) in the
170170
remote staging location.
171171
"""
172172

173+
bucket, path = self._uri_to_bucket_key(uri)
173174
gs_bucket = self.gcs_client.get_bucket(bucket)
174175

175176
if "*" in path:
@@ -184,7 +185,7 @@ def list_files(self, bucket: str, path: str) -> List[str]:
184185
if re.match(regex, file) and file not in path
185186
]
186187
else:
187-
return [f"{GS}://{bucket}/{path.lstrip('/')}"]
188+
return [f"{GS}://{bucket}/{path}"]
188189

189190
def _uri_to_bucket_key(self, remote_path: ParseResult) -> Tuple[str, str]:
190191
assert remote_path.hostname is not None
@@ -234,25 +235,24 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
234235
Returns:
235236
TemporaryFile object
236237
"""
237-
url = uri.path.lstrip("/")
238-
bucket = uri.hostname
238+
bucket, url = self._uri_to_bucket_key(uri)
239239
file_obj = TemporaryFile()
240240
self.s3_client.download_fileobj(bucket, url, file_obj)
241241
return file_obj
242242

243-
def list_files(self, bucket: str, path: str) -> List[str]:
243+
def list_files(self, uri: ParseResult) -> List[str]:
244244
"""
245245
Lists all the files under a directory in s3 if path has wildcard(*) character.
246246
247247
Args:
248-
bucket (str): s3 bucket name.
249-
path (str): Object location in s3.
248+
uri (urllib.parse.ParseResult): Parsed uri of this location
250249
251250
Returns:
252251
List[str]: A list containing the full path to the file(s) in the
253252
remote staging location.
254253
"""
255254

255+
bucket, path = self._uri_to_bucket_key(uri)
256256
if "*" in path:
257257
regex = re.compile(path.replace("*", ".*?").strip("/"))
258258
blob_list = self.s3_client.list_objects(
@@ -265,7 +265,7 @@ def list_files(self, bucket: str, path: str) -> List[str]:
265265
if re.match(regex, file) and file not in path
266266
]
267267
else:
268-
return [f"{self.url_scheme}://{bucket}/{path.lstrip('/')}"]
268+
return [f"{self.url_scheme}://{bucket}/{path}"]
269269

270270
def _uri_to_bucket_key(self, remote_path: ParseResult) -> Tuple[str, str]:
271271
assert remote_path.hostname is not None
@@ -313,6 +313,90 @@ def upload_fileobj(
313313
return remote_uri
314314

315315

316+
class AzureBlobClient(AbstractStagingClient):
    """
    Implementation of AbstractStagingClient for Azure Blob storage.

    Uris handled by this client look like
    ``https://<account_name>.blob.core.windows.net/<container>/<path>``,
    where the first path segment is the container ("bucket") and the rest
    is the blob key.
    """

    def __init__(self, account_name: str, account_access_key: str):
        """
        Args:
            account_name (str): Azure storage account name; forms the account url
                ``https://<account_name>.blob.core.windows.net``.
            account_access_key (str): shared access key used as the credential.

        Raises:
            ImportError: if the optional azure-storage-blob package is missing.
        """
        try:
            from azure.storage.blob import BlobServiceClient
        except ImportError:
            # Azure support is an optional extra; give an actionable hint.
            # NOTE: the original message concatenated two literals with no
            # separator ("...supportrun pip install..."); fixed here.
            raise ImportError(
                "Install package azure-storage-blob for azure blob staging support; "
                "run ```pip install azure-storage-blob```"
            )
        self.account_url = f"https://{account_name}.blob.core.windows.net"
        self.blob_service_client = BlobServiceClient(
            account_url=self.account_url, credential=account_access_key
        )

    def download_file(self, uri: ParseResult) -> IO[bytes]:
        """
        Downloads a file from Azure blob storage and returns a TemporaryFile object

        Args:
            uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("https://account_name.blob.core.windows.net/bucket/file.avro")

        Returns:
            TemporaryFile object
        """
        bucket, path = self._uri_to_bucket_key(uri)
        container_client = self.blob_service_client.get_container_client(bucket)
        # download_blob(...).readall() yields raw bytes; wrap them in a seekable
        # file object so this matches the IO[bytes] contract declared above and
        # the behavior of the GCS/S3 clients (callers expect a file, not bytes).
        file_obj = TemporaryFile()
        file_obj.write(container_client.download_blob(path).readall())
        file_obj.seek(0)
        return file_obj

    def list_files(self, uri: ParseResult) -> List[str]:
        """
        Lists all the files under a directory in azure blob storage if path has wildcard(*) character.

        Args:
            uri (urllib.parse.ParseResult): Parsed uri of this location

        Returns:
            List[str]: A list containing the full path to the file(s) in the
                remote staging location.
        """

        bucket, path = self._uri_to_bucket_key(uri)
        if "*" in path:
            # Translate the shell-style wildcard into a regex; the prefix before
            # the first '*' narrows the server-side listing.
            regex = re.compile(path.replace("*", ".*?").strip("/"))
            container_client = self.blob_service_client.get_container_client(bucket)
            blob_list = container_client.list_blobs(
                name_starts_with=path.strip("/").split("*")[0]
            )
            # File path should not be in path (file path must be longer than path)
            return [
                f"{self.account_url}/{bucket}/{file}"
                for file in [x.name for x in blob_list]
                if re.match(regex, file) and file not in path
            ]
        else:
            # No wildcard: the uri already points at a single blob.
            return [f"{self.account_url}/{bucket}/{path}"]

    def _uri_to_bucket_key(self, uri: ParseResult) -> Tuple[str, str]:
        """
        Splits an Azure blob uri into (container, blob key).

        Raises:
            AssertionError: if the uri's host is not this client's account.
        """
        assert uri.hostname == urlparse(self.account_url).hostname
        # path is "/<container>/<key...>"; maxsplit=1 keeps slashes inside the
        # key. A container-root uri has no key segment — return "" instead of
        # raising IndexError as the original did.
        parts = uri.path.lstrip("/").split("/", 1)
        bucket = parts[0]
        key = parts[1] if len(parts) > 1 else ""
        return bucket, key

    def upload_fileobj(
        self,
        fileobj: IO[bytes],
        local_path: str,
        *,
        remote_uri: Optional[ParseResult] = None,
        remote_path_prefix: Optional[str] = None,
        remote_path_suffix: Optional[str] = None,
    ) -> ParseResult:
        """
        Uploads a file object to Azure blob storage.

        Either remote_uri, or both remote_path_prefix and remote_path_suffix,
        must be supplied (resolved by _gen_remote_uri).

        Returns:
            ParseResult: the uri the object was uploaded to.
        """
        remote_uri = _gen_remote_uri(
            fileobj, remote_uri, remote_path_prefix, remote_path_suffix, None
        )
        bucket, key = self._uri_to_bucket_key(remote_uri)
        container_client = self.blob_service_client.get_container_client(bucket)
        container_client.upload_blob(name=key, data=fileobj)
        return remote_uri
398+
399+
316400
class LocalFSClient(AbstractStagingClient):
317401
"""
318402
Implementation of AbstractStagingClient for local file
@@ -327,15 +411,15 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
327411
Reads a local file from the disk
328412
329413
Args:
330-
uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("file://folder/file.avro")
414+
uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("file:///folder/file.avro")
331415
Returns:
332416
TemporaryFile object
333417
"""
334418
url = uri.path
335419
file_obj = open(url, "rb")
336420
return file_obj
337421

338-
def list_files(self, bucket: str, path: str) -> List[str]:
422+
def list_files(self, uri: ParseResult) -> List[str]:
339423
raise NotImplementedError("list files not implemented for Local file")
340424

341425
def _uri_to_path(self, uri: ParseResult) -> str:
@@ -381,6 +465,18 @@ def _gcs_client(config: Config = None):
381465
return GCSClient()
382466

383467

468+
def _azure_blob_client(config: Config = None):
    """Build an AzureBlobClient from feast config.

    Requires both the Azure account name and account access key options to be
    present in config; raises otherwise.
    """
    # Guard up front: both credentials live in the config object.
    if config is None:
        raise Exception("Azure blob client requires config")
    name = config.get(opt.AZURE_BLOB_ACCOUNT_NAME, None)
    access_key = config.get(opt.AZURE_BLOB_ACCOUNT_ACCESS_KEY, None)
    if name is None or access_key is None:
        raise Exception(
            f"Azure blob client requires {opt.AZURE_BLOB_ACCOUNT_NAME} and {opt.AZURE_BLOB_ACCOUNT_ACCESS_KEY} set in config"
        )
    return AzureBlobClient(name, access_key)
478+
479+
384480
def _local_fs_client(config: Config = None):
385481
return LocalFSClient()
386482

@@ -389,6 +485,7 @@ def _local_fs_client(config: Config = None):
389485
GS: _gcs_client,
390486
S3: _s3_client,
391487
S3A: _s3a_client,
488+
AZURE_SCHEME: _azure_blob_client, # note we currently interpret all uris beginning https:// as Azure blob uris
392489
LOCAL_FILE: _local_fs_client,
393490
}
394491

@@ -408,5 +505,5 @@ def get_staging_client(scheme, config: Config = None) -> AbstractStagingClient:
408505
return storage_clients[scheme](config)
409506
except ValueError:
410507
raise Exception(
411-
f"Could not identify file scheme {scheme}. Only gs://, file:// and s3:// are supported"
508+
f"Could not identify file scheme {scheme}. Only gs://, file://, s3:// and https:// (for Azure) are supported"
412509
)

0 commit comments

Comments
 (0)