bedrock.cache.websockets
import json  # pragma: unit
import time  # pragma: unit

from jwt import ExpiredSignatureError  # pragma: unit
from valkey import Valkey  # pragma: unit

from bedrock.cache import get_cache, get_cache_topics  # pragma: unit
from bedrock.cache.websockets.websocket_connection_data import WebsocketConnectionData  # pragma: unit
from bedrock.config import get_config_params  # pragma: unit
from bedrock.endpoints.decorators.protected import AuthTypes, get_auth_header_value  # pragma: unit
from bedrock.log import log_config  # pragma: unit
from bedrock.websockets.authentication import authenticate_token, authorise_token  # pragma: unit

logger = log_config("websocket_connections")  # pragma: unit


def save_connection(connection_data: WebsocketConnectionData) -> dict:  # pragma: no cover
    """
    Store a new websocket connection in the shared cache.

    :param connection_data: The connection to store.
    :return: The JSON-serialisable representation that was stored.
    """
    return _save_connection(connection_data, get_cache())


def _save_connection(connection_data: WebsocketConnectionData, cache: Valkey) -> dict:  # pragma: unit
    """
    Store a new websocket connection in the cache.

    The entry is keyed by connection ID and written with ``ex=token_expiry``
    (seconds), so it expires together with the token (no expiry when None).

    :return: The dict produced by ``connection_data.as_json()``.
    """
    data_to_store = connection_data.as_json()
    cache.set(connection_data.connection_id, json.dumps(data_to_store), ex=connection_data.token_expiry)
    return data_to_store


def remove_connection(connection_id: str) -> str | None:  # pragma: no cover
    """
    Remove a websocket connection from the cache.
    """
    return _remove_connection(connection_id, get_cache())


def _remove_connection(connection_id: str, cache: Valkey) -> str | None:  # pragma: unit
    """
    Delete the cache entry for a websocket connection.

    :return: ``connection_id`` if an entry was deleted, otherwise None.
    """
    logger.debug(f"Removing connection with id: {connection_id}")
    # DEL replies with the number of keys removed, so one round trip both deletes
    # and tells us whether the key existed (no exists/delete race with the TTL).
    if cache.delete(connection_id):
        return connection_id
    return None


def subscribe_connection_to_topic(connection_id: str, topic: str, filters: dict | None = None) -> bool:  # pragma: no cover
    """
    Add a connection to a topic in the cache.

    :param connection_id: The ID of the websocket connection.
    :param topic: Unprefixed topic name; the configured ``cache.topic_prefix`` is applied.
    :param filters: Optional filters to store for this connection on this topic.
    :return: True on success, False if a cache operation failed.
    """
    config = get_config_params()
    return _add_connection_to_topic(connection_id, topic, get_cache(), config["cache"]["topic_prefix"], filters)


def _add_connection_to_topic(connection_id: str, topic: str, cache: Valkey, topic_prefix: str,
                             filters: dict | None = None) -> bool:  # pragma: unit
    """
    Add ``connection_id`` to the ``{topic_prefix}-{topic}`` set and store or clear its filters.

    When ``filters`` is empty any previously stored filters are deleted, so a
    re-subscription without filters does not keep stale ones.

    :return: True on success, False if a cache operation failed.
    """
    _topic = f"{topic_prefix}-{topic}"
    logger.info(f"Subscribing connection id: {connection_id} to topic: {_topic}")
    try:
        cache.sadd(_topic, connection_id)
        if filters:
            _add_connection_topic_filters(connection_id, _topic, cache, filters)
        else:
            _remove_connection_topic_filters(connection_id, _topic, cache)
        return True
    except Exception as e:
        # Best effort: report failure through the return value instead of raising.
        logger.error(f"Failed to add connection {connection_id} to topic {topic}: {e}")
        return False


def unsubscribe_connection(connection_id: str, topics: list[str] | None = None) -> None:  # pragma: no cover
    """
    Remove a connection from one or more topics in the cache.

    :param connection_id: The ID of the websocket connection.
    :param topics: Unprefixed topic names; the configured ``cache.topic_prefix`` is
        applied to each. When omitted or empty, every topic returned by
        ``get_cache_topics()`` is used as-is (presumably already prefixed — confirm).
    """
    config = get_config_params()
    cache = get_cache()  # one client for the whole loop rather than one lookup per topic
    if topics:
        topic_names, prefix = topics, config["cache"]["topic_prefix"]
    else:
        topic_names, prefix = get_cache_topics(), None
    for topic in topic_names:
        _remove_connection_from_topic(connection_id, topic, cache, prefix)


def _remove_connection_from_topic(connection_id: str, topic: str, cache: Valkey,
                                  topic_prefix: str | None) -> bool:  # pragma: unit
    """
    Remove ``connection_id`` from a topic set and delete its stored filters for that topic.

    :param topic_prefix: Applied as ``{topic_prefix}-{topic}``; when falsy, ``topic``
        is used as the full key.
    :return: True on success, False if a cache operation failed.
    """
    _topic = f"{topic_prefix}-{topic}" if topic_prefix else topic
    logger.debug(f"Unsubscribing connection id: {connection_id} from topic: {_topic}")
    try:
        cache.srem(_topic, connection_id)
        _remove_connection_topic_filters(connection_id, _topic, cache)
        return True
    except Exception as e:
        logger.error(f"Failed to remove connection {connection_id} from topic {topic}: {e}")
        return False


def is_connection_token_valid(connection_id: str) -> bool:  # pragma: no cover
    """
    Check if the token for a connection is valid and recheck it if necessary.

    The token is only re-authenticated once the last check is at least
    ``websockets.ping_interval`` seconds old; a successful recheck refreshes the
    stored expiry, last-check time and the cache entry's TTL.

    :param connection_id: The ID of the websocket connection.
    :return: True if the token is (still) valid; False if it has expired, the
        connection is no longer cached, or the check failed.
    """
    try:
        cache = get_cache()
        ping_interval = int(get_config_params()["websockets"]["ping_interval"])
        if not _token_needs_recheck(connection_id, ping_interval, cache):
            return True

        raw_data = cache.get(connection_id)
        if raw_data is None:
            # Entry removed or expired: there is no token left to recheck.
            logger.warning(f"Connection {connection_id} not found in cache.")
            return False
        connection_data = WebsocketConnectionData.from_dict(json.loads(raw_data))

        # Recheck the token
        _, last_token_check, ttl = authenticate_token(connection_data.token)

        # Update the connection data with the new token expiry and last check time
        connection_data.token_expiry = ttl
        connection_data.last_token_check = last_token_check
        cache.set(connection_id, json.dumps(connection_data.as_json()), ex=ttl)
        return True
    except ExpiredSignatureError:
        logger.warning(f"Token for connection {connection_id} has expired.")
        return False
    except Exception as e:
        logger.error(f"Failed to ping connection {connection_id}: {e}")
        return False


def _token_needs_recheck(connection_id: str, ping_interval: int, cache: Valkey) -> bool:  # pragma: unit
    """
    Check if the token for a connection needs to be rechecked based on the last check time.

    :param connection_id: The ID of the websocket connection.
    :param ping_interval: The interval in seconds to check if the token needs rechecking.
    :return: False if the token was checked less than ``ping_interval`` seconds ago,
        True otherwise (including when the connection is not cached).
    """
    # Single GET: an exists-then-get pair can see the key expire in between.
    raw_data = cache.get(connection_id)
    if raw_data is None:
        return True
    # ``or 0`` also covers a stored null, which would break the subtraction below.
    last_check = json.loads(raw_data).get("lastTokenCheck") or 0
    return int(time.time()) - last_check >= ping_interval


def create_websocket_connection_data_from_connect_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
    """
    Create unauthorised websocket connection data from a connect event.

    :param event: The connect event; reads ``requestContext.connectionId`` and
        ``requestContext.identity.sourceIp``.
    :return: A WebsocketConnectionData instance for the not-yet-authorised connection.
    """
    connection_id = event["requestContext"]["connectionId"]
    source_ip = event["requestContext"]["identity"]["sourceIp"]
    return WebsocketConnectionData(connection_id=connection_id, source_ip=source_ip)


def create_websocket_connection_data_from_authenticate_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
    """
    Create a WebsocketConnectionData object from an authorise event.

    :param event: The event data containing the bearer token in its authorisation header.
    :return: An authorised WebsocketConnectionData instance.
    :raises: Whatever ``authenticate_token`` / ``authorise_token`` raise for a bad
        token (e.g. ``ExpiredSignatureError``).
    """
    token = get_auth_header_value(event, AuthTypes.BEARER)
    _, last_token_check, ttl = authenticate_token(token)
    connection_id = event["requestContext"]["connectionId"]
    source_ip = event["requestContext"]["identity"]["sourceIp"]
    broadcast_filters = authorise_token(token)

    return WebsocketConnectionData(connection_id=connection_id,
                                   source_ip=source_ip,
                                   token=token,
                                   token_expiry=ttl,
                                   broadcast_filters=broadcast_filters,
                                   last_token_check=last_token_check,
                                   is_authorised=True)


def _get_connection_data(connection_id: str, cache: Valkey) -> WebsocketConnectionData | None:  # pragma: unit
    """
    Load a connection's data from the cache.

    :return: The deserialised WebsocketConnectionData, or None if it is not cached.
    """
    raw_data = cache.get(connection_id)
    if raw_data is None:
        return None
    return WebsocketConnectionData.from_dict(json.loads(raw_data))


def get_connection_data(
        connection_id: str) -> WebsocketConnectionData | None:  # pragma: no cover - The logic is already covered by _get_connection_data
    """
    Load a connection's data from the shared cache, or None if it is not cached.
    """
    return _get_connection_data(connection_id, get_cache())


def _add_connection_topic_filters(connection_id: str, topic: str, cache: Valkey,
                                  filters: dict) -> dict:  # pragma: unit
    """
    Store a connection's filters for a topic under the key ``{connection_id}-{topic}``.

    NOTE(review): the key is written without an expiry, and ``_remove_connection``
    does not delete it — confirm it is always cleaned up via unsubscribe.
    """
    key = f"{connection_id}-{topic}"
    cache.set(key, json.dumps(filters))
    return filters


def _remove_connection_topic_filters(connection_id: str, topic: str, cache: Valkey) -> None:  # pragma: unit
    """
    Delete a connection's stored filters for a topic (no-op if none are stored).
    """
    key = f"{connection_id}-{topic}"
    cache.delete(key)


def get_connection_topic_filters(connection_id: str, topic: str) -> dict | None:  # pragma: unit
    """
    Return a connection's stored filters for a (prefixed) topic, or None if none are stored.
    """
    return _get_connection_topic_filters(connection_id, topic, get_cache())


def _get_connection_topic_filters(connection_id: str, topic: str, cache: Valkey) -> dict | None:  # pragma: unit
    """
    Load a connection's filters for a topic from the ``{connection_id}-{topic}`` key.

    :return: The decoded filters, or None if the key does not exist.
    """
    raw_filters = cache.get(f"{connection_id}-{topic}")
    if raw_filters is None:
        return None
    return json.loads(raw_filters)
logger =
<MyLogger BEDROCK-websocket_connections (INFO)>
def
save_connection( connection_data: bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData) -> dict:
def
remove_connection(connection_id: str) -> str | None:
def remove_connection(connection_id: str) -> str | None:  # pragma: no cover
    """
    Delete the cache entry that tracks a websocket connection.

    :param connection_id: The ID of the websocket connection.
    :return: The connection ID if an entry was removed, otherwise None.
    """
    shared_cache = get_cache()
    return _remove_connection(connection_id, shared_cache)
Remove a websocket connection from the cache.
def
subscribe_connection_to_topic(connection_id: str, topic: str, filters: dict = None) -> bool:
def subscribe_connection_to_topic(connection_id: str, topic: str, filters: dict | None = None) -> bool:  # pragma: no cover
    """
    Add a connection to a topic in the cache.

    :param connection_id: The ID of the websocket connection.
    :param topic: Unprefixed topic name; the configured ``cache.topic_prefix`` is applied.
    :param filters: Optional filters to store for this connection on this topic.
    :return: The success flag returned by ``_add_connection_to_topic``.
    """
    config = get_config_params()
    return _add_connection_to_topic(connection_id, topic, get_cache(), config["cache"]["topic_prefix"], filters)
Add a connection to a topic in the cache.
def
unsubscribe_connection(connection_id: str, topics: list[str] = None) -> None:
def unsubscribe_connection(connection_id: str, topics: list[str] | None = None) -> None:  # pragma: no cover
    """
    Remove a connection from one or more topics in the cache.

    :param connection_id: The ID of the websocket connection.
    :param topics: Unprefixed topic names; the configured ``cache.topic_prefix`` is
        applied to each. When omitted or empty, every topic returned by
        ``get_cache_topics()`` is used as-is (presumably already prefixed — confirm).
    """
    config = get_config_params()
    cache = get_cache()  # one client for the whole loop rather than one lookup per topic
    if topics:
        topic_names, prefix = topics, config["cache"]["topic_prefix"]
    else:
        topic_names, prefix = get_cache_topics(), None
    for topic in topic_names:
        _remove_connection_from_topic(connection_id, topic, cache, prefix)
Remove a connection from one or more topics in the cache (every topic returned by get_cache_topics() when no topics are given).
def
is_connection_token_valid(connection_id: str) -> bool:
def is_connection_token_valid(connection_id: str) -> bool:  # pragma: no cover
    """
    Check if the token for a connection is valid and recheck it if necessary.

    The token is only re-authenticated when ``_token_needs_recheck`` says so for
    the configured ``websockets.ping_interval``; a successful recheck refreshes the
    stored expiry, last-check time and the cache entry's TTL.

    :param connection_id: The ID of the websocket connection.
    :return: True if the token is (still) valid; False if it has expired, the
        connection is no longer cached, or the check failed.
    """
    try:
        cache = get_cache()
        ping_interval = int(get_config_params()["websockets"]["ping_interval"])
        if not _token_needs_recheck(connection_id, ping_interval, cache):
            return True

        raw_data = cache.get(connection_id)
        if raw_data is None:
            # Entry removed or expired: there is no token left to recheck.
            logger.warning(f"Connection {connection_id} not found in cache.")
            return False
        connection_data = WebsocketConnectionData.from_dict(json.loads(raw_data))

        # Recheck the token
        _, last_token_check, ttl = authenticate_token(connection_data.token)

        # Update the connection data with the new token expiry and last check time
        connection_data.token_expiry = ttl
        connection_data.last_token_check = last_token_check
        cache.set(connection_id, json.dumps(connection_data.as_json()), ex=ttl)
        return True
    except ExpiredSignatureError:
        logger.warning(f"Token for connection {connection_id} has expired.")
        return False
    except Exception as e:
        logger.error(f"Failed to ping connection {connection_id}: {e}")
        return False
Check if the token for a connection is valid and recheck it if necessary.
Parameters
- connection_id: The ID of the websocket connection.
def
create_websocket_connection_data_from_connect_event( event: dict) -> bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData:
def create_websocket_connection_data_from_connect_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
    """
    Create unauthorised websocket connection data from a connect event.

    :param event: The connect event; reads ``requestContext.connectionId`` and
        ``requestContext.identity.sourceIp``.
    :return: A WebsocketConnectionData instance for the not-yet-authorised connection.
    """
    connection_id = event["requestContext"]["connectionId"]
    source_ip = event["requestContext"]["identity"]["sourceIp"]
    return WebsocketConnectionData(connection_id=connection_id, source_ip=source_ip)
Create unauthorised websocket connection data from a connect event.
Parameters
- event: The event data containing connection information.
Returns
A WebsocketConnectionData instance for the unauthorised connection.
def
create_websocket_connection_data_from_authenticate_event( event: dict) -> bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData:
def create_websocket_connection_data_from_authenticate_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
    """
    Create a WebsocketConnectionData object from an authorise event.

    :param event: The event data containing the bearer token in its authorisation header.
    :return: An authorised WebsocketConnectionData instance.
    :raises: Whatever ``authenticate_token`` / ``authorise_token`` raise for a bad
        token (e.g. ``ExpiredSignatureError``).
    """
    token = get_auth_header_value(event, AuthTypes.BEARER)
    # Authenticate before touching the rest of the event so a bad token fails first.
    _, last_token_check, ttl = authenticate_token(token)
    connection_id = event["requestContext"]["connectionId"]
    source_ip = event["requestContext"]["identity"]["sourceIp"]
    broadcast_filters = authorise_token(token)

    return WebsocketConnectionData(connection_id=connection_id,
                                   source_ip=source_ip,
                                   token=token,
                                   token_expiry=ttl,
                                   broadcast_filters=broadcast_filters,
                                   last_token_check=last_token_check,
                                   is_authorised=True)
Create a WebsocketConnectionData object from an authorise event.
Parameters
- event: The event data containing authorisation information.
Returns
A WebsocketConnectionData instance.
def
get_connection_data( connection_id: str) -> bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData | None:
def
get_connection_topic_filters(connection_id: str, topic: str) -> dict: