bedrock.endpoints.local_runner

  1import asyncio  # pragma: unit
  2import json  # pragma: unit
  3import re  # pragma: unit
  4import uuid  # pragma: unit
  5from base64 import b64encode  # pragma: unit
  6from http.server import BaseHTTPRequestHandler  # pragma: unit
  7from json import dumps, loads  # pragma: unit
  8from urllib.parse import unquote  # pragma: unit
  9
 10import validators  # pragma: unit
 11import websockets  # pragma: unit
 12
 13from bedrock.config.headers import ValidHeaders  # pragma: unit
 14from bedrock.log import log_config  # pragma: unit
 15from bedrock.websockets.routes import (
 16    route_connect_handler,
 17    route_default_handler,
 18    route_ping_handler,
 19    route_disconnect_handler,
 20    route_subscribe_handler,
 21    route_unsubscribe_handler,
 22    route_send_message_handler,
 23    route_authenticate_handler,
 24)  # pragma: unit
 25from bedrock.websockets.websocket_gateway_factory import get_websocket_api_gateway_client  # pragma: unit
 26
 27log = log_config("local_runner")  # pragma: unit
 28LAMBDA_HANDLER = None  # pragma: unit
 29MAPPINGS = None  # pragma: unit
 30
 31
 32def split_query_params_in_path(path: str) -> map:  # pragma: unit
 33    if not path:
 34        raise ValueError('path cannot be None')
 35    return map(lambda x: (x.split("=")[0], unquote(x.split("=")[1])), path.split("?")[1].split("&"))
 36
 37
 38def to_multi_value_params(path: str) -> dict:  # pragma: unit
 39    try:
 40        d = {}
 41        tuples = split_query_params_in_path(path)
 42        for pair in tuples:
 43            if pair[0] not in d:
 44                d[pair[0]] = []
 45            d[pair[0]] = d[pair[0]] + [unquote(pair[1])]
 46        return d
 47    except:  # pragma: no cover
 48        return None
 49
 50
 51def to_single_value_params(path: str) -> dict:  # pragma: unit
 52    try:
 53        return dict(split_query_params_in_path(path))
 54    except:  # pragma: no cover
 55        return None
 56
 57
 58def path_matches_mapping(path: str, mapping: str):  # pragma: unit
 59    _mapping = f'/{mapping.strip("/")}'
 60    matcher_expression = "^" + re.sub(r"\{\w*\}", "([^/]+)", _mapping) + "/?$"
 61    just_path = path.split("?")[0]
 62    return re.match(matcher_expression, just_path) is not None
 63
 64
 65def get_path_params(path: str, mapping: str):  # pragma: unit
 66    _mapping = f'/{mapping.strip("/")}'
 67    matcher_expression = "^" + re.sub(r"\{\w*\}", "([^/]+)", _mapping) + "/?$"
 68    just_path = path.split("?")[0]
 69    value_match = re.match(matcher_expression, just_path)
 70    key_match = re.match(matcher_expression, _mapping)
 71    return dict(((re.sub(r"[\{\}]", "", key), value_match.groups()[i])) for i, key in enumerate(key_match.groups()))
 72
 73
 74def path_to_path_info(path: str, mappings: dict):  # pragma: unit
 75    mapping_paths = mappings.items() if mappings else []
 76    for mapping, endpoint in mapping_paths:
 77        if path_matches_mapping(path, mapping):
 78            return {
 79                "resource": mapping,
 80                "params": get_path_params(path, mapping),
 81                "path": path
 82            }
 83    return {
 84        "resource": "/??????",
 85        "params": {},
 86        "path": path
 87    }
 88
 89
 90def to_api_gateway_request(path: str,
 91                           method: str,
 92                           headers: dict,
 93                           body: dict = None,
 94                           mappings: dict = None):  # pragma: unit
 95    path_info = path_to_path_info(path, mappings)
 96    return {
 97        "resource": path_info["resource"],
 98        "path": path,
 99        "httpMethod": method,
100        "queryStringParameters": to_single_value_params(path),
101        "multiValueQueryStringParameters": to_multi_value_params(path),
102        "pathParameters": path_info["params"],
103        "stageVariables": None,
104        "body": body,
105        "isBase64Encoded": False,
106        "headers": headers
107    }
108
109
def to_msk_request(path, body: "str | dict | None" = None, mappings: dict = None):  # pragma: unit
    """Build an MSK (Kafka)-shaped Lambda event from a local HTTP request.

    *body* maps partition keys to lists of records. Each record's ``value`` is
    JSON-serialised and then base64-encoded before being placed under
    ``records``.

    :param path: request path, used only to resolve ``resource``.
    :param body: the records, as a dict or a JSON string. ``None`` or empty means
        no records. A dict argument is deep-copied and never modified, so
        reusing it does not double-encode its values.
    :param mappings: route templates passed to ``path_to_path_info``.
    :return: event dict with fixed placeholder ``eventSourceArn`` and
        ``bootstrapServers`` values.
    """
    import copy  # function-scope: copy the caller's payload before encoding it

    path_info = path_to_path_info(path, mappings)
    if not body:
        records = {}
    elif isinstance(body, dict):
        records = copy.deepcopy(body)
    else:
        records = loads(body)
    for partition_records in records.values():
        for record in partition_records:
            record['value'] = b64encode(dumps(record['value']).encode('ascii')).decode('ascii')

    return {
        "resource": path_info["resource"],
        "eventSource": "aws:kafka",
        "eventSourceArn": "arn:aws:kafka:eu-west-1:888888888888:cluster/kafka-cluster-test/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee-1",
        "bootstrapServers": "b-2.kafkaclustertest.aaaaaa.bb.kafka.eu-west-1.amazonaws.com:9094,b-3.kafkaclustertest.cccccc.dd.kafka.eu-west-1.amazonaws.com:9094,b-1.kafkaclustertest.eeeeee.ff.kafka.eu-west-1.amazonaws.com:9094",
        "records": records
    }
125
126
def set_lambda_handler(handler):  # pragma: no cover
    """Register *handler* as the callable that ``HandlerClass.aws`` invokes for every request."""
    global LAMBDA_HANDLER
    LAMBDA_HANDLER = handler
130
131
def set_mappings(mappings):  # pragma: no cover
    """Register the route-template *mappings* that ``HandlerClass.aws`` resolves paths against."""
    global MAPPINGS
    MAPPINGS = mappings
135
136
class HandlerClass(BaseHTTPRequestHandler):  # pragma: unit
    """HTTP handler that replays local requests through the configured Lambda handler.

    Each request becomes an API Gateway-shaped event, or an MSK-shaped event when
    the ``eventSource: aws:kafka`` header is present. The event is passed to
    ``LAMBDA_HANDLER``, and the returned ``statusCode``/``headers``/``body``
    are written back to the client.
    """

    def _set_response(self, response):  # pragma: no cover
        """Send the status line and headers of a Lambda *response*.

        ``Content-type: application/json`` is sent by default, but only when the
        handler did not supply its own content type (avoids a duplicate header).
        A missing or ``None`` ``headers`` entry means no extra headers.
        """
        self.send_response(response["statusCode"])
        headers = response.get("headers") or {}
        if not any(str(name).lower() == "content-type" for name in headers):
            self.send_header('Content-type', 'application/json')
        for name, value in headers.items():
            self.send_header(name, value)
        self.end_headers()

    def aws(self, method, body=None):  # pragma: no cover
        """Invoke the configured Lambda handler for this request and write its response.

        A ``None`` response body is sent as an empty payload. A non-string body
        is JSON-encoded instead of failing on ``.encode``.
        """
        response = self._aws(self.path, self.headers, LAMBDA_HANDLER, MAPPINGS, method, body)
        self._set_response(response)
        payload = response.get("body")
        if payload is None:
            payload = ""
        elif not isinstance(payload, str):
            payload = json.dumps(payload)
        self.wfile.write(payload.encode(encoding='utf_8'))

    @classmethod
    def _aws(cls,
             path: str,
             headers: dict,
             lambda_handler: callable,
             mappings: dict,
             method: str,
             body: dict = None):  # pragma: unit
        """Validate *path* and dispatch the request to *lambda_handler*.

        :param path: raw request path, possibly percent-encoded and with a query string.
        :param headers: request headers; ``eventSource: aws:kafka`` selects the Kafka event shape.
        :param lambda_handler: called with the built event; its return value is returned.
        :param mappings: route templates used to resolve ``resource`` and path parameters.
        :param method: HTTP method name.
        :param body: decoded request body, if any.
        :return: the handler's response dict, or a 400 response when the path
            fails URL validation.
        """
        # NOTE(review): the whole path, query string included, is URL-decoded here,
        # and query values are decoded again by split_query_params_in_path. An
        # encoded '&', '=' or '%' inside a query value is therefore mis-parsed.
        # Left as-is because the decoded path is also what lands in the event.
        path = unquote(path)
        try:
            is_kafka_event = headers['eventSource'] == 'aws:kafka'
        except KeyError:
            is_kafka_event = False

        url_to_validate = f"http://localhost:5050{path.split('?')[0]}"
        if not validators.url(url_to_validate, simple_host=True):
            log.error(f"Invalid URL: {url_to_validate}")
            return {
                'statusCode': 400,
                'body': json.dumps({
                    "error": f"URL validation failed: {url_to_validate}"
                }),
                'headers': {
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
                    "Access-Control-Allow-Headers": "*"
                },
                "isBase64Encoded": False
            }

        # The dispatch decision is logged once, and only for requests that are dispatched.
        if is_kafka_event:
            log.debug("Request to be handled as Kafka event")
            event = to_msk_request(str(path), body, mappings)
        else:
            log.debug("Request to be handled as APIGW event")
            event = to_api_gateway_request(str(path), method, headers, body, mappings)
        return lambda_handler(event)

    def process_with_body(self, method):  # pragma: no cover
        """Read the request body (``Content-Length`` bytes) and forward it with *method*.

        A missing ``Content-Length`` header means an empty body instead of a
        crash on ``int(None)``.
        """
        content_length = int(self.headers.get('Content-Length') or 0)
        body = self.rfile.read(content_length).decode('utf-8')
        log.info(method + " request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
                 str(self.path), str(self.headers), body)
        self.aws(method, body)

    def do_GET(self):  # pragma: no cover
        log.info("\n  GET request\n  Path: %s \n", str(self.path))
        self.aws("GET")

    def do_POST(self):  # pragma: no cover
        log.info("\n  POST request\n  Path: %s \n", str(self.path))
        self.process_with_body("POST")

    def do_PUT(self):  # pragma: no cover
        log.info("\n  PUT request\n  Path: %s \n", str(self.path))
        self.process_with_body("PUT")

    def do_DELETE(self):  # pragma: no cover
        log.info("\n  DELETE request\n  Path: %s \n", str(self.path))
        self.aws("DELETE")

    def do_OPTIONS(self):  # pragma: no cover
        # Answered locally with permissive CORS headers; the Lambda handler is not called.
        log.info("\n  OPTIONS request\n  Path: %s \n", str(self.path))
        self.send_response(200, "ok")
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS')
        self.send_header("Access-Control-Allow-Headers", "*")
        self.end_headers()
224
225
226# Websocket methods
def handle_routes(event, context):  # pragma: unit
    """Dispatch a local websocket *event* to the route handler for its ``action``.

    ``event["body"]`` holds the already-parsed message, usually a dict with an
    ``action`` key. It is re-serialised to a JSON string before dispatch. Unknown
    actions and bodies that are not JSON objects (lists, numbers, strings,
    ``null``) go to ``route_default_handler`` instead of raising.

    :return: whatever the selected route handler returns.
    """
    routes = {
        "$connect": route_connect_handler,
        "$disconnect": route_disconnect_handler,
        "ping": route_ping_handler,
        "send-message": route_send_message_handler,
        "subscribe": route_subscribe_handler,
        "unsubscribe": route_unsubscribe_handler,
        "authenticate": route_authenticate_handler,
    }

    payload = event.get("body")
    action = payload.get("action", "") if isinstance(payload, dict) else ""
    event["body"] = json.dumps(payload)

    # A non-string action (e.g. a list) would make the dict lookup raise TypeError.
    route_handler = routes.get(action) if isinstance(action, str) else None
    if route_handler is None:
        return route_default_handler(event, context)
    return route_handler(event, context)
244
245
async def handler(websocket):  # pragma: unit
    """Serve one local websocket connection using API Gateway-style websocket events.

    The flow is:

    1. Register the socket with the websocket API Gateway client under a fresh
       connection id.
    2. Send a ``$connect`` event carrying the request headers.
    3. Route every incoming message through ``handle_routes``.
    4. Send a ``$disconnect`` event when the client goes away.

    If ``$connect`` does not return ``statusCode`` 200, or raises, the socket is
    unregistered again before returning.
    """
    connection_id = str(uuid.uuid4())
    api_gateway_client = get_websocket_api_gateway_client()
    api_gateway_client.set_loop(loop=asyncio.get_running_loop())
    api_gateway_client.add_websocket_connection(connection_id, websocket)

    def request_context(event_type):
        # Every event of this connection shares the same local context; only eventType varies.
        return {
            "connectionId": connection_id,
            "eventType": event_type,
            "domainName": "localhost",
            "stage": "local",
            "identity": {"sourceIp": "127.0.0.1"},
        }

    headers = dict(websocket.request.headers)
    if "authorization" in headers:
        # Pop first, then store under the project's header name. The value then
        # survives even if ValidHeaders.AUTHORIZATION.value is "authorization" itself.
        headers[ValidHeaders.AUTHORIZATION.value] = headers.pop("authorization")

    connect_event = {
        "requestContext": request_context("CONNECT"),
        "headers": headers,
        "body": {"action": "$connect"},
    }

    try:
        result = handle_routes(connect_event, None)
    except Exception:
        api_gateway_client.remove_websocket_connection(connection_id)
        raise
    if not result or result.get("statusCode") != 200:
        # Rejected: unregister the socket, otherwise it stays in the client's registry.
        api_gateway_client.remove_websocket_connection(connection_id)
        return

    print(f"client connected to {connection_id}")

    try:
        async for message in websocket:  # pragma: no cover
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                data = {}

            event = {
                "requestContext": request_context("MESSAGE"),
                "body": data,
            }
            handle_routes(event, None)
    finally:
        print(f"Closing connection: {connection_id}")
        api_gateway_client.remove_websocket_connection(connection_id)
        disconnect_event = {
            "requestContext": request_context("DISCONNECT"),
            "body": {"action": "$disconnect"},
        }
        handle_routes(disconnect_event, None)
304
305
async def run_websocket_server():  # pragma: no cover
    """Serve ``handler`` on 0.0.0.0:5051, awaiting a future that is never resolved."""
    host, port = "0.0.0.0", 5051
    server = websockets.serve(handler, host, port)
    async with server:
        print("WebSocket server listening on ws://localhost:5051")
        never_done = asyncio.Future()
        await never_done
log = <MyLogger BEDROCK-local_runner (INFO)>
LAMBDA_HANDLER = None
MAPPINGS = None
def split_query_params_in_path(path: str) -> map:  # pragma: unit
    """Split the query string of *path* into ``(key, value)`` pairs.

    Everything after the first ``?`` is the query string. Each ``&``-separated
    segment is split on its first ``=`` only, so values that themselves contain
    ``=`` (e.g. base64 tokens) stay intact and a bare key such as ``?flag``
    yields ``("flag", "")``. Empty segments (``a=1&&b=2``) are skipped.
    Values are URL-decoded with ``unquote``; keys are returned as-is.

    :param path: request path, optionally followed by ``?query``.
    :return: lazy iterator of ``(key, value)`` tuples.
    :raises ValueError: if *path* is empty or ``None``.
    :raises IndexError: if *path* contains no ``?``; callers rely on this to
        report "no query string".
    """
    if not path:
        raise ValueError('path cannot be None')
    _, separator, query = path.partition("?")
    if not separator:
        raise IndexError(f"no query string in path: {path}")
    segments = (segment.partition("=") for segment in query.split("&") if segment)
    return map(lambda parts: (parts[0], unquote(parts[2])), segments)
def to_multi_value_params(path: str) -> dict:  # pragma: unit
    """Build API Gateway-style ``multiValueQueryStringParameters`` from *path*.

    Repeated keys collect every value in order of appearance, e.g.
    ``/a?x=1&x=2`` -> ``{"x": ["1", "2"]}``. Values already come URL-decoded
    from ``split_query_params_in_path``. They are not decoded a second time:
    that turned ``%2520`` into a space and disagreed with
    ``to_single_value_params``.

    :return: key -> list-of-values mapping, or ``None`` when *path* is empty or
        has no parsable query string.
    """
    try:
        grouped = {}
        # The pairs are produced lazily, so parse errors can surface while
        # iterating; the loop must stay inside the try.
        for key, value in split_query_params_in_path(path):
            grouped.setdefault(key, []).append(value)
        return grouped
    except (ValueError, IndexError):  # pragma: no cover
        return None
def to_single_value_params(path: str) -> dict:  # pragma: unit
    """Build API Gateway-style ``queryStringParameters`` from *path*.

    When a key repeats, the last value wins (plain ``dict`` construction).

    :return: key -> value mapping, or ``None`` when *path* is empty or has no
        parsable query string.
    """
    try:
        return dict(split_query_params_in_path(path))
    # Only the parse failures signalled by split_query_params_in_path; a bare
    # except would also swallow KeyboardInterrupt/SystemExit and real bugs.
    except (ValueError, IndexError):  # pragma: no cover
        return None
def path_matches_mapping(path: str, mapping: str):  # pragma: unit
    """Return True if *path* (query string ignored) matches route template *mapping*.

    Each ``{name}`` placeholder matches exactly one non-empty path segment.
    Every other template character is matched literally, so a ``.`` in
    ``/file.json`` no longer matches an arbitrary character. Leading and
    trailing slashes on the template are normalised, and one trailing slash on
    *path* is accepted.
    """
    template = f'/{mapping.strip("/")}'
    literal_parts = re.split(r"\{\w*\}", template)
    pattern = "^" + "([^/]+)".join(re.escape(part) for part in literal_parts) + "/?$"
    return re.match(pattern, path.split("?")[0]) is not None
def get_path_params(path: str, mapping: str):  # pragma: unit
    """Extract the values of the ``{name}`` placeholders of *mapping* from *path*.

    Matching follows ``path_matches_mapping``: literal template characters are
    matched exactly, each placeholder captures one path segment, and the query
    string is ignored.

    :return: dict of placeholder name -> captured segment.
    :raises ValueError: if *path* does not match *mapping*.
    """
    template = f'/{mapping.strip("/")}'
    names = re.findall(r"\{(\w*)\}", template)
    literal_parts = re.split(r"\{\w*\}", template)
    pattern = "^" + "([^/]+)".join(re.escape(part) for part in literal_parts) + "/?$"
    match = re.match(pattern, path.split("?")[0])
    if match is None:
        raise ValueError(f"path {path!r} does not match mapping {mapping!r}")
    return dict(zip(names, match.groups()))
def path_to_path_info(path: str, mappings: dict):  # pragma: unit
    """Resolve *path* against the route templates in *mappings*.

    Templates (the mapping keys) are tried in iteration order and the first
    match wins. The mapped endpoint values are not consulted.

    :return: dict with ``resource`` (the matched template, or the placeholder
        ``"/??????"`` when nothing matches), ``params`` (path parameters) and
        the unchanged ``path``.
    """
    for mapping in (mappings or {}):
        if path_matches_mapping(path, mapping):
            return {
                "resource": mapping,
                "params": get_path_params(path, mapping),
                "path": path
            }
    return {
        "resource": "/??????",
        "params": {},
        "path": path
    }
def to_api_gateway_request(path: str,
                           method: str,
                           headers: dict,
                           body: dict = None,
                           mappings: dict = None):  # pragma: unit
    """Build an API Gateway-shaped Lambda event for a local HTTP request.

    :param path: request path, possibly with a query string; copied into ``path``.
    :param method: HTTP method, copied into ``httpMethod``.
    :param headers: request headers, passed through unchanged.
    :param body: request body, passed through unchanged (never base64-encoded here).
    :param mappings: route templates used to resolve ``resource`` and ``pathParameters``.
    :return: the event dict.
    """
    resolved = path_to_path_info(path, mappings)
    single_query = to_single_value_params(path)
    multi_query = to_multi_value_params(path)
    return dict(
        resource=resolved["resource"],
        path=path,
        httpMethod=method,
        queryStringParameters=single_query,
        multiValueQueryStringParameters=multi_query,
        pathParameters=resolved["params"],
        stageVariables=None,
        body=body,
        isBase64Encoded=False,
        headers=headers,
    )
def to_msk_request(path, body: "str | dict | None" = None, mappings: dict = None):  # pragma: unit
    """Build an MSK (Kafka)-shaped Lambda event from a local HTTP request.

    *body* maps partition keys to lists of records. Each record's ``value`` is
    JSON-serialised and then base64-encoded before being placed under
    ``records``.

    :param path: request path, used only to resolve ``resource``.
    :param body: the records, as a dict or a JSON string. ``None`` or empty means
        no records. A dict argument is deep-copied and never modified, so
        reusing it does not double-encode its values.
    :param mappings: route templates passed to ``path_to_path_info``.
    :return: event dict with fixed placeholder ``eventSourceArn`` and
        ``bootstrapServers`` values.
    """
    import copy  # function-scope: copy the caller's payload before encoding it

    path_info = path_to_path_info(path, mappings)
    if not body:
        records = {}
    elif isinstance(body, dict):
        records = copy.deepcopy(body)
    else:
        records = loads(body)
    for partition_records in records.values():
        for record in partition_records:
            record['value'] = b64encode(dumps(record['value']).encode('ascii')).decode('ascii')

    return {
        "resource": path_info["resource"],
        "eventSource": "aws:kafka",
        "eventSourceArn": "arn:aws:kafka:eu-west-1:888888888888:cluster/kafka-cluster-test/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee-1",
        "bootstrapServers": "b-2.kafkaclustertest.aaaaaa.bb.kafka.eu-west-1.amazonaws.com:9094,b-3.kafkaclustertest.cccccc.dd.kafka.eu-west-1.amazonaws.com:9094,b-1.kafkaclustertest.eeeeee.ff.kafka.eu-west-1.amazonaws.com:9094",
        "records": records
    }
def set_lambda_handler(handler):  # pragma: no cover
    """Register *handler* as the callable that ``HandlerClass.aws`` invokes for every request."""
    global LAMBDA_HANDLER
    LAMBDA_HANDLER = handler
def set_mappings(mappings):  # pragma: no cover
    """Register the route-template *mappings* that ``HandlerClass.aws`` resolves paths against."""
    global MAPPINGS
    MAPPINGS = mappings
class HandlerClass(BaseHTTPRequestHandler):  # pragma: unit
    """HTTP handler that replays local requests through the configured Lambda handler.

    Each request becomes an API Gateway-shaped event, or an MSK-shaped event when
    the ``eventSource: aws:kafka`` header is present. The event is passed to
    ``LAMBDA_HANDLER``, and the returned ``statusCode``/``headers``/``body``
    are written back to the client.
    """

    def _set_response(self, response):  # pragma: no cover
        """Send the status line and headers of a Lambda *response*.

        ``Content-type: application/json`` is sent by default, but only when the
        handler did not supply its own content type (avoids a duplicate header).
        A missing or ``None`` ``headers`` entry means no extra headers.
        """
        self.send_response(response["statusCode"])
        headers = response.get("headers") or {}
        if not any(str(name).lower() == "content-type" for name in headers):
            self.send_header('Content-type', 'application/json')
        for name, value in headers.items():
            self.send_header(name, value)
        self.end_headers()

    def aws(self, method, body=None):  # pragma: no cover
        """Invoke the configured Lambda handler for this request and write its response.

        A ``None`` response body is sent as an empty payload. A non-string body
        is JSON-encoded instead of failing on ``.encode``.
        """
        response = self._aws(self.path, self.headers, LAMBDA_HANDLER, MAPPINGS, method, body)
        self._set_response(response)
        payload = response.get("body")
        if payload is None:
            payload = ""
        elif not isinstance(payload, str):
            payload = json.dumps(payload)
        self.wfile.write(payload.encode(encoding='utf_8'))

    @classmethod
    def _aws(cls,
             path: str,
             headers: dict,
             lambda_handler: callable,
             mappings: dict,
             method: str,
             body: dict = None):  # pragma: unit
        """Validate *path* and dispatch the request to *lambda_handler*.

        :param path: raw request path, possibly percent-encoded and with a query string.
        :param headers: request headers; ``eventSource: aws:kafka`` selects the Kafka event shape.
        :param lambda_handler: called with the built event; its return value is returned.
        :param mappings: route templates used to resolve ``resource`` and path parameters.
        :param method: HTTP method name.
        :param body: decoded request body, if any.
        :return: the handler's response dict, or a 400 response when the path
            fails URL validation.
        """
        # NOTE(review): the whole path, query string included, is URL-decoded here,
        # and query values are decoded again by split_query_params_in_path. An
        # encoded '&', '=' or '%' inside a query value is therefore mis-parsed.
        # Left as-is because the decoded path is also what lands in the event.
        path = unquote(path)
        try:
            is_kafka_event = headers['eventSource'] == 'aws:kafka'
        except KeyError:
            is_kafka_event = False

        url_to_validate = f"http://localhost:5050{path.split('?')[0]}"
        if not validators.url(url_to_validate, simple_host=True):
            log.error(f"Invalid URL: {url_to_validate}")
            return {
                'statusCode': 400,
                'body': json.dumps({
                    "error": f"URL validation failed: {url_to_validate}"
                }),
                'headers': {
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
                    "Access-Control-Allow-Headers": "*"
                },
                "isBase64Encoded": False
            }

        # The dispatch decision is logged once, and only for requests that are dispatched.
        if is_kafka_event:
            log.debug("Request to be handled as Kafka event")
            event = to_msk_request(str(path), body, mappings)
        else:
            log.debug("Request to be handled as APIGW event")
            event = to_api_gateway_request(str(path), method, headers, body, mappings)
        return lambda_handler(event)

    def process_with_body(self, method):  # pragma: no cover
        """Read the request body (``Content-Length`` bytes) and forward it with *method*.

        A missing ``Content-Length`` header means an empty body instead of a
        crash on ``int(None)``.
        """
        content_length = int(self.headers.get('Content-Length') or 0)
        body = self.rfile.read(content_length).decode('utf-8')
        log.info(method + " request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
                 str(self.path), str(self.headers), body)
        self.aws(method, body)

    def do_GET(self):  # pragma: no cover
        log.info("\n  GET request\n  Path: %s \n", str(self.path))
        self.aws("GET")

    def do_POST(self):  # pragma: no cover
        log.info("\n  POST request\n  Path: %s \n", str(self.path))
        self.process_with_body("POST")

    def do_PUT(self):  # pragma: no cover
        log.info("\n  PUT request\n  Path: %s \n", str(self.path))
        self.process_with_body("PUT")

    def do_DELETE(self):  # pragma: no cover
        log.info("\n  DELETE request\n  Path: %s \n", str(self.path))
        self.aws("DELETE")

    def do_OPTIONS(self):  # pragma: no cover
        # Answered locally with permissive CORS headers; the Lambda handler is not called.
        log.info("\n  OPTIONS request\n  Path: %s \n", str(self.path))
        self.send_response(200, "ok")
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS')
        self.send_header("Access-Control-Allow-Headers", "*")
        self.end_headers()

HTTP request handler base class.

The following explanation of HTTP serves to guide you through the code as well as to expose any misunderstandings I may have about HTTP (so you don't need to read the code to figure out I'm wrong :-).

HTTP (HyperText Transfer Protocol) is an extensible protocol on top of a reliable stream transport (e.g. TCP/IP). The protocol recognizes three parts to a request:

  1. One line identifying the request type and path
  2. An optional set of RFC-822-style headers
  3. An optional data part

The headers and data are separated by a blank line.

The first line of the request has the form

COMMAND PATH VERSION

where COMMAND is a (case-sensitive) keyword such as GET or POST, PATH is a string containing path information for the request, and VERSION should be the string "HTTP/1.0" or "HTTP/1.1". PATH is encoded using the URL encoding scheme (using %xx to signify the ASCII character with hex code xx).

The specification specifies that lines are separated by CRLF but for compatibility with the widest range of clients recommends servers also handle LF. Similarly, whitespace in the request line is treated sensibly (allowing multiple spaces between components and allowing trailing whitespace).

Similarly, for output, lines ought to be separated by CRLF pairs but most clients grok LF characters just fine.

If the first line of the request has the form

COMMAND PATH

(i.e. VERSION is left out) then this is assumed to be an HTTP 0.9 request; this form has no optional headers and data part and the reply consists of just the data.

The reply form of the HTTP 1.x protocol again has three parts:

  1. One line giving the response code
  2. An optional set of RFC-822-style headers
  3. The data

Again, the headers and data are separated by a blank line.

The response code line has the form

VERSION RESPONSECODE RESPONSESTRING

where VERSION is the protocol version ("HTTP/1.0" or "HTTP/1.1"), RESPONSECODE is a 3-digit response code indicating success or failure of the request, and RESPONSESTRING is an optional human-readable string explaining what the response code means.

This server parses the request and the headers, and then calls a function specific to the request type (COMMAND). Specifically, a request SPAM will be handled by a method do_SPAM(). If no such method exists the server sends an error response to the client. If it exists, it is called with no arguments:

do_SPAM()

Note that the request name is case sensitive (i.e. SPAM and spam are different requests).

The various request details are stored in instance variables:

  • client_address is the client IP address in the form (host, port);

  • command, path and version are the broken-down request line;

  • headers is an instance of email.message.Message (or a derived class) containing the header information;

  • rfile is a file object open for reading positioned at the start of the optional input data part;

  • wfile is a file object open for writing.

IT IS IMPORTANT TO ADHERE TO THE PROTOCOL FOR WRITING!

The first thing to be written must be the response line. Then follow 0 or more header lines, then a blank line, and then the actual data (if any). The meaning of the header lines depends on the command executed by the server; in most cases, when data is returned, there should be at least one header line of the form

Content-type: TYPE/SUBTYPE

where TYPE and SUBTYPE should be registered MIME types, e.g. "text/html" or "text/plain".

def aws(self, method, body=None):  # pragma: no cover
    """Invoke the configured Lambda handler for this request and write its response.

    A ``None`` response body is sent as an empty payload. A non-string body is
    JSON-encoded instead of failing on ``.encode``.
    """
    response = self._aws(self.path, self.headers, LAMBDA_HANDLER, MAPPINGS, method, body)
    self._set_response(response)
    payload = response.get("body")
    if payload is None:
        payload = ""
    elif not isinstance(payload, str):
        payload = json.dumps(payload)
    self.wfile.write(payload.encode(encoding='utf_8'))
def process_with_body(self, method):  # pragma: no cover
    """Read the request body (``Content-Length`` bytes) and forward it with *method*.

    A missing ``Content-Length`` header means an empty body instead of a crash on
    ``int(None)``. The body is decoded once and reused for logging and dispatch.
    """
    content_length = int(self.headers.get('Content-Length') or 0)
    body = self.rfile.read(content_length).decode('utf-8')
    log.info(method + " request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
             str(self.path), str(self.headers), body)
    self.aws(method, body)
def do_GET(self):  # pragma: no cover
    """Handle GET: log the path, then forward the request to the Lambda handler without a body."""
    log.info("\n  GET request\n  Path: %s \n", str(self.path))
    self.aws(method="GET")
def do_POST(self):  # pragma: no cover
    """Handle POST: log the path, then read the body and forward both to the Lambda handler."""
    log.info("\n  POST request\n  Path: %s \n", str(self.path))
    self.process_with_body(method="POST")
def do_PUT(self):  # pragma: no cover
    """Handle PUT: log the path, then read the body and forward both to the Lambda handler."""
    log.info("\n  PUT request\n  Path: %s \n", str(self.path))
    self.process_with_body(method="PUT")
def do_DELETE(self):  # pragma: no cover
    """Handle DELETE: log the path, then forward the request to the Lambda handler without a body."""
    log.info("\n  DELETE request\n  Path: %s \n", str(self.path))
    self.aws(method="DELETE")
def do_OPTIONS(self):  # pragma: no cover
    """Answer OPTIONS locally with permissive CORS headers; the Lambda handler is not called."""
    log.info("\n  OPTIONS request\n  Path: %s \n", str(self.path))
    self.send_response(200, "ok")
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS'),
        ("Access-Control-Allow-Headers", "*"),
    )
    for header_name, header_value in cors_headers:
        self.send_header(header_name, header_value)
    self.end_headers()
def handle_routes(event, context):  # pragma: unit
    """Dispatch a local websocket *event* to the route handler for its ``action``.

    ``event["body"]`` holds the already-parsed message, usually a dict with an
    ``action`` key. It is re-serialised to a JSON string before dispatch. Unknown
    actions and bodies that are not JSON objects (lists, numbers, strings,
    ``null``) go to ``route_default_handler`` instead of raising.

    :return: whatever the selected route handler returns.
    """
    routes = {
        "$connect": route_connect_handler,
        "$disconnect": route_disconnect_handler,
        "ping": route_ping_handler,
        "send-message": route_send_message_handler,
        "subscribe": route_subscribe_handler,
        "unsubscribe": route_unsubscribe_handler,
        "authenticate": route_authenticate_handler,
    }

    payload = event.get("body")
    action = payload.get("action", "") if isinstance(payload, dict) else ""
    event["body"] = json.dumps(payload)

    # A non-string action (e.g. a list) would make the dict lookup raise TypeError.
    route_handler = routes.get(action) if isinstance(action, str) else None
    if route_handler is None:
        return route_default_handler(event, context)
    return route_handler(event, context)
async def handler(websocket):  # pragma: unit
    """Serve one local websocket connection using API Gateway-style websocket events.

    The flow is:

    1. Register the socket with the websocket API Gateway client under a fresh
       connection id.
    2. Send a ``$connect`` event carrying the request headers.
    3. Route every incoming message through ``handle_routes``.
    4. Send a ``$disconnect`` event when the client goes away.

    If ``$connect`` does not return ``statusCode`` 200, or raises, the socket is
    unregistered again before returning.
    """
    connection_id = str(uuid.uuid4())
    api_gateway_client = get_websocket_api_gateway_client()
    api_gateway_client.set_loop(loop=asyncio.get_running_loop())
    api_gateway_client.add_websocket_connection(connection_id, websocket)

    def request_context(event_type):
        # Every event of this connection shares the same local context; only eventType varies.
        return {
            "connectionId": connection_id,
            "eventType": event_type,
            "domainName": "localhost",
            "stage": "local",
            "identity": {"sourceIp": "127.0.0.1"},
        }

    headers = dict(websocket.request.headers)
    if "authorization" in headers:
        # Pop first, then store under the project's header name. The value then
        # survives even if ValidHeaders.AUTHORIZATION.value is "authorization" itself.
        headers[ValidHeaders.AUTHORIZATION.value] = headers.pop("authorization")

    connect_event = {
        "requestContext": request_context("CONNECT"),
        "headers": headers,
        "body": {"action": "$connect"},
    }

    try:
        result = handle_routes(connect_event, None)
    except Exception:
        api_gateway_client.remove_websocket_connection(connection_id)
        raise
    if not result or result.get("statusCode") != 200:
        # Rejected: unregister the socket, otherwise it stays in the client's registry.
        api_gateway_client.remove_websocket_connection(connection_id)
        return

    print(f"client connected to {connection_id}")

    try:
        async for message in websocket:  # pragma: no cover
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                data = {}

            event = {
                "requestContext": request_context("MESSAGE"),
                "body": data,
            }
            handle_routes(event, None)
    finally:
        print(f"Closing connection: {connection_id}")
        api_gateway_client.remove_websocket_connection(connection_id)
        disconnect_event = {
            "requestContext": request_context("DISCONNECT"),
            "body": {"action": "$disconnect"},
        }
        handle_routes(disconnect_event, None)
async def run_websocket_server():  # pragma: no cover
    """Serve ``handler`` on 0.0.0.0:5051, awaiting a future that is never resolved."""
    host, port = "0.0.0.0", 5051
    server = websockets.serve(handler, host, port)
    async with server:
        print("WebSocket server listening on ws://localhost:5051")
        never_done = asyncio.Future()
        await never_done