diff --git a/src/idpyoidc/message/oauth2/__init__.py b/src/idpyoidc/message/oauth2/__init__.py index b8554f75..86494e68 100644 --- a/src/idpyoidc/message/oauth2/__init__.py +++ b/src/idpyoidc/message/oauth2/__init__.py @@ -404,6 +404,25 @@ class SecurityEventToken(Message): "toe": SINGLE_OPTIONAL_INT, } +class JWTAccessToken(Message): + c_param = { + "iss": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_INT, + "aud": REQUIRED_LIST_OF_STRINGS, + "sub": SINGLE_REQUIRED_STRING, + "client_id": SINGLE_REQUIRED_STRING, + "iat": SINGLE_REQUIRED_INT, + "jti": SINGLE_REQUIRED_STRING, + "auth_time": SINGLE_OPTIONAL_INT, + "acr": SINGLE_OPTIONAL_STRING, + "amr": OPTIONAL_LIST_OF_STRINGS, + 'scope': OPTIONAL_LIST_OF_SP_SEP_STRINGS, + 'groups': OPTIONAL_LIST_OF_STRINGS, + 'roles': OPTIONAL_LIST_OF_STRINGS, + 'entitlements': OPTIONAL_LIST_OF_STRINGS + } + + def factory(msgtype, **kwargs): """ diff --git a/src/idpyoidc/server/token/jwt_token.py b/src/idpyoidc/server/token/jwt_token.py index 9c8ab32a..e552115b 100644 --- a/src/idpyoidc/server/token/jwt_token.py +++ b/src/idpyoidc/server/token/jwt_token.py @@ -1,31 +1,36 @@ from typing import Callable from typing import Optional +from typing import Union from cryptojwt import JWT from cryptojwt.jws.exception import JWSException +from cryptojwt.utils import importer -from idpyoidc.encrypter import init_encrypter from idpyoidc.server.exception import ToOld - -from ..constant import DEFAULT_TOKEN_LIFETIME -from . import Token from . import is_expired +from . import Token from .exception import UnknownToken from .exception import WrongTokenClass +from ..constant import DEFAULT_TOKEN_LIFETIME +from ...message import Message +from ...message.oauth2 import JWTAccessToken class JWTToken(Token): + def __init__( - self, - token_class, - # keyjar: KeyJar = None, - issuer: str = None, - aud: Optional[list] = None, - alg: str = "ES256", - lifetime: int = DEFAULT_TOKEN_LIFETIME, - server_get: Callable = None, - token_type: str = "Bearer", - **kwargs + self, + token_class, + # keyjar: KeyJar = None, + issuer: str = None, + aud: Optional[list] = None, + alg: str = "ES256", + lifetime: int = DEFAULT_TOKEN_LIFETIME, + server_get: Callable = None, + token_type: str = "Bearer", + profile: Optional[Union[Message, str]] = JWTAccessToken, + with_jti: Optional[bool] = False, + **kwargs ): Token.__init__(self, token_class, **kwargs) self.token_type = token_type @@ -40,17 +45,27 @@ def __init__( self.def_aud = aud or [] self.alg = alg + if isinstance(profile, str): + self.profile = importer(profile) + else: + self.profile = profile + self.with_jti = with_jti + + if self.with_jti is False and profile == JWTAccessToken: + self.with_jti = True def load_custom_claims(self, payload: dict = None): # inherit me and do your things here return payload def __call__( - self, - session_id: Optional[str] = "", - token_class: Optional[str] = "", - usage_rules: Optional[dict] = None, - **payload + self, + session_id: Optional[str] = "", + token_class: Optional[str] = "", + usage_rules: Optional[dict] = None, + profile: Optional[Message] = None, + with_jti: Optional[bool] = None, + **payload ) -> str: """ Return a token. @@ -81,6 +96,18 @@ def __call__( lifetime=lifetime, sign_alg=self.alg, ) + if isinstance(payload, Message): # don't mess with it. + pass + else: + if profile: + payload = profile(**payload).to_dict() + elif self.profile: + payload = self.profile(**payload).to_dict() + + if with_jti: + signer.with_jti = True + elif with_jti is None: + signer.with_jti = self.with_jti return signer.pack(payload) diff --git a/tests/test_server_20e_jwt_token.py b/tests/test_server_20e_jwt_token.py index 9bd86599..b363bac2 100644 --- a/tests/test_server_20e_jwt_token.py +++ b/tests/test_server_20e_jwt_token.py @@ -517,7 +517,7 @@ def test_mint_with_scope(self): grant, session_id, code, - scope=["openid"], + scope=["openid", 'foobar'], aud=["https://audience.example.com"], ) @@ -527,7 +527,7 @@ def test_mint_with_scope(self): assert _info["token_class"] == "access_token" # assert _info["eduperson_scoped_affiliation"] == ["staff@example.org"] assert set(_info["aud"]) == {"https://audience.example.com"} - assert _info["scope"] == ["openid"] + assert _info["scope"] == "openid foobar" def test_mint_with_extra(self): _auth_req = AuthorizationRequest( diff --git a/tests/test_server_24_oauth2_token_endpoint.py b/tests/test_server_24_oauth2_token_endpoint.py index 25c479f4..68b63345 100644 --- a/tests/test_server_24_oauth2_token_endpoint.py +++ b/tests/test_server_24_oauth2_token_endpoint.py @@ -3,9 +3,20 @@ import pytest from cryptojwt import JWT +from cryptojwt import KeyJar +from cryptojwt.jws.jws import factory from cryptojwt.key_jar import build_keyjar +from idpyoidc.message import REQUIRED_LIST_OF_STRINGS +from idpyoidc.message import SINGLE_REQUIRED_INT + +from idpyoidc.message import SINGLE_REQUIRED_STRING + +from idpyoidc.message import Message + +from idpyoidc.context import OidcContext from idpyoidc.defaults import JWT_BEARER +from idpyoidc.message.oauth2 import JWTAccessToken from idpyoidc.message.oidc import AccessTokenRequest from idpyoidc.message.oidc import AuthorizationRequest from idpyoidc.message.oidc import RefreshAccessTokenRequest @@ -18,7 +29,7 @@ from idpyoidc.server.exception import InvalidToken from idpyoidc.server.oauth2.authorization import Authorization from idpyoidc.server.oauth2.token import Token -from idpyoidc.server.session import MintingNotAllowed +from idpyoidc.server.token import handler from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD from idpyoidc.server.user_info import UserInfo from idpyoidc.time_util import utc_time_sans_frac @@ -162,6 +173,7 @@ def conf(): class TestEndpoint(object): + @pytest.fixture(autouse=True) def create_endpoint(self, conf): server = Server(ASConfiguration(conf=conf, base_path=BASEDIR), cwd=BASEDIR) @@ -777,3 +789,118 @@ def test_refresh_token_request_other_client(self): ) assert isinstance(_resp, TokenErrorResponse) assert _resp.to_dict() == {"error": "invalid_grant", "error_description": "Wrong client"} + + +DEFAULT_TOKEN_HANDLER_ARGS = { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"] + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, +} +TOKEN_HANDLER_ARGS = { + "jwks_file": "private/token_jwks.json", + "code": {"lifetime": 600, "kwargs": {"crypt_conf": CRYPT_CONFIG}}, + "token": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "add_claims_by_scope": True, + "aud": ["https://example.org/appl"], + "profile": 'idpyoidc.message.oauth2.JWTAccessToken', + "with_jti": True + }, + }, + "refresh": { + "class": "idpyoidc.server.token.jwt_token.JWTToken", + "kwargs": { + "lifetime": 3600, + "aud": ["https://example.org/appl"], + }, + }, +} + +CONTEXT = OidcContext() +CONTEXT.cwd = BASEDIR +CONTEXT.issuer = "https://op.example.com" +CONTEXT.cdb = { + "client_1": {} +} +CONTEXT.keyjar = KeyJar() +CONTEXT.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "client_1") +CONTEXT.keyjar.import_jwks(CLIENT_KEYJAR.export_jwks(private=True), "") + +def server_get(what, *args): + if what == "endpoint_context": + if not args: + return CONTEXT + +def test_def_jwttoken(): + _handler = handler.factory(server_get=server_get, **DEFAULT_TOKEN_HANDLER_ARGS) + token_handler = _handler['access_token'] + token_payload = { + 'sub': 'subject_id', + 'aud': 'resource_1', + 'client_id': 'client_1' + } + value = token_handler(session_id='session_id', **token_payload) + + _jws = factory(value) + msg = JWTAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + +def test_jwttoken(): + _handler = handler.factory(server_get=server_get, **TOKEN_HANDLER_ARGS) + token_handler = _handler['access_token'] + token_payload = { + 'sub': 'subject_id', + 'aud': 'resource_1', + 'client_id': 'client_1' + } + value = token_handler(session_id='session_id', **token_payload) + + _jws = factory(value) + msg = JWTAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True + +class MyAccessToken(Message): + c_param = { + "iss": SINGLE_REQUIRED_STRING, + "exp": SINGLE_REQUIRED_INT, + "aud": REQUIRED_LIST_OF_STRINGS, + "sub": SINGLE_REQUIRED_STRING, + "iat": SINGLE_REQUIRED_INT, + 'usage': SINGLE_REQUIRED_STRING + } + +def test_jwttoken_2(): + _handler = handler.factory(server_get=server_get, **TOKEN_HANDLER_ARGS) + token_handler = _handler['access_token'] + token_payload = { + 'sub': 'subject_id', + 'aud': 'Skiresort', + 'usage': 'skilift' + } + value = token_handler(session_id='session_id', profile=MyAccessToken, **token_payload) + + _jws = factory(value) + msg = MyAccessToken(**_jws.jwt.payload()) + # test if all required claims are there + msg.verify() + assert True \ No newline at end of file