diff --git a/cry/web.py b/cry/web.py index 7a5fea5..1717cab 100644 --- a/cry/web.py +++ b/cry/web.py @@ -1,6 +1,7 @@ import asyncio import contextlib import dataclasses +import functools import html import http.server import io @@ -9,19 +10,23 @@ import traceback import threading import urllib.parse +from typing import Any, Callable, Concatenate, ParamSpec, Protocol + from . import database from . import feed class DeadlineCondition: + """A condition variable that allows you to wait with a timeout.""" + lock: threading.Lock - waiting: int + waiters: int event: threading.Semaphore def __init__(self): self.lock = threading.Lock() self.event = threading.Semaphore() - self.waiting = 0 + self.waiters = 0 @contextlib.contextmanager def acquire(self, deadline: float): @@ -51,6 +56,10 @@ class DeadlineCondition: class AcquiredDeadlineCondition: + """An aquired condition lock, which can be used to wait for conditions or + signal other waiters. + """ + locked: bool condition: DeadlineCondition deadline: float @@ -65,8 +74,8 @@ class AcquiredDeadlineCondition: assert self.locked condition = self.condition - count = condition.waiting - condition.waiting = 0 + count = condition.waiters + condition.waiters = 0 if count > 0: condition.event.release(count) @@ -87,7 +96,7 @@ class AcquiredDeadlineCondition: if timeout <= 0: raise TimeoutError() - self.condition.waiting += 1 + self.condition.waiters += 1 self.locked = False self.condition.lock.release() @@ -95,7 +104,7 @@ class AcquiredDeadlineCondition: # If we timeout it's not safe to decrement the semaphore so we # just leave it; the side effect is that a signaller will # increment the semaphore more than it should, and so other - # waiters will have spurious waits, but that's a known problem + # waiters will have spurious wakes, but that's a known problem # with this kind of synchronization. if not self.condition.event.acquire(timeout=timeout): raise TimeoutError() @@ -185,16 +194,26 @@ class EventConsumer: return Event(event=None, data=None, id=None) -# THE EVENT STUFF +# Background Tasks -class RefreshTask: +class BackgroundTask: + """Some task running in the background. The specified `func` receives an + `EventChannel` object which it can use to report progress. That same + `EventChannel` object is present as the `sink` property, and you can use + it to subscribe to the events. + """ + + func: Callable[[EventChannel], None] + sink: EventChannel thread: threading.Thread - def __init__(self): + def __init__(self, func: Callable[[EventChannel], None]): + self.func = func + self.sink = EventChannel() - self.thread = threading.Thread(target=self._refresh_thread) + self.thread = threading.Thread(target=self._background_thread) self.thread.daemon = True self.thread.start() @@ -202,31 +221,37 @@ class RefreshTask: def closed(self): return self.sink.closed - def _refresh_thread(self): + def _background_thread(self): sink = self.sink try: - db = database.Database.local() - sink.status("Synchronizing state...") - database.sync(db) - - sink.status("Loading subscriptions...") - metas = db.load_all_meta() - - sink.status("Refreshing subscriptions...") - asyncio.run(self._refresh_all(db, metas)) - - sink.status("Done") - sink.redirect("/") + self.func(sink) finally: sink.close() - async def _refresh_all(self, db: database.Database, metas: list[feed.FeedMeta]): - async with asyncio.TaskGroup() as group: - for meta in metas: - group.create_task(self._refresh_meta(db, meta)) - async def _refresh_meta(self, db: database.Database, meta: feed.FeedMeta): - sink = self.sink +TaskParams = ParamSpec("TaskParams") + + +def background_task( + f: Callable[Concatenate[EventChannel, TaskParams], None] +) -> Callable[TaskParams, BackgroundTask]: + """A decorator that wraps a function to produce a BackgroundTask. The + function will receive an `EventChannel` as its first argument, along with + whatever other arguments it might need. + """ + + @functools.wraps(f) + def impl(*args: TaskParams.args, **kwargs: TaskParams.kwargs) -> BackgroundTask: + return BackgroundTask(lambda sink: f(sink, *args, **kwargs)) + + return impl + + +@background_task +def refresh_feeds(sink: EventChannel): + """Refresh all the subscribed feeds.""" + + async def _refresh_meta(db: database.Database, meta: feed.FeedMeta): sink.log(f"[{meta.url}] Fetching...") d = None try: @@ -247,66 +272,63 @@ class RefreshTask: except Exception as e: sink.log(f"[{meta.url}] Error refressing feed: {e}") + async def _refresh_all(db: database.Database, metas: list[feed.FeedMeta]): + async with asyncio.TaskGroup() as group: + for meta in metas: + group.create_task(_refresh_meta(db, meta)) -REFRESH_TASK: RefreshTask | None = None + db = database.Database.local() + sink.status("Synchronizing state...") + database.sync(db) + + sink.status("Loading subscriptions...") + metas = db.load_all_meta() + + sink.status("Refreshing subscriptions...") + asyncio.run(_refresh_all(db, metas)) + + sink.status("Done") + sink.redirect("/") -class SubscribeTask: - url: str - sink: EventChannel - thread: threading.Thread - - def __init__(self, url: str): - self.url = url - - self.sink = EventChannel() - self.thread = threading.Thread(target=self._task_thread) - self.thread.daemon = True - self.thread.start() - - @property - def closed(self): - return self.sink.closed - - def _task_thread(self): - sink = self.sink - url = self.url - - try: - db = database.Database.local() - sink.status("Synchronizing state...") - database.sync(db) - - sink.status("Searching for feeds...") - feeds = asyncio.run(feed.feed_search(url)) - if len(feeds) == 0: - sink.status(f"Unable to find a suitable feed for {url}") - return - - if len(feeds) > 1: - candidates = [(("t", f.title), ("u", f.meta.url)) for f in feeds] - qs = urllib.parse.urlencode([e for c in candidates for e in c]) - sink.redirect(f"/subscribe-choose?{qs}") - return - - result = feeds[0] - sink.log(f"Identified {result.meta.url} as a feed for {url}") - - existing = db.load_meta(result.meta.url) - if existing is not None: - sink.log(f"This feed already exists (as {result.meta.url})") - sink.status("Already subscribed") - return - - db.store_feed(result) - sink.status("Done") - sink.redirect("/") - - finally: - sink.close() +REFRESH_TASK: BackgroundTask | None = None -SUBSCRIBE_TASK: SubscribeTask | None = None +@background_task +def subscribe(sink: EventChannel, url: str): + """Subscribe to a feed.""" + + db = database.Database.local() + sink.status("Synchronizing state...") + database.sync(db) + + sink.status("Searching for feeds...") + feeds = asyncio.run(feed.feed_search(url)) + if len(feeds) == 0: + sink.status(f"Unable to find a suitable feed for {url}") + return + + if len(feeds) > 1: + candidates = [(("t", f.title), ("u", f.meta.url)) for f in feeds] + qs = urllib.parse.urlencode([e for c in candidates for e in c]) + sink.redirect(f"/subscribe-choose?{qs}") + return + + result = feeds[0] + sink.log(f"Identified {result.meta.url} as a feed for {url}") + + existing = db.load_meta(result.meta.url) + if existing is not None: + sink.log(f"This feed already exists (as {result.meta.url})") + sink.status("Already subscribed") + return + + db.store_feed(result) + sink.status("Done") + sink.redirect("/") + + +SUBSCRIBE_TASK: BackgroundTask | None = None class Handler(http.server.BaseHTTPRequestHandler): @@ -338,7 +360,7 @@ class Handler(http.server.BaseHTTPRequestHandler): def do_refresh(self): global REFRESH_TASK if REFRESH_TASK is None or REFRESH_TASK.closed: - REFRESH_TASK = RefreshTask() + REFRESH_TASK = refresh_feeds() self.send_response(303) self.send_header("Location", "/refresh-status") @@ -358,13 +380,13 @@ class Handler(http.server.BaseHTTPRequestHandler): global SUBSCRIBE_TASK if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed: - SUBSCRIBE_TASK = SubscribeTask(url) + SUBSCRIBE_TASK = subscribe(url) self.send_response(303) self.send_header("Location", "/subscribe-status") self.end_headers() - def serve_events(self, task: RefreshTask | SubscribeTask | None): + def serve_events(self, task: BackgroundTask | None): if task is None or task.closed: self.send_response(204) self.end_headers()