# Reconstructed from mail-formatted git patch c855ff46d2e308c95722705dd08e2077a724949a
# ("Refresh from web with server-sent events", John Doty, 2024-07-25, cry/web.py).
# NOTE(review): the patch arrived with its line structure destroyed; this is the
# post-patch module, reformatted, with the defects found in review fixed inline
# (each marked BUGFIX). Portions outside the visible hunks (serve_feeds, do_POST,
# the exact HTML template) are marked NOTE(review) and must be confirmed against
# the committed file.

import asyncio
import contextlib
import dataclasses
import html
import http.server
import io
import threading
import time

from . import database
from . import feed


class DeadlineCondition:
    """A lock + condition variable whose every operation is bounded by a
    wall-clock deadline instead of blocking forever.

    Built from a plain Lock and a Semaphore because threading.Condition has no
    deadline-style API. See AcquiredDeadlineCondition for wait/signal.
    """

    lock: threading.Lock
    waiting: int              # number of threads currently blocked in wait()
    event: threading.Semaphore

    def __init__(self):
        self.lock = threading.Lock()
        # BUGFIX: threading.Semaphore() defaults to an initial value of 1,
        # which let the very first wait() return once without any signal.
        # Start fully drained.
        self.event = threading.Semaphore(0)
        self.waiting = 0

    @contextlib.contextmanager
    def acquire(self, deadline: float):
        """Acquire the lock with the given deadline.

        Yields an `AcquiredDeadlineCondition` that can be used to wait or
        signal the condition. If the deadline passes before the lock can be
        acquired, raises TimeoutError.
        """
        # (The patch checked the deadline twice in a row here; once suffices.)
        timeout = deadline - time.time()
        if timeout <= 0:
            raise TimeoutError()

        if not self.lock.acquire(timeout=timeout):
            raise TimeoutError()

        state = AcquiredDeadlineCondition(self, deadline)
        try:
            yield state
        finally:
            # wait() may have timed out with the lock released; only release
            # if we still hold it.
            if state.locked:
                self.lock.release()


class AcquiredDeadlineCondition:
    """Handle for a DeadlineCondition that is currently held by this thread."""

    locked: bool              # do we currently hold condition.lock?
    condition: DeadlineCondition
    deadline: float           # default deadline for wait()

    def __init__(self, lock: DeadlineCondition, deadline: float):
        self.condition = lock
        self.deadline = deadline
        self.locked = True

    def signal(self):
        """Wake up every locked consumer that is blocked in `wait`."""
        assert self.locked

        condition = self.condition
        count = condition.waiting
        condition.waiting = 0
        if count > 0:
            condition.event.release(count)

    def wait(self, deadline: float | None = None):
        """Release the lock, wait for the condition to be signaled, then
        re-acquire the lock.

        If the specified deadline passes before the lock is re-acquired,
        raises `TimeoutError`. If no deadline is provided then we use the
        deadline that was provided when the lock was originally acquired.
        """
        assert self.locked

        if deadline is None:
            deadline = self.deadline

        # BUGFIX: the patch computed `time.time() - deadline` here (and at the
        # two checks below), which is negative while time REMAINS — so every
        # wait raised TimeoutError immediately and consumers only ever saw
        # keepalives. The correct remaining-time is `deadline - time.time()`.
        timeout = deadline - time.time()
        if timeout <= 0:
            raise TimeoutError()

        self.condition.waiting += 1
        self.locked = False
        self.condition.lock.release()

        # NOTE: Leak here: we might wake the semaphore more than we intend to.
        #       If we timeout it's not safe to decrement the semaphore so we
        #       just leave it; the side effect is that a signaller will
        #       increment the semaphore more than it should, and so other
        #       waiters will have spurious wakes, but that's a known problem
        #       with this kind of synchronization.
        if not self.condition.event.acquire(timeout=timeout):
            raise TimeoutError()

        timeout = deadline - time.time()  # BUGFIX: sign (see above)
        if timeout <= 0:
            raise TimeoutError()

        if not self.condition.lock.acquire(timeout=timeout):
            raise TimeoutError()

        self.locked = True
        return True


class Closed:
    """Sentinel type; see CLOSED."""

    pass


CLOSED = Closed()


@dataclasses.dataclass
class Event:
    # All three None at once means "keepalive" (deadline passed with no event).
    event: str | None
    data: str | None
    id: int | None


class EventChannel:
    """A multi-consumer, append-only stream of Events with a close signal.

    Producers call event()/close(); each EventConsumer tracks its own cursor
    into `events`, so late subscribers and SSE reconnects can replay history.
    """

    condition: DeadlineCondition
    events: list[Event]
    closed: bool

    def __init__(self):
        self.condition = DeadlineCondition()
        self.events = []
        self.closed = False

    def close(self):
        """Mark the channel closed and wake all waiting consumers."""
        deadline = time.time() + 30
        with self.condition.acquire(deadline) as lock:
            if not self.closed:
                self.closed = True
                lock.signal()

    def event(self, event: str, data: str | None = None):
        """Append an event (id = its index) and wake all waiting consumers."""
        deadline = time.time() + 30
        with self.condition.acquire(deadline) as lock:
            assert not self.closed
            self.events.append(Event(event=event, data=data, id=len(self.events)))
            lock.signal()

    def get_consumer(self, index: int | None = None) -> "EventConsumer":
        return EventConsumer(self, index)


class EventConsumer:
    """A cursor over an EventChannel; each consumer advances independently."""

    channel: EventChannel
    index: int                # next event index to deliver

    def __init__(self, channel: EventChannel, index: int | None):
        self.channel = channel
        # (Was `index or 0`, which also mapped an explicit 0 — be precise.)
        self.index = 0 if index is None else index

    @property
    def closed(self) -> bool:
        return self.channel.closed

    def get_event(self, deadline: float) -> Event | None:
        """Return the next event; None when the channel is closed and drained;
        a blank keepalive Event (all fields None) when the deadline passes."""
        try:
            with self.channel.condition.acquire(deadline) as cond:
                # Loop: wait() can wake spuriously (see the leak note there).
                while self.index == len(self.channel.events) and not self.closed:
                    cond.wait()

                if self.index < len(self.channel.events):
                    result = self.channel.events[self.index]
                    self.index += 1
                    return result

                # Closed and fully drained. (Was an implicit fall-through.)
                return None

        except TimeoutError:
            return Event(event=None, data=None, id=None)


class RefreshTask:
    """Background daemon thread that refreshes every feed, reporting progress
    on `sink` as "status"/"log" events for the SSE endpoint."""

    sink: EventChannel
    thread: threading.Thread

    def __init__(self):
        self.sink = EventChannel()
        self.thread = threading.Thread(target=self._refresh_thread)
        self.thread.daemon = True
        self.thread.start()

    @property
    def closed(self):
        # A closed sink doubles as "the task is finished".
        return self.sink.closed

    def _refresh_thread(self):
        sink = self.sink
        try:
            db = database.Database.local()
            sink.event("status", "Synchronizing state...")
            database.sync(db)

            sink.event("status", "Loading subscriptions...")
            metas = db.load_all_meta()

            sink.event("status", "Refreshing subscriptions...")
            asyncio.run(self._refresh_all(db, metas))

            sink.event("status", "Done")
        finally:
            # Always close, even on failure, so SSE clients disconnect and
            # do_refresh() can start a fresh task.
            sink.close()

    async def _refresh_all(self, db: database.Database, metas: list[feed.FeedMeta]):
        # Fetch every feed concurrently; TaskGroup propagates any failure.
        async with asyncio.TaskGroup() as group:
            for meta in metas:
                group.create_task(self._refresh_meta(db, meta))

    async def _refresh_meta(self, db: database.Database, meta: feed.FeedMeta):
        sink = self.sink
        sink.event("log", f"[{meta.url}] Fetching...")
        d = None
        try:
            d, meta = await feed.fetch_feed(meta)
            if d is None:
                sink.event("log", f"[{meta.url}] No updates")
                db.update_meta(meta)
            elif isinstance(d, str):
                sink.event(
                    "log",
                    f"[{meta.url}] WARNING: returned a non-feed result!",
                )
            else:
                new_count = db.store_feed(d)
                # BUGFIX: dropped the stray trailing "\n" — serve_events writes
                # data as a single "data: ..." line, so an embedded newline
                # produced a malformed SSE frame.
                sink.event(
                    "log",
                    f"[{meta.url}] {new_count} new items",
                )
        except Exception as e:
            # BUGFIX: typo in the log message ("refressing" -> "refreshing").
            sink.event("log", f"[{meta.url}] Error refreshing feed: {e}")


# At most one refresh runs at a time; replaced once the previous one closes.
REFRESH_TASK: RefreshTask | None = None


class Handler(http.server.BaseHTTPRequestHandler):
    # NOTE(review): serve_feeds() and do_POST() exist in the committed file but
    # fall outside the visible patch hunks; carry them over unchanged.

    def do_GET(self):
        if self.path == "/":
            return self.serve_feeds()
        elif self.path == "/refresh-status":
            return self.serve_refresh_status()
        elif self.path == "/events/refresh":
            return self.serve_events(REFRESH_TASK)
        else:
            self.send_error(404)

    def do_refresh(self):
        """Start a refresh (if none is running) and redirect to the status page."""
        global REFRESH_TASK
        if REFRESH_TASK is None or REFRESH_TASK.closed:
            REFRESH_TASK = RefreshTask()

        self.send_response(303)
        self.send_header("Location", "/refresh-status")
        self.end_headers()

    def serve_events(self, task: RefreshTask | None):
        """Stream the refresh task's progress as server-sent events."""
        if task is None or task.closed:
            # 204 tells EventSource to stop reconnecting.
            self.send_response(204)
            self.end_headers()
            return

        # Handle reconnect: Last-Event-ID names the last event the client
        # already received, so resume at the event AFTER it.
        last_id = self.headers.get("Last-Event-ID", None)
        if last_id is not None:
            try:
                # BUGFIX: was int(last_id), which re-delivered the last event.
                last_index = int(last_id) + 1
            except ValueError:
                last_index = 0
        else:
            last_index = 0

        consumer = task.sink.get_consumer(last_index)

        self.send_response(200)
        self.send_header("content-type", "text/event-stream")
        self.send_header("x-accel-buffering", "no")  # disable proxy buffering
        self.send_header("cache-control", "no-cache")
        self.end_headers()
        while True:
            deadline = time.time() + 5  # 5 seconds from now
            event = consumer.get_event(deadline)
            if event is None:
                # Event stream closed.
                break

            if event.id is None and event.data is None and event.event is None:
                # Comment line for connection keepalive.
                self.wfile.write(b":\n\n")
            else:
                if event.id is not None:
                    self.wfile.write(f"id: {event.id}\n".encode("utf-8"))
                if event.data is not None:
                    self.wfile.write(f"data: {event.data}\n".encode("utf-8"))
                if event.event is not None:
                    self.wfile.write(f"event: {event.event}\n".encode("utf-8"))
                self.wfile.write(b"\n")
            self.wfile.flush()

        # Final frame so clients can close() instead of auto-reconnecting.
        # (The patch wrote the bare field "data" with no colon — legal per the
        # SSE spec, but needlessly obscure.)
        self.wfile.write(b"event: closed\ndata:\n\n")
        self.wfile.flush()

    def serve_refresh_status(self):
        # NOTE(review): the HTML template in the patch was truncated/garbled.
        # The visible fragments show a "Refresh Status" page with a
        # "Status: Starting..." line and a log area driven by /events/refresh;
        # restore the exact markup and script from the committed cry/web.py.
        buffer = io.StringIO()
        buffer.write(
            """<!DOCTYPE html>
<html>
<head><title>Refresh Status</title></head>
<body>
<h1>Refresh Status</h1>
<div>Status: <span id="status">Starting...</span></div>
<pre id="log"></pre>
<script>
const source = new EventSource("/events/refresh");
source.addEventListener("status", (e) => {
  document.getElementById("status").textContent = e.data;
});
source.addEventListener("log", (e) => {
  document.getElementById("log").textContent += e.data + "\\n";
});
source.addEventListener("closed", () => source.close());
</script>
</body>
</html>
"""
        )
        content = buffer.getvalue().encode("utf-8")
        self.send_response(200)
        self.send_header("content-type", "text/html; charset=utf-8")
        self.send_header("content-length", str(len(content)))
        self.end_headers()
        self.wfile.write(content)