diff --git a/cry/cli.py b/cry/cli.py
index 9a93c30..86d173f 100644
--- a/cry/cli.py
+++ b/cry/cli.py
@@ -250,8 +250,8 @@ def serve():
     web.serve()
 
 
-@cli.command("reconcile")
-def reconcile():
+@cli.command("sync")
+def sync():
     local_db = database.Database.local()
     local_version = local_db.get_property("version", 0)
     for p in database.databases_directory().glob("*.db"):
@@ -263,6 +263,8 @@ def reconcile():
             if local_db.origin == other_db.origin:
                 continue
 
+            # Ensure the schema version is compatible so that we don't run
+            # into trouble trying to query the other database.
             other_version = other_db.get_property("version", 0)
             if other_version != local_version:
                 click.echo(
@@ -270,13 +272,15 @@ def reconcile():
                 )
                 continue
 
-            # TODO: GET CLOCK OF BOTH.
+            # Check to see if we've already reconciled this other database.
             other_clock = other_db.get_clock()
-            reconciled_clock = local_db.get_reconcile_clock(other_db.origin)
+            reconciled_clock = local_db.get_sync_clock(other_db.origin)
             if other_clock == reconciled_clock:
                 continue
 
-            # TODO: RECONCILE FOR REALS
+            local_db.sync_from(other_db)
+
+            local_db.set_sync_clock(other_db.origin, other_clock)
 
         except Exception as e:
             click.echo(f"Error loading {p}: {e}")
diff --git a/cry/database.py b/cry/database.py
index 138a831..72f5848 100644
--- a/cry/database.py
+++ b/cry/database.py
@@ -100,7 +100,7 @@ SCHEMA_STATEMENTS = [
     END;
     """,
     """
-    CREATE TABLE reconcile_status (
+    CREATE TABLE sync_status (
         origin VARCHAR NOT NULL PRIMARY KEY,
         clock INT NOT NULL
     );
@@ -289,9 +289,9 @@ class Database:
                 ) = row
                 meta = feed.FeedMeta(
                     url=url,
-                    last_fetched_ts=last_fetched_ts,
-                    retry_after_ts=retry_after_ts,
-                    status=status,
+                    last_fetched_ts=int(last_fetched_ts),
+                    retry_after_ts=int(retry_after_ts),
+                    status=int(status),
                     etag=etag,
                     modified=modified,
                 )
@@ -403,6 +403,121 @@ class Database:
                 [feed.FEED_STATUS_UNSUBSCRIBED, int(time.time()), old_url],
             )
 
+    def get_sync_clock(self, origin: str) -> int | None:
+        """Return the clock stored in `sync_status` for `origin`.
+
+        Returns None when no row exists for that origin.
+        """
+        with self.db:
+            cursor = self.db.execute(
+                "SELECT clock FROM sync_status WHERE origin = ?",
+                [origin],
+            )
+            row = cursor.fetchone()
+            if row is None:
+                return None
+            return int(row[0])
+
+    def set_sync_clock(self, origin: str, clock: int):
+        """Store `clock` as the `sync_status` clock for `origin`.
+
+        Inserts the row, or overwrites the stored clock when a row for
+        `origin` already exists.
+        """
+        with self.db:
+            # Name the conflict target explicitly: SQLite releases before
+            # 3.35 reject DO UPDATE without one.
+            self.db.execute(
+                """
+                INSERT INTO sync_status (origin, clock)
+                VALUES (?, ?)
+                ON CONFLICT (origin) DO UPDATE SET clock=excluded.clock
+                """,
+                [origin, clock],
+            )
+
+    def sync_from(self, other: "Database", *, batch_size: int = 256):
+        """Copy every feed and its entries from `other` into this database.
+
+        Each row of `other`'s `feeds` table is handed to `_insert_feed` and
+        that feed's `entries` rows are handed to `_insert_entries`; how rows
+        that already exist here are merged is up to those helpers.
+
+        Args:
+            other: The database to read from.
+            batch_size: Maximum number of entries passed to
+                `_insert_entries` per call.
+
+        Raises:
+            ValueError: If `batch_size` is less than 1.
+        """
+        if batch_size < 1:
+            raise ValueError("batch_size must be at least 1")
+
+        with self.db, other.db:
+            feed_cursor = other.db.execute(
+                """
+                SELECT
+                    url,
+                    last_fetched_ts,
+                    retry_after_ts,
+                    status,
+                    etag,
+                    modified,
+                    title,
+                    link
+                FROM feeds
+                """
+            )
+            for row in feed_cursor:
+                (
+                    url,
+                    last_fetched_ts,
+                    retry_after_ts,
+                    status,
+                    etag,
+                    modified,
+                    title,
+                    link,
+                ) = row
+                meta = feed.FeedMeta(
+                    url=url,
+                    last_fetched_ts=int(last_fetched_ts),
+                    retry_after_ts=int(retry_after_ts),
+                    status=int(status),
+                    etag=etag,
+                    modified=modified,
+                )
+                self._insert_feed(meta, title, link)
+
+                entries_cursor = other.db.execute(
+                    """
+                    SELECT
+                        id,
+                        inserted_at,
+                        title,
+                        link
+                    FROM entries
+                    WHERE feed_url=?
+                    """,
+                    [url],
+                )
+                # fetchmany() with no size uses cursor.arraysize, which
+                # defaults to 1; request a real batch instead.
+                while batch := entries_cursor.fetchmany(batch_size):
+                    self._insert_entries(
+                        url,
+                        [
+                            feed.Entry(
+                                id=entry_id,
+                                inserted_at=int(inserted_at),
+                                title=entry_title,
+                                link=entry_link,
+                            )
+                            for entry_id, inserted_at, entry_title, entry_link in batch
+                        ],
+                    )
+
     def _get_property(self, prop: str, default=None) -> typing.Any:
         cursor = self.db.execute("SELECT value FROM properties WHERE name=?", (prop,))
         result = cursor.fetchone()
@@ -553,9 +668,9 @@ class Database:
         last_fetched_ts, retry_after_ts, status, etag, modified = row
         return feed.FeedMeta(
             url=url,
-            last_fetched_ts=last_fetched_ts,
-            retry_after_ts=retry_after_ts,
-            status=status,
+            last_fetched_ts=int(last_fetched_ts),
+            retry_after_ts=int(retry_after_ts),
+            status=int(status),
             etag=etag,
             modified=modified,
         )
diff --git a/tests/test_database.py b/tests/test_database.py
index 50a245d..79eea7d 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -104,6 +104,28 @@ FEED = feed.Feed(
     ],
 )
 
+OTHER_FEED = feed.Feed(
+    meta=feed.FeedMeta(
+        url="http://example.com/test/other",
+        last_fetched_ts=REF_TIME,
+        retry_after_ts=REF_TIME,
+        status=feed.FEED_STATUS_ALIVE,
+        etag=random_slug(),
+        modified=random_slug(),
+    ),
+    title="Other Test Feed",
+    link="http://example.com/other",
+    entries=[
+        feed.Entry(
+            id=random_slug(),
+            inserted_at=(REF_TIME * 1000) + index,
+            title=f"Entry {index}",
+            link=f"http://example.com/other/a{index}",
+        )
+        for index in range(100, 0, -1)
+    ],
+)
+
 
 def test_database_load_store_meta():
     db = database.Database(":memory:", random_slug())
@@ -274,6 +296,7 @@ def test_database_store_with_older_entries():
     db.store_feed(FEED)
 
     old_entry = FEED.entries[0]
+    assert old_entry.link is not None
 
     older_entry = dataclasses.replace(
         old_entry,
@@ -314,7 +337,7 @@ def test_database_store_update_meta():
     assert db.load_all_meta()[0] == new_meta
 
 
-def test_database_store_update_meta():
+def test_database_store_update_meta_clock():
     db = database.Database(":memory:", random_slug())
     db.ensure_database_schema()
 
@@ -413,3 +436,51 @@ def test_database_redirect_clock():
     db.redirect_feed(FEED.meta.url, new_url)
 
     assert db.get_clock() != old_clock
+
+
+def test_database_sync_clocks():
+    db = database.Database(":memory:", random_slug())
+    db.ensure_database_schema()
+
+    other_origin = f"other_{random_slug()}"
+
+    other_clock = db.get_sync_clock(other_origin)
+    assert other_clock is None
+
+    db.set_sync_clock(other_origin, 1234)
+
+    other_clock = db.get_sync_clock(other_origin)
+    assert other_clock == 1234
+
+
+def test_database_do_sync():
+    db = database.Database(":memory:", random_slug())
+    db.ensure_database_schema()
+
+    other = database.Database(":memory:", random_slug())
+    other.ensure_database_schema()
+
+    other.store_feed(FEED)
+    other.store_feed(OTHER_FEED)
+
+    db.sync_from(other)
+
+    others = db.load_all(feed_limit=99999)
+    assert others == [FEED, OTHER_FEED]
+
+
+def test_database_do_sync_conflict():
+    db = database.Database(":memory:", random_slug())
+    db.ensure_database_schema()
+
+    other = database.Database(":memory:", random_slug())
+    other.ensure_database_schema()
+
+    db.store_feed(FEED)
+    other.store_feed(FEED)
+    other.store_feed(OTHER_FEED)
+
+    db.sync_from(other)
+
+    feeds = db.load_all(feed_limit=99999)
+    assert feeds == [FEED, OTHER_FEED]