bedrock.cache.websockets

  1import json  # pragma: unit
  2import time  # pragma: unit
  3
  4from jwt import ExpiredSignatureError  # pragma: unit
  5from valkey import Valkey  # pragma: unit
  6
  7from bedrock.cache import get_cache, get_cache_topics  # pragma: unit
  8from bedrock.cache.websockets.websocket_connection_data import WebsocketConnectionData  # pragma: unit
  9from bedrock.config import get_config_params  # pragma: unit
 10from bedrock.endpoints.decorators.protected import AuthTypes, get_auth_header_value  # pragma: unit
 11from bedrock.log import log_config  # pragma: unit
 12from bedrock.websockets.authentication import authenticate_token, authorise_token  # pragma: unit
 13
# Module-level logger named "websocket_connections", shared by the connection/topic helpers below.
logger = log_config("websocket_connections")  # pragma: unit
 15
 16
 17def save_connection(connection_data: WebsocketConnectionData) -> dict:  # pragma: no cover
 18    return _save_connection(connection_data, get_cache())
 19
 20
 21def _save_connection(connection_data: WebsocketConnectionData, cache: Valkey) -> dict:  # pragma: unit
 22    """
 23    Store a new websocket connection in the cache.
 24    """
 25    data_to_store = connection_data.as_json()
 26    cache.set(connection_data.connection_id, json.dumps(data_to_store), ex=connection_data.token_expiry)
 27    return data_to_store
 28
 29
 30def remove_connection(connection_id: str) -> str | None:  # pragma: no cover
 31    """
 32    Remove a websocket connection from the cache.
 33    """
 34    return _remove_connection(connection_id, get_cache())
 35
 36
 37def _remove_connection(connection_id: str, cache: Valkey) -> str | None:  # pragma: unit
 38    logger.debug(f"Removing connection with id: {connection_id}")
 39    if cache.exists(connection_id):
 40        cache.delete(connection_id)
 41        return connection_id
 42    return None
 43
 44
 45def subscribe_connection_to_topic(connection_id: str, topic: str, filters: dict = None) -> bool:  # pragma: no cover
 46    """
 47    Add a connection to a topic in the cache.
 48    """
 49    config = get_config_params()
 50    return _add_connection_to_topic(connection_id, topic, get_cache(), config["cache"]["topic_prefix"], filters)
 51
 52
 53def _add_connection_to_topic(connection_id: str, topic: str, cache: Valkey, topic_prefix: str,
 54                             filters: dict = None) -> bool:  # pragma: unit
 55    _topic = f"{topic_prefix}-{topic}"
 56    logger.info(f"Subscribing connection id: {connection_id} to topic: {_topic}")
 57    try:
 58        cache.sadd(_topic, connection_id)
 59        if filters:
 60            _add_connection_topic_filters(connection_id, _topic, cache, filters)
 61        else:
 62            _remove_connection_topic_filters(connection_id, _topic, cache)
 63        return True
 64    except Exception as e:
 65        logger.error(f"Failed to add connection {connection_id} to topic {topic}: {e}")
 66        return False
 67
 68
 69def unsubscribe_connection(connection_id: str, topics: list[str] = None) -> None:  # pragma: no cover
 70    """
 71    Remove a connection from a topic in the cache.
 72    """
 73    config = get_config_params()
 74    _topics = topics if topics else get_cache_topics()
 75    prefix = config["cache"]["topic_prefix"] if topics else None
 76    for topic in _topics:
 77        _remove_connection_from_topic(connection_id, topic, get_cache(), prefix)
 78
 79
 80def _remove_connection_from_topic(connection_id: str, topic: str, cache: Valkey,
 81                                  topic_prefix: str) -> bool:  # pragma: unit
 82    _topic = f"{topic_prefix}-{topic}" if topic_prefix else topic
 83    logger.debug(f"Unsubscribing connection id: {connection_id} from topic: {_topic}")
 84    try:
 85        cache.srem(_topic, connection_id)
 86        _remove_connection_topic_filters(connection_id, _topic, cache)
 87        return True
 88    except Exception as e:
 89        logger.error(f"Failed to remove connection {connection_id} from topic {topic}: {e}")
 90        return False
 91
 92
 93def is_connection_token_valid(connection_id: str) -> bool:  # pragma: no cover
 94    """
 95    Check if the token for a connection is valid and recheck it if necessary.
 96    :param connection_id: The ID of the websocket connection.
 97    """
 98    try:
 99        cache = get_cache()
100        ping_interval = int(get_config_params()["websockets"]["ping_interval"])
101        if not _token_needs_recheck(connection_id, ping_interval, cache):
102            return True
103        else:
104            # Load the connection data from the cache
105            connection_data = json.loads(cache.get(connection_id))
106            connection_data = WebsocketConnectionData.from_dict(connection_data)
107
108            # Recheck the token
109            _, last_token_check, ttl = authenticate_token(connection_data.token)
110
111            # Update the connection data with the new token expiry and last check time
112            connection_data.token_expiry = ttl
113            connection_data.last_token_check = last_token_check
114            cache.set(connection_id, json.dumps(connection_data.as_json()), ex=ttl)
115            return True
116    except ExpiredSignatureError:
117        logger.warning(f"Token for connection {connection_id} has expired.")
118        return False
119    except Exception as e:
120        logger.error(f"Failed to ping connection {connection_id}: {e}")
121        return False
122
123
def _token_needs_recheck(connection_id: str, ping_interval: int, cache: Valkey) -> bool:  # pragma: unit
    """
    Check if the token for a connection needs to be rechecked based on the last check time.

    :param connection_id: The ID of the websocket connection.
    :param ping_interval: The interval in seconds to check if the token needs rechecking.
    :param cache: The Valkey client holding the connection entries.
    :return: ``False`` only when the connection is cached and its ``lastTokenCheck`` (default 0)
        is less than ``ping_interval`` seconds ago. ``True`` otherwise, including when the
        connection is not cached.
    """
    # A single GET instead of EXISTS + GET: the entry carries a TTL, so it can expire between two
    # calls, and json.loads(None) would then raise TypeError.
    raw = cache.get(connection_id)
    if raw is None:
        return True
    last_check = json.loads(raw).get("lastTokenCheck", 0)
    return int(time.time()) - last_check >= ping_interval
137
138
def create_websocket_connection_data_from_connect_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
    """
    Create unauthorised websocket connection data from a connect event.

    :param event: The event data; ``requestContext.connectionId`` and
        ``requestContext.identity.sourceIp`` are read from it.
    :return: A WebsocketConnectionData instance built from only the connection id and source IP.
    """
    request_context = event["requestContext"]
    return WebsocketConnectionData(
        connection_id=request_context["connectionId"],
        source_ip=request_context["identity"]["sourceIp"],
    )
148
149
def create_websocket_connection_data_from_authenticate_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
    """
    Create authorised websocket connection data from an authenticate event.

    The bearer token is read from the event's auth header. It is authenticated (yielding its
    expiry and check time) and authorised (yielding its broadcast filters). Exceptions raised
    during either step propagate to the caller.
    :param event: The event data containing the bearer token and connection information.
    :return: A WebsocketConnectionData instance marked ``is_authorised=True``.
    """
    token = get_auth_header_value(event, AuthTypes.BEARER)
    _, checked_at, expiry = authenticate_token(token)
    request_context = event["requestContext"]
    connection_id = request_context["connectionId"]
    source_ip = request_context["identity"]["sourceIp"]
    filters = authorise_token(token)
    return WebsocketConnectionData(
        connection_id=connection_id,
        source_ip=source_ip,
        token=token,
        token_expiry=expiry,
        broadcast_filters=filters,
        last_token_check=checked_at,
        is_authorised=True,
    )
169
170
def _get_connection_data(connection_id: str, cache: Valkey) -> WebsocketConnectionData | None:  # pragma: unit
    """
    Load a websocket connection from the cache.

    :param connection_id: The ID of the websocket connection.
    :param cache: The Valkey client holding the connection entries.
    :return: The deserialised WebsocketConnectionData, or ``None`` if no entry exists.
    """
    # A single GET avoids the EXISTS/GET race: the entry carries a TTL and can expire between two
    # calls, and json.loads(None) would then raise TypeError.
    raw = cache.get(connection_id)
    if raw is None:
        return None
    return WebsocketConnectionData.from_dict(json.loads(raw))
176
177
def get_connection_data(
        connection_id: str) -> WebsocketConnectionData | None:  # pragma: no cover - The logic is already covered by _get_connection_data
    """
    Fetch a websocket connection from the shared cache.

    :param connection_id: The ID of the websocket connection.
    :return: The cached WebsocketConnectionData, or ``None`` if it is not cached.
    """
    shared_cache = get_cache()
    return _get_connection_data(connection_id, shared_cache)
181
182
def _add_connection_topic_filters(connection_id: str, topic: str, cache: Valkey,
                                  filters: dict) -> dict:  # pragma: unit
    """
    Store ``filters`` as JSON under the ``"{connection_id}-{topic}"`` key (no expiry is set).

    :return: The same ``filters`` dict that was stored.
    """
    cache.set(f"{connection_id}-{topic}", json.dumps(filters))
    return filters
188
189
def _remove_connection_topic_filters(connection_id: str, topic: str, cache: Valkey):  # pragma: unit
    """Delete any filters stored under the ``"{connection_id}-{topic}"`` key."""
    cache.delete(f"{connection_id}-{topic}")
193
194
def get_connection_topic_filters(connection_id: str, topic: str) -> dict | None:  # pragma: unit
    """
    Fetch the filters stored for a connection's subscription to a topic.

    :param connection_id: The ID of the websocket connection.
    :param topic: The topic key exactly as used when the filters were stored. For subscriptions
        made via ``subscribe_connection_to_topic``, that is the prefixed ``"{prefix}-{topic}"`` key.
    :return: The stored filters dict, or ``None`` if no filters are stored.
    """
    return _get_connection_topic_filters(connection_id, topic, get_cache())
197
198
def _get_connection_topic_filters(connection_id: str, topic: str, cache: Valkey) -> dict | None:  # pragma: unit
    """
    Load the filters stored under the ``"{connection_id}-{topic}"`` key.

    :param connection_id: The ID of the websocket connection.
    :param topic: The (prefixed) topic key the filters were stored under.
    :param cache: The Valkey client holding the filter entries.
    :return: The decoded filters, or ``None`` if the key does not exist.
    """
    # One GET rather than EXISTS + GET saves a round trip. It also closes the window in which the
    # key could be deleted (e.g. by a concurrent unsubscribe) between the two calls, which made
    # json.loads(None) raise TypeError.
    raw = cache.get(f"{connection_id}-{topic}")
    if raw is None:
        return None
    return json.loads(raw)
logger = <MyLogger BEDROCK-websocket_connections (INFO)>
def save_connection( connection_data: bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData) -> dict:
18def save_connection(connection_data: WebsocketConnectionData) -> dict:  # pragma: no cover
19    return _save_connection(connection_data, get_cache())
def remove_connection(connection_id: str) -> str | None:
31def remove_connection(connection_id: str) -> str | None:  # pragma: no cover
32    """
33    Remove a websocket connection from the cache.
34    """
35    return _remove_connection(connection_id, get_cache())

Remove a websocket connection from the cache.

def subscribe_connection_to_topic(connection_id: str, topic: str, filters: dict = None) -> bool:
46def subscribe_connection_to_topic(connection_id: str, topic: str, filters: dict = None) -> bool:  # pragma: no cover
47    """
48    Add a connection to a topic in the cache.
49    """
50    config = get_config_params()
51    return _add_connection_to_topic(connection_id, topic, get_cache(), config["cache"]["topic_prefix"], filters)

Add a connection to a topic in the cache.

def unsubscribe_connection(connection_id: str, topics: list[str] = None) -> None:
70def unsubscribe_connection(connection_id: str, topics: list[str] = None) -> None:  # pragma: no cover
71    """
72    Remove a connection from a topic in the cache.
73    """
74    config = get_config_params()
75    _topics = topics if topics else get_cache_topics()
76    prefix = config["cache"]["topic_prefix"] if topics else None
77    for topic in _topics:
78        _remove_connection_from_topic(connection_id, topic, get_cache(), prefix)

Remove a connection from the given topics in the cache, or from all cached topics when no topics are given.

def is_connection_token_valid(connection_id: str) -> bool:
 94def is_connection_token_valid(connection_id: str) -> bool:  # pragma: no cover
 95    """
 96    Check if the token for a connection is valid and recheck it if necessary.
 97    :param connection_id: The ID of the websocket connection.
 98    """
 99    try:
100        cache = get_cache()
101        ping_interval = int(get_config_params()["websockets"]["ping_interval"])
102        if not _token_needs_recheck(connection_id, ping_interval, cache):
103            return True
104        else:
105            # Load the connection data from the cache
106            connection_data = json.loads(cache.get(connection_id))
107            connection_data = WebsocketConnectionData.from_dict(connection_data)
108
109            # Recheck the token
110            _, last_token_check, ttl = authenticate_token(connection_data.token)
111
112            # Update the connection data with the new token expiry and last check time
113            connection_data.token_expiry = ttl
114            connection_data.last_token_check = last_token_check
115            cache.set(connection_id, json.dumps(connection_data.as_json()), ex=ttl)
116            return True
117    except ExpiredSignatureError:
118        logger.warning(f"Token for connection {connection_id} has expired.")
119        return False
120    except Exception as e:
121        logger.error(f"Failed to ping connection {connection_id}: {e}")
122        return False

Check if the token for a connection is valid and recheck it if necessary.

Parameters
  • connection_id: The ID of the websocket connection.
def create_websocket_connection_data_from_connect_event( event: dict) -> bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData:
140def create_websocket_connection_data_from_connect_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
141    """
142    Create unauthorised websocket connection data from a connect event.
143    :param event: The event data containing connection information.
144    :return: A dictionary representing the WebsocketConnectionData.
145    """
146    connection_id = event["requestContext"]["connectionId"]
147    source_ip = event["requestContext"]["identity"]["sourceIp"]
148    return WebsocketConnectionData(connection_id=connection_id, source_ip=source_ip)

Create unauthorised websocket connection data from a connect event.

Parameters
  • event: The event data containing connection information.
Returns

A WebsocketConnectionData instance for the new, unauthorised connection.

def create_websocket_connection_data_from_authenticate_event( event: dict) -> bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData:
151def create_websocket_connection_data_from_authenticate_event(event: dict) -> WebsocketConnectionData:  # pragma: unit
152    """
153    Create a WebsocketConnectionData object from an authorise event.
154    :param event: The event data authorisation information.
155    :return: A WebsocketConnectionData instance.
156    """
157    token = get_auth_header_value(event, AuthTypes.BEARER)
158    _, last_token_check, ttl = authenticate_token(token)
159    connection_id = event["requestContext"]["connectionId"]
160    source_ip = event["requestContext"]["identity"]["sourceIp"]
161    broadcast_filters = authorise_token(token)
162
163    return WebsocketConnectionData(connection_id=connection_id,
164                                   source_ip=source_ip,
165                                   token=token,
166                                   token_expiry=ttl,
167                                   broadcast_filters=broadcast_filters,
168                                   last_token_check=last_token_check,
169                                   is_authorised=True)

Create a WebsocketConnectionData object from an authorise event.

Parameters
  • event: The event data containing authorisation information.
Returns

A WebsocketConnectionData instance.

def get_connection_data( connection_id: str) -> bedrock.cache.websockets.websocket_connection_data.WebsocketConnectionData | None:
179def get_connection_data(
180        connection_id: str) -> WebsocketConnectionData | None:  # pragma: no cover - The logic is already covered by _get_connection_data
181    return _get_connection_data(connection_id, get_cache())
def get_connection_topic_filters(connection_id: str, topic: str) -> dict:
196def get_connection_topic_filters(connection_id: str, topic: str) -> dict:  # pragma: unit
197    return _get_connection_topic_filters(connection_id, topic, get_cache())