bedrock.endpoints.decorators.protected

  1import functools  # pragma: unit
  2from enum import Enum  # pragma: unit
  3
  4from bedrock.config.headers import ValidHeaders  # pragma: unit
  5from bedrock.exceptions import UnauthorisedException  # pragma: unit
  6from bedrock.external.auth_handler import get_token  # pragma: unit
  7
  8
  9def protected(entities: str | list[str] = None,
 10              scopes: str | list[str] = None, scope_mode: str = "all",
 11              keys: str | list[str] = None):  # pragma: no cover
 12    """
 13    Protect an endpoint by ensuring that the request contains a valid JWT.
 14
 15    The JWT is then converted into a token and the token's matching `entities` attributes are added to the endpoint
 16    handler's kwargs.
 17
 18    Additionally, it adds the token to the event's `tkc_token` attribute.
 19
 20    :param entities: One entity or a list of entities.
 21    :param scopes: Which scopes to check for.
 22    :param scope_mode: Whether all or any of the scopes must be present. Defaults to `all`.
 23    :param keys: For machine-to-machine integration. Looks for a `x-api-key` header and checks it against the defined key.
 24
 25    Example usage: for a `GET` to `/countries/` which only requires a valid JWT (i.e. a logged-in user):
 26    ```python
 27    class Countries(Endpoint):
 28        # ...
 29        @protected()
 30        def get_global(self, event):
 31            # ...
 32    ```
 33
 34    Example usage: for a `POST` to `/countries/` which requires a valid JWT and exposes the user's authorised `country_code`s:
 35    ```python
 36    class Countries(Endpoint):
 37        # ...
 38        @protected("country_code")
 39        def post_global(self, event, country_code):
 40            # ...
 41            # `country_code` is now available to be used when filtering data
 42    ```
 43
 44    Example usage: for a `POST` to `/countries/` which requires a valid JWT with the `country:write` scope and exposes the user's authorised `country_code`s:
 45    ```python
 46    class Countries(Endpoint):
 47        # ...
 48        @protected(entities="country_code", scopes="country:write")
 49        def post_global(self, event, country_code):
 50            # ...
 51            # `country_code` is now available to be used when filtering data
 52    ```
 53
 54    Example usage: for a `POST` to `/countries/` which requires a valid JWT with either `country:write` or `country:read` scopes and exposes the user's authorised `country_code`s:
 55    ```python
 56    class Countries(Endpoint):
 57        # ...
 58        @protected(entities="country_code", scopes=["country:write", "country:read"], scope_mode="any")
 59        def post_global(self, event, country_code):
 60            # ...
 61            # `country_code` is now available to be used when filtering data
 62    ```
 63
 64    Example usage: for a `GET` to `/countries/` which requires a valid key:
 65    ```python
 66    class Countries(Endpoint):
 67        # ...
 68        @protected(keys="my-api-key")
 69        def get_global(self, event):
 70            # ...
 71    ```
 72    """
 73    entity_list = [] if not entities else [entities] if isinstance(entities, str) else entities
 74    scope_list = [] if not scopes else [scopes] if isinstance(scopes, str) else scopes
 75    key_list = [] if not keys else [keys] if isinstance(keys, str) else keys
 76
 77    def decorator(func):
 78        @functools.wraps(func)
 79        def wrapper(*args, **kwargs):
 80            event = args[1]
 81
 82            auth_type = get_auth_header_type(event)
 83            if not auth_type:
 84                raise UnauthorisedException("Missing authentication header")
 85
 86            extra_kwargs = {}
 87            auth_value = get_auth_header_value(event, auth_type)
 88            if auth_type == AuthTypes.API_KEY:
 89                if auth_value not in key_list:
 90                    raise UnauthorisedException("Invalid API key")
 91            elif auth_type == AuthTypes.BEARER:
 92                token = get_token(auth_value)
 93                scope_matches = [token.has_permission(scope) for scope in scope_list]
 94                if scope_mode == "all" and scope_matches.count(True) != len(scope_matches):
 95                    raise UnauthorisedException("Insufficient permissions")
 96                if scope_mode == "any" and scope_matches.count(True) == 0:
 97                    raise UnauthorisedException("Insufficient permissions")
 98                extra_kwargs = dict((e, getattr(token, e)) for e in entity_list)
 99                event["tkc_token"] = token
100
101            return func(*args, **kwargs, **extra_kwargs)
102
103        return wrapper
104
105    return decorator
106
107
class AuthTypes(Enum):  # pragma: unit
    """Authentication schemes that `get_auth_header_type` can detect on a request."""

    BEARER = 1  # a "Bearer <token>" value in the authorisation header
    API_KEY = 2  # a value in the API key header
111
112
def get_auth_header_type(event):  # pragma: unit
    """
    Determine which authentication scheme a request uses.

    :param event: The request event; only its `headers` entry is inspected. A falsy `headers` value
        (e.g. `None`) is treated as "no headers".
    :return: `AuthTypes.BEARER` if the authorisation header value starts with `Bearer ` (case-insensitive),
        otherwise `AuthTypes.API_KEY` if the API key header is present, otherwise `None`.
    """
    # NOTE(review): presumably some event sources send `headers: None` for header-less requests;
    # treat that as empty so the request is rejected as unauthenticated rather than crashing.
    headers = event["headers"] or {}
    authorization_key = ValidHeaders.AUTHORIZATION.value
    authorization = headers[authorization_key] if authorization_key in headers else None
    # The scheme must lead the value. A substring match would accept values such as
    # "NotBearer x", from which get_auth_header_value would extract the wrong part.
    if authorization is not None and authorization.lower().startswith("bearer "):
        return AuthTypes.BEARER
    if ValidHeaders.API_KEY_HEADER.value in headers:
        return AuthTypes.API_KEY
    return None
123
124
def get_auth_header_value(event, auth_type):  # pragma: unit
    """
    Extract the credential for the given authentication scheme from the request headers.

    :param event: The request event; its `headers` entry is read.
    :param auth_type: The scheme reported by `get_auth_header_type`.
    :return: For `AuthTypes.BEARER`, the text after the scheme name with surrounding whitespace removed
        (possibly empty). For `AuthTypes.API_KEY`, the raw API key header value. Otherwise `None`.
    """
    headers = event["headers"]
    if auth_type == AuthTypes.BEARER:
        # partition (rather than split(" ")[1]) tolerates repeated spaces after the scheme
        # and never raises IndexError when nothing follows it.
        _scheme, _separator, credentials = headers[ValidHeaders.AUTHORIZATION.value].partition(" ")
        return credentials.strip()
    if auth_type == AuthTypes.API_KEY:
        return headers[ValidHeaders.API_KEY_HEADER.value]
    return None
def protected(entities: str | list[str] | None = None,
              scopes: str | list[str] | None = None, scope_mode: str = "all",
              keys: str | list[str] | None = None):  # pragma: no cover
    """
    Protect an endpoint by ensuring that the request carries a valid JWT or a valid API key.

    The authentication scheme is chosen from the request headers by `get_auth_header_type`:

    * Bearer JWT: the JWT is converted into a token and the required `scopes` are checked. The token's
      matching `entities` attributes are added to the endpoint handler's kwargs. The token is stored in
      the event under the `tkc_token` key.
    * API key (machine-to-machine): the supplied key must equal one of `keys`. No token is created, no scope
      check is made and no entity kwargs are added.

    :param entities: One entity or a list of entities.
    :param scopes: Which scopes to check for.
    :param scope_mode: Whether `all` or `any` of the scopes must be present. Defaults to `all`. When no
        scopes are given, no scope check is performed in either mode.
    :param keys: For machine-to-machine integration. Looks for an `x-api-key` header and checks it against the
        defined key(s).
    :raises ValueError: When decorating, if `scope_mode` is neither `"all"` nor `"any"`.
    :raises UnauthorisedException: On a request that has no authentication header, that has an API key not
        listed in `keys`, or whose token lacks the required scopes.

    Example usage: for a `GET` to `/countries/` which only requires a valid JWT (i.e. a logged-in user):
    ```python
    class Countries(Endpoint):
        # ...
        @protected()
        def get_global(self, event):
            # ...
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected("country_code")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT with the `country:write` scope and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(entities="country_code", scopes="country:write")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT with either `country:write` or `country:read` scopes and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(entities="country_code", scopes=["country:write", "country:read"], scope_mode="any")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `GET` to `/countries/` which requires a valid key:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(keys="my-api-key")
        def get_global(self, event):
            # ...
    ```
    """
    import hmac  # constant-time comparison of secret API keys

    # Any other value used to skip the scope check entirely, silently disabling authorisation.
    if scope_mode not in ("all", "any"):
        raise ValueError(f"scope_mode must be 'all' or 'any', got {scope_mode!r}")

    def _as_list(value):
        """Normalise a single string, an iterable of strings, or a falsy value into a list."""
        if not value:
            return []
        return [value] if isinstance(value, str) else list(value)

    entity_list = _as_list(entities)
    scope_list = _as_list(scopes)
    key_list = _as_list(keys)
    scope_check = all if scope_mode == "all" else any

    def _key_matches(candidate):
        """Return True if `candidate` equals one of the configured keys, compared in constant time."""
        if not isinstance(candidate, str):
            return False
        supplied = candidate.encode("utf-8", "surrogatepass")
        return any(
            hmac.compare_digest(supplied, key.encode("utf-8", "surrogatepass"))
            for key in key_list
            if isinstance(key, str)
        )

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Assumes handlers are methods called as handler(self, event, ...), so the event is args[1].
            event = args[1]

            auth_type = get_auth_header_type(event)
            if auth_type is None:
                raise UnauthorisedException("Missing authentication header")

            extra_kwargs = {}
            auth_value = get_auth_header_value(event, auth_type)
            if auth_type == AuthTypes.API_KEY:
                if not _key_matches(auth_value):
                    raise UnauthorisedException("Invalid API key")
            elif auth_type == AuthTypes.BEARER:
                # NOTE(review): a bearer token is accepted even when only `keys` is configured;
                # confirm that keys-only endpoints are meant to accept JWTs too.
                token = get_token(auth_value)
                # With no scopes configured, only a valid token is required (in either scope_mode).
                if scope_list and not scope_check(token.has_permission(scope) for scope in scope_list):
                    raise UnauthorisedException("Insufficient permissions")
                extra_kwargs = {entity: getattr(token, entity) for entity in entity_list}
                event["tkc_token"] = token

            return func(*args, **kwargs, **extra_kwargs)

        return wrapper

    return decorator

Protect an endpoint by ensuring that the request contains a valid JWT or, for machine-to-machine calls, a valid API key.

The JWT is then converted into a token and the token's matching entities attributes are added to the endpoint handler's kwargs.

Additionally, for JWT-authenticated requests, it stores the token in the event under the tkc_token key.

Parameters
  • entities: One entity or a list of entities.
  • scopes: Which scopes to check for.
  • scope_mode: Whether all or any of the scopes must be present. Defaults to all.
  • keys: For machine-to-machine integration. Looks for an x-api-key header and checks it against the defined key(s).

Example usage: for a GET to /countries/ which only requires a valid JWT (i.e. a logged-in user):

class Countries(Endpoint):
    # ...
    @protected()
    def get_global(self, event):
        # ...

Example usage: for a POST to /countries/ which requires a valid JWT and exposes the user's authorised country_codes:

class Countries(Endpoint):
    # ...
    @protected("country_code")
    def post_global(self, event, country_code):
        # ...
        # `country_code` is now available to be used when filtering data

Example usage: for a POST to /countries/ which requires a valid JWT with the country:write scope and exposes the user's authorised country_codes:

class Countries(Endpoint):
    # ...
    @protected(entities="country_code", scopes="country:write")
    def post_global(self, event, country_code):
        # ...
        # `country_code` is now available to be used when filtering data

Example usage: for a POST to /countries/ which requires a valid JWT with either country:write or country:read scopes and exposes the user's authorised country_codes:

class Countries(Endpoint):
    # ...
    @protected(entities="country_code", scopes=["country:write", "country:read"], scope_mode="any")
    def post_global(self, event, country_code):
        # ...
        # `country_code` is now available to be used when filtering data

Example usage: for a GET to /countries/ which requires a valid key:

class Countries(Endpoint):
    # ...
    @protected(keys="my-api-key")
    def get_global(self, event):
        # ...
class AuthTypes(Enum):  # pragma: unit
    """Authentication schemes that `get_auth_header_type` can detect on a request."""

    BEARER = 1  # a "Bearer <token>" value in the authorisation header
    API_KEY = 2  # a value in the API key header
def get_auth_header_type(event):  # pragma: unit
    """
    Determine which authentication scheme a request uses.

    :param event: The request event; only its `headers` entry is inspected. A falsy `headers` value
        (e.g. `None`) is treated as "no headers".
    :return: `AuthTypes.BEARER` if the authorisation header value starts with `Bearer ` (case-insensitive),
        otherwise `AuthTypes.API_KEY` if the API key header is present, otherwise `None`.
    """
    # NOTE(review): presumably some event sources send `headers: None` for header-less requests;
    # treat that as empty so the request is rejected as unauthenticated rather than crashing.
    headers = event["headers"] or {}
    authorization_key = ValidHeaders.AUTHORIZATION.value
    authorization = headers[authorization_key] if authorization_key in headers else None
    # The scheme must lead the value. A substring match would accept values such as
    # "NotBearer x", from which get_auth_header_value would extract the wrong part.
    if authorization is not None and authorization.lower().startswith("bearer "):
        return AuthTypes.BEARER
    if ValidHeaders.API_KEY_HEADER.value in headers:
        return AuthTypes.API_KEY
    return None
def get_auth_header_value(event, auth_type):  # pragma: unit
    """
    Extract the credential for the given authentication scheme from the request headers.

    :param event: The request event; its `headers` entry is read.
    :param auth_type: The scheme reported by `get_auth_header_type`.
    :return: For `AuthTypes.BEARER`, the text after the scheme name with surrounding whitespace removed
        (possibly empty). For `AuthTypes.API_KEY`, the raw API key header value. Otherwise `None`.
    """
    headers = event["headers"]
    if auth_type == AuthTypes.BEARER:
        # partition (rather than split(" ")[1]) tolerates repeated spaces after the scheme
        # and never raises IndexError when nothing follows it.
        _scheme, _separator, credentials = headers[ValidHeaders.AUTHORIZATION.value].partition(" ")
        return credentials.strip()
    if auth_type == AuthTypes.API_KEY:
        return headers[ValidHeaders.API_KEY_HEADER.value]
    return None