import asyncio
|
|
import contextlib
|
|
import dataclasses
|
|
import dominate.tags as tags
|
|
import functools
|
|
import http.server
|
|
import io
|
|
import pathlib
|
|
import time
|
|
import traceback
|
|
import threading
|
|
import urllib.parse
|
|
|
|
from typing import Callable, Concatenate, ParamSpec
|
|
|
|
from . import database
|
|
from . import feed
|
|
from . import views
|
|
|
|
class DeadlineCondition:
|
|
"""A condition variable that allows you to wait with a timeout."""
|
|
|
|
lock: threading.Lock
|
|
waiters: int
|
|
event: threading.Semaphore
|
|
|
|
def __init__(self):
|
|
self.lock = threading.Lock()
|
|
self.event = threading.Semaphore()
|
|
self.waiters = 0
|
|
|
|
@contextlib.contextmanager
|
|
def acquire(self, deadline: float):
|
|
"""Acquire the lock with the given deadline.
|
|
|
|
Yields an `AcquiredDeadlineCondition` that can be used to wait or signal
|
|
the condition. If the deadline passes before the lock can be acquired,
|
|
raises TimeoutError.
|
|
"""
|
|
now = time.time()
|
|
if now >= deadline:
|
|
raise TimeoutError()
|
|
|
|
timeout = deadline - now
|
|
if timeout <= 0:
|
|
raise TimeoutError()
|
|
|
|
if not self.lock.acquire(timeout=timeout):
|
|
raise TimeoutError()
|
|
|
|
state = AcquiredDeadlineCondition(self, deadline)
|
|
try:
|
|
yield state
|
|
finally:
|
|
if state.locked:
|
|
self.lock.release()
|
|
|
|
|
|
class AcquiredDeadlineCondition:
|
|
"""An aquired condition lock, which can be used to wait for conditions or
|
|
signal other waiters.
|
|
"""
|
|
|
|
locked: bool
|
|
condition: DeadlineCondition
|
|
deadline: float
|
|
|
|
def __init__(self, lock: DeadlineCondition, deadline: float):
|
|
self.condition = lock
|
|
self.deadline = deadline
|
|
self.locked = True
|
|
|
|
def signal(self):
|
|
"""Wake up every locked consumer that is blocked in `wait`."""
|
|
assert self.locked
|
|
|
|
condition = self.condition
|
|
count = condition.waiters
|
|
condition.waiters = 0
|
|
if count > 0:
|
|
condition.event.release(count)
|
|
|
|
def wait(self, deadline: float | None = None):
|
|
"""Release the lock, wait for the lock to be signaled, then re-acquire
|
|
the lock.
|
|
|
|
If the specified deadline passes before the lock is acquired, raises
|
|
`TimeoutError`. If no deadline is provided then we use the deadline
|
|
that was provided when the lock was originally acquired.
|
|
"""
|
|
assert self.locked
|
|
|
|
if deadline is None:
|
|
deadline = self.deadline
|
|
|
|
timeout = time.time() - deadline
|
|
if timeout <= 0:
|
|
raise TimeoutError()
|
|
|
|
self.condition.waiters += 1
|
|
self.locked = False
|
|
self.condition.lock.release()
|
|
|
|
# NOTE: Leak here: we might wake the semaphore more than we intend to.
|
|
# If we timeout it's not safe to decrement the semaphore so we
|
|
# just leave it; the side effect is that a signaller will
|
|
# increment the semaphore more than it should, and so other
|
|
# waiters will have spurious wakes, but that's a known problem
|
|
# with this kind of synchronization.
|
|
if not self.condition.event.acquire(timeout=timeout):
|
|
raise TimeoutError()
|
|
|
|
timeout = time.time() - deadline
|
|
if timeout <= 0:
|
|
raise TimeoutError()
|
|
|
|
if not self.condition.lock.acquire(timeout=timeout):
|
|
raise TimeoutError()
|
|
|
|
self.locked = True
|
|
return True
|
|
|
|
|
|
@dataclasses.dataclass
class Event:
    # Event type name ("log", "status", "redirect" from EventChannel's
    # helpers); None only for the timeout sentinel from EventConsumer.
    event: str | None
    # Payload text for the event, if any.
    data: str | None
    # Position of this event in its channel's list (used by SSE clients to
    # resume via Last-Event-ID); None only for the timeout sentinel.
    id: int | None
|
|
|
|
|
|
class EventChannel:
    """An append-only stream of `Event`s shared between one producer and
    any number of consumers.

    Producers append via `event` (or the `log`/`status`/`redirect`
    shorthands) and finish with `close`; consumers follow the stream
    through an `EventConsumer` from `get_consumer`.
    """

    condition: DeadlineCondition
    events: list[Event]
    closed: bool

    def __init__(self):
        self.condition = DeadlineCondition()
        self.events = []
        self.closed = False

    def close(self):
        """Mark the channel closed and wake every blocked consumer."""
        deadline = time.time() + 30
        with self.condition.acquire(deadline) as lock:
            if self.closed:
                return
            self.closed = True
            lock.signal()

    def log(self, line: str):
        """Emit a "log" event carrying `line`."""
        self.event("log", line)

    def status(self, status: str):
        """Emit a "status" event carrying `status`."""
        self.event("status", status)

    def redirect(self, url: str):
        """Emit a "redirect" event carrying `url`."""
        self.event("redirect", url)

    def event(self, event: str, data: str | None = None):
        """Append a new event and wake every blocked consumer.

        The channel must not be closed.
        """
        deadline = time.time() + 30
        with self.condition.acquire(deadline) as lock:
            assert not self.closed
            # The event's id is its index in the list.
            next_id = len(self.events)
            self.events.append(Event(event=event, data=data, id=next_id))
            lock.signal()

    def get_consumer(self, index: int | None = None) -> "EventConsumer":
        """Create a consumer positioned at `index` (default: the start)."""
        return EventConsumer(self, index)
|
|
|
|
|
|
class EventConsumer:
    """A cursor over an `EventChannel`, yielding its events one at a time."""

    channel: EventChannel
    # Index of the next event to hand out.
    index: int

    def __init__(self, channel: EventChannel, index: int | None):
        self.channel = channel
        self.index = index if index else 0

    @property
    def closed(self) -> bool:
        """Whether the underlying channel has been closed."""
        return self.channel.closed

    def get_event(self, deadline: float) -> Event | None:
        """Return the next event, blocking until `deadline` at the latest.

        Returns None once the channel is closed and fully drained. On
        timeout, returns a sentinel Event with every field set to None.
        """
        channel = self.channel
        try:
            with channel.condition.acquire(deadline) as cond:
                # Block until something new arrives or the channel closes.
                while not self.closed and self.index == len(channel.events):
                    cond.wait()

                if self.index < len(channel.events):
                    event = channel.events[self.index]
                    self.index += 1
                    return event

                # Nothing left to read: the loop only exits without a new
                # event when the channel has closed.
                if channel.closed:
                    return None

        except TimeoutError:
            # Timed out: hand back an all-None sentinel event.
            return Event(event=None, data=None, id=None)
|
|
|
|
|
|
###############################################################################
|
|
# Background Tasks
|
|
###############################################################################
|
|
|
|
|
|
class BackgroundTask:
    """Some task running in the background. The specified `func` receives an
    `EventChannel` object which it can use to report progress. That same
    `EventChannel` object is present as the `sink` property, and you can use
    it to subscribe to the events.
    """

    # The work to run; invoked once on the worker thread.
    func: Callable[[EventChannel], None]

    # Progress channel shared between the worker and subscribers.
    sink: EventChannel
    thread: threading.Thread

    def __init__(self, func: Callable[[EventChannel], None]):
        self.func = func
        self.sink = EventChannel()

        # Daemon thread: do not keep the process alive for a stuck task.
        worker = threading.Thread(target=self._background_thread, daemon=True)
        self.thread = worker
        worker.start()

    @property
    def closed(self):
        """Whether the task has finished (its event channel is closed)."""
        return self.sink.closed

    def _background_thread(self):
        """Worker entry point: run `func`, always closing the channel so
        consumers are never left waiting.
        """
        try:
            self.func(self.sink)
        finally:
            self.sink.close()
|
|
|
|
|
|
TaskParams = ParamSpec("TaskParams")


def background_task(
    f: Callable[Concatenate[EventChannel, TaskParams], None]
) -> Callable[TaskParams, BackgroundTask]:
    """A decorator that wraps a function to produce a BackgroundTask. The
    function will receive an `EventChannel` as its first argument, along with
    whatever other arguments it might need.
    """

    @functools.wraps(f)
    def impl(*args: TaskParams.args, **kwargs: TaskParams.kwargs) -> BackgroundTask:
        # Bind the caller's arguments; the task supplies the sink itself.
        def run(sink: EventChannel) -> None:
            f(sink, *args, **kwargs)

        return BackgroundTask(run)

    return impl
|
|
|
|
|
|
@background_task
def refresh_feeds(sink: EventChannel):
    """Refresh all the subscribed feeds.

    Progress is reported through `sink` as "status"/"log" events; a final
    "redirect" event sends the client back to the index page.
    """

    async def _refresh_meta(db: database.Database, meta: feed.FeedMeta):
        # Fetch one feed and store whatever came back. Failures are logged
        # and swallowed so one bad feed does not abort the rest.
        sink.log(f"[{meta.url}] Fetching...")
        d = None
        try:
            d, meta = await feed.fetch_feed(meta)
            if d is None:
                # No new content, but still persist the (possibly updated)
                # metadata — presumably cache validators; confirm in
                # feed.fetch_feed.
                sink.log(f"[{meta.url}] No updates")
                db.update_meta(meta)
            elif isinstance(d, str):
                sink.log(
                    f"[{meta.url}] WARNING: returned a non-feed result!",
                )
            else:
                new_count = db.store_feed(d)
                sink.log(
                    f"[{meta.url}] {new_count} new items\n",
                )

        except Exception as e:
            # Typo fix: "refressing" -> "refreshing".
            sink.log(f"[{meta.url}] Error refreshing feed: {e}")

    async def _refresh_all(db: database.Database, metas: list[feed.FeedMeta]):
        # Refresh every feed concurrently; TaskGroup waits for all of them.
        async with asyncio.TaskGroup() as group:
            for meta in metas:
                group.create_task(_refresh_meta(db, meta))

    db = database.Database.local()
    sink.status("Synchronizing state...")
    database.sync(db)

    sink.status("Loading subscriptions...")
    metas = db.load_all_meta()

    sink.status("Refreshing subscriptions...")
    asyncio.run(_refresh_all(db, metas))

    sink.status("Done")
    sink.redirect("/")
|
|
|
|
|
|
# Most recently started refresh task; do_refresh reuses it until it closes.
REFRESH_TASK: BackgroundTask | None = None
|
|
|
|
@background_task
def subscribe(sink: EventChannel, url: str):
    """Subscribe to a feed."""

    db = database.Database.local()
    sink.status("Synchronizing state...")
    database.sync(db)

    sink.status("Searching for feeds...")
    feeds = asyncio.run(feed.feed_search(url))
    if not feeds:
        sink.status(f"Unable to find a suitable feed for {url}")
        return

    if len(feeds) > 1:
        # Several candidates: defer the choice to the chooser page, encoding
        # alternating ("t", title) / ("u", url) pairs in the query string.
        pairs = []
        for candidate in feeds:
            pairs.append(("t", candidate.title))
            pairs.append(("u", candidate.meta.url))
        qs = urllib.parse.urlencode(pairs)
        sink.redirect(f"/subscribe-choose?{qs}")
        return

    result = feeds[0]
    sink.log(f"Identified {result.meta.url} as a feed for {url}")

    # Skip feeds we are already subscribed to.
    existing = db.load_meta(result.meta.url)
    if existing is not None:
        sink.log(f"This feed already exists (as {result.meta.url})")
        sink.status("Already subscribed")
        return

    db.store_feed(result)
    sink.status("Done")
    sink.redirect("/")
|
|
|
|
|
|
# Most recently started subscribe task; do_subscribe reuses it until it closes.
SUBSCRIBE_TASK: BackgroundTask | None = None
|
|
|
|
|
|
class Handler(http.server.BaseHTTPRequestHandler):
    """HTTP handler for the feed reader's web UI.

    GET serves the feed index, static assets, status pages, and
    Server-Sent-Event streams; POST starts the background refresh and
    subscribe tasks (at most one of each, via the module-level
    REFRESH_TASK / SUBSCRIBE_TASK globals).
    """

    def do_GET(self):
        # Dispatch on the exact path; only the subscribe chooser carries a
        # query string, so it is matched by prefix.
        if self.path == "/":
            return self.serve_feeds()
        elif self.path == "/style.css":
            return self.serve_style()
        elif self.path == "/event.js":
            return self.serve_event_js()
        elif self.path == "/refresh-status":
            return self.serve_status()
        elif self.path == "/subscribe-status":
            return self.serve_status()
        elif self.path == "/refresh-status/events":
            return self.serve_events(REFRESH_TASK)
        elif self.path == "/subscribe-status/events":
            return self.serve_events(SUBSCRIBE_TASK)
        elif self.path.startswith("/subscribe-choose?"):
            return self.serve_subscribe_choose()
        else:
            self.send_error(404)

    def do_POST(self):
        # Only the two task-starting endpoints accept POST.
        if self.path == "/refresh":
            self.do_refresh()
        elif self.path == "/subscribe":
            self.do_subscribe()
        else:
            self.log_error(f"Bad POST: {repr(self.path)}")
            self.send_error(400)

    def do_refresh(self):
        """Start a feed refresh (unless one is still running) and redirect
        the client to the status page.
        """
        global REFRESH_TASK
        # Reuse an in-flight task; start a new one only after the previous
        # one has finished (its event channel closed).
        if REFRESH_TASK is None or REFRESH_TASK.closed:
            self.log_message("Starting new refresh task...")
            REFRESH_TASK = refresh_feeds()

        # 303 See Other so the browser fetches the status page with GET.
        self.send_response(303)
        self.send_header("Location", "/refresh-status")
        self.end_headers()

    def do_subscribe(self):
        """Start a subscribe task for the URL in the form body and redirect
        the client to the status page.
        """
        # pull url from form body
        try:
            content_length = int(self.headers.get("content-length", "0"))
            content_str = self.rfile.read(content_length).decode("utf-8")
            params = urllib.parse.parse_qs(content_str)
            url = params["url"][0]
        except Exception as e:
            # Malformed body or missing "url" field: report 400 with the
            # traceback attached for debugging.
            tb = "\n".join(traceback.format_exception(e))
            self.log_error(f"Bad subscribe request: {tb}")
            self.send_error(400, explain=tb)
            return

        global SUBSCRIBE_TASK
        # NOTE(review): if a subscribe task is still running, the new URL is
        # silently dropped and the client sees the old task's status.
        if SUBSCRIBE_TASK is None or SUBSCRIBE_TASK.closed:
            SUBSCRIBE_TASK = subscribe(url)

        self.send_response(303)
        self.send_header("Location", "/subscribe-status")
        self.end_headers()

    def serve_events(self, task: BackgroundTask | None):
        """Stream the task's events to the client as Server-Sent Events,
        writing a keepalive comment whenever no event arrives for 5 seconds.
        """
        if task is None:
            # No such task was ever started: nothing to stream.
            self.send_response(204)
            self.end_headers()
            return

        # Handle reconnect
        # Browsers resend the last event id seen; resume from there rather
        # than replaying the whole stream.
        last_id = self.headers.get("Last-Event-ID", None)
        if last_id is not None:
            try:
                last_index = int(last_id)
            except ValueError:
                last_index = 0
        else:
            last_index = 0

        consumer = task.sink.get_consumer(last_index)

        self.send_response(200)
        self.send_header("content-type", "text/event-stream")
        # Tell nginx-style proxies not to buffer the stream.
        self.send_header("x-accel-buffering", "no")
        self.send_header("cache-control", "no-cache")
        self.end_headers()
        while True:
            deadline = time.time() + 5  # 5 seconds from now
            event = consumer.get_event(deadline)
            if event is None:
                # Event stream closed
                break

            if event.id is None and event.data is None and event.event is None:
                # Empty line for connection keepalive
                self.wfile.write(b":\n\n")
            else:
                # Write the SSE fields, then a blank line to dispatch.
                if event.id is not None:
                    self.wfile.write(f"id: {event.id}\n".encode("utf-8"))
                if event.data is not None:
                    self.wfile.write(f"data: {event.data}\n".encode("utf-8"))
                if event.event is not None:
                    self.wfile.write(f"event: {event.event}\n".encode("utf-8"))
                self.wfile.write(b"\n")
            self.wfile.flush()

        # Final "closed" event so the client stops reconnecting.
        self.wfile.write(b"event: closed\ndata\n\n")
        self.wfile.flush()

    def serve_status(self):
        """Render the status page for a background task."""
        document = views.status_view()

        self.write_html("<!DOCTYPE html>\n" + document.render())

    def serve_feeds(self):
        """Render the index page: all feeds, sorted by feed.sort_key
        descending, with up to 10 items each.
        """
        db = database.Database.local()
        feeds = db.load_all(feed_limit=10)
        del db

        feeds.sort(key=feed.sort_key, reverse=True)

        document = views.feed_view(feeds)
        self.write_html("<!DOCTYPE html>\n" + document.render())

    def serve_subscribe_choose(self):
        """Render the chooser page for picking among several feed candidates
        encoded in the query string as alternating "t"/"u" parameters.
        """
        try:
            req_url = urllib.parse.urlsplit(self.path)
            parsed = urllib.parse.parse_qs(req_url.query)

            # Pair each title ("t") with its feed URL ("u").
            candidates = zip(parsed["t"], parsed["u"])
        except Exception as e:
            tb = "\n".join(traceback.format_exception(e))
            self.log_error(f"Error parsing query string for subscription: {tb}")
            self.send_error(400, explain=tb)
            return

        document = views.subscribe_choose_view(candidates)
        self.write_html("<!DOCTYPE html>\n" + document.render())

    def write_html(self, html: str):
        """Send `html` as a complete 200 text/html response."""
        response = html.encode("utf-8")
        self.send_response(200)
        self.send_header("content-type", "text/html")
        self.send_header("content-length", str(len(response)))
        self.end_headers()
        self.wfile.write(response)

    def serve_event_js(self):
        """Serve the static event.js file."""
        self.write_file(
            pathlib.Path(__file__).parent / "static" / "event.js",
            content_type="text/javascript",
        )

    def serve_style(self):
        """Serve the static stylesheet."""
        self.write_file(
            pathlib.Path(__file__).parent / "static" / "style.css",
            content_type="text/css",
        )

    def write_file(self, path: pathlib.Path, content_type: str):
        """Send the file at `path` as a 200 response with `content_type`."""
        with open(path, "rb") as file:
            content = file.read()

        self.send_response(200)
        self.send_header("content-type", content_type)
        self.send_header("content-length", str(len(content)))
        self.end_headers()
        self.wfile.write(content)
|
|
|
|
|
|
def serve():
    """Run the feed reader's HTTP server on port 8000 until interrupted."""
    address = ("", 8000)
    with http.server.ThreadingHTTPServer(address, Handler) as httpd:
        print("Serving at http://127.0.0.1:8000/")
        httpd.serve_forever()
|