The most self-indulgent refactor of background tasks
The ParamSpec stuff is special-purpose but gee, the type checker has gotten pretty good, huh.
This commit is contained in:
parent
785f71223b
commit
74f7146937
1 changed file with 109 additions and 87 deletions
196
cry/web.py
196
cry/web.py
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import html
|
||||
import http.server
|
||||
import io
|
||||
|
|
@ -9,19 +10,23 @@ import traceback
|
|||
import threading
|
||||
import urllib.parse
|
||||
|
||||
from typing import Any, Callable, Concatenate, ParamSpec, Protocol
|
||||
|
||||
from . import database
|
||||
from . import feed
|
||||
|
||||
|
||||
class DeadlineCondition:
|
||||
"""A condition variable that allows you to wait with a timeout."""
|
||||
|
||||
lock: threading.Lock
|
||||
waiting: int
|
||||
waiters: int
|
||||
event: threading.Semaphore
|
||||
|
||||
def __init__(self):
|
||||
self.lock = threading.Lock()
|
||||
self.event = threading.Semaphore()
|
||||
self.waiting = 0
|
||||
self.waiters = 0
|
||||
|
||||
@contextlib.contextmanager
|
||||
def acquire(self, deadline: float):
|
||||
|
|
@ -51,6 +56,10 @@ class DeadlineCondition:
|
|||
|
||||
|
||||
class AcquiredDeadlineCondition:
|
||||
"""An aquired condition lock, which can be used to wait for conditions or
|
||||
signal other waiters.
|
||||
"""
|
||||
|
||||
locked: bool
|
||||
condition: DeadlineCondition
|
||||
deadline: float
|
||||
|
|
@ -65,8 +74,8 @@ class AcquiredDeadlineCondition:
|
|||
assert self.locked
|
||||
|
||||
condition = self.condition
|
||||
count = condition.waiting
|
||||
condition.waiting = 0
|
||||
count = condition.waiters
|
||||
condition.waiters = 0
|
||||
if count > 0:
|
||||
condition.event.release(count)
|
||||
|
||||
|
|
@ -87,7 +96,7 @@ class AcquiredDeadlineCondition:
|
|||
if timeout <= 0:
|
||||
raise TimeoutError()
|
||||
|
||||
self.condition.waiting += 1
|
||||
self.condition.waiters += 1
|
||||
self.locked = False
|
||||
self.condition.lock.release()
|
||||
|
||||
|
|
@ -95,7 +104,7 @@ class AcquiredDeadlineCondition:
|
|||
# If we timeout it's not safe to decrement the semaphore so we
|
||||
# just leave it; the side effect is that a signaller will
|
||||
# increment the semaphore more than it should, and so other
|
||||
# waiters will have spurious waits, but that's a known problem
|
||||
# waiters will have spurious wakes, but that's a known problem
|
||||
# with this kind of synchronization.
|
||||
if not self.condition.event.acquire(timeout=timeout):
|
||||
raise TimeoutError()
|
||||
|
|
@ -185,16 +194,26 @@ class EventConsumer:
|
|||
return Event(event=None, data=None, id=None)
|
||||
|
||||
|
||||
# THE EVENT STUFF
|
||||
# Background Tasks
|
||||
|
||||
|
||||
class RefreshTask:
|
||||
class BackgroundTask:
|
||||
"""Some task running in the background. The specified `func` receives an
|
||||
`EventChannel` object which it can use to report progress. That same
|
||||
`EventChannel` object is present as the `sink` property, and you can use
|
||||
it to subscribe to the events.
|
||||
"""
|
||||
|
||||
func: Callable[[EventChannel], None]
|
||||
|
||||
sink: EventChannel
|
||||
thread: threading.Thread
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, func: Callable[[EventChannel], None]):
|
||||
self.func = func
|
||||
|
||||
self.sink = EventChannel()
|
||||
self.thread = threading.Thread(target=self._refresh_thread)
|
||||
self.thread = threading.Thread(target=self._background_thread)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
|
|
@ -202,31 +221,37 @@ class RefreshTask:
|
|||
def closed(self):
|
||||
return self.sink.closed
|
||||
|
||||
def _refresh_thread(self):
|
||||
def _background_thread(self):
|
||||
sink = self.sink
|
||||
try:
|
||||
db = database.Database.local()
|
||||
sink.status("Synchronizing state...")
|
||||
database.sync(db)
|
||||
|
||||
sink.status("Loading subscriptions...")
|
||||
metas = db.load_all_meta()
|
||||
|
||||
sink.status("Refreshing subscriptions...")
|
||||
asyncio.run(self._refresh_all(db, metas))
|
||||
|
||||
sink.status("Done")
|
||||
sink.redirect("/")
|
||||
self.func(sink)
|
||||
finally:
|
||||
sink.close()
|
||||
|
||||
async def _refresh_all(self, db: database.Database, metas: list[feed.FeedMeta]):
|
||||
async with asyncio.TaskGroup() as group:
|
||||
for meta in metas:
|
||||
group.create_task(self._refresh_meta(db, meta))
|
||||
|
||||
async def _refresh_meta(self, db: database.Database, meta: feed.FeedMeta):
|
||||
sink = self.sink
|
||||
TaskParams = ParamSpec("TaskParams")


def background_task(
    f: Callable[Concatenate[EventChannel, TaskParams], None]
) -> Callable[TaskParams, BackgroundTask]:
    """Turn `f` into a factory of `BackgroundTask`s.

    Calling the decorated function spawns a background thread running `f`,
    which receives an `EventChannel` as its first argument followed by
    whatever other arguments the caller supplied.
    """

    @functools.wraps(f)
    def wrapper(*args: TaskParams.args, **kwargs: TaskParams.kwargs) -> BackgroundTask:
        # Bind the caller's arguments now; the channel is injected later by
        # BackgroundTask when the worker thread starts.
        def run(sink: EventChannel) -> None:
            f(sink, *args, **kwargs)

        return BackgroundTask(run)

    return wrapper
|
||||
|
||||
|
||||
@background_task
|
||||
def refresh_feeds(sink: EventChannel):
|
||||
"""Refresh all the subscribed feeds."""
|
||||
|
||||
async def _refresh_meta(db: database.Database, meta: feed.FeedMeta):
|
||||
sink.log(f"[{meta.url}] Fetching...")
|
||||
d = None
|
||||
try:
|
||||
|
|
@ -247,66 +272,63 @@ class RefreshTask:
|
|||
except Exception as e:
|
||||
sink.log(f"[{meta.url}] Error refressing feed: {e}")
|
||||
|
||||
async def _refresh_all(db: database.Database, metas: list[feed.FeedMeta]):
|
||||
async with asyncio.TaskGroup() as group:
|
||||
for meta in metas:
|
||||
group.create_task(_refresh_meta(db, meta))
|
||||
|
||||
REFRESH_TASK: RefreshTask | None = None
|
||||
db = database.Database.local()
|
||||
sink.status("Synchronizing state...")
|
||||
database.sync(db)
|
||||
|
||||
sink.status("Loading subscriptions...")
|
||||
metas = db.load_all_meta()
|
||||
|
||||
sink.status("Refreshing subscriptions...")
|
||||
asyncio.run(_refresh_all(db, metas))
|
||||
|
||||
sink.status("Done")
|
||||
sink.redirect("/")
|
||||
|
||||
|
||||
class SubscribeTask:
    """Runs a feed subscription for `url` on a daemon thread, reporting
    progress and the final outcome through the `sink` event channel.
    """

    # The URL the user wants to subscribe to (a page or feed URL).
    url: str
    # Event channel watchers can subscribe to for progress events.
    sink: EventChannel
    # The daemon worker thread; started immediately by __init__.
    thread: threading.Thread

    def __init__(self, url: str):
        self.url = url

        self.sink = EventChannel()
        self.thread = threading.Thread(target=self._task_thread)
        # Daemon so an in-flight subscription never blocks interpreter exit.
        self.thread.daemon = True
        self.thread.start()

    @property
    def closed(self):
        # True once the worker has finished and closed the event channel.
        return self.sink.closed

    def _task_thread(self):
        """Worker: sync the database, search `url` for feeds, then store the
        single match, redirect to a chooser when there are several, or report
        failure when there are none.
        """
        sink = self.sink
        url = self.url

        try:
            db = database.Database.local()
            sink.status("Synchronizing state...")
            database.sync(db)

            sink.status("Searching for feeds...")
            feeds = asyncio.run(feed.feed_search(url))
            if len(feeds) == 0:
                sink.status(f"Unable to find a suitable feed for {url}")
                return

            if len(feeds) > 1:
                # Multiple candidates: flatten (title, url) pairs into the
                # query string and let the user pick on /subscribe-choose.
                candidates = [(("t", f.title), ("u", f.meta.url)) for f in feeds]
                qs = urllib.parse.urlencode([e for c in candidates for e in c])
                sink.redirect(f"/subscribe-choose?{qs}")
                return

            result = feeds[0]
            sink.log(f"Identified {result.meta.url} as a feed for {url}")

            existing = db.load_meta(result.meta.url)
            if existing is not None:
                sink.log(f"This feed already exists (as {result.meta.url})")
                sink.status("Already subscribed")
                return

            db.store_feed(result)
            sink.status("Done")
            sink.redirect("/")

        finally:
            # Always close the channel so watchers see the stream end, even
            # when an exception escapes the worker.
            sink.close()
|
||||
REFRESH_TASK: BackgroundTask | None = None
|
||||
|
||||
|
||||
SUBSCRIBE_TASK: SubscribeTask | None = None
|
||||
@background_task
def subscribe(sink: EventChannel, url: str):
    """Subscribe to a feed.

    Searches `url` for candidate feeds: with exactly one match it is stored
    (unless already subscribed); with several, the user is redirected to a
    chooser page; with none, a failure status is reported.
    """

    db = database.Database.local()
    sink.status("Synchronizing state...")
    database.sync(db)

    sink.status("Searching for feeds...")
    matches = asyncio.run(feed.feed_search(url))
    if not matches:
        sink.status(f"Unable to find a suitable feed for {url}")
        return

    if len(matches) > 1:
        # Several candidates: carry (title, url) pairs in the query string
        # and let the user pick one on the chooser page.
        params = []
        for m in matches:
            params.append(("t", m.title))
            params.append(("u", m.meta.url))
        qs = urllib.parse.urlencode(params)
        sink.redirect(f"/subscribe-choose?{qs}")
        return

    found = matches[0]
    sink.log(f"Identified {found.meta.url} as a feed for {url}")

    known = db.load_meta(found.meta.url)
    if known is not None:
        sink.log(f"This feed already exists (as {found.meta.url})")
        sink.status("Already subscribed")
        return

    db.store_feed(found)
    sink.status("Done")
    sink.redirect("/")
|
||||
|
||||
|
||||
SUBSCRIBE_TASK: BackgroundTask | None = None
|
||||
|
||||
|
||||
class Handler(http.server.BaseHTTPRequestHandler):
|
||||
|
|
@ -338,7 +360,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
|||
def do_refresh(self):
|
||||
global REFRESH_TASK
|
||||
if REFRESH_TASK is None or REFRESH_TASK.closed:
|
||||
REFRESH_TASK = RefreshTask()
|
||||
REFRESH_TASK = refresh_feeds()
|
||||
|
||||
self.send_response(303)
|
||||
self.send_header("Location", "/refresh-status")
|
||||
|
|
@ -358,13 +380,13 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
|||
|
||||
global SUBSCRIBE_TASK
|
||||
if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed:
|
||||
SUBSCRIBE_TASK = SubscribeTask(url)
|
||||
SUBSCRIBE_TASK = subscribe(url)
|
||||
|
||||
self.send_response(303)
|
||||
self.send_header("Location", "/subscribe-status")
|
||||
self.end_headers()
|
||||
|
||||
def serve_events(self, task: RefreshTask | SubscribeTask | None):
|
||||
def serve_events(self, task: BackgroundTask | None):
|
||||
if task is None or task.closed:
|
||||
self.send_response(204)
|
||||
self.end_headers()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue