bedrock.endpoints.decorators.protected
import functools  # pragma: unit
import hmac  # pragma: unit
from enum import Enum  # pragma: unit

from bedrock.config.headers import ValidHeaders  # pragma: unit
from bedrock.exceptions import UnauthorisedException  # pragma: unit
from bedrock.external.auth_handler import get_token  # pragma: unit


def protected(entities: str | list[str] = None,
              scopes: str | list[str] = None, scope_mode: str = "all",
              keys: str | list[str] = None):  # pragma: no cover
    """
    Protect an endpoint by ensuring that the request contains a valid JWT or, for machine-to-machine calls, a valid
    API key.

    For a JWT (``Authorization: Bearer <jwt>``), the JWT is converted into a token with ``get_token``, the required
    scopes are checked, and the token's attributes named in `entities` are added to the endpoint handler's kwargs.
    The token is also stored in the event under the `tkc_token` key.

    For an API key, the API-key header is compared against `keys`. No entity kwargs are added and the event is not
    modified.

    The decorated function must be a method called as ``handler(self, event, ...)``: ``args[1]`` is used as the event.

    :param entities: One entity or a list of entities. Each is read from the token with ``getattr`` and passed to the
        handler as a keyword argument of the same name.
    :param scopes: Which scopes to check for. When omitted, any valid JWT is accepted.
    :param scope_mode: Whether all or any of the scopes must be present: ``"all"`` (default) or ``"any"``.
    :param keys: For machine-to-machine integration. Looks for an `x-api-key` header and checks it against the defined
        key(s) using a constant-time comparison.
    :raises ValueError: When decorating, if `scope_mode` is neither ``"all"`` nor ``"any"``.
    :raises UnauthorisedException: At request time, if no authentication header is present, the API key does not
        match any of `keys`, or the token lacks the required scopes. (``get_token`` presumably raises for an invalid
        JWT; confirm in ``bedrock.external.auth_handler``.)

    Example usage: for a `GET` to `/countries/` which only requires a valid JWT (i.e. a logged-in user):
    ```python
    class Countries(Endpoint):
        # ...
        @protected()
        def get_global(self, event):
            # ...
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected("country_code")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT with the `country:write` scope and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(entities="country_code", scopes="country:write")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT with either `country:write` or `country:read` scopes and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(entities="country_code", scopes=["country:write", "country:read"], scope_mode="any")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `GET` to `/countries/` which requires a valid key:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(keys="my-api-key")
        def get_global(self, event):
            # ...
    ```
    """
    if scope_mode not in ("all", "any"):
        # Fail closed when decorating: an unrecognised mode would otherwise skip the scope check entirely.
        raise ValueError(f"scope_mode must be 'all' or 'any', got {scope_mode!r}")

    # Normalise each argument to a private list (a single string becomes a one-element list).
    entity_list = [] if not entities else [entities] if isinstance(entities, str) else list(entities)
    scope_list = [] if not scopes else [scopes] if isinstance(scopes, str) else list(scopes)
    key_list = [] if not keys else [keys] if isinstance(keys, str) else list(keys)
    # Encode once: hmac.compare_digest only accepts ASCII-only str, but accepts any bytes.
    encoded_keys = [key.encode("utf-8") for key in key_list]

    def is_valid_key(candidate):
        """Return True if `candidate` equals one of the configured keys, comparing in constant time."""
        if not isinstance(candidate, str):
            return False
        candidate_bytes = candidate.encode("utf-8")
        return any(hmac.compare_digest(candidate_bytes, key) for key in encoded_keys)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Handlers are methods, so args[0] is `self` and args[1] is the request event.
            event = args[1]

            auth_type = get_auth_header_type(event)
            if auth_type is None:
                raise UnauthorisedException("Missing authentication header")

            extra_kwargs = {}
            auth_value = get_auth_header_value(event, auth_type)
            if auth_type == AuthTypes.API_KEY:
                if not is_valid_key(auth_value):
                    raise UnauthorisedException("Invalid API key")
            elif auth_type == AuthTypes.BEARER:
                token = get_token(auth_value)
                # Evaluate every scope, then apply the mode. With no scopes configured nothing is
                # required, so "any" must not reject the request.
                scope_matches = [token.has_permission(scope) for scope in scope_list]
                if scope_matches:
                    satisfied = all(scope_matches) if scope_mode == "all" else any(scope_matches)
                    if not satisfied:
                        raise UnauthorisedException("Insufficient permissions")
                extra_kwargs = {entity: getattr(token, entity) for entity in entity_list}
                event["tkc_token"] = token
            # NOTE(review): a valid JWT is accepted even when only `keys` were configured, and API-key
            # requests receive no entity kwargs (a handler that requires them fails with TypeError).
            # Confirm both are intended.

            return func(*args, **kwargs, **extra_kwargs)

        return wrapper

    return decorator


class AuthTypes(Enum):  # pragma: unit
    """Authentication schemes recognised by `protected`."""
    BEARER = 1
    API_KEY = 2


def get_auth_header_type(event):  # pragma: unit
    """
    Work out which authentication scheme, if any, a request uses.

    A bearer ``Authorization`` header takes precedence over the API-key header.

    :param event: Request event whose ``"headers"`` entry maps header names to values (a None value is treated as
        no headers).
    :return: ``AuthTypes.BEARER`` if the ``Authorization`` header's scheme is ``Bearer`` (case-insensitive),
        otherwise ``AuthTypes.API_KEY`` if the API-key header is present, otherwise ``None``.
    """
    headers = event["headers"] or {}
    authorization = headers.get(ValidHeaders.AUTHORIZATION.value)
    # The scheme must be the first token; a substring match would also accept e.g. "Basic bearer x".
    if authorization is not None and authorization.lstrip().lower().startswith("bearer "):
        return AuthTypes.BEARER
    if ValidHeaders.API_KEY_HEADER.value in headers:
        return AuthTypes.API_KEY
    return None


def get_auth_header_value(event, auth_type):  # pragma: unit
    """
    Extract the credential for `auth_type` from the request headers.

    :param event: Request event whose ``"headers"`` entry maps header names to values.
    :param auth_type: The value returned by `get_auth_header_type` for this event.
    :return: For ``AuthTypes.BEARER``, the text after the scheme with surrounding whitespace removed (``""`` if there
        is none). For ``AuthTypes.API_KEY``, the raw API-key header value. Otherwise ``None``.
    :raises KeyError: If the header for `auth_type` is absent.
    """
    headers = event["headers"]
    if auth_type == AuthTypes.BEARER:
        authorization = headers[ValidHeaders.AUTHORIZATION.value]
        # Drop the scheme and tolerate repeated spaces, so "Bearer  <token>" still yields the token.
        _scheme, _sep, credentials = authorization.strip().partition(" ")
        return credentials.strip()
    if auth_type == AuthTypes.API_KEY:
        return headers[ValidHeaders.API_KEY_HEADER.value]
    return None
def protected(entities: str | list[str] = None,
              scopes: str | list[str] = None, scope_mode: str = "all",
              keys: str | list[str] = None):  # pragma: no cover
    """
    Protect an endpoint by ensuring that the request contains a valid JWT or, for machine-to-machine calls, a valid
    API key.

    For a JWT (``Authorization: Bearer <jwt>``), the JWT is converted into a token with ``get_token``, the required
    scopes are checked, and the token's attributes named in `entities` are added to the endpoint handler's kwargs.
    The token is also stored in the event under the `tkc_token` key.

    For an API key, the API-key header is compared against `keys`. No entity kwargs are added and the event is not
    modified.

    The decorated function must be a method called as ``handler(self, event, ...)``: ``args[1]`` is used as the event.

    :param entities: One entity or a list of entities. Each is read from the token with ``getattr`` and passed to the
        handler as a keyword argument of the same name.
    :param scopes: Which scopes to check for. When omitted, any valid JWT is accepted.
    :param scope_mode: Whether all or any of the scopes must be present: ``"all"`` (default) or ``"any"``.
    :param keys: For machine-to-machine integration. Looks for an `x-api-key` header and checks it against the defined
        key(s) using a constant-time comparison.
    :raises ValueError: When decorating, if `scope_mode` is neither ``"all"`` nor ``"any"``.
    :raises UnauthorisedException: At request time, if no authentication header is present, the API key does not
        match any of `keys`, or the token lacks the required scopes. (``get_token`` presumably raises for an invalid
        JWT; confirm in ``bedrock.external.auth_handler``.)

    Example usage: for a `GET` to `/countries/` which only requires a valid JWT (i.e. a logged-in user):
    ```python
    class Countries(Endpoint):
        # ...
        @protected()
        def get_global(self, event):
            # ...
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected("country_code")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT with the `country:write` scope and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(entities="country_code", scopes="country:write")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `POST` to `/countries/` which requires a valid JWT with either `country:write` or `country:read` scopes and exposes the user's authorised `country_code`s:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(entities="country_code", scopes=["country:write", "country:read"], scope_mode="any")
        def post_global(self, event, country_code):
            # ...
            # `country_code` is now available to be used when filtering data
    ```

    Example usage: for a `GET` to `/countries/` which requires a valid key:
    ```python
    class Countries(Endpoint):
        # ...
        @protected(keys="my-api-key")
        def get_global(self, event):
            # ...
    ```
    """
    import hmac  # stdlib; constant-time comparison of API keys

    if scope_mode not in ("all", "any"):
        # Fail closed when decorating: an unrecognised mode would otherwise skip the scope check entirely.
        raise ValueError(f"scope_mode must be 'all' or 'any', got {scope_mode!r}")

    # Normalise each argument to a private list (a single string becomes a one-element list).
    entity_list = [] if not entities else [entities] if isinstance(entities, str) else list(entities)
    scope_list = [] if not scopes else [scopes] if isinstance(scopes, str) else list(scopes)
    key_list = [] if not keys else [keys] if isinstance(keys, str) else list(keys)
    # Encode once: hmac.compare_digest only accepts ASCII-only str, but accepts any bytes.
    encoded_keys = [key.encode("utf-8") for key in key_list]

    def is_valid_key(candidate):
        """Return True if `candidate` equals one of the configured keys, comparing in constant time."""
        if not isinstance(candidate, str):
            return False
        candidate_bytes = candidate.encode("utf-8")
        return any(hmac.compare_digest(candidate_bytes, key) for key in encoded_keys)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Handlers are methods, so args[0] is `self` and args[1] is the request event.
            event = args[1]

            auth_type = get_auth_header_type(event)
            if auth_type is None:
                raise UnauthorisedException("Missing authentication header")

            extra_kwargs = {}
            auth_value = get_auth_header_value(event, auth_type)
            if auth_type == AuthTypes.API_KEY:
                if not is_valid_key(auth_value):
                    raise UnauthorisedException("Invalid API key")
            elif auth_type == AuthTypes.BEARER:
                token = get_token(auth_value)
                # Evaluate every scope, then apply the mode. With no scopes configured nothing is
                # required, so "any" must not reject the request.
                scope_matches = [token.has_permission(scope) for scope in scope_list]
                if scope_matches:
                    satisfied = all(scope_matches) if scope_mode == "all" else any(scope_matches)
                    if not satisfied:
                        raise UnauthorisedException("Insufficient permissions")
                extra_kwargs = {entity: getattr(token, entity) for entity in entity_list}
                event["tkc_token"] = token
            # NOTE(review): a valid JWT is accepted even when only `keys` were configured, and API-key
            # requests receive no entity kwargs (a handler that requires them fails with TypeError).
            # Confirm both are intended.

            return func(*args, **kwargs, **extra_kwargs)

        return wrapper

    return decorator
Protect an endpoint by ensuring that the request contains a valid JWT or, for machine-to-machine integration, a valid API key.
The JWT is then converted into a token, and the token's attributes named in `entities` are added to the endpoint
handler's kwargs.
Additionally, for JWT-authenticated requests, it stores the token in the event under the `tkc_token` key.
Parameters
- entities: One entity or a list of entities.
- scopes: Which scopes to check for.
- scope_mode: Whether all or any of the scopes must be present. Defaults to `all`.
- keys: For machine-to-machine integration. Looks for an `x-api-key` header and checks it against the defined key.
Example usage: for a GET to /countries/ which only requires a valid JWT (i.e. a logged-in user):
class Countries(Endpoint):
# ...
@protected()
def get_global(self, event):
# ...
Example usage: for a `POST` to `/countries/` which requires a valid JWT and exposes the user's authorised `country_code`s:
class Countries(Endpoint):
# ...
@protected("country_code")
def post_global(self, event, country_code):
# ...
# `country_code` is now available to be used when filtering data
Example usage: for a `POST` to `/countries/` which requires a valid JWT with the `country:write` scope and exposes the user's authorised `country_code`s:
class Countries(Endpoint):
# ...
@protected(entities="country_code", scopes="country:write")
def post_global(self, event, country_code):
# ...
# `country_code` is now available to be used when filtering data
Example usage: for a `POST` to `/countries/` which requires a valid JWT with either the `country:write` or the `country:read` scope and exposes the user's authorised `country_code`s:
class Countries(Endpoint):
# ...
@protected(entities="country_code", scopes=["country:write", "country:read"], scope_mode="any")
def post_global(self, event, country_code):
# ...
# `country_code` is now available to be used when filtering data
Example usage: for a GET to /countries/ which requires a valid key:
class Countries(Endpoint):
# ...
@protected(keys="my-api-key")
def get_global(self, event):
# ...
def get_auth_header_type(event):  # pragma: unit
    """
    Work out which authentication scheme, if any, a request uses.

    A bearer ``Authorization`` header takes precedence over the API-key header.

    :param event: Request event whose ``"headers"`` entry maps header names to values (a None value is treated as
        no headers).
    :return: ``AuthTypes.BEARER`` if the ``Authorization`` header's scheme is ``Bearer`` (case-insensitive),
        otherwise ``AuthTypes.API_KEY`` if the API-key header is present, otherwise ``None``.
    """
    headers = event["headers"] or {}
    authorization = headers.get(ValidHeaders.AUTHORIZATION.value)
    # The scheme must be the first token; a substring match would also accept e.g. "Basic bearer x".
    if authorization is not None and authorization.lstrip().lower().startswith("bearer "):
        return AuthTypes.BEARER
    if ValidHeaders.API_KEY_HEADER.value in headers:
        return AuthTypes.API_KEY
    return None
def get_auth_header_value(event, auth_type):  # pragma: unit
    """
    Extract the credential for `auth_type` from the request headers.

    :param event: Request event whose ``"headers"`` entry maps header names to values.
    :param auth_type: The value returned by `get_auth_header_type` for this event.
    :return: For ``AuthTypes.BEARER``, the text after the scheme with surrounding whitespace removed (``""`` if there
        is none). For ``AuthTypes.API_KEY``, the raw API-key header value. Otherwise ``None``.
    :raises KeyError: If the header for `auth_type` is absent.
    """
    headers = event["headers"]
    if auth_type == AuthTypes.BEARER:
        authorization = headers[ValidHeaders.AUTHORIZATION.value]
        # Drop the scheme and tolerate repeated spaces, so "Bearer  <token>" still yields the token.
        _scheme, _sep, credentials = authorization.strip().partition(" ")
        return credentials.strip()
    if auth_type == AuthTypes.API_KEY:
        return headers[ValidHeaders.API_KEY_HEADER.value]
    return None