bedrock.endpoints.local_runner
1import asyncio # pragma: unit 2import json # pragma: unit 3import re # pragma: unit 4import uuid # pragma: unit 5from base64 import b64encode # pragma: unit 6from http.server import BaseHTTPRequestHandler # pragma: unit 7from json import dumps, loads # pragma: unit 8from urllib.parse import unquote # pragma: unit 9 10import validators # pragma: unit 11import websockets # pragma: unit 12 13from bedrock.config.headers import ValidHeaders # pragma: unit 14from bedrock.log import log_config # pragma: unit 15from bedrock.websockets.routes import ( 16 route_connect_handler, 17 route_default_handler, 18 route_ping_handler, 19 route_disconnect_handler, 20 route_subscribe_handler, 21 route_unsubscribe_handler, 22 route_send_message_handler, 23 route_authenticate_handler, 24) # pragma: unit 25from bedrock.websockets.websocket_gateway_factory import get_websocket_api_gateway_client # pragma: unit 26 27log = log_config("local_runner") # pragma: unit 28LAMBDA_HANDLER = None # pragma: unit 29MAPPINGS = None # pragma: unit 30 31 32def split_query_params_in_path(path: str) -> map: # pragma: unit 33 if not path: 34 raise ValueError('path cannot be None') 35 return map(lambda x: (x.split("=")[0], unquote(x.split("=")[1])), path.split("?")[1].split("&")) 36 37 38def to_multi_value_params(path: str) -> dict: # pragma: unit 39 try: 40 d = {} 41 tuples = split_query_params_in_path(path) 42 for pair in tuples: 43 if pair[0] not in d: 44 d[pair[0]] = [] 45 d[pair[0]] = d[pair[0]] + [unquote(pair[1])] 46 return d 47 except: # pragma: no cover 48 return None 49 50 51def to_single_value_params(path: str) -> dict: # pragma: unit 52 try: 53 return dict(split_query_params_in_path(path)) 54 except: # pragma: no cover 55 return None 56 57 58def path_matches_mapping(path: str, mapping: str): # pragma: unit 59 _mapping = f'/{mapping.strip("/")}' 60 matcher_expression = "^" + re.sub(r"\{\w*\}", "([^/]+)", _mapping) + "/?$" 61 just_path = path.split("?")[0] 62 return re.match(matcher_expression, 
just_path) is not None 63 64 65def get_path_params(path: str, mapping: str): # pragma: unit 66 _mapping = f'/{mapping.strip("/")}' 67 matcher_expression = "^" + re.sub(r"\{\w*\}", "([^/]+)", _mapping) + "/?$" 68 just_path = path.split("?")[0] 69 value_match = re.match(matcher_expression, just_path) 70 key_match = re.match(matcher_expression, _mapping) 71 return dict(((re.sub(r"[\{\}]", "", key), value_match.groups()[i])) for i, key in enumerate(key_match.groups())) 72 73 74def path_to_path_info(path: str, mappings: dict): # pragma: unit 75 mapping_paths = mappings.items() if mappings else [] 76 for mapping, endpoint in mapping_paths: 77 if path_matches_mapping(path, mapping): 78 return { 79 "resource": mapping, 80 "params": get_path_params(path, mapping), 81 "path": path 82 } 83 return { 84 "resource": "/??????", 85 "params": {}, 86 "path": path 87 } 88 89 90def to_api_gateway_request(path: str, 91 method: str, 92 headers: dict, 93 body: dict = None, 94 mappings: dict = None): # pragma: unit 95 path_info = path_to_path_info(path, mappings) 96 return { 97 "resource": path_info["resource"], 98 "path": path, 99 "httpMethod": method, 100 "queryStringParameters": to_single_value_params(path), 101 "multiValueQueryStringParameters": to_multi_value_params(path), 102 "pathParameters": path_info["params"], 103 "stageVariables": None, 104 "body": body, 105 "isBase64Encoded": False, 106 "headers": headers 107 } 108 109 110def to_msk_request(path, body: str or dict = {}, mappings: dict = None): # pragma: unit 111 path_info = path_to_path_info(path, mappings) 112 _body = body if isinstance(body, dict) else loads(body) 113 if _body != {}: 114 for partition in _body: 115 for event in _body[partition]: 116 event['value'] = b64encode(dumps(event['value']).encode('ascii')).decode('ascii') 117 118 return { 119 "resource": path_info["resource"], 120 "eventSource": "aws:kafka", 121 "eventSourceArn": 
"arn:aws:kafka:eu-west-1:888888888888:cluster/kafka-cluster-test/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee-1", 122 "bootstrapServers": "b-2.kafkaclustertest.aaaaaa.bb.kafka.eu-west-1.amazonaws.com:9094,b-3.kafkaclustertest.cccccc.dd.kafka.eu-west-1.amazonaws.com:9094,b-1.kafkaclustertest.eeeeee.ff.kafka.eu-west-1.amazonaws.com:9094", 123 "records": _body 124 } 125 126 127def set_lambda_handler(handler): # pragma: no cover 128 global LAMBDA_HANDLER 129 LAMBDA_HANDLER = handler 130 131 132def set_mappings(mappings): # pragma: no cover 133 global MAPPINGS 134 MAPPINGS = mappings 135 136 137class HandlerClass(BaseHTTPRequestHandler): # pragma: unit 138 def _set_response(self, response): # pragma: no cover 139 self.send_response(response["statusCode"]) 140 self.send_header('Content-type', 'application/json') 141 for header in response["headers"]: 142 self.send_header(header, response["headers"][header]) 143 self.end_headers() 144 145 def aws(self, method, body=None): # pragma: no cover 146 response = self._aws(self.path, self.headers, LAMBDA_HANDLER, MAPPINGS, method, body) 147 self._set_response(response) 148 self.wfile.write(response["body"].encode(encoding='utf_8')) 149 150 @classmethod 151 def _aws(cls, 152 path: str, 153 headers: dict, 154 lambda_handler: callable, 155 mappings: dict, 156 method: str, 157 body: dict = None): # pragma: unit 158 path = unquote(path) 159 is_kafka_event = False 160 try: 161 is_kafka_event = headers['eventSource'] == 'aws:kafka' 162 except KeyError: 163 pass 164 if is_kafka_event: 165 log.debug("Request to be handled as Kafka event") 166 else: 167 log.debug("Request to be handled as APIGW event") 168 169 url_to_validate = f"http://localhost:5050{path.split('?')[0]}" 170 if not validators.url(url_to_validate, simple_host=True): 171 log.error(f"Invalid URL: {url_to_validate}") 172 resp = { 173 'statusCode': 400, 174 'body': json.dumps({ 175 "error": f"URL validation failed: {url_to_validate}" 176 }), 177 'headers': { 178 
"Access-Control-Allow-Origin": "*", 179 "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", 180 "Access-Control-Allow-Headers": "*" 181 }, 182 "isBase64Encoded": False 183 } 184 elif is_kafka_event: 185 log.debug("Request to be handled as Kafka event") 186 event = to_msk_request(str(path), body, mappings) 187 resp = lambda_handler(event) 188 else: 189 log.debug("Request to be handled as APIGW event") 190 event = to_api_gateway_request(str(path), method, headers, body, mappings) 191 resp = lambda_handler(event) 192 return resp 193 194 def process_with_body(self, method): # pragma: no cover 195 content_length = int(self.headers['Content-Length']) # <--- Gets the size of data 196 data = self.rfile.read(content_length) # <--- Gets the data itself 197 log.info(method + " request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n", 198 str(self.path), str(self.headers), data.decode('utf-8')) 199 self.aws(method, data.decode('utf-8')) 200 201 def do_GET(self): # pragma: no cover 202 log.info("\n GET request\n Path: %s \n", str(self.path)) 203 self.aws("GET") 204 205 def do_POST(self): # pragma: no cover 206 log.info("\n POST request\n Path: %s \n", str(self.path)) 207 self.process_with_body("POST") 208 209 def do_PUT(self): # pragma: no cover 210 log.info("\n PUT request\n Path: %s \n", str(self.path)) 211 self.process_with_body("PUT") 212 213 def do_DELETE(self): # pragma: no cover 214 log.info("\n DELETE request\n Path: %s \n", str(self.path)) 215 self.aws("DELETE") 216 217 def do_OPTIONS(self): # pragma: no cover 218 log.info("\n OPTIONS request\n Path: %s \n", str(self.path)) 219 self.send_response(200, "ok") 220 self.send_header('Access-Control-Allow-Origin', '*') 221 self.send_header('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS') 222 self.send_header("Access-Control-Allow-Headers", "*") 223 self.end_headers() 224 225 226# Websocket methods 227def handle_routes(event, context): # pragma: unit 228 routes = { 229 "$connect": 
route_connect_handler, 230 "$disconnect": route_disconnect_handler, 231 "ping": route_ping_handler, 232 "send-message": route_send_message_handler, 233 "subscribe": route_subscribe_handler, 234 "unsubscribe": route_unsubscribe_handler, 235 "authenticate": route_authenticate_handler, 236 } 237 238 route = event.get("body").get("action", "") 239 event["body"] = json.dumps(event.get("body")) 240 241 if route in routes: 242 return routes[route](event, context) 243 return route_default_handler(event, context) 244 245 246async def handler(websocket): # pragma: unit 247 connection_id = str(uuid.uuid4()) 248 api_gateway_client = get_websocket_api_gateway_client() 249 api_gateway_client.set_loop(loop=asyncio.get_running_loop()) 250 api_gateway_client.add_websocket_connection(connection_id, websocket) 251 252 headers = dict(websocket.request.headers) 253 if "authorization" in headers: 254 headers[ValidHeaders.AUTHORIZATION.value] = websocket.request.headers["authorization"] 255 headers.pop("authorization") 256 257 connect_event = { 258 "requestContext": { 259 "connectionId": connection_id, 260 "eventType": "CONNECT", 261 "domainName": "localhost", 262 "stage": "local", 263 "identity": {"sourceIp": "127.0.0.1"}, 264 }, 265 "headers": headers, 266 "body": {"action": "$connect"}, 267 } 268 269 result = handle_routes(connect_event, None) 270 if result.get("statusCode") != 200: 271 return 272 273 print(f"client connected to {connection_id}") 274 275 try: 276 async for message in websocket: # pragma: no cover 277 try: 278 data = json.loads(message) 279 except json.JSONDecodeError: 280 data = {} 281 282 event = { 283 "requestContext": { 284 "connectionId": connection_id, 285 "identity": {"sourceIp": "127.0.0.1"}, 286 }, 287 "body": data, 288 } 289 handle_routes(event, None) 290 finally: 291 print(f"Closing connection: {connection_id}") 292 api_gateway_client.remove_websocket_connection(connection_id) 293 disconnect_event = { 294 "requestContext": { 295 "connectionId": 
connection_id, 296 "eventType": "CONNECT", 297 "domainName": "localhost", 298 "stage": "local", 299 "identity": {"sourceIp": "127.0.0.1"}, 300 }, 301 "body": {"action": "$disconnect"}, 302 } 303 handle_routes(disconnect_event, None) 304 305 306async def run_websocket_server(): # pragma: no cover 307 async with websockets.serve(handler, "0.0.0.0", 5051): 308 print("WebSocket server listening on ws://localhost:5051") 309 await asyncio.Future()
def get_path_params(path: str, mapping: str) -> dict:  # pragma: unit
    """Extract path-parameter values from ``path`` using a route template.

    ``mapping`` is a route template such as ``/users/{id}``; every ``{name}``
    placeholder matches exactly one path segment (one or more non-``/``
    characters). Leading/trailing slashes on the template are ignored, any
    query string on ``path`` is ignored, and a single trailing slash on the
    path is tolerated.

    Args:
        path: Request path, optionally followed by ``?query``.
        mapping: Route template with ``{name}`` placeholders.

    Returns:
        A dict mapping each placeholder name to the matched segment. An empty
        dict is returned when the template has no placeholders or when the
        path does not match the template at all.
    """
    _mapping = f'/{mapping.strip("/")}'
    matcher_expression = "^" + re.sub(r"\{\w*\}", "([^/]+)", _mapping) + "/?$"
    just_path = path.split("?")[0]
    value_match = re.match(matcher_expression, just_path)
    if value_match is None:
        # Non-matching path: report "no parameters" instead of failing with
        # an AttributeError on ``None.groups()``.
        return {}
    # Read the placeholder names straight from the template (same pattern as
    # the substitution above, so names and capture groups line up one-to-one)
    # rather than matching the template against its own regex.
    names = re.findall(r"\{(\w*)\}", _mapping)
    return dict(zip(names, value_match.groups()))
def path_to_path_info(path: str, mappings: dict):  # pragma: unit
    """Resolve ``path`` against the configured route templates.

    The templates (keys of ``mappings``) are tried in iteration order and the
    first one accepted by ``path_matches_mapping`` wins.

    Args:
        path: Request path, optionally including a query string.
        mappings: Dict keyed by route template; may be ``None`` or empty.

    Returns:
        A dict with ``resource`` (the matching template, or the placeholder
        ``"/??????"`` when nothing matches), ``params`` (the extracted path
        parameters, ``{}`` when nothing matches) and ``path`` (the input path,
        unchanged).
    """
    if mappings:
        for template, _endpoint in mappings.items():
            if not path_matches_mapping(path, template):
                continue
            return {
                "resource": template,
                "params": get_path_params(path, template),
                "path": path,
            }
    # No template matched (or no templates were configured).
    return {"resource": "/??????", "params": {}, "path": path}
def to_api_gateway_request(path: str,
                           method: str,
                           headers: dict,
                           body: dict = None,
                           mappings: dict = None):  # pragma: unit
    """Build an API Gateway proxy-style event dict for a local HTTP request.

    Args:
        path: Request path, optionally including a query string.
        method: HTTP method, stored under ``httpMethod``.
        headers: Request headers, passed through unchanged.
        body: Request body, passed through unchanged.
        mappings: Route templates used to resolve ``resource`` and
            ``pathParameters``.

    Returns:
        The event dict. ``stageVariables`` is always ``None`` and
        ``isBase64Encoded`` is always ``False``.
    """
    route = path_to_path_info(path, mappings)
    single_valued = to_single_value_params(path)
    multi_valued = to_multi_value_params(path)

    event = {}
    event["resource"] = route["resource"]
    event["path"] = path
    event["httpMethod"] = method
    event["queryStringParameters"] = single_valued
    event["multiValueQueryStringParameters"] = multi_valued
    event["pathParameters"] = route["params"]
    event["stageVariables"] = None
    event["body"] = body
    event["isBase64Encoded"] = False
    event["headers"] = headers
    return event
def to_msk_request(path, body: "str | dict | None" = None, mappings: dict = None):  # pragma: unit
    """Build an MSK (Kafka) style event dict from a local request.

    ``body`` is expected to map a partition key to a list of event dicts, each
    carrying a ``value`` entry. Every ``value`` is JSON-serialised and
    base64-encoded in the returned ``records``.

    Args:
        path: Request path used to resolve the ``resource`` template.
        body: Records as a dict or as a JSON string. ``None`` or an empty
            value means "no records".
        mappings: Route templates used to resolve ``resource``.

    Returns:
        The event dict with hard-coded test values for ``eventSourceArn`` and
        ``bootstrapServers``. ``records`` is a newly built dict: the caller's
        ``body`` is left untouched, so re-sending the same dict does not
        encode the values a second time.

    Raises:
        json.JSONDecodeError: if ``body`` is a non-empty string that is not
            valid JSON.
    """
    path_info = path_to_path_info(path, mappings)

    if not body:
        # None / "" / {}: the request carried no body.
        parsed = {}
    elif isinstance(body, dict):
        parsed = body
    else:
        parsed = loads(body)

    records = {
        partition: [
            {
                **event,
                "value": b64encode(dumps(event["value"]).encode("ascii")).decode("ascii"),
            }
            for event in events
        ]
        for partition, events in parsed.items()
    }

    return {
        "resource": path_info["resource"],
        "eventSource": "aws:kafka",
        "eventSourceArn": "arn:aws:kafka:eu-west-1:888888888888:cluster/kafka-cluster-test/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee-1",
        "bootstrapServers": "b-2.kafkaclustertest.aaaaaa.bb.kafka.eu-west-1.amazonaws.com:9094,b-3.kafkaclustertest.cccccc.dd.kafka.eu-west-1.amazonaws.com:9094,b-1.kafkaclustertest.eeeeee.ff.kafka.eu-west-1.amazonaws.com:9094",
        "records": records
    }
class HandlerClass(BaseHTTPRequestHandler):  # pragma: unit
    """HTTP request handler that feeds local requests to the Lambda handler.

    Each request is converted into either an API Gateway style event or, when
    the ``eventSource`` header equals ``aws:kafka``, an MSK style event. The
    event is passed to the module-level ``LAMBDA_HANDLER`` and the returned
    dict is written back as the HTTP response.
    """

    def _set_response(self, response):  # pragma: no cover
        """Write the status line and headers taken from a Lambda response dict."""
        self.send_response(response["statusCode"])
        self.send_header('Content-type', 'application/json')
        # "headers" is optional in a Lambda proxy response; tolerate its absence.
        for name, value in (response.get("headers") or {}).items():
            self.send_header(name, value)
        self.end_headers()

    def aws(self, method, body=None):  # pragma: no cover
        """Run the current request through the Lambda handler and send the reply."""
        response = self._aws(self.path, self.headers, LAMBDA_HANDLER, MAPPINGS, method, body)
        self._set_response(response)
        # "body" is optional as well; an absent/None body is sent as empty.
        payload = response.get("body") or ""
        self.wfile.write(payload.encode(encoding='utf_8'))

    @classmethod
    def _aws(cls,
             path: str,
             headers: dict,
             lambda_handler: callable,
             mappings: dict,
             method: str,
             body: dict = None):  # pragma: unit
        """Build the event for a request and invoke ``lambda_handler`` with it.

        Args:
            path: Raw request path; it is URL-unquoted before use.
            headers: Request headers (any mapping offering ``get``).
            lambda_handler: Callable that receives the event dict.
            mappings: Route templates used to resolve the resource.
            method: HTTP method name.
            body: Request body, forwarded to the event builder as-is.

        Returns:
            The value returned by ``lambda_handler``, or a 400 response dict
            when the unquoted path does not form a valid URL.
        """
        path = unquote(path)
        is_kafka_event = headers.get('eventSource') == 'aws:kafka'

        # Validate only the path portion; the query string is excluded.
        url_to_validate = f"http://localhost:5050{path.split('?')[0]}"
        if not validators.url(url_to_validate, simple_host=True):
            log.error(f"Invalid URL: {url_to_validate}")
            return {
                'statusCode': 400,
                'body': json.dumps({
                    "error": f"URL validation failed: {url_to_validate}"
                }),
                'headers': {
                    "Access-Control-Allow-Origin": "*",
                    "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
                    "Access-Control-Allow-Headers": "*"
                },
                "isBase64Encoded": False
            }

        if is_kafka_event:
            log.debug("Request to be handled as Kafka event")
            event = to_msk_request(str(path), body, mappings)
        else:
            log.debug("Request to be handled as APIGW event")
            event = to_api_gateway_request(str(path), method, headers, body, mappings)
        return lambda_handler(event)

    def process_with_body(self, method):  # pragma: no cover
        """Read the request body and dispatch the request with it.

        A missing ``Content-Length`` header is treated as an empty body.
        """
        content_length = int(self.headers.get('Content-Length') or 0)
        data = self.rfile.read(content_length).decode('utf-8')
        log.info(method + " request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
                 str(self.path), str(self.headers), data)
        self.aws(method, data)

    def do_GET(self):  # pragma: no cover
        """Handle GET: dispatch without a body."""
        log.info("\n GET request\n Path: %s \n", str(self.path))
        self.aws("GET")

    def do_POST(self):  # pragma: no cover
        """Handle POST: dispatch with the request body."""
        log.info("\n POST request\n Path: %s \n", str(self.path))
        self.process_with_body("POST")

    def do_PUT(self):  # pragma: no cover
        """Handle PUT: dispatch with the request body."""
        log.info("\n PUT request\n Path: %s \n", str(self.path))
        self.process_with_body("PUT")

    def do_DELETE(self):  # pragma: no cover
        """Handle DELETE: dispatch without a body."""
        log.info("\n DELETE request\n Path: %s \n", str(self.path))
        self.aws("DELETE")

    def do_OPTIONS(self):  # pragma: no cover
        """Answer CORS preflight requests with permissive headers; no Lambda call."""
        log.info("\n OPTIONS request\n Path: %s \n", str(self.path))
        self.send_response(200, "ok")
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS')
        self.send_header("Access-Control-Allow-Headers", "*")
        self.end_headers()
HTTP request handler base class.
The following explanation of HTTP serves to guide you through the code as well as to expose any misunderstandings I may have about HTTP (so you don't need to read the code to figure out I'm wrong :-).
HTTP (HyperText Transfer Protocol) is an extensible protocol on top of a reliable stream transport (e.g. TCP/IP). The protocol recognizes three parts to a request:
- One line identifying the request type and path
- An optional set of RFC-822-style headers
- An optional data part
The headers and data are separated by a blank line.
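The first line of the request has the form
<command> <path> <version>
where <command> is a (case-sensitive) keyword such as GET or POST, <path> is a string containing path information for the request, and <version> should be the string "HTTP/1.0" or "HTTP/1.1". <path> is encoded using the URL encoding scheme (using %xx to signify the ASCII character with hex code xx).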
The specification specifies that lines are separated by CRLF but for compatibility with the widest range of clients recommends servers also handle LF. Similarly, whitespace in the request line is treated sensibly (allowing multiple spaces between components and allowing trailing whitespace).
Similarly, for output, lines ought to be separated by CRLF pairs but most clients grok LF characters just fine.
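If the first line of the request has the form
<command> <path>
(i.e. <version> is left out) then this is assumed to be an HTTP 0.9 request; this form has no optional headers and data part and the reply consists of just the data.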
The reply form of the HTTP 1.x protocol again has three parts:
- One line giving the response code
- An optional set of RFC-822-style headers
- The data
Again, the headers and data are separated by a blank line.
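The response code line has the form
<version> <responsecode> <responsestring>
where <version> is the protocol version ("HTTP/1.0" or "HTTP/1.1"), <responsecode> is a 3-digit response code indicating success or failure of the request, and <responsestring> is an optional human-readable string explaining what the response code means.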
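This server parses the request and the headers, and then calls a function specific to the request type (<command>). Specifically, a request SPAM will be handled by the method do_SPAM(). If no such method exists the server sends an error response to the client. If it exists, it is called with no arguments:
do_SPAM()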
Note that the request name is case sensitive (i.e. SPAM and spam are different requests).
The various request details are stored in instance variables:
client_address is the client IP address in the form (host, port);
command, path and version are the broken-down request line;
headers is an instance of email.message.Message (or a derived class) containing the header information;
rfile is a file object open for reading positioned at the start of the optional input data part;
wfile is a file object open for writing.
IT IS IMPORTANT TO ADHERE TO THE PROTOCOL FOR WRITING!
The first thing to be written must be the response line. Then follow 0 or more header lines, then a blank line, and then the actual data (if any). The meaning of the header lines depends on the command executed by the server; in most cases, when data is returned, there should be at least one header line of the form
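Content-type: <type>/<subtype>
where <type> and <subtype> should be registered MIME types, e.g. "text/html" or "text/plain".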
def process_with_body(self, method):  # pragma: no cover
    """Read the request body and dispatch the request with it.

    Reads ``Content-Length`` bytes from ``self.rfile``, decodes them as
    UTF-8, logs the request and hands method and body to ``self.aws``.
    A missing ``Content-Length`` header is treated as an empty body instead
    of failing on ``int(None)``.

    Args:
        method: HTTP method name, used in the log line and passed to
            ``self.aws``.
    """
    content_length = int(self.headers.get('Content-Length') or 0)
    # Decode once and reuse the text for both logging and dispatch.
    data = self.rfile.read(content_length).decode('utf-8')
    log.info(method + " request,\nPath: %s\nHeaders:\n%s\n\nBody:\n%s\n",
             str(self.path), str(self.headers), data)
    self.aws(method, data)
def do_OPTIONS(self):  # pragma: no cover
    """Reply to an OPTIONS (CORS preflight) request.

    Sends ``200 ok`` with permissive ``Access-Control-Allow-*`` headers and
    no body.
    """
    log.info("\n OPTIONS request\n Path: %s \n", str(self.path))
    self.send_response(200, "ok")
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS'),
        ("Access-Control-Allow-Headers", "*"),
    )
    for name, value in cors_headers:
        self.send_header(name, value)
    self.end_headers()
def handle_routes(event, context):  # pragma: unit
    """Dispatch a websocket event to the handler for its ``action``.

    The route is read from ``event["body"]["action"]``. Before dispatch,
    ``event["body"]`` is replaced in place by its JSON-serialised string.

    Args:
        event: Event dict whose ``body`` is normally a dict with an
            ``action`` key.
        context: Passed through to the route handler unchanged.

    Returns:
        Whatever the selected route handler returns. A missing or non-dict
        body, a non-string action, or an unknown action goes to
        ``route_default_handler``.
    """
    routes = {
        "$connect": route_connect_handler,
        "$disconnect": route_disconnect_handler,
        "ping": route_ping_handler,
        "send-message": route_send_message_handler,
        "subscribe": route_subscribe_handler,
        "unsubscribe": route_unsubscribe_handler,
        "authenticate": route_authenticate_handler,
    }

    body = event.get("body")
    # A client may send any JSON value (array, number, null); only objects
    # can carry an action.
    route = body.get("action", "") if isinstance(body, dict) else ""
    if not isinstance(route, str):
        # Unhashable or otherwise unexpected action values cannot name a route.
        route = ""
    event["body"] = json.dumps(body)

    route_handler = routes.get(route, route_default_handler)
    return route_handler(event, context)
async def handler(websocket):  # pragma: unit
    """Serve one local websocket connection through the route handlers.

    Steps:
      1. Register the connection with the websocket gateway client under a
         fresh UUID.
      2. Dispatch a ``$connect`` event carrying the handshake headers; if the
         result's ``statusCode`` is not 200, unregister the connection and
         stop.
      3. Dispatch every received message; frames that are not a JSON object
         are dispatched with an empty body.
      4. On exit, unregister the connection and dispatch ``$disconnect``.

    Args:
        websocket: The connection object supplied by ``websockets.serve``.
    """
    connection_id = str(uuid.uuid4())
    api_gateway_client = get_websocket_api_gateway_client()
    api_gateway_client.set_loop(loop=asyncio.get_running_loop())
    api_gateway_client.add_websocket_connection(connection_id, websocket)

    headers = dict(websocket.request.headers)
    if "authorization" in headers:
        # Re-key the header under the name the application expects.
        headers[ValidHeaders.AUTHORIZATION.value] = headers.pop("authorization")

    connect_event = {
        "requestContext": {
            "connectionId": connection_id,
            "eventType": "CONNECT",
            "domainName": "localhost",
            "stage": "local",
            "identity": {"sourceIp": "127.0.0.1"},
        },
        "headers": headers,
        "body": {"action": "$connect"},
    }

    result = handle_routes(connect_event, None)
    if not result or result.get("statusCode") != 200:
        # Rejected connection: do not leave it registered with the client.
        api_gateway_client.remove_websocket_connection(connection_id)
        return

    print(f"client connected to {connection_id}")

    try:
        async for message in websocket:  # pragma: no cover
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                data = {}
            if not isinstance(data, dict):
                # Valid JSON that is not an object cannot carry an action.
                data = {}

            event = {
                "requestContext": {
                    "connectionId": connection_id,
                    "identity": {"sourceIp": "127.0.0.1"},
                },
                "body": data,
            }
            handle_routes(event, None)
    finally:
        print(f"Closing connection: {connection_id}")
        api_gateway_client.remove_websocket_connection(connection_id)
        disconnect_event = {
            "requestContext": {
                "connectionId": connection_id,
                # Was "CONNECT" (copy-paste from the connect event).
                "eventType": "DISCONNECT",
                "domainName": "localhost",
                "stage": "local",
                "identity": {"sourceIp": "127.0.0.1"},
            },
            "body": {"action": "$disconnect"},
        }
        handle_routes(disconnect_event, None)