The most self-indulgent re-factor of background tasks
The ParamSpec stuff is special-purpose but gee, the type checker has gotten pretty good, huh.
This commit is contained in:
parent
785f71223b
commit
74f7146937
1 changed files with 109 additions and 87 deletions
134
cry/web.py
134
cry/web.py
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import functools
|
||||||
import html
|
import html
|
||||||
import http.server
|
import http.server
|
||||||
import io
|
import io
|
||||||
|
|
@ -9,19 +10,23 @@ import traceback
|
||||||
import threading
|
import threading
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
|
from typing import Any, Callable, Concatenate, ParamSpec, Protocol
|
||||||
|
|
||||||
from . import database
|
from . import database
|
||||||
from . import feed
|
from . import feed
|
||||||
|
|
||||||
|
|
||||||
class DeadlineCondition:
|
class DeadlineCondition:
|
||||||
|
"""A condition variable that allows you to wait with a timeout."""
|
||||||
|
|
||||||
lock: threading.Lock
|
lock: threading.Lock
|
||||||
waiting: int
|
waiters: int
|
||||||
event: threading.Semaphore
|
event: threading.Semaphore
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.event = threading.Semaphore()
|
self.event = threading.Semaphore()
|
||||||
self.waiting = 0
|
self.waiters = 0
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def acquire(self, deadline: float):
|
def acquire(self, deadline: float):
|
||||||
|
|
@ -51,6 +56,10 @@ class DeadlineCondition:
|
||||||
|
|
||||||
|
|
||||||
class AcquiredDeadlineCondition:
|
class AcquiredDeadlineCondition:
|
||||||
|
"""An aquired condition lock, which can be used to wait for conditions or
|
||||||
|
signal other waiters.
|
||||||
|
"""
|
||||||
|
|
||||||
locked: bool
|
locked: bool
|
||||||
condition: DeadlineCondition
|
condition: DeadlineCondition
|
||||||
deadline: float
|
deadline: float
|
||||||
|
|
@ -65,8 +74,8 @@ class AcquiredDeadlineCondition:
|
||||||
assert self.locked
|
assert self.locked
|
||||||
|
|
||||||
condition = self.condition
|
condition = self.condition
|
||||||
count = condition.waiting
|
count = condition.waiters
|
||||||
condition.waiting = 0
|
condition.waiters = 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
condition.event.release(count)
|
condition.event.release(count)
|
||||||
|
|
||||||
|
|
@ -87,7 +96,7 @@ class AcquiredDeadlineCondition:
|
||||||
if timeout <= 0:
|
if timeout <= 0:
|
||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
||||||
self.condition.waiting += 1
|
self.condition.waiters += 1
|
||||||
self.locked = False
|
self.locked = False
|
||||||
self.condition.lock.release()
|
self.condition.lock.release()
|
||||||
|
|
||||||
|
|
@ -95,7 +104,7 @@ class AcquiredDeadlineCondition:
|
||||||
# If we timeout it's not safe to decrement the semaphore so we
|
# If we timeout it's not safe to decrement the semaphore so we
|
||||||
# just leave it; the side effect is that a signaller will
|
# just leave it; the side effect is that a signaller will
|
||||||
# increment the semaphore more than it should, and so other
|
# increment the semaphore more than it should, and so other
|
||||||
# waiters will have spurious waits, but that's a known problem
|
# waiters will have spurious wakes, but that's a known problem
|
||||||
# with this kind of synchronization.
|
# with this kind of synchronization.
|
||||||
if not self.condition.event.acquire(timeout=timeout):
|
if not self.condition.event.acquire(timeout=timeout):
|
||||||
raise TimeoutError()
|
raise TimeoutError()
|
||||||
|
|
@ -185,16 +194,26 @@ class EventConsumer:
|
||||||
return Event(event=None, data=None, id=None)
|
return Event(event=None, data=None, id=None)
|
||||||
|
|
||||||
|
|
||||||
# THE EVENT STUFF
|
# Background Tasks
|
||||||
|
|
||||||
|
|
||||||
class RefreshTask:
|
class BackgroundTask:
|
||||||
|
"""Some task running in the background. The specified `func` receives an
|
||||||
|
`EventChannel` object which it can use to report progress. That same
|
||||||
|
`EventChannel` object is present as the `sink` property, and you can use
|
||||||
|
it to subscribe to the events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
func: Callable[[EventChannel], None]
|
||||||
|
|
||||||
sink: EventChannel
|
sink: EventChannel
|
||||||
thread: threading.Thread
|
thread: threading.Thread
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, func: Callable[[EventChannel], None]):
|
||||||
|
self.func = func
|
||||||
|
|
||||||
self.sink = EventChannel()
|
self.sink = EventChannel()
|
||||||
self.thread = threading.Thread(target=self._refresh_thread)
|
self.thread = threading.Thread(target=self._background_thread)
|
||||||
self.thread.daemon = True
|
self.thread.daemon = True
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
|
|
||||||
|
|
@ -202,31 +221,37 @@ class RefreshTask:
|
||||||
def closed(self):
|
def closed(self):
|
||||||
return self.sink.closed
|
return self.sink.closed
|
||||||
|
|
||||||
def _refresh_thread(self):
|
def _background_thread(self):
|
||||||
sink = self.sink
|
sink = self.sink
|
||||||
try:
|
try:
|
||||||
db = database.Database.local()
|
self.func(sink)
|
||||||
sink.status("Synchronizing state...")
|
|
||||||
database.sync(db)
|
|
||||||
|
|
||||||
sink.status("Loading subscriptions...")
|
|
||||||
metas = db.load_all_meta()
|
|
||||||
|
|
||||||
sink.status("Refreshing subscriptions...")
|
|
||||||
asyncio.run(self._refresh_all(db, metas))
|
|
||||||
|
|
||||||
sink.status("Done")
|
|
||||||
sink.redirect("/")
|
|
||||||
finally:
|
finally:
|
||||||
sink.close()
|
sink.close()
|
||||||
|
|
||||||
async def _refresh_all(self, db: database.Database, metas: list[feed.FeedMeta]):
|
|
||||||
async with asyncio.TaskGroup() as group:
|
|
||||||
for meta in metas:
|
|
||||||
group.create_task(self._refresh_meta(db, meta))
|
|
||||||
|
|
||||||
async def _refresh_meta(self, db: database.Database, meta: feed.FeedMeta):
|
TaskParams = ParamSpec("TaskParams")
|
||||||
sink = self.sink
|
|
||||||
|
|
||||||
|
def background_task(
|
||||||
|
f: Callable[Concatenate[EventChannel, TaskParams], None]
|
||||||
|
) -> Callable[TaskParams, BackgroundTask]:
|
||||||
|
"""A decorator that wraps a function to produce a BackgroundTask. The
|
||||||
|
function will receive an `EventChannel` as its first argument, along with
|
||||||
|
whatever other arguments it might need.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(f)
|
||||||
|
def impl(*args: TaskParams.args, **kwargs: TaskParams.kwargs) -> BackgroundTask:
|
||||||
|
return BackgroundTask(lambda sink: f(sink, *args, **kwargs))
|
||||||
|
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
@background_task
|
||||||
|
def refresh_feeds(sink: EventChannel):
|
||||||
|
"""Refresh all the subscribed feeds."""
|
||||||
|
|
||||||
|
async def _refresh_meta(db: database.Database, meta: feed.FeedMeta):
|
||||||
sink.log(f"[{meta.url}] Fetching...")
|
sink.log(f"[{meta.url}] Fetching...")
|
||||||
d = None
|
d = None
|
||||||
try:
|
try:
|
||||||
|
|
@ -247,32 +272,32 @@ class RefreshTask:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
sink.log(f"[{meta.url}] Error refressing feed: {e}")
|
sink.log(f"[{meta.url}] Error refressing feed: {e}")
|
||||||
|
|
||||||
|
async def _refresh_all(db: database.Database, metas: list[feed.FeedMeta]):
|
||||||
|
async with asyncio.TaskGroup() as group:
|
||||||
|
for meta in metas:
|
||||||
|
group.create_task(_refresh_meta(db, meta))
|
||||||
|
|
||||||
REFRESH_TASK: RefreshTask | None = None
|
db = database.Database.local()
|
||||||
|
sink.status("Synchronizing state...")
|
||||||
|
database.sync(db)
|
||||||
|
|
||||||
|
sink.status("Loading subscriptions...")
|
||||||
|
metas = db.load_all_meta()
|
||||||
|
|
||||||
|
sink.status("Refreshing subscriptions...")
|
||||||
|
asyncio.run(_refresh_all(db, metas))
|
||||||
|
|
||||||
|
sink.status("Done")
|
||||||
|
sink.redirect("/")
|
||||||
|
|
||||||
|
|
||||||
class SubscribeTask:
|
REFRESH_TASK: BackgroundTask | None = None
|
||||||
url: str
|
|
||||||
sink: EventChannel
|
|
||||||
thread: threading.Thread
|
|
||||||
|
|
||||||
def __init__(self, url: str):
|
|
||||||
self.url = url
|
|
||||||
|
|
||||||
self.sink = EventChannel()
|
@background_task
|
||||||
self.thread = threading.Thread(target=self._task_thread)
|
def subscribe(sink: EventChannel, url: str):
|
||||||
self.thread.daemon = True
|
"""Subscribe to a feed."""
|
||||||
self.thread.start()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def closed(self):
|
|
||||||
return self.sink.closed
|
|
||||||
|
|
||||||
def _task_thread(self):
|
|
||||||
sink = self.sink
|
|
||||||
url = self.url
|
|
||||||
|
|
||||||
try:
|
|
||||||
db = database.Database.local()
|
db = database.Database.local()
|
||||||
sink.status("Synchronizing state...")
|
sink.status("Synchronizing state...")
|
||||||
database.sync(db)
|
database.sync(db)
|
||||||
|
|
@ -302,11 +327,8 @@ class SubscribeTask:
|
||||||
sink.status("Done")
|
sink.status("Done")
|
||||||
sink.redirect("/")
|
sink.redirect("/")
|
||||||
|
|
||||||
finally:
|
|
||||||
sink.close()
|
|
||||||
|
|
||||||
|
SUBSCRIBE_TASK: BackgroundTask | None = None
|
||||||
SUBSCRIBE_TASK: SubscribeTask | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(http.server.BaseHTTPRequestHandler):
|
class Handler(http.server.BaseHTTPRequestHandler):
|
||||||
|
|
@ -338,7 +360,7 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||||
def do_refresh(self):
|
def do_refresh(self):
|
||||||
global REFRESH_TASK
|
global REFRESH_TASK
|
||||||
if REFRESH_TASK is None or REFRESH_TASK.closed:
|
if REFRESH_TASK is None or REFRESH_TASK.closed:
|
||||||
REFRESH_TASK = RefreshTask()
|
REFRESH_TASK = refresh_feeds()
|
||||||
|
|
||||||
self.send_response(303)
|
self.send_response(303)
|
||||||
self.send_header("Location", "/refresh-status")
|
self.send_header("Location", "/refresh-status")
|
||||||
|
|
@ -358,13 +380,13 @@ class Handler(http.server.BaseHTTPRequestHandler):
|
||||||
|
|
||||||
global SUBSCRIBE_TASK
|
global SUBSCRIBE_TASK
|
||||||
if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed:
|
if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed:
|
||||||
SUBSCRIBE_TASK = SubscribeTask(url)
|
SUBSCRIBE_TASK = subscribe(url)
|
||||||
|
|
||||||
self.send_response(303)
|
self.send_response(303)
|
||||||
self.send_header("Location", "/subscribe-status")
|
self.send_header("Location", "/subscribe-status")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
def serve_events(self, task: RefreshTask | SubscribeTask | None):
|
def serve_events(self, task: BackgroundTask | None):
|
||||||
if task is None or task.closed:
|
if task is None or task.closed:
|
||||||
self.send_response(204)
|
self.send_response(204)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue