The most self-indulgent refactor of background tasks

The ParamSpec stuff is special-purpose but gee, the type checker has
gotten pretty good, huh.
This commit is contained in:
John Doty 2024-08-18 08:54:28 -07:00
parent 785f71223b
commit 74f7146937

View file

@ -1,6 +1,7 @@
import asyncio
import contextlib
import dataclasses
import functools
import html
import http.server
import io
@ -9,19 +10,23 @@ import traceback
import threading
import urllib.parse
from typing import Any, Callable, Concatenate, ParamSpec, Protocol
from . import database
from . import feed
class DeadlineCondition:
"""A condition variable that allows you to wait with a timeout."""
lock: threading.Lock
waiting: int
waiters: int
event: threading.Semaphore
def __init__(self):
self.lock = threading.Lock()
self.event = threading.Semaphore()
self.waiting = 0
self.waiters = 0
@contextlib.contextmanager
def acquire(self, deadline: float):
@ -51,6 +56,10 @@ class DeadlineCondition:
class AcquiredDeadlineCondition:
"""An aquired condition lock, which can be used to wait for conditions or
signal other waiters.
"""
locked: bool
condition: DeadlineCondition
deadline: float
@ -65,8 +74,8 @@ class AcquiredDeadlineCondition:
assert self.locked
condition = self.condition
count = condition.waiting
condition.waiting = 0
count = condition.waiters
condition.waiters = 0
if count > 0:
condition.event.release(count)
@ -87,7 +96,7 @@ class AcquiredDeadlineCondition:
if timeout <= 0:
raise TimeoutError()
self.condition.waiting += 1
self.condition.waiters += 1
self.locked = False
self.condition.lock.release()
@ -95,7 +104,7 @@ class AcquiredDeadlineCondition:
# If we timeout it's not safe to decrement the semaphore so we
# just leave it; the side effect is that a signaller will
# increment the semaphore more than it should, and so other
# waiters will have spurious waits, but that's a known problem
# waiters will have spurious wakes, but that's a known problem
# with this kind of synchronization.
if not self.condition.event.acquire(timeout=timeout):
raise TimeoutError()
@ -185,16 +194,26 @@ class EventConsumer:
return Event(event=None, data=None, id=None)
# THE EVENT STUFF
# Background Tasks
class RefreshTask:
class BackgroundTask:
"""Some task running in the background. The specified `func` receives an
`EventChannel` object which it can use to report progress. That same
`EventChannel` object is present as the `sink` property, and you can use
it to subscribe to the events.
"""
func: Callable[[EventChannel], None]
sink: EventChannel
thread: threading.Thread
def __init__(self):
def __init__(self, func: Callable[[EventChannel], None]):
self.func = func
self.sink = EventChannel()
self.thread = threading.Thread(target=self._refresh_thread)
self.thread = threading.Thread(target=self._background_thread)
self.thread.daemon = True
self.thread.start()
@ -202,31 +221,37 @@ class RefreshTask:
def closed(self):
return self.sink.closed
def _refresh_thread(self):
def _background_thread(self):
sink = self.sink
try:
db = database.Database.local()
sink.status("Synchronizing state...")
database.sync(db)
sink.status("Loading subscriptions...")
metas = db.load_all_meta()
sink.status("Refreshing subscriptions...")
asyncio.run(self._refresh_all(db, metas))
sink.status("Done")
sink.redirect("/")
self.func(sink)
finally:
sink.close()
async def _refresh_all(self, db: database.Database, metas: list[feed.FeedMeta]):
async with asyncio.TaskGroup() as group:
for meta in metas:
group.create_task(self._refresh_meta(db, meta))
async def _refresh_meta(self, db: database.Database, meta: feed.FeedMeta):
sink = self.sink
TaskParams = ParamSpec("TaskParams")


def background_task(
    f: Callable[Concatenate[EventChannel, TaskParams], None]
) -> Callable[TaskParams, BackgroundTask]:
    """Turn `f` into a factory of BackgroundTask objects.

    The decorated function keeps its original signature minus the leading
    `EventChannel` parameter: calling it starts a BackgroundTask whose
    worker invokes `f` with the task's own channel prepended to the
    caller's arguments.
    """

    @functools.wraps(f)
    def start_task(*args: TaskParams.args, **kwargs: TaskParams.kwargs) -> BackgroundTask:
        # Bind the caller's arguments now; the EventChannel is supplied by
        # BackgroundTask when the worker thread actually runs.
        return BackgroundTask(lambda sink: f(sink, *args, **kwargs))

    return start_task
@background_task
def refresh_feeds(sink: EventChannel):
"""Refresh all the subscribed feeds."""
async def _refresh_meta(db: database.Database, meta: feed.FeedMeta):
sink.log(f"[{meta.url}] Fetching...")
d = None
try:
@ -247,66 +272,63 @@ class RefreshTask:
except Exception as e:
sink.log(f"[{meta.url}] Error refressing feed: {e}")
async def _refresh_all(db: database.Database, metas: list[feed.FeedMeta]):
async with asyncio.TaskGroup() as group:
for meta in metas:
group.create_task(_refresh_meta(db, meta))
REFRESH_TASK: RefreshTask | None = None
db = database.Database.local()
sink.status("Synchronizing state...")
database.sync(db)
sink.status("Loading subscriptions...")
metas = db.load_all_meta()
sink.status("Refreshing subscriptions...")
asyncio.run(_refresh_all(db, metas))
sink.status("Done")
sink.redirect("/")
class SubscribeTask:
    """A background thread that subscribes to whatever feed `url` points at,
    reporting progress (and a final redirect) over the `sink` channel.
    """

    url: str
    sink: EventChannel
    thread: threading.Thread

    def __init__(self, url: str):
        self.url = url
        self.sink = EventChannel()
        self.thread = threading.Thread(target=self._task_thread)
        self.thread.daemon = True
        self.thread.start()

    @property
    def closed(self):
        # The task counts as finished once its channel has been closed.
        return self.sink.closed

    def _task_thread(self):
        sink, url = self.sink, self.url
        try:
            db = database.Database.local()
            sink.status("Synchronizing state...")
            database.sync(db)

            sink.status("Searching for feeds...")
            feeds = asyncio.run(feed.feed_search(url))
            if not feeds:
                sink.status(f"Unable to find a suitable feed for {url}")
                return
            if len(feeds) > 1:
                # Multiple candidates: punt the decision to the chooser page,
                # flattening (title, url) pairs into the query string.
                pairs = [p for f in feeds for p in (("t", f.title), ("u", f.meta.url))]
                qs = urllib.parse.urlencode(pairs)
                sink.redirect(f"/subscribe-choose?{qs}")
                return

            result = feeds[0]
            sink.log(f"Identified {result.meta.url} as a feed for {url}")
            if db.load_meta(result.meta.url) is not None:
                sink.log(f"This feed already exists (as {result.meta.url})")
                sink.status("Already subscribed")
                return

            db.store_feed(result)
            sink.status("Done")
            sink.redirect("/")
        finally:
            # Always close the channel so subscribers see the task finish,
            # even when the subscription attempt raised.
            sink.close()
REFRESH_TASK: BackgroundTask | None = None
SUBSCRIBE_TASK: SubscribeTask | None = None
@background_task
def subscribe(sink: EventChannel, url: str):
    """Subscribe to the feed found at `url`, reporting progress via `sink`."""
    db = database.Database.local()
    sink.status("Synchronizing state...")
    database.sync(db)

    sink.status("Searching for feeds...")
    feeds = asyncio.run(feed.feed_search(url))
    if not feeds:
        sink.status(f"Unable to find a suitable feed for {url}")
        return
    if len(feeds) > 1:
        # Multiple candidates: hand the choice off to the chooser page,
        # flattening (title, url) pairs into the query string.
        pairs = [p for f in feeds for p in (("t", f.title), ("u", f.meta.url))]
        sink.redirect(f"/subscribe-choose?{urllib.parse.urlencode(pairs)}")
        return

    result = feeds[0]
    sink.log(f"Identified {result.meta.url} as a feed for {url}")
    if db.load_meta(result.meta.url) is not None:
        sink.log(f"This feed already exists (as {result.meta.url})")
        sink.status("Already subscribed")
        return

    db.store_feed(result)
    sink.status("Done")
    sink.redirect("/")
SUBSCRIBE_TASK: BackgroundTask | None = None
class Handler(http.server.BaseHTTPRequestHandler):
@ -338,7 +360,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
def do_refresh(self):
global REFRESH_TASK
if REFRESH_TASK is None or REFRESH_TASK.closed:
REFRESH_TASK = RefreshTask()
REFRESH_TASK = refresh_feeds()
self.send_response(303)
self.send_header("Location", "/refresh-status")
@ -358,13 +380,13 @@ class Handler(http.server.BaseHTTPRequestHandler):
global SUBSCRIBE_TASK
if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed:
SUBSCRIBE_TASK = SubscribeTask(url)
SUBSCRIBE_TASK = subscribe(url)
self.send_response(303)
self.send_header("Location", "/subscribe-status")
self.end_headers()
def serve_events(self, task: RefreshTask | SubscribeTask | None):
def serve_events(self, task: BackgroundTask | None):
if task is None or task.closed:
self.send_response(204)
self.end_headers()