200 lines
6.1 KiB
Python
200 lines
6.1 KiB
Python
import asyncio
|
|
import dataclasses
|
|
import http.server
|
|
import threading
|
|
import typing
|
|
|
|
import contextlib
|
|
import time
|
|
|
|
from cry import feed
|
|
|
|
# Set to True to print ENTER/EXIT timing lines from dumb_trace().
TRACE_ELAPSED = False


@contextlib.contextmanager
def dumb_trace(name):
    """Print ENTER/EXIT timestamps around a block when ``TRACE_ELAPSED`` is set.

    A minimal debugging aid: with the module-level ``TRACE_ELAPSED`` flag
    enabled, prints a wall-clock timestamp when the block is entered, and on
    exit prints the exit timestamp plus the elapsed seconds. With the flag
    off nothing is printed.

    The EXIT line is printed even when the traced block raises; the
    exception is re-raised unchanged.

    :param name: Label included in the ENTER and EXIT lines.
    """
    # Read the flag once so every ENTER line is paired with an EXIT line.
    enabled = TRACE_ELAPSED
    start = time.time()
    if enabled:
        print(f"{start:.3f} ENTER {name}")
    try:
        yield
    finally:
        # Runs on both normal exit and exceptions, so failures are traced too.
        end = time.time()
        if enabled:
            print(f"{end:.3f} EXIT {name} (elapsed: {end - start:.3f}s)")
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class MockResponse:
|
|
code: int = 200
|
|
headers: list[typing.Tuple[str, str]] | None = None
|
|
content: bytes = b""
|
|
|
|
def write_to(self, handler: http.server.BaseHTTPRequestHandler):
|
|
handler.send_response(self.code)
|
|
if self.headers is not None:
|
|
for name, value in self.headers:
|
|
handler.send_header(name, value)
|
|
handler.send_header("content-length", str(len(self.content)))
|
|
handler.end_headers()
|
|
handler.wfile.write(self.content)
|
|
|
|
|
|
class TestWebServer:
    """Minimal threaded HTTP server that replays canned responses for tests.

    Usage::

        with TestWebServer() as server:
            server.handle("/", b"body", content_type="text/plain")
            url = server.make_url("/")
            ...

    Responses are queued per request path. Each GET consumes the first
    queued response for its path, except that the last one is never
    removed, so repeated requests keep receiving it. Paths with nothing
    queued get a 404.
    """

    # The class name starts with "Test"; this tells pytest not to collect it.
    __test__ = False

    server_port: int
    http_server: http.server.ThreadingHTTPServer
    server_thread: threading.Thread

    # Queued responses keyed by the exact request path (``handler.path``).
    handlers: dict[str, list[MockResponse]]

    def __init__(self):
        """Bind to an OS-chosen loopback port; serving starts in ``__enter__``."""
        with dumb_trace("init"):
            server = self

            class TestWebHandler(http.server.BaseHTTPRequestHandler):
                # Route every GET back to the owning TestWebServer.
                def do_GET(self):
                    server._handle_GET(self)

            # Port 0 lets the OS pick a free port. Bind to loopback only:
            # clients reach the server via make_url(), i.e. localhost, so
            # there is no reason to listen on external interfaces.
            self.http_server = http.server.ThreadingHTTPServer(
                ("127.0.0.1", 0), TestWebHandler
            )
            # Don't wait for in-flight request threads when closing.
            self.http_server.block_on_close = False
            self.http_server.daemon_threads = True
            self.server_thread = threading.Thread(target=self._do_serve)
            self.server_thread.daemon = True
            self.handlers = {}
            self.server_port = self.http_server.server_port

    def _do_serve(self):
        """Thread target: run the request loop until ``shutdown()`` is called."""
        # A short poll interval makes shutdown() return quickly.
        self.http_server.serve_forever(poll_interval=0.01)

    def make_url(self, path: str) -> str:
        """Return ``http://localhost:<port>`` followed verbatim by ``path``."""
        return f"http://localhost:{self.server_port}{path}"

    def handle(
        self,
        path: str,
        content: bytes = b"",
        code: int = 200,
        content_type: str | None = None,
        headers: list[typing.Tuple[str, str]] | None = None,
    ):
        """Queue a response for GET requests to ``path``.

        :param path: Exact request path to match, e.g. ``"/feed"``.
        :param content: Response body.
        :param code: HTTP status code.
        :param content_type: If given, sent as the ``content-type`` header.
        :param headers: Extra ``(name, value)`` headers. The list is copied,
            so the caller's list is never modified.
        """
        headers = [] if headers is None else list(headers)
        if content_type is not None:
            headers.append(("content-type", content_type))
        self.respond(path, MockResponse(code, headers, content))

    def respond(self, path: str, response: MockResponse):
        """Append ``response`` to the queue served for ``path``."""
        self.handlers.setdefault(path, []).append(response)

    def _handle_GET(self, handler: http.server.BaseHTTPRequestHandler):
        """Write the next queued response for ``handler.path`` (404 if none)."""
        responses = self.handlers.get(handler.path)
        if responses:
            # Consume responses in order, but leave the last one in place so
            # every later request for this path receives it again.
            response = responses.pop(0) if len(responses) > 1 else responses[0]
        else:
            response = MockResponse(code=404)
        # NOTE(review): ThreadingHTTPServer runs each request on its own
        # thread and this queue is unlocked; concurrent GETs to the same
        # path could race. Fine for the sequential fetches used here.
        response.write_to(handler)

    def __enter__(self):
        """Start serving on the background thread and return ``self``."""
        with dumb_trace("__enter__"):
            self.server_thread.start()
            return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop serving and release the listening socket.

        Returns ``None``, so an exception raised in the ``with`` body
        propagates to the caller.
        """
        del exc_type, exc_value, traceback  # unused
        with dumb_trace("__exit__"):
            if self.server_thread.is_alive():
                # shutdown() waits for serve_forever() to stop; it must only
                # be called while the loop is running in another thread.
                self.http_server.shutdown()
                self.server_thread.join()
            # shutdown() only stops the loop; server_close() closes the
            # listening socket so it is not leaked across tests.
            self.http_server.server_close()
|
|
|
|
|
|
# Smallest useful RSS 2.0 document: one channel with a title and one item.
# The XML declaration must be the very first bytes of the document (XML 1.0,
# "Prolog and Document Type Declaration"); a leading newline makes the
# document not well-formed. So the literal starts directly with "<?xml".
TEST_FEED = b"""<?xml version="1.0" encoding="utf-8" standalone="yes" ?>
<rss version="2.0">
  <channel>
    <title>Smallest</title>
    <item>
      <title>An Item</title>
      <link>http://example.com</link>
      <guid>tag:d0ty.me,2024-01-01:smallest:one</guid>
    </item>
  </channel>
</rss>
"""
|
|
|
|
|
|
def test_basic_successful_fetch():
    """A plain 200 response parses into a Feed and keeps the requested URL."""
    with TestWebServer() as server:
        server.handle("/", TEST_FEED, content_type="text/xml")

        feed_url = server.make_url("/")
        original = feed.FeedMeta.from_url(feed_url)
        parsed, updated = asyncio.run(feed.fetch_feed(original))

        assert updated.url == original.url
        assert isinstance(parsed, feed.Feed)
        assert len(parsed.entries) == 1
        assert parsed.title == "Smallest"
|
|
|
|
|
|
def test_fetch_after_temp_redirect():
    """A 307 is followed, but the stored URL stays the original one."""
    with TestWebServer() as server:
        server.handle("/old", code=307, headers=[("location", "/temp")])
        server.handle("/temp", TEST_FEED, content_type="text/xml")

        start_url = server.make_url("/old")
        original = feed.FeedMeta.from_url(start_url)
        outcome = asyncio.run(feed.fetch_feed(original))
        parsed, updated = outcome

        assert updated.url == original.url
        assert isinstance(parsed, feed.Feed)
|
|
|
|
|
|
def test_fetch_after_permanent_redirect():
    """A 308 is followed and its target becomes the stored URL."""
    with TestWebServer() as server:
        server.handle("/old", code=308, headers=[("location", "/perm")])
        server.handle("/perm", TEST_FEED, content_type="text/xml")

        start_url = server.make_url("/old")
        original = feed.FeedMeta.from_url(start_url)
        parsed, updated = asyncio.run(feed.fetch_feed(original))

        expected_url = server.make_url("/perm")
        assert updated.url == expected_url
        assert isinstance(parsed, feed.Feed)
|
|
|
|
|
|
def test_fetch_after_permanent_to_temporary_redirect():
    """In a 308 -> 307 chain, only the permanent hop updates the stored URL."""
    with TestWebServer() as server:
        server.handle("/old", code=308, headers=[("location", "/perm")])
        server.handle("/perm", code=307, headers=[("location", "/temp")])
        server.handle("/temp", TEST_FEED, content_type="text/xml")

        start_url = server.make_url("/old")
        original = feed.FeedMeta.from_url(start_url)
        parsed, updated = asyncio.run(feed.fetch_feed(original))

        # NOTE: we should record the PERMANENT redirect, not the temporary one.
        expected_url = server.make_url("/perm")
        assert updated.url == expected_url
        assert isinstance(parsed, feed.Feed)
|
|
|
|
|
|
def test_fetch_after_permanent_to_permanent_redirect():
    """In a 308 -> 308 chain, the final permanent target is stored."""
    with TestWebServer() as server:
        server.handle("/old", code=308, headers=[("location", "/one")])
        server.handle("/one", code=308, headers=[("location", "/two")])
        server.handle("/two", TEST_FEED, content_type="text/xml")

        start_url = server.make_url("/old")
        original = feed.FeedMeta.from_url(start_url)
        parsed, updated = asyncio.run(feed.fetch_feed(original))

        # NOTE: we should record the latest redirect.
        expected_url = server.make_url("/two")
        assert updated.url == expected_url
        assert isinstance(parsed, feed.Feed)
|