# cry/cry/web.py
#
# Web front-end for the cry feed reader.
# (Original listing: 507 lines, 15 KiB, Python.)

import asyncio
import contextlib
import dataclasses
import dominate.tags as tags
import functools
import http.server
import io
import pathlib
import time
import traceback
import threading
import urllib.parse
from typing import Callable, Concatenate, ParamSpec
from . import database
from . import feed
from . import views
class DeadlineCondition:
    """A condition variable whose acquire/wait operations honor a deadline."""

    lock: threading.Lock
    waiters: int
    event: threading.Semaphore

    def __init__(self):
        self.lock = threading.Lock()
        self.event = threading.Semaphore()
        self.waiters = 0

    @contextlib.contextmanager
    def acquire(self, deadline: float):
        """Acquire the lock, giving up at `deadline` (absolute time).

        Yields an `AcquiredDeadlineCondition` that can be used to wait or
        signal the condition. If the deadline passes before the lock can be
        acquired, raises TimeoutError.
        """
        remaining = deadline - time.time()
        if remaining <= 0:
            raise TimeoutError()
        if not self.lock.acquire(timeout=remaining):
            raise TimeoutError()
        acquired = AcquiredDeadlineCondition(self, deadline)
        try:
            yield acquired
        finally:
            # A timed-out wait() leaves the lock unowned; only release if the
            # holder still owns it.
            if acquired.locked:
                self.lock.release()
class AcquiredDeadlineCondition:
"""An aquired condition lock, which can be used to wait for conditions or
signal other waiters.
"""
locked: bool
condition: DeadlineCondition
deadline: float
def __init__(self, lock: DeadlineCondition, deadline: float):
self.condition = lock
self.deadline = deadline
self.locked = True
def signal(self):
"""Wake up every locked consumer that is blocked in `wait`."""
assert self.locked
condition = self.condition
count = condition.waiters
condition.waiters = 0
if count > 0:
condition.event.release(count)
def wait(self, deadline: float | None = None):
"""Release the lock, wait for the lock to be signaled, then re-acquire
the lock.
If the specified deadline passes before the lock is acquired, raises
`TimeoutError`. If no deadline is provided then we use the deadline
that was provided when the lock was originally acquired.
"""
assert self.locked
if deadline is None:
deadline = self.deadline
timeout = time.time() - deadline
if timeout <= 0:
raise TimeoutError()
self.condition.waiters += 1
self.locked = False
self.condition.lock.release()
# NOTE: Leak here: we might wake the semaphore more than we intend to.
# If we timeout it's not safe to decrement the semaphore so we
# just leave it; the side effect is that a signaller will
# increment the semaphore more than it should, and so other
# waiters will have spurious wakes, but that's a known problem
# with this kind of synchronization.
if not self.condition.event.acquire(timeout=timeout):
raise TimeoutError()
timeout = time.time() - deadline
if timeout <= 0:
raise TimeoutError()
if not self.condition.lock.acquire(timeout=timeout):
raise TimeoutError()
self.locked = True
return True
@dataclasses.dataclass
class Event:
    """A single queued event on an EventChannel.

    All three fields are None for the synthetic keepalive event produced
    when a consumer's wait times out (see EventConsumer.get_event).
    """

    # Event type, e.g. "log", "status", "redirect"; None for keepalive.
    event: str | None
    # Event payload line; None when the event carries no data.
    data: str | None
    # Position of this event in EventChannel.events; doubles as the SSE id.
    id: int | None
class EventChannel:
    """An append-only sequence of events guarded by a DeadlineCondition."""

    condition: DeadlineCondition
    events: list[Event]
    closed: bool

    def __init__(self):
        self.condition = DeadlineCondition()
        self.events = []
        self.closed = False

    def close(self):
        """Mark the channel closed and wake every blocked consumer."""
        with self.condition.acquire(time.time() + 30) as lock:
            if self.closed:
                return
            self.closed = True
            lock.signal()

    def log(self, line: str):
        """Emit a "log" event carrying one line of progress output."""
        self.event("log", line)

    def status(self, status: str):
        """Emit a "status" event describing the current phase."""
        self.event("status", status)

    def redirect(self, url: str):
        """Emit a "redirect" event pointing the client at `url`."""
        self.event("redirect", url)

    def event(self, event: str, data: str | None = None):
        """Append an event to the channel and wake every blocked consumer."""
        with self.condition.acquire(time.time() + 30) as lock:
            assert not self.closed
            self.events.append(Event(event=event, data=data, id=len(self.events)))
            lock.signal()

    def get_consumer(self, index: int | None = None) -> "EventConsumer":
        """Create a consumer that starts reading at `index` (default 0)."""
        return EventConsumer(self, index)
class EventConsumer:
    """A read cursor over an EventChannel's event list."""

    channel: EventChannel
    index: int

    def __init__(self, channel: EventChannel, index: int | None):
        self.channel = channel
        self.index = index if index is not None else 0

    @property
    def closed(self) -> bool:
        """Whether the underlying channel has been closed."""
        return self.channel.closed

    def get_event(self, deadline: float) -> Event | None:
        """Return the next event, blocking until `deadline` at the latest.

        Returns None once the channel is closed and drained. If the deadline
        passes first, returns a synthetic keepalive Event with all fields None.
        """
        try:
            with self.channel.condition.acquire(deadline) as cond:
                queue = self.channel.events
                # Block only while we are exactly caught up on an open channel.
                while self.index == len(queue) and not self.closed:
                    cond.wait()
                if self.index < len(queue):
                    event = queue[self.index]
                    self.index += 1
                    return event
                # Closed and drained, or an out-of-range resume index.
                return None
        except TimeoutError:
            return Event(event=None, data=None, id=None)
###############################################################################
# Background Tasks
###############################################################################
class BackgroundTask:
    """Some task running in the background. The specified `func` receives an
    `EventChannel` object which it can use to report progress. That same
    `EventChannel` object is present as the `sink` property, and you can use
    it to subscribe to the events.
    """

    func: Callable[[EventChannel], None]
    sink: EventChannel
    thread: threading.Thread

    def __init__(self, func: Callable[[EventChannel], None]):
        self.func = func
        self.sink = EventChannel()
        # Daemon thread: the process may exit without waiting for the task.
        worker = threading.Thread(target=self._background_thread, daemon=True)
        self.thread = worker
        worker.start()

    @property
    def closed(self):
        """True once the task has finished and its channel is closed."""
        return self.sink.closed

    def _background_thread(self):
        # Run the task, guaranteeing the channel closes even if it raises.
        try:
            self.func(self.sink)
        finally:
            self.sink.close()
TaskParams = ParamSpec("TaskParams")


def background_task(
    f: Callable[Concatenate[EventChannel, TaskParams], None]
) -> Callable[TaskParams, BackgroundTask]:
    """A decorator that wraps a function to produce a BackgroundTask. The
    function will receive an `EventChannel` as its first argument, along with
    whatever other arguments it might need.
    """

    @functools.wraps(f)
    def spawn(*args: TaskParams.args, **kwargs: TaskParams.kwargs) -> BackgroundTask:
        # Bind the call arguments now; the task body runs on its own thread.
        def run(sink: EventChannel) -> None:
            f(sink, *args, **kwargs)

        return BackgroundTask(run)

    return spawn
@background_task
def refresh_feeds(sink: EventChannel):
    """Refresh all the subscribed feeds, reporting progress through `sink`."""

    async def _refresh_meta(db: database.Database, meta: feed.FeedMeta):
        # Fetch one feed and store whatever came back.
        sink.log(f"[{meta.url}] Fetching...")
        try:
            d, meta = await feed.fetch_feed(meta)
            if d is None:
                # Feed unchanged; just persist the refreshed metadata.
                sink.log(f"[{meta.url}] No updates")
                db.update_meta(meta)
            elif isinstance(d, str):
                sink.log(
                    f"[{meta.url}] WARNING: returned a non-feed result!",
                )
            else:
                new_count = db.store_feed(d)
                # FIX: dropped the stray trailing "\n" -- a newline embedded in
                # the data breaks the SSE framing in Handler.serve_events.
                sink.log(
                    f"[{meta.url}] {new_count} new items",
                )
        except Exception as e:
            # FIX: "refressing" -> "refreshing" in the log message.
            sink.log(f"[{meta.url}] Error refreshing feed: {e}")

    async def _refresh_all(db: database.Database, metas: list[feed.FeedMeta]):
        # Refresh every subscription concurrently.
        async with asyncio.TaskGroup() as group:
            for meta in metas:
                group.create_task(_refresh_meta(db, meta))

    db = database.Database.local()
    sink.status("Synchronizing state...")
    database.sync(db)
    sink.status("Loading subscriptions...")
    metas = db.load_all_meta()
    sink.status("Refreshing subscriptions...")
    asyncio.run(_refresh_all(db, metas))
    sink.status("Done")
    sink.redirect("/")


# The most recent refresh task, if any; reused while it is still running.
REFRESH_TASK: BackgroundTask | None = None
@background_task
def subscribe(sink: EventChannel, url: str):
    """Subscribe to a feed."""
    db = database.Database.local()
    sink.status("Synchronizing state...")
    database.sync(db)
    sink.status("Searching for feeds...")
    found = asyncio.run(feed.feed_search(url))
    if not found:
        sink.status(f"Unable to find a suitable feed for {url}")
        return
    if len(found) > 1:
        # Several candidates: send the user to the chooser page, carrying
        # alternating title ("t") / url ("u") query parameters.
        pairs = [(("t", f.title), ("u", f.meta.url)) for f in found]
        qs = urllib.parse.urlencode([item for pair in pairs for item in pair])
        sink.redirect(f"/subscribe-choose?{qs}")
        return
    chosen = found[0]
    sink.log(f"Identified {chosen.meta.url} as a feed for {url}")
    if db.load_meta(chosen.meta.url) is not None:
        sink.log(f"This feed already exists (as {chosen.meta.url})")
        sink.status("Already subscribed")
        return
    db.store_feed(chosen)
    sink.status("Done")
    sink.redirect("/")


# The most recent subscription task, if any; reused while it is still running.
SUBSCRIBE_TASK: BackgroundTask | None = None
class Handler(http.server.BaseHTTPRequestHandler):
    """HTTP request handler for the feed-reader web UI.

    Routes GET requests to page/asset/event-stream endpoints and POST
    requests to the refresh and subscribe actions.
    """

    def do_GET(self):
        # Route by exact path; only /subscribe-choose carries a query string.
        if self.path == "/":
            return self.serve_feeds()
        elif self.path == "/style.css":
            return self.serve_style()
        elif self.path == "/event.js":
            return self.serve_event_js()
        elif self.path == "/refresh-status":
            return self.serve_status()
        elif self.path == "/subscribe-status":
            return self.serve_status()
        elif self.path == "/refresh-status/events":
            return self.serve_events(REFRESH_TASK)
        elif self.path == "/subscribe-status/events":
            return self.serve_events(SUBSCRIBE_TASK)
        elif self.path.startswith("/subscribe-choose?"):
            return self.serve_subscribe_choose()
        else:
            self.send_error(404)

    def do_POST(self):
        if self.path == "/refresh":
            self.do_refresh()
        elif self.path == "/subscribe":
            self.do_subscribe()
        else:
            self.log_error(f"Bad POST: {repr(self.path)}")
            self.send_error(400)

    def do_refresh(self):
        """Start a feed refresh (unless one is already running), then
        redirect the client to the status page."""
        global REFRESH_TASK
        # Reuse a still-running task so concurrent POSTs don't stack refreshes.
        if REFRESH_TASK is None or REFRESH_TASK.closed:
            self.log_message("Starting new refresh task...")
            REFRESH_TASK = refresh_feeds()
        self.send_response(303)
        self.send_header("Location", "/refresh-status")
        self.end_headers()

    def do_subscribe(self):
        """Start a subscription task for the URL in the form body, then
        redirect the client to the status page."""
        # pull url from form body
        try:
            content_length = int(self.headers.get("content-length", "0"))
            content_str = self.rfile.read(content_length).decode("utf-8")
            params = urllib.parse.parse_qs(content_str)
            url = params["url"][0]
        except Exception as e:
            tb = "\n".join(traceback.format_exception(e))
            self.log_error(f"Bad subscribe request: {tb}")
            self.send_error(400, explain=tb)
            return
        global SUBSCRIBE_TASK
        # Only one subscription task at a time; a running one wins.
        if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed:
            SUBSCRIBE_TASK = subscribe(url)
        self.send_response(303)
        self.send_header("Location", "/subscribe-status")
        self.end_headers()

    def serve_events(self, task: BackgroundTask | None):
        """Stream a task's EventChannel to the client as server-sent events."""
        if task is None:
            # No task has ever been started: nothing to stream.
            self.send_response(204)
            self.end_headers()
            return
        # Handle reconnect
        last_id = self.headers.get("Last-Event-ID", None)
        if last_id is not None:
            try:
                last_index = int(last_id)
            except ValueError:
                last_index = 0
        else:
            last_index = 0
        consumer = task.sink.get_consumer(last_index)
        self.send_response(200)
        self.send_header("content-type", "text/event-stream")
        self.send_header("x-accel-buffering", "no")
        self.send_header("cache-control", "no-cache")
        self.end_headers()
        while True:
            deadline = time.time() + 5  # 5 seconds from now
            event = consumer.get_event(deadline)
            if event is None:
                # Event stream closed
                break
            if event.id is None and event.data is None and event.event is None:
                # Empty line for connection keepalive
                self.wfile.write(b":\n\n")
            else:
                if event.id is not None:
                    self.wfile.write(f"id: {event.id}\n".encode("utf-8"))
                if event.data is not None:
                    self.wfile.write(f"data: {event.data}\n".encode("utf-8"))
                if event.event is not None:
                    self.wfile.write(f"event: {event.event}\n".encode("utf-8"))
                self.wfile.write(b"\n")
            self.wfile.flush()
        # Tell the client the stream is over so it stops reconnecting.
        self.wfile.write(b"event: closed\ndata\n\n")
        self.wfile.flush()

    def serve_status(self):
        """Render the shared refresh/subscribe status page."""
        document = views.status_view()
        self.write_html("<!DOCTYPE html>\n" + document.render())

    def serve_feeds(self):
        """Render the front page listing all feeds, sorted newest-first."""
        db = database.Database.local()
        feeds = db.load_all(feed_limit=10)
        del db
        feeds.sort(key=feed.sort_key, reverse=True)
        document = views.feed_view(feeds)
        self.write_html("<!DOCTYPE html>\n" + document.render())

    def serve_subscribe_choose(self):
        """Render the chooser page when a subscription found several feeds.

        Expects paired `t` (title) and `u` (url) query parameters as
        produced by the `subscribe` task.
        """
        try:
            req_url = urllib.parse.urlsplit(self.path)
            parsed = urllib.parse.parse_qs(req_url.query)
            candidates = zip(parsed["t"], parsed["u"])
        except Exception as e:
            tb = "\n".join(traceback.format_exception(e))
            self.log_error(f"Error parsing query string for subscription: {tb}")
            self.send_error(400, explain=tb)
            return
        document = views.subscribe_choose_view(candidates)
        self.write_html("<!DOCTYPE html>\n" + document.render())

    def write_html(self, html: str):
        """Send `html` as a complete 200 text/html response."""
        response = html.encode("utf-8")
        self.send_response(200)
        self.send_header("content-type", "text/html")
        self.send_header("content-length", str(len(response)))
        self.end_headers()
        self.wfile.write(response)

    def serve_event_js(self):
        self.write_file(
            pathlib.Path(__file__).parent / "static" / "event.js",
            content_type="text/javascript",
        )

    def serve_style(self):
        self.write_file(
            pathlib.Path(__file__).parent / "static" / "style.css",
            content_type="text/css",
        )

    def write_file(self, path: pathlib.Path, content_type: str):
        """Send the contents of a static file with the given content type."""
        with open(path, "rb") as file:
            content = file.read()
        self.send_response(200)
        self.send_header("content-type", content_type)
        self.send_header("content-length", str(len(content)))
        self.end_headers()
        self.wfile.write(content)
def serve(port: int = 8000):
    """Run the blocking HTTP server until interrupted.

    Generalized: `port` (default 8000, matching the previous hard-coded
    value) selects the listening port; the printed URL reflects it.
    Binds to all interfaces.
    """
    with http.server.ThreadingHTTPServer(("", port), Handler) as server:
        print(f"Serving at http://127.0.0.1:{port}/")
        server.serve_forever()