Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions registry/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__pycache__
.env
.vscode
5 changes: 5 additions & 0 deletions registry/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
AAD_INSTANCE=https://login.microsoftonline.com
API_CLIENT_ID=e1475341-ff84-4b6e-a187-349e35551cff
API_CLIENT_SECRET=cMZ8Q~d.gI6A-j-EMFbbmz.xy6jXOxUKH4dLYdhH
SWAGGER_UI_CLIENT_ID=e1335db3-a7eb-42b0-976c-82c9b5bdf0fb
AAD_TENANT_ID=72f988bf-86f1-41af-91ab-2d7cd011db47
4 changes: 4 additions & 0 deletions registry/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__pycache__
.env
.vscode
.idea
9 changes: 9 additions & 0 deletions registry/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
FROM python:3.9

COPY ./ /usr/src

WORKDIR /usr/src
RUN pip install -r requirements.txt

# Start web server
CMD [ "uvicorn","main:app","--host", "0.0.0.0", "--port", "80" ]
4 changes: 4 additions & 0 deletions registry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from common.database import DbConnection, connect
from common.registry_models import *
from sql_registry import registry
from access_control import *
25 changes: 25 additions & 0 deletions registry/access_control/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Access Control Notes (WIP)

A Global Switch `RBAC_ENABLED` is required to set as `True` to turn on the access control protection feature.

Please note that in this version, only access control management API and UI experience are included. Supported scenarios status are tracked below:

- General Fundation:
- [x] Access Control Abstrct Class
- [x] API Spec Contents for Access Control Management APIs
- [ ] API Sepc Contents for Registry API Access Control
- SQL Implementaion:
- [x] `userroles` table CUD through FastAPI
- [x] `userroles` table schema & test data
- [x] Enable/Disable with `RBAC_ENABLE` configuration
- [ ] Initialize default Admin role for project creator: After `create_project` API is ready
- UI Experience
- [x] Hidden page `../management` for global admin to make CUD requests to `userroles` table
- [x] Use id token in Management API Request headers to identify requestor
- [ ] Protect SQL Registry API with Access Control: After `create_project` API is ready
- Future Enhancements:
- [ ] Functional in Feathr Client
- [ ] Support Security Group scenario in Access Control
- [ ] Support AAD Groups
- [ ] Support Other OAuth Providers

8 changes: 8 additions & 0 deletions registry/access_control/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
__all__ = ["auth", "models", "interface", "db_rbac"]


from access_control.auth import *
from access_control.interface import RBAC
from access_control.models import *
from access_control.db_rbac import DbRBAC
from common.database import DbConnection, connect
31 changes: 31 additions & 0 deletions registry/access_control/access-control-gateway-sepc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Feathr Registry Access Control Gateway Specifications

## Registry API with Access Control Gateway
**Access Control Gateway** is an access control **Plugin** component of feature registry API. It can work with different type of backend registry. When user enables this component, registry requests will be validated in a gateway as below flow chart:

```mermaid
flowchart TD
A[Get Registry API Request] --> B{Is Id Token Valid?};
B -- No --> D[Return 401];
B -- Yes --> C{Have Permission?};
C -- No --> F[Return 403];
C -- Yes --> E[Call Registry API*];
E --> G{API Service Avaiable?}
G -- No --> I[Return 503]
G -- Yes --> H[Return API Results]
```
If Access control component is NOT enabled, the flow will start from __Call Registry API*__

## Acess Control Rules
- For all **get** requests, check **read** permission for certain project.
- For all **post** request, check **write** permission for certain project.
- For all **access control management** request, check **manage** permission for certian project.
- In case of feature level query, will verify the parent project access of the feature.
- Registry API calls and returns will be transparently transfered.
- A header `x-access-control-enabled` will be added for API calls protected by access control gateway

## Management Rules
### Initialize `userroles` table
In current version, user needs to mannually initialze `userroles` table admins in SQL table.
When `create_registry` and `create_project` API is enabled, default admin role will be assgined to the creator.
Admin roles can add or delete roles in management UI page or thorugh management API.
36 changes: 36 additions & 0 deletions registry/access_control/access-control-gateway-spec.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Feathr Registry Access Control Gateway Specifications

## Registry API with Access Control Gateway

**Access Control Gateway** is an access control **Plugin** component of feature registry API. It can work with different type of backend registry. When user enables this component, registry requests will be validated in a gateway as below flow chart:

```mermaid
flowchart TD
A[Get Registry API Request] --> B{Is Id Token Valid?};
B -- No --> D[Return 401];
B -- Yes --> C{Have Permission?};
C -- No --> F[Return 403];
C -- Yes --> E[Call Registry API*];
E --> G{API Service Avaiable?}
G -- No --> I[Return 503]
G -- Yes --> H[Return API Results]
```

If Access control component is NOT enabled, the flow will start from **Call Registry API***

## Acess Control Rules

- For all **get** requests, check **read** permission for certain project.
- For all **post** request, check **write** permission for certain project.
- For all **access control management** request, check **manage** permission for certian project.
- In case of feature level query, will verify the parent project access of the feature.
- Registry API calls and returns will be transparently transfered.
- A header `x-access-control-enabled` will be added for API calls protected by access control gateway

## Management Rules

### Initialize `userroles` table

In current version, user needs to mannually initialze `userroles` table admins in SQL table.
When `create_registry` and `create_project` API is enabled, default admin role will be assgined to the creator.
Admin roles can add or delete roles in management UI page or thorugh management API.
39 changes: 39 additions & 0 deletions registry/access_control/access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Any
from xmlrpc.client import Boolean

from fastapi import Depends, HTTPException, status
from access_control.db_rbac import DbRBAC

from access_control.models import SUPER_ADMIN_SCOPE, AccessType, RoleType, User
from access_control.authorize import authorize

rbac = DbRBAC()

class ForbiddenAccess(HTTPException):
def __init__(self, detail: Any = None) -> None:
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail, headers={"WWW-Authenticate": "Bearer"})


def get_user(user: User = Depends(authorize)) -> User:
return user

def project_read_access(project: str, user: User = Depends(authorize)) -> User:
return _project_access(project, user, AccessType.READ)

def project_write_access(project: str, user: User = Depends(authorize)) -> User:
return _project_access(project, user, AccessType.WRITE)

def project_manage_access(project: str, user: User = Depends(authorize)) -> User:
return _project_access(project, user, AccessType.MANAGE)

def _project_access(project:str, user: User, access: str):
if rbac.validate_project_access_users(project, user.preferred_username, access):
return user
else:
raise ForbiddenAccess(f"{access} privileges for project {project} required for user {user.preferred_username}")

def global_admin_access(user: User = Depends(authorize)):
if user.preferred_username in rbac.global_admin:
return user
else:
raise ForbiddenAccess('Admin privileges required')
49 changes: 49 additions & 0 deletions registry/access_control/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
import requests
import jwt
from jwt.algorithms import RSAAlgorithm

BEARER_TOKEN = "BEARER "


class AuthProvider():
"""This is the abstract class to decode JWT ID token.
Sample Usage with Azure ID token:
jwks_uri = "https://login.microsoftonline.com/common/discovery/v2.0/keys"
client_id = {Your Client Id}
auth = rbac.AuthProvider(jwks_uri, client_id)
"""

def __init__(self, jwks_uri, client_id):
"""Args:
- client Id: used as audience ("aud")
- jwks_uri: used to get public key pool
"""
self.client_id = client_id
self.cert_set = self.get_certs(jwks_uri)

def get_certs(self, jwks_uri: str):
"""Get certs from jwks uri"""
certs = requests.get(jwks_uri).json()
return {cert['kid']: cert for cert in certs['keys']}

def get_public_key(self, token: str):
""" Get public key based on token kid"""
header_data = jwt.get_unverified_header(token)
kid = header_data['kid']
return RSAAlgorithm.from_jwk(json.dumps(self.cert_set[kid]))

def decode_token(self, bearer_token: str):
""" Decode ID token with RA256 Algorithm
Sample Usage with Azure ID token:
decoded = auth.decode_token(token)
username = decoded.get('preferred_username').lower()
"""
# TODO: Process Bearer Token more elegantly
token = bearer_token[len(BEARER_TOKEN):]
return jwt.decode(token, self.get_public_key(token), algorithms=[
"RS256"], audience=self.client_id)

def AzureADAuth(client_id: str):
jwks_uri = "https://login.microsoftonline.com/common/discovery/v2.0/keys"
auth = AuthProvider(jwks_uri, client_id)
114 changes: 114 additions & 0 deletions registry/access_control/authorize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import base64
import logging
import requests
import rsa
from typing import Any, Dict, Mapping, Optional, Union

from fastapi import HTTPException, Request, status
from fastapi.security import OAuth2AuthorizationCodeBearer
import jwt
from jwt.exceptions import ExpiredSignatureError, PyJWKError

import access_control.config as config
from access_control.models import User


log = logging.getLogger()

log.info(config.AAD_INSTANCE)

BEARER_TOKEN = "BEARER "

class InvalidAuthorization(HTTPException):
def __init__(self, detail: Any = None) -> None:
super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail, headers={"WWW-Authenticate": "Bearer"})


class AzureADAuth(OAuth2AuthorizationCodeBearer):
# cached AAD jwt keys
aad_jwt_keys_cache: dict = {}

def __init__(self, aad_instance: str = config.AAD_INSTANCE, aad_tenant: str = config.AAD_TENANT_ID):
self.base_auth_url: str = f"{aad_instance}/{aad_tenant}"

async def __call__(self, request: Request) -> User:
# token: str = await super(AzureADAuthorization, self).__call__(request) or ''
bearer_token: str = request.headers.get("authorization")
token = bearer_token[len(BEARER_TOKEN):]
decoded_token = self._decode_token(token)
return self._get_user_from_token(decoded_token)

@staticmethod
def _get_user_from_token(decoded_token: Mapping) -> User:
try:
user_id = decoded_token['oid']
except Exception as e:
logging.debug(e)
raise InvalidAuthorization(detail='Unable to extract user details from token')

return User(
id=user_id,
name=decoded_token.get('name', ''),
preferred_username=decoded_token.get('preferred_username', ''),
roles=decoded_token.get('roles', [])
)

@staticmethod
def _get_key_id(token: str) -> Optional[str]:
headers = jwt.get_unverified_header(token)
return headers['kid'] if headers and 'kid' in headers else None

@staticmethod
def _ensure_b64padding(key: str) -> str:
"""
The base64 encoded keys are not always correctly padded, so pad with the right number of =
"""
key = key.encode('utf-8')
missing_padding = len(key) % 4
for _ in range(missing_padding):
key = key + b'='
return key

def _cache_aad_keys(self) -> None:
"""
Cache all AAD JWT keys - so we don't have to make a web call each auth request
"""
response = requests.get(f"{self.base_auth_url}/v2.0/.well-known/openid-configuration")
aad_metadata = response.json() if response.ok else None
jwks_uri = aad_metadata['jwks_uri'] if aad_metadata and 'jwks_uri' in aad_metadata else None
if jwks_uri:
response = requests.get(jwks_uri)
keys = response.json() if response.ok else None
if keys and 'keys' in keys:
for key in keys['keys']:
n = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['n'])), "big")
e = int.from_bytes(base64.urlsafe_b64decode(self._ensure_b64padding(key['e'])), "big")
pub_key = rsa.PublicKey(n, e)
# Cache the PEM formatted public key.
AzureADAuth.aad_jwt_keys_cache[key['kid']] = pub_key.save_pkcs1()

def _get_token_key(self, key_id: str) -> str:
if key_id not in AzureADAuth.aad_jwt_keys_cache:
self._cache_aad_keys()
return AzureADAuth.aad_jwt_keys_cache[key_id]

def _decode_token(self, token: str) -> Mapping:
key_id = self._get_key_id(token)
if not key_id:
raise InvalidAuthorization('The token does not contain kid')
key = self._get_token_key(key_id)
try:
decode = jwt.decode(token, key=key, algorithms=['RS256'], audience=config.API_AUDIENCE)
return decode
except ExpiredSignatureError as e:
logging.debug(f'The token signature has expired: {e}')
raise InvalidAuthorization('The token signature has expired')
except PyJWKError as e:
logging.debug(f'Invalid token: {e}')
raise InvalidAuthorization('The token is invalid')
except Exception as e:
logging.debug(f'Unexpected error: {e}')
raise InvalidAuthorization('Unable to decode token')


authorize = AzureADAuth()
18 changes: 18 additions & 0 deletions registry/access_control/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from starlette.config import Config


config = Config(".env")

API_PREFIX: str = "/api"
VERSION: str = "0.1.0"
PROJECT_NAME: str = "FastAPI with AAD Authentication"
DEBUG: bool = config("DEBUG", cast=bool, default=False)

# Authentication
API_CLIENT_ID: str = config("API_CLIENT_ID", default="e1475341-ff84-4b6e-a187-349e35551cff")
API_CLIENT_SECRET: str = config("API_CLIENT_SECRET", default="cMZ8Q~d.gI6A-j-EMFbbmz.xy6jXOxUKH4dLYdhH")
SWAGGER_UI_CLIENT_ID: str = config("SWAGGER_UI_CLIENT_ID", default=" e1335db3-a7eb-42b0-976c-82c9b5bdf0fb")
AAD_TENANT_ID: str = config("AAD_TENANT_ID", default="72f988bf-86f1-41af-91ab-2d7cd011db47")

AAD_INSTANCE: str = config("AAD_INSTANCE", default="https://login.microsoftonline.com")
API_AUDIENCE: str = config("API_AUDIENCE", default="db8dc4b0-202e-450c-b38d-7396ad9631a5")
Loading