88 ExternalAuthProvider ,
99 DatabricksOAuthProvider ,
1010)
11- from databricks .sql .auth .endpoint import infer_cloud_from_host , CloudType
12- from databricks .sql .experimental .oauth_persistence import OAuthPersistence
1311
1412
1513class AuthType (Enum ):
1614 DATABRICKS_OAUTH = "databricks-oauth"
15+ AZURE_OAUTH = "azure-oauth"
1716 # other supported types (access_token, user/pass) can be inferred
1817 # we can add more types as needed later
1918
@@ -51,7 +50,7 @@ def __init__(
5150def get_auth_provider (cfg : ClientContext ):
5251 if cfg .credentials_provider :
5352 return ExternalAuthProvider (cfg .credentials_provider )
54- if cfg .auth_type == AuthType .DATABRICKS_OAUTH .value :
53+ if cfg .auth_type in [ AuthType .DATABRICKS_OAUTH .value , AuthType . AZURE_OAUTH . value ] :
5554 assert cfg .oauth_redirect_port_range is not None
5655 assert cfg .oauth_client_id is not None
5756 assert cfg .oauth_scopes is not None
@@ -62,6 +61,7 @@ def get_auth_provider(cfg: ClientContext):
6261 cfg .oauth_redirect_port_range ,
6362 cfg .oauth_client_id ,
6463 cfg .oauth_scopes ,
64+ cfg .auth_type ,
6565 )
6666 elif cfg .access_token is not None :
6767 return AccessTokenAuthProvider (cfg .access_token )
@@ -87,20 +87,22 @@ def normalize_host_name(hostname: str):
8787 return f"{ maybe_scheme } { hostname } { maybe_trailing_slash } "
8888
8989
90- def get_client_id_and_redirect_port (hostname : str ):
91- cloud_type = infer_cloud_from_host (hostname )
90+ def get_client_id_and_redirect_port (use_azure_auth : bool ):
9291 return (
9392 (PYSQL_OAUTH_CLIENT_ID , PYSQL_OAUTH_REDIRECT_PORT_RANGE )
94- if cloud_type == CloudType . AWS or cloud_type == CloudType . GCP
93+ if not use_azure_auth
9594 else (PYSQL_OAUTH_AZURE_CLIENT_ID , PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE )
9695 )
9796
9897
9998def get_python_sql_connector_auth_provider (hostname : str , ** kwargs ):
100- (client_id , redirect_port_range ) = get_client_id_and_redirect_port (hostname )
99+ auth_type = kwargs .get ("auth_type" )
100+ (client_id , redirect_port_range ) = get_client_id_and_redirect_port (
101+ auth_type == AuthType .AZURE_OAUTH .value
102+ )
101103 cfg = ClientContext (
102104 hostname = normalize_host_name (hostname ),
103- auth_type = kwargs . get ( " auth_type" ) ,
105+ auth_type = auth_type ,
104106 access_token = kwargs .get ("access_token" ),
105107 username = kwargs .get ("_username" ),
106108 password = kwargs .get ("_password" ),
0 commit comments