Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Private Network Access support in CORSMiddleware #2621

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ The following arguments are supported:
* `allow_methods` - A list of HTTP methods that should be allowed for cross-origin requests. Defaults to `['GET']`. You can use `['*']` to allow all standard methods.
* `allow_headers` - A list of HTTP request headers that should be supported for cross-origin requests. Defaults to `[]`. You can use `['*']` to allow all headers. The `Accept`, `Accept-Language`, `Content-Language` and `Content-Type` headers are always allowed for CORS requests.
* `allow_credentials` - Indicate that cookies should be supported for cross-origin requests. Defaults to `False`. Also, `allow_origins`, `allow_methods` and `allow_headers` cannot be set to `['*']` for credentials to be allowed, all of them must be explicitly specified.
* `allow_private_network` - Indicate that direct access to localhost or private network endpoints from public websites allowed. It also allows access to localhost endpoints from private network. Defaults to `False`.
* `expose_headers` - Indicate any response headers that should be made accessible to the browser. Defaults to `[]`.
* `max_age` - Sets a maximum time in seconds for browsers to cache CORS responses. Defaults to `600`.

Expand All @@ -87,6 +88,12 @@ These are any `OPTIONS` request with `Origin` and `Access-Control-Request-Method
In this case the middleware will intercept the incoming request and respond with
appropriate CORS headers, and either a 200 or 400 response for informational purposes.

#### PNA preflight requests

These are any `OPTIONS` request with `Origin` and `Access-Control-Request-Private-Network` headers.
In this case the middleware will intercept the incoming request and respond with
appropriate CORS headers, and either a 200 or 400 response for informational purposes.

#### Simple requests

Any request with an `Origin` header. In this case the middleware will pass the
Expand Down
49 changes: 46 additions & 3 deletions starlette/middleware/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
allow_methods: typing.Sequence[str] = ("GET",),
allow_headers: typing.Sequence[str] = (),
allow_credentials: bool = False,
allow_private_network: bool = False,
allow_origin_regex: str | None = None,
expose_headers: typing.Sequence[str] = (),
max_age: int = 600,
Expand All @@ -40,12 +41,15 @@ def __init__(
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if allow_private_network:
simple_headers["Access-Control-Allow-Private-Network"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)

preflight_headers = {}
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
# The origin value will be set in cors_preflight_response()
# if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers["Access-Control-Allow-Origin"] = "*"
Expand All @@ -60,10 +64,13 @@ def __init__(
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"
if allow_private_network:
preflight_headers["Access-Control-Allow-Private-Network"] = "true"

self.app = app
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_private_network = allow_private_network
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
Expand All @@ -86,7 +93,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return

if method == "OPTIONS" and "access-control-request-method" in headers:
response = self.preflight_response(request_headers=headers)
response = self.cors_preflight_response(request_headers=headers)
await response(scope, receive, send)
return

# Must respond to preflight request in no-cors mode
# (https://developer.chrome.com/blog/private-network-access-preflight#no-cors_mode)
if method == "OPTIONS" and "access-control-request-private-network" in headers:
response = self.pna_preflight_response(request_headers=headers)
await response(scope, receive, send)
return

Expand All @@ -103,7 +117,7 @@ def is_allowed_origin(self, origin: str) -> bool:

return origin in self.allow_origins

def preflight_response(self, request_headers: Headers) -> Response:
def cors_preflight_response(self, request_headers: Headers) -> Response:
requested_origin = request_headers["origin"]
requested_method = request_headers["access-control-request-method"]
requested_headers = request_headers.get("access-control-request-headers")
Expand Down Expand Up @@ -141,6 +155,35 @@ def preflight_response(self, request_headers: Headers) -> Response:

return PlainTextResponse("OK", status_code=200, headers=headers)

def pna_preflight_response(self, request_headers: Headers) -> Response:
requested_origin = request_headers["origin"]
requested_private_network = request_headers[
"access-control-request-private-network"
]

headers = dict(self.preflight_headers)
failures = []

if self.is_allowed_origin(origin=requested_origin):
if self.preflight_explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
else:
failures.append("origin")

if requested_private_network == "true" and not self.allow_private_network:
failures.append("private-network")

# We don't strictly need to use 400 responses here, since its up to
# the browser to enforce the CORS policy, but its more informative
# if we do.
if failures:
failure_text = "Disallowed PNA " + ", ".join(failures)
return PlainTextResponse(failure_text, status_code=400, headers=headers)

return PlainTextResponse("OK", status_code=200, headers=headers)

async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
) -> None:
Expand Down
Loading