diff --git a/cry/cli.py b/cry/cli.py index 86d173f..2158738 100644 --- a/cry/cli.py +++ b/cry/cli.py @@ -253,35 +253,4 @@ def serve(): @cli.command("sync") def sync(): local_db = database.Database.local() - local_version = local_db.get_property("version", 0) - for p in database.databases_directory().glob("*.db"): - if not p.is_file(): - continue - - try: - other_db = database.Database.from_file(p) - if local_db.origin == other_db.origin: - continue - - # Ensure the schema version is compatible so that we don't run - # into trouble trying to query the other database. - other_version = other_db.get_property("version", 0) - if other_version != local_version: - click.echo( - f"{other_db.origin}: Not reconciling version {other_version} against {local_version}" - ) - continue - - # Check to see if we've already reconciled this other database. - other_clock = other_db.get_clock() - reconciled_clock = local_db.get_sync_clock(other_db.origin) - if other_clock == reconciled_clock: - continue - - # TODO: RECONCILE FOR REALS - local_db.sync_from(other_db) - - local_db.set_sync_clock(other_db.origin, other_clock) - - except Exception as e: - click.echo(f"Error loading {p}: {e}") + database.sync(local_db) diff --git a/cry/database.py b/cry/database.py index 72f5848..b9026d2 100644 --- a/cry/database.py +++ b/cry/database.py @@ -1,3 +1,4 @@ +import logging import pathlib import random import socket @@ -10,6 +11,8 @@ import platformdirs from . import feed +LOG = logging.getLogger(__name__) + SCHEMA_STATEMENTS = [ """ CREATE TABLE feeds ( @@ -647,3 +650,37 @@ class Database: etag=etag, modified=modified, ) + + +def sync(local_db: Database): + local_version = local_db.get_property("version", 0) + for p in databases_directory().glob("*.db"): + if not p.is_file(): + continue + + try: + other_db = Database.from_file(p) + if local_db.origin == other_db.origin: + continue + + # Ensure the schema version is compatible so that we don't run + # into trouble trying to query the other database. + other_version = other_db.get_property("version", 0) + if other_version != local_version: + LOG.warn( + f"{other_db.origin}: Not reconciling version {other_version} against {local_version}" + ) + continue + + # Check to see if we've already reconciled this other database. + other_clock = other_db.get_clock() + reconciled_clock = local_db.get_sync_clock(other_db.origin) + if other_clock == reconciled_clock: + continue + + local_db.sync_from(other_db) + + local_db.set_sync_clock(other_db.origin, other_clock) + + except Exception as e: + LOG.error(f"Error loading {p}: {e}") diff --git a/cry/web.py b/cry/web.py index c6e5801..1f149ba 100644 --- a/cry/web.py +++ b/cry/web.py @@ -1,17 +1,120 @@ +import asyncio import html import http.server import io +import threading from . import database from . import feed +class Refresher: + status: io.StringIO + + def start(self): + self.status = io.StringIO() + self.thread = threading.Thread(target=self.refresh_thread) + self.thread.daemon = True + self.thread.start() + + def getvalue(self) -> str: + return self.status.getvalue() + + def refresh_thread(self): + global REFRESH_STATUS + + db = database.Database.local() + database.sync(db) + + metas = db.load_all_meta() + asyncio.run(self.do_refresh(db, metas)) + + # Mark done for redirect... + REFRESH_STATUS = None + + async def do_refresh(self, db: database.Database, metas: list[feed.FeedMeta]): + async with asyncio.TaskGroup() as group: + for meta in metas: + group.create_task(self.refresh_meta(db, meta)) + + async def refresh_meta(self, db: database.Database, meta: feed.FeedMeta): + self.status.write(f"[{meta.url}] Refreshing...\n") + d = None + try: + d, meta = await feed.fetch_feed(meta) + if d is None: + self.status.write(f"[{meta.url}] No updates\n") + db.update_meta(meta) + elif isinstance(d, str): + self.status.write( + f"[{meta.url}] WARNING: returned a non-feed result!\n" + ) + else: + new_count = db.store_feed(d) + self.status.write(f"[{meta.url}] {new_count} new items\n") + + except Exception as e: + self.status.write(f"[{meta.url}] Error while fetching: {e}\n") + + +REFRESH_STATUS: Refresher | None = None + + class Handler(http.server.BaseHTTPRequestHandler): def do_GET(self): if self.path == "/": return self.serve_feeds() + elif self.path == "/refresh-status": + return self.serve_refresh_status() else: - self.send_response_only(404) + self.send_error(404) + + def do_POST(self): + print(f"{self.path}") + if self.path == "/refresh": + self.do_refresh() + else: + self.send_error(400) + + def do_refresh(self): + global REFRESH_STATUS + if REFRESH_STATUS is None: + REFRESH_STATUS = Refresher() + REFRESH_STATUS.start() + + self.send_response(303) + self.send_header("Location", "/refresh-status") + self.end_headers() + + def serve_refresh_status(self): + global REFRESH_STATUS + status = REFRESH_STATUS + if status is None: + self.send_response(302) + self.send_header("Location", "/") + self.end_headers() + return + + buffer = io.StringIO() + buffer.write( + """ + + + Refresh Status + +
"""
+        )
+        buffer.write(status.getvalue())
+        buffer.write(
+            """
+ + """ + ) + + self.write_html(buffer.getvalue()) def serve_feeds(self): db = database.Database.local() @@ -31,9 +134,15 @@ class Handler(http.server.BaseHTTPRequestHandler): body { margin-left: 4rem; margin-right: 4rem; } li.entry { display: inline; padding-right: 1rem; } li.entry:before { content: '\\2022'; padding-right: 0.5rem; } + h1 { margin-bottom: 0.25rem; }

Feeds

+
+
+ +
+
""" ) for f in feeds: @@ -56,9 +165,11 @@ class Handler(http.server.BaseHTTPRequestHandler): buffer.write("No entries...") buffer.write(f"") # feed buffer.flush() - text = buffer.getvalue() - response = text.encode("utf-8") + self.write_html(buffer.getvalue()) + + def write_html(self, html: str): + response = html.encode("utf-8") self.send_response(200) self.send_header("content-type", "text/html") self.send_header("content-length", str(len(response))) @@ -67,6 +178,6 @@ class Handler(http.server.BaseHTTPRequestHandler): def serve(): - with http.server.HTTPServer(("", 8000), Handler) as server: + with http.server.ThreadingHTTPServer(("", 8000), Handler) as server: print("Serving at http://127.0.0.1:8000/") server.serve_forever() diff --git a/tests/test_feed.py b/tests/test_feed.py index c3c877a..4dd0ae5 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -57,6 +57,8 @@ class TestWebServer: server._handle_GET(self) self.http_server = http.server.ThreadingHTTPServer(("", 0), TestWebHandler) + self.http_server.block_on_close = False + self.http_server.daemon_threads = True self.server_thread = threading.Thread(target=self._do_serve) self.server_thread.daemon = True self.handlers = {}