Commit c9dd670e authored by pgjones's avatar pgjones

Bugfix the root_path handling

If a root_path is specified the server should return 404 to any
requests that have a path that do not start with the root_path. The
server should also adjust the path to remove the root_path, whilst
keeping the raw_path intact.

This allows the root_path to be used as a global prefix to all routes
served.
parent a9357dea
Pipeline #244852758 passed with stages
in 6 minutes and 26 seconds
......@@ -46,6 +46,7 @@ class Config:
_quic_bind: List[str] = []
_quic_addresses: List[Tuple] = []
_log: Optional[Logger] = None
_root_path: str = ""
access_log_format = '%(h)s %(l)s %(l)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
accesslog: Union[logging.Logger, str, None] = None
......@@ -74,7 +75,6 @@ class Config:
loglevel: str = "INFO"
max_app_queue_size: int = 10
pid_path: Optional[str] = None
root_path = ""
server_names: List[str] = []
shutdown_timeout = 60 * SECONDS
ssl_handshake_timeout = 60 * SECONDS
......@@ -136,6 +136,14 @@ class Config:
else:
self._quic_bind = value
@property
def root_path(self) -> str:
return self._root_path
@root_path.setter
def root_path(self, value: str) -> None:
self._root_path = value.rstrip("/")
def create_ssl_context(self) -> Optional[SSLContext]:
if not self.ssl_enabled:
return None
......
from enum import auto, Enum
from time import time
from typing import Awaitable, Callable, Optional, Tuple
from urllib.parse import unquote
from .events import Body, EndBody, Event, Request, Response, StreamClosed
from ..config import Config
from ..typing import ASGIFramework, ASGISendEvent, Context, HTTPResponseStartEvent, HTTPScope
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name
from ..utils import (
build_and_validate_headers,
extract_path,
suppress_body,
UnexpectedMessage,
valid_server_name,
)
PUSH_VERSIONS = {"2", "3"}
......@@ -55,15 +60,16 @@ class HTTPStream:
return
elif isinstance(event, Request):
self.start_time = time()
path, _, query_string = event.raw_path.partition(b"?")
raw_path, _, query_string = event.raw_path.partition(b"?")
path = extract_path(raw_path, self.config.root_path)
self.scope = {
"type": "http",
"http_version": event.http_version,
"asgi": {"spec_version": "2.1"},
"method": event.method,
"scheme": self.scheme,
"path": unquote(path.decode("ascii")),
"raw_path": path,
"path": path,
"raw_path": raw_path,
"query_string": query_string,
"root_path": self.config.root_path,
"headers": event.headers,
......@@ -74,13 +80,13 @@ class HTTPStream:
if event.http_version in PUSH_VERSIONS:
self.scope["extensions"]["http.response.push"] = {}
if valid_server_name(self.config, event):
if path is None or not valid_server_name(self.config, event):
await self._send_error_response(404)
self.closed = True
else:
self.app_put = await self.context.spawn_app(
self.app, self.config, self.scope, self.app_send
)
else:
await self._send_error_response(404)
self.closed = True
elif isinstance(event, Body):
await self.app_put(
......
from enum import auto, Enum
from time import time
from typing import Awaitable, Callable, List, Optional, Tuple, Union
from urllib.parse import unquote
from wsproto.connection import Connection, ConnectionState, ConnectionType
from wsproto.events import (
......@@ -28,7 +27,13 @@ from ..typing import (
WebsocketResponseStartEvent,
WebsocketScope,
)
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name
from ..utils import (
build_and_validate_headers,
extract_path,
suppress_body,
UnexpectedMessage,
valid_server_name,
)
class ASGIWebsocketState(Enum):
......@@ -183,14 +188,15 @@ class WSStream:
elif isinstance(event, Request):
self.start_time = time()
self.handshake = Handshake(event.headers, event.http_version)
path, _, query_string = event.raw_path.partition(b"?")
raw_path, _, query_string = event.raw_path.partition(b"?")
path = extract_path(raw_path, self.config.root_path)
self.scope = {
"type": "websocket",
"asgi": {"spec_version": "2.1"},
"scheme": self.scheme,
"http_version": event.http_version,
"path": unquote(path.decode("ascii")),
"raw_path": path,
"path": path,
"raw_path": raw_path,
"query_string": query_string,
"root_path": self.config.root_path,
"headers": event.headers,
......@@ -200,7 +206,7 @@ class WSStream:
"extensions": {"websocket.http.response": {}},
}
if not valid_server_name(self.config, event):
if path is None or not valid_server_name(self.config, event):
await self._send_error_response(404)
self.closed = True
elif not self.handshake.is_valid():
......
......@@ -19,6 +19,7 @@ from typing import (
Tuple,
TYPE_CHECKING,
)
from urllib.parse import unquote
from .config import Config
from .typing import (
......@@ -259,3 +260,13 @@ def valid_server_name(config: Config, request: "Request") -> bool:
host = value.decode()
break
return host in config.server_names
def extract_path(raw_path: bytes, root_path: str) -> Optional[str]:
"""Extract the path portion from the raw_path accounting for any root_path."""
path = unquote(raw_path.decode("ascii"))
if path.startswith(root_path):
path = path[len(root_path) :]
return path if path != "" else "/"
else:
return None
......@@ -146,6 +146,30 @@ async def test_invalid_server_name(stream: HTTPStream) -> None:
await stream.handle(Body(stream_id=1, data=b"Body"))
@pytest.mark.asyncio
async def test_root_path(stream: HTTPStream) -> None:
stream.config.root_path = "/bob"
await stream.handle(
Request(
stream_id=1,
http_version="2",
headers=[(b"host", b"example.com")],
raw_path=b"/",
method="GET",
)
)
assert stream.send.call_args_list == [
call(
Response(
stream_id=1,
headers=[(b"content-length", b"0"), (b"connection", b"close")],
status_code=404,
)
),
call(EndBody(stream_id=1)),
]
@pytest.mark.asyncio
async def test_send_push(stream: HTTPStream) -> None:
stream.scope = {"scheme": "https", "headers": [(b"host", b"hypercorn")], "http_version": "2"}
......
......@@ -267,6 +267,30 @@ async def test_invalid_server_name(stream: WSStream) -> None:
await stream.handle(Body(stream_id=1, data=b"Body"))
@pytest.mark.asyncio
async def test_root_path(stream: WSStream) -> None:
stream.config.root_path = "/bob"
await stream.handle(
Request(
stream_id=1,
http_version="2",
headers=[(b"host", b"example.com"), (b"sec-websocket-version", b"13")],
raw_path=b"/",
method="GET",
)
)
assert stream.send.call_args_list == [
call(
Response(
stream_id=1,
headers=[(b"content-length", b"0"), (b"connection", b"close")],
status_code=404,
)
),
call(EndBody(stream_id=1)),
]
@pytest.mark.asyncio
async def test_send_app_error_handshake(stream: WSStream) -> None:
await stream.handle(
......
from typing import Callable
from typing import Callable, Optional
import pytest
......@@ -97,3 +97,17 @@ def test_filter_pseudo_headers() -> None:
[(b":authority", b"quart"), (b":path", b"/"), (b"user-agent", b"something")]
)
assert result == [(b"host", b"quart"), (b"user-agent", b"something")]
@pytest.mark.parametrize(
"raw_path, root_path, expected",
[
(b"/bob/", "/bob", "/"),
(b"/bob", "/bob", "/"),
(b"/foo", "/bob", None),
(b"/bob", "", "/bob"),
(b"/bob", "bob", None),
],
)
def test_extract_path(raw_path: bytes, root_path: str, expected: Optional[str]) -> None:
assert hypercorn.utils.extract_path(raw_path, root_path) == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment