3 Star 0 Fork 0

mirrors_LF-Engineering / serverless-wsgi

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
serverless_wsgi.py 10.63 KB
一键复制 编辑 原始数据 按行查看 历史
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module converts an AWS API Gateway proxied request to a WSGI request.
Inspired by: https://github.com/miserlou/zappa
Author: Logan Raarup <logan@logan.dk>
"""
import base64
import io
import json
import os
import sys
from werkzeug.datastructures import Headers, iter_multi_items, MultiDict
from werkzeug.wrappers import Response
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
from werkzeug.http import HTTP_STATUS_CODES
# MIME types whose response bodies are sent back as plain text instead of
# base64. Anything under ``text/*`` is already treated as text and does not
# need to be listed here.
TEXT_MIME_TYPES = [
    "application/json",
    "application/javascript",
    "application/xml",
    "application/vnd.api+json",
    "image/svg+xml",
]
def all_casings(input_string):
    """
    Lazily yield every upper/lower-case permutation of ``input_string``.

    Characters without two distinct casings (digits, "-", ...) are kept
    as-is. The first case-sensitive character changes fastest, lower case
    first: ``"ab"`` yields ``"ab", "Ab", "aB", "AB"``. An empty (or falsy)
    input yields a single ``""``.

    Based on an algorithm by @Amber:
    http://stackoverflow.com/questions/6792803/finding-all-possible-case-permutations-in-python

    Unlike a recursive generator, this uses no recursion. Long strings
    therefore cannot hit the recursion limit, and each value costs O(n)
    rather than O(n**2).
    """
    if not input_string:
        yield ""
        return

    # Positions of characters that actually have distinct lower/upper forms.
    variable = [
        index
        for index, char in enumerate(input_string)
        if char.lower() != char.upper()
    ]
    chars = list(input_string)

    # Bit k of ``mask`` selects upper case for the k-th case-sensitive
    # character, so the first one toggles fastest. Iterating a range keeps
    # the generator lazy: callers that only need the first few casings
    # (see split_headers) never pay for all 2**k of them.
    for mask in range(1 << len(variable)):
        for bit, index in enumerate(variable):
            char = input_string[index]
            chars[index] = char.upper() if (mask >> bit) & 1 else char.lower()
        yield "".join(chars)
def split_headers(headers):
    """
    Flatten ``headers`` into a plain dict suitable for API Gateway's
    single-value ``headers`` field.

    A header that occurs once keeps its original name. A header that occurs
    several times is stored under distinct case variations of its name (one
    per value, taken in order from ``all_casings``), so every value survives
    as a separate dict key. This is a hack that's currently needed; see
    https://github.com/logandk/serverless-wsgi/issues/11

    Source: https://github.com/Miserlou/Zappa/blob/master/zappa/middleware.py

    NOTE(review): ``zip`` stops at the shorter sequence, so if a header has
    more values than its name has casings, the extra values are dropped.
    """
    flattened = {}
    for name in headers.keys():
        occurrences = headers.get_all(name)
        if not occurrences:
            continue
        if len(occurrences) == 1:
            flattened[name] = occurrences[0]
            continue
        flattened.update(zip(all_casings(name), occurrences))
    return flattened
def group_headers(headers):
    """
    Map each header name in ``headers`` to the list of all its values, as
    used for the ``multiValueHeaders`` response field.

    ``headers`` must provide ``keys()`` and ``get_all(name)``.
    """
    return {name: headers.get_all(name) for name in headers.keys()}
def is_alb_event(event):
    """
    Return ``event["requestContext"]["elb"]`` if present, else ``None``.

    Callers use the result's truthiness to detect Application Load Balancer
    events. A ``requestContext`` that is present but ``null`` is treated
    like a missing one instead of raising ``AttributeError``.
    """
    request_context = event.get("requestContext") or {}
    return request_context.get("elb")
def encode_query_string(event):
    """
    Build a WSGI ``QUERY_STRING`` from the event's query parameters.

    The first non-empty value among ``multiValueQueryStringParameters``,
    ``queryStringParameters`` and ``query`` is used; if all are empty the
    result is an empty query string. For ALB events, keys and values are
    unquoted first, presumably because ALB forwards them still
    percent-encoded (confirm), so that ``url_encode`` does not encode them
    twice.
    """
    params = (
        event.get("multiValueQueryStringParameters")
        or event.get("queryStringParameters")
        or event.get("query")
        or ""
    )
    if is_alb_event(event):
        unquoted_pairs = [
            (url_unquote_plus(name), url_unquote_plus(value))
            for name, value in iter_multi_items(params)
        ]
        params = MultiDict(unquoted_pairs)
    return url_encode(params)
def get_script_name(headers, request_context):
    """
    Return the WSGI ``SCRIPT_NAME`` for a request.

    When the ``Host`` header contains ``amazonaws.com``, the stage name from
    ``request_context`` is used as the prefix (``"/<stage>"``). Setting the
    ``STRIP_STAGE_PATH`` environment variable to yes/y/true/t/1
    (case-insensitive, surrounding whitespace ignored) disables this. In all
    other cases the script name is empty.
    """
    flag = os.environ.get("STRIP_STAGE_PATH", "").lower().strip()
    strip_stage_path = flag in ("yes", "y", "true", "t", "1")
    on_aws_host = "amazonaws.com" in headers.get("Host", "")

    if not on_aws_host or strip_stage_path:
        return ""
    return "/{}".format(request_context.get("stage", ""))
def get_body_bytes(event, body):
    """
    Return the request body as bytes.

    If the event is flagged ``isBase64Encoded``, ``body`` is base64-decoded.
    Otherwise a ``str`` body is UTF-8 encoded, and anything else is returned
    unchanged.
    """
    if event.get("isBase64Encoded", False):
        return base64.b64decode(body)
    return body.encode("utf-8") if isinstance(body, str) else body
def setup_environ_items(environ, headers):
    """
    Finalize a WSGI environ dict in place and return it.

    String values already present in ``environ`` are re-encoded from UTF-8
    into latin-1 code points (the "bytes-as-unicode" form PEP 3333 uses for
    native strings). Each request header is then added as ``HTTP_<NAME>``
    (upper-cased, "-" replaced by "_"). Content-Type and Content-Length are
    skipped because they are carried in CONTENT_TYPE / CONTENT_LENGTH.

    A header that occurs more than once is combined into a single value
    rather than letting the last occurrence overwrite the earlier ones. Most
    headers are joined with "," (as Werkzeug's development server does);
    Cookie is joined with "; ".

    :param environ: partially built WSGI environ; mutated in place.
    :param headers: object whose ``items()`` yields ``(name, value)`` pairs,
        possibly with repeated names.
    :returns: the same ``environ`` object.
    """
    for key, value in environ.items():
        if isinstance(value, str):
            environ[key] = value.encode("utf-8").decode("latin1", "replace")

    # Track which HTTP_* keys came from headers, so repeats are merged
    # without touching any HTTP_* value the caller may have pre-set.
    seen = set()
    for name, value in headers.items():
        key = "HTTP_" + name.upper().replace("-", "_")
        if key in ("HTTP_CONTENT_TYPE", "HTTP_CONTENT_LENGTH"):
            continue
        if key in seen:
            separator = "; " if key == "HTTP_COOKIE" else ","
            environ[key] = environ[key] + separator + value
        else:
            environ[key] = value
            seen.add(key)
    return environ
def generate_response(response, event):
    """
    Convert a Werkzeug response into the dict returned to API Gateway / ALB.

    :param response: the Werkzeug ``Response`` produced by the WSGI app.
    :param event: the original Lambda event. If it contains
        ``multiValueHeaders``, response headers are grouped into lists;
        otherwise repeated headers are spread over case-varied names via
        ``split_headers``. ALB events additionally get a
        ``statusDescription``.
    :returns: dict with ``statusCode``, the headers field and, when the
        response has a body, ``body`` plus ``isBase64Encoded``.
    """
    status_code = response.status_code
    returndict = {"statusCode": status_code}

    if "multiValueHeaders" in event:
        returndict["multiValueHeaders"] = group_headers(response.headers)
    else:
        returndict["headers"] = split_headers(response.headers)

    if is_alb_event(event):
        # If the request comes from ALB we need to add a status description.
        # Fall back to a generic reason phrase for codes Werkzeug does not
        # know (e.g. 299, 599) instead of raising KeyError.
        returndict["statusDescription"] = "%d %s" % (
            status_code,
            HTTP_STATUS_CODES.get(status_code, "Unknown"),
        )

    body = response.data
    if body:
        mimetype = response.mimetype or "text/plain"
        is_text = mimetype.startswith("text/") or mimetype in TEXT_MIME_TYPES

        text = None
        # A content-encoded (e.g. gzipped) body is binary even when its
        # MIME type is textual.
        if is_text and not response.headers.get("Content-Encoding", ""):
            try:
                text = response.get_data(as_text=True)
            except UnicodeDecodeError:
                # Body is not valid in the response charset: send it as
                # base64 below rather than failing the whole invocation.
                text = None

        if text is not None:
            returndict["body"] = text
            returndict["isBase64Encoded"] = False
        else:
            returndict["body"] = base64.b64encode(body).decode("utf-8")
            returndict["isBase64Encoded"] = True
    return returndict
def handle_request(app, event, context):
    """
    Dispatch a Lambda event to the handler matching its shape.

    - Warm-up pings (``source`` of ``aws.events`` or
      ``serverless-plugin-warmup``) return ``{}`` without invoking the app.
    - Events with neither ``version`` nor ``isBase64Encoded`` that are not
      ALB events go to the lambda (non-proxy) integration handler.
    - ``version == "2.0"`` goes to the payload v2 handler.
    - Everything else goes to the payload v1 handler.
    """
    warmup_sources = ("aws.events", "serverless-plugin-warmup")
    if event.get("source") in warmup_sources:
        print("Lambda warming event received, skipping handler")
        return {}

    looks_like_lambda_integration = (
        event.get("version") is None
        and event.get("isBase64Encoded") is None
        and not is_alb_event(event)
    )
    if looks_like_lambda_integration:
        return handle_lambda_integration(app, event, context)

    if event.get("version") == "2.0":
        handler = handle_payload_v2
    else:
        handler = handle_payload_v1
    return handler(app, event, context)
def handle_payload_v1(app, event, context):
    """
    Serve an API Gateway payload v1.0 (REST API) or ALB event through ``app``.

    :param app: the WSGI application callable.
    :param event: the Lambda event. Reads ``multiValueHeaders``/``headers``,
        ``path``, ``httpMethod``, ``body``, ``isBase64Encoded``, the
        query-string keys and ``requestContext``.
    :param context: the Lambda context, exposed as ``serverless.context``.
    :returns: the response dict built by ``generate_response``.
    """
    if "multiValueHeaders" in event:
        headers = Headers(event["multiValueHeaders"])
    else:
        headers = Headers(event["headers"])

    # ``requestContext`` and its ``authorizer`` may be present but null.
    request_context = event.get("requestContext") or {}
    authorizer = request_context.get("authorizer") or {}

    script_name = get_script_name(headers, request_context)

    # If a user is using a custom domain on API Gateway, they may have a base
    # path in their URL. This allows us to strip it out via an optional
    # environment variable ("api", "/api" and "api/" are all accepted).
    path_info = event["path"]
    base_path = (os.environ.get("API_GATEWAY_BASE_PATH") or "").strip("/")
    if base_path:
        script_name = "/" + base_path
        # Only strip on a path-segment boundary, so a base path of "/api"
        # does not eat the start of "/apiary".
        if path_info == script_name or path_info.startswith(script_name + "/"):
            path_info = path_info[len(script_name) :]

    body = get_body_bytes(event, event.get("body") or "")

    environ = {
        "CONTENT_LENGTH": str(len(body)),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": url_unquote(path_info),
        "QUERY_STRING": encode_query_string(event),
        "REMOTE_ADDR": (request_context.get("identity") or {}).get("sourceIp", ""),
        "REMOTE_USER": authorizer.get("principalId", ""),
        "REQUEST_METHOD": event.get("httpMethod", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": request_context.get("authorizer"),
        "serverless.event": event,
        "serverless.context": context,
    }
    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)
    try:
        return generate_response(response, event)
    finally:
        # PEP 3333: the gateway must call close() on the app's iterable once
        # the request is finished; Response.close() forwards to it.
        response.close()
def handle_payload_v2(app, event, context):
    """
    Serve an API Gateway HTTP API payload 2.0 event through ``app``.

    :param app: the WSGI application callable.
    :param event: the Lambda event. Reads ``headers``, ``cookies``,
        ``rawPath``, ``rawQueryString``, ``body``, ``isBase64Encoded`` and
        ``requestContext``.
    :param context: the Lambda context, exposed as ``serverless.context``.
    :returns: the response dict built by ``generate_response``.
    """
    headers = Headers(event.get("headers"))

    # Nested context objects may be missing or null.
    request_context = event.get("requestContext") or {}
    http_context = request_context.get("http") or {}
    authorizer = request_context.get("authorizer") or {}

    script_name = get_script_name(headers, request_context)
    path_info = event["rawPath"]

    # A missing or null body (e.g. for GET requests) becomes b"" rather than
    # None, which would break len() below.
    body = get_body_bytes(event, event.get("body") or "")

    # Payload v2 delivers cookies as a separate list; fold them back into a
    # single Cookie header, but only when there are any.
    cookies = event.get("cookies")
    if cookies:
        headers["Cookie"] = "; ".join(cookies)

    environ = {
        "CONTENT_LENGTH": str(len(body)),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": url_unquote(path_info),
        "QUERY_STRING": event.get("rawQueryString", ""),
        "REMOTE_ADDR": http_context.get("sourceIp", ""),
        "REMOTE_USER": authorizer.get("principalId", ""),
        "REQUEST_METHOD": http_context.get("method", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": request_context.get("authorizer"),
        "serverless.event": event,
        "serverless.context": context,
    }
    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)
    try:
        return generate_response(response, event)
    finally:
        # PEP 3333: the gateway must call close() on the app's iterable once
        # the request is finished; Response.close() forwards to it.
        response.close()
def handle_lambda_integration(app, event, context):
    """
    Serve an API Gateway lambda (non-proxy) integration event through ``app``.

    :param app: the WSGI application callable.
    :param event: the Lambda event. Reads ``headers``, ``requestPath``,
        ``path`` (path parameters), ``body``, ``query``, ``method``,
        ``identity``, ``principalId``, ``enhancedAuthContext`` and the
        top-level ``stage``.
    :param context: the Lambda context, exposed as ``serverless.context``.
    :returns: the response dict built by ``generate_response`` when the
        status code is below 300.
    :raises RuntimeError: for status codes >= 300, with the JSON-encoded
        response dict as the message (presumably so API Gateway integration
        response patterns can match it -- confirm against the deployment
        template).
    """
    headers = Headers(event["headers"])
    # The event itself is passed as the request context, so the stage is
    # read from the event's top-level ``stage`` key.
    script_name = get_script_name(headers, event)

    # Substitute path parameters back into the templated request path,
    # covering both "{name}" and greedy "{name+}" placeholders.
    path_info = event["requestPath"]
    for key, value in (event.get("path") or {}).items():
        path_info = path_info.replace("{%s}" % key, value)
        path_info = path_info.replace("{%s+}" % key, value)

    # The integration delivers a parsed body; re-serialize it as JSON.
    body = event.get("body", {})
    body = json.dumps(body) if body else ""
    body = get_body_bytes(event, body)

    environ = {
        "CONTENT_LENGTH": str(len(body)),
        "CONTENT_TYPE": headers.get("Content-Type", ""),
        "PATH_INFO": url_unquote(path_info),
        "QUERY_STRING": url_encode(event.get("query", {})),
        "REMOTE_ADDR": (event.get("identity") or {}).get("sourceIp", ""),
        "REMOTE_USER": event.get("principalId", ""),
        "REQUEST_METHOD": event.get("method", ""),
        "SCRIPT_NAME": script_name,
        "SERVER_NAME": headers.get("Host", "lambda"),
        "SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stderr,
        "wsgi.input": io.BytesIO(body),
        "wsgi.multiprocess": False,
        "wsgi.multithread": False,
        "wsgi.run_once": False,
        "wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
        "wsgi.version": (1, 0),
        "serverless.authorizer": event.get("enhancedAuthContext"),
        "serverless.event": event,
        "serverless.context": context,
    }
    environ = setup_environ_items(environ, headers)

    response = Response.from_app(app, environ)
    try:
        returndict = generate_response(response, event)
    finally:
        # PEP 3333: the gateway must call close() on the app's iterable,
        # including when the result is turned into an error below.
        response.close()

    if response.status_code >= 300:
        raise RuntimeError(json.dumps(returndict))
    return returndict
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mirrors_LF-Engineering/serverless-wsgi.git
git@gitee.com:mirrors_LF-Engineering/serverless-wsgi.git
mirrors_LF-Engineering
serverless-wsgi
serverless-wsgi
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891