3030GS = "gs"
3131S3 = "s3"
3232S3A = "s3a"
33+ AZURE_SCHEME = "https"
3334LOCAL_FILE = "file"
3435
3536
@@ -82,7 +83,7 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
8283 pass
8384
8485 @abstractmethod
85- def list_files (self , bucket : str , path : str ) -> List [str ]:
86+ def list_files (self , uri : ParseResult ) -> List [str ]:
8687 """
8788 Lists all the files under a directory in an object store.
8889 """
@@ -157,19 +158,19 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
157158 file_obj .seek (0 )
158159 return file_obj
159160
160- def list_files (self , bucket : str , path : str ) -> List [str ]:
161+ def list_files (self , uri : ParseResult ) -> List [str ]:
161162 """
162163 Lists all the files under a directory in google cloud storage if path has wildcard(*) character.
163164
164165 Args:
165- bucket (str): google cloud storage bucket name
166- path (str): object location in google cloud storage.
166+ uri (urllib.parse.ParseResult): Parsed uri of this location
167167
168168 Returns:
169169 List[str]: A list containing the full path to the file(s) in the
170170 remote staging location.
171171 """
172172
173+ bucket , path = self ._uri_to_bucket_key (uri )
173174 gs_bucket = self .gcs_client .get_bucket (bucket )
174175
175176 if "*" in path :
@@ -184,7 +185,7 @@ def list_files(self, bucket: str, path: str) -> List[str]:
184185 if re .match (regex , file ) and file not in path
185186 ]
186187 else :
187- return [f"{ GS } ://{ bucket } /{ path . lstrip ( '/' ) } " ]
188+ return [f"{ GS } ://{ bucket } /{ path } " ]
188189
189190 def _uri_to_bucket_key (self , remote_path : ParseResult ) -> Tuple [str , str ]:
190191 assert remote_path .hostname is not None
@@ -234,25 +235,24 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
234235 Returns:
235236 TemporaryFile object
236237 """
237- url = uri .path .lstrip ("/" )
238- bucket = uri .hostname
238+ bucket , url = self ._uri_to_bucket_key (uri )
239239 file_obj = TemporaryFile ()
240240 self .s3_client .download_fileobj (bucket , url , file_obj )
241241 return file_obj
242242
243- def list_files (self , bucket : str , path : str ) -> List [str ]:
243+ def list_files (self , uri : ParseResult ) -> List [str ]:
244244 """
245245 Lists all the files under a directory in s3 if path has wildcard(*) character.
246246
247247 Args:
248- bucket (str): s3 bucket name.
249- path (str): Object location in s3.
248+ uri (urllib.parse.ParseResult): Parsed uri of this location
250249
251250 Returns:
252251 List[str]: A list containing the full path to the file(s) in the
253252 remote staging location.
254253 """
255254
255+ bucket , path = self ._uri_to_bucket_key (uri )
256256 if "*" in path :
257257 regex = re .compile (path .replace ("*" , ".*?" ).strip ("/" ))
258258 blob_list = self .s3_client .list_objects (
@@ -265,7 +265,7 @@ def list_files(self, bucket: str, path: str) -> List[str]:
265265 if re .match (regex , file ) and file not in path
266266 ]
267267 else :
268- return [f"{ self .url_scheme } ://{ bucket } /{ path . lstrip ( '/' ) } " ]
268+ return [f"{ self .url_scheme } ://{ bucket } /{ path } " ]
269269
270270 def _uri_to_bucket_key (self , remote_path : ParseResult ) -> Tuple [str , str ]:
271271 assert remote_path .hostname is not None
@@ -313,6 +313,90 @@ def upload_fileobj(
313313 return remote_uri
314314
315315
class AzureBlobClient(AbstractStagingClient):
    """
    Implementation of AbstractStagingClient for Azure Blob storage.

    URIs handled by this client have the form
    ``https://<account_name>.blob.core.windows.net/<container>/<blob path>``;
    the first path segment is the container ("bucket") and the remainder is
    the blob name ("key").
    """

    def __init__(self, account_name: str, account_access_key: str):
        """
        Creates the underlying ``BlobServiceClient`` for the given account.

        Args:
            account_name (str): Storage account name, used to build the account url.
            account_access_key (str): Passed as ``credential`` to ``BlobServiceClient``.

        Raises:
            ImportError: If the ``azure-storage-blob`` package is not installed.
        """
        try:
            # Imported lazily so the package is only required when Azure staging is used.
            from azure.storage.blob import BlobServiceClient
        except ImportError:
            raise ImportError(
                "Install package azure-storage-blob for azure blob staging support, "
                "run ```pip install azure-storage-blob```"
            )
        self.account_url = f"https://{account_name}.blob.core.windows.net"
        self.blob_service_client = BlobServiceClient(
            account_url=self.account_url, credential=account_access_key
        )

    def download_file(self, uri: ParseResult) -> IO[bytes]:
        """
        Downloads a file from Azure blob storage and returns a TemporaryFile object

        Args:
            uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("https://account_name.blob.core.windows.net/bucket/file.avro")

        Returns:
            TemporaryFile object, positioned at the start of the downloaded data
        """
        bucket, path = self._uri_to_bucket_key(uri)
        container_client = self.blob_service_client.get_container_client(bucket)
        # Return a file object (not raw bytes) so the result honours the
        # IO[bytes] contract shared with the other staging clients.
        file_obj = TemporaryFile()
        container_client.download_blob(path).readinto(file_obj)
        file_obj.seek(0)
        return file_obj

    def list_files(self, uri: ParseResult) -> List[str]:
        """
        Lists all the files under a directory in azure blob storage if path has wildcard(*) character.

        Args:
            uri (urllib.parse.ParseResult): Parsed uri of this location

        Returns:
            List[str]: A list containing the full path to the file(s) in the
                    remote staging location.
        """

        bucket, path = self._uri_to_bucket_key(uri)
        if "*" not in path:
            return [f"{self.account_url}/{bucket}/{path}"]

        # Only "*" is a wildcard; every other character of the path is matched
        # literally, so regex metacharacters such as "." are escaped.
        literal_parts = path.strip("/").split("*")
        regex = re.compile(".*?".join(re.escape(part) for part in literal_parts))
        container_client = self.blob_service_client.get_container_client(bucket)
        # Narrow the server-side listing to the literal prefix before the first "*".
        blob_list = container_client.list_blobs(name_starts_with=literal_parts[0])
        # File path should not be in path (file path must be longer than path)
        return [
            f"{self.account_url}/{bucket}/{name}"
            for name in (blob.name for blob in blob_list)
            if regex.match(name) and name not in path
        ]

    def _uri_to_bucket_key(self, uri: ParseResult) -> Tuple[str, str]:
        """
        Splits a blob uri into its container name and blob name.

        Args:
            uri (urllib.parse.ParseResult): Parsed uri whose hostname must be
                this client's storage account host.

        Returns:
            Tuple[str, str]: (container, blob name). The blob name is an empty
                string when the uri path contains only the container.
        """
        expected_host = urlparse(self.account_url).hostname
        assert (
            uri.hostname == expected_host
        ), f"Expected uri host {expected_host}, got {uri.hostname}"
        # partition() tolerates a container-only path instead of raising IndexError.
        bucket, _, key = uri.path.lstrip("/").partition("/")
        return bucket, key

    def upload_fileobj(
        self,
        fileobj: IO[bytes],
        local_path: str,
        *,
        remote_uri: Optional[ParseResult] = None,
        remote_path_prefix: Optional[str] = None,
        remote_path_suffix: Optional[str] = None,
    ) -> ParseResult:
        """
        Uploads a file object to Azure blob storage.

        Args:
            fileobj (IO[bytes]): File object to upload.
            local_path (str): Not used by this implementation.
            remote_uri (Optional[ParseResult]): Destination uri; forwarded to
                ``_gen_remote_uri`` together with the prefix and suffix.
            remote_path_prefix (Optional[str]): Forwarded to ``_gen_remote_uri``.
            remote_path_suffix (Optional[str]): Forwarded to ``_gen_remote_uri``.

        Returns:
            ParseResult: The uri the file object was uploaded to.
        """
        remote_uri = _gen_remote_uri(
            fileobj, remote_uri, remote_path_prefix, remote_path_suffix, None
        )
        bucket, key = self._uri_to_bucket_key(remote_uri)
        container_client = self.blob_service_client.get_container_client(bucket)
        container_client.upload_blob(name=key, data=fileobj)
        return remote_uri
398+
399+
316400class LocalFSClient (AbstractStagingClient ):
317401 """
318402 Implementation of AbstractStagingClient for local file
@@ -327,15 +411,15 @@ def download_file(self, uri: ParseResult) -> IO[bytes]:
327411 Reads a local file from the disk
328412
329413 Args:
330- uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("file://folder/file.avro")
414+ uri (urllib.parse.ParseResult): Parsed uri of the file ex: urlparse("file:/// folder/file.avro")
331415 Returns:
332416 TemporaryFile object
333417 """
334418 url = uri .path
335419 file_obj = open (url , "rb" )
336420 return file_obj
337421
338- def list_files (self , bucket : str , path : str ) -> List [str ]:
422+ def list_files (self , uri : ParseResult ) -> List [str ]:
339423 raise NotImplementedError ("list files not implemented for Local file" )
340424
341425 def _uri_to_path (self , uri : ParseResult ) -> str :
@@ -381,6 +465,18 @@ def _gcs_client(config: Config = None):
381465 return GCSClient ()
382466
383467
def _azure_blob_client(config: Config = None):
    """
    Builds an AzureBlobClient from the account name and access key in config.

    Raises:
        Exception: If config is None, or if either the account name or the
            access key option is missing from config.
    """
    if config is None:
        raise Exception("Azure blob client requires config")

    name_option = opt.AZURE_BLOB_ACCOUNT_NAME
    key_option = opt.AZURE_BLOB_ACCOUNT_ACCESS_KEY
    credentials = (config.get(name_option, None), config.get(key_option, None))
    if any(value is None for value in credentials):
        raise Exception(
            f"Azure blob client requires {name_option} and {key_option} set in config"
        )
    return AzureBlobClient(*credentials)
478+
479+
def _local_fs_client(config: Config = None):
    """
    Builds the staging client for local files; ``config`` is accepted but not used.
    """
    client = LocalFSClient()
    return client
386482
@@ -389,6 +485,7 @@ def _local_fs_client(config: Config = None):
389485 GS : _gcs_client ,
390486 S3 : _s3_client ,
391487 S3A : _s3a_client ,
488+ AZURE_SCHEME : _azure_blob_client , # note we currently interpret all uris beginning https:// as Azure blob uris
392489 LOCAL_FILE : _local_fs_client ,
393490}
394491
@@ -408,5 +505,5 @@ def get_staging_client(scheme, config: Config = None) -> AbstractStagingClient:
408505 return storage_clients [scheme ](config )
409506 except ValueError :
410507 raise Exception (
411- f"Could not identify file scheme { scheme } . Only gs://, file:// and s3 :// are supported"
508+ f"Could not identify file scheme { scheme } . Only gs://, file://, s3:// and https :// (for Azure) are supported"
412509 )
0 commit comments