# Feed database: SQLite persistence plus clock-based cross-instance sync.
import logging
|
|
import pathlib
|
|
import random
|
|
import socket
|
|
import sqlite3
|
|
import string
|
|
import time
|
|
import typing
|
|
|
|
import platformdirs
|
|
|
|
from . import feed
|
|
|
|
LOG = logging.getLogger(__name__)
|
|
|
|
# Ordered list of migration scripts.  ensure_database_schema() stores the
# count of applied scripts in the "version" property and runs only
# SCHEMA_STATEMENTS[version:] — so existing entries must never be edited or
# reordered once released; only append new ones.
SCHEMA_STATEMENTS = [
    """
    CREATE TABLE feeds (
        url VARCHAR NOT NULL PRIMARY KEY,
        last_fetched_ts INTEGER NOT NULL,
        retry_after_ts INTEGER NOT NULL,
        status INTEGER NOT NULL,
        etag VARCHAR,
        modified VARCHAR,
        title VARCHAR,
        link VARCHAR
    );

    CREATE TABLE entries(
        id VARCHAR NOT NULL,
        inserted_at INTEGER NOT NULL,
        feed_url VARCHAR NOT NULL,
        title VARCHAR,
        link VARCHAR,
        PRIMARY KEY (id, feed_url),
        FOREIGN KEY (feed_url) REFERENCES feeds(url)
            ON UPDATE CASCADE
            ON DELETE CASCADE
    );
    """,
    # I went and changed the status enum to make ALIVE == 0 when I added the
    # "unsubscribed" status. I should probably make these strings huh.
    """
    UPDATE feeds
    SET status=CASE
        WHEN status = 0 THEN 1
        WHEN status = 1 THEN 0
        ELSE status
    END
    """,
    # The "clock" is a number that increments as we make changes. We use this
    # to do reconciliation, and track which versions of other databases we
    # have reconciled already.
    # NOTE: the UPDATE triggers use "IS NOT" (NULL-safe inequality) per
    # column so the clock is only bumped when a value actually changed;
    # no-op UPDATEs therefore don't force peers to re-sync this database.
    """
    INSERT INTO properties (name, value) VALUES ('clock', 1);

    CREATE TRIGGER update_clock_on_feed_insert
    AFTER INSERT ON feeds
    BEGIN
        UPDATE properties SET value=value + 1 WHERE name='clock';
    END;

    CREATE TRIGGER update_clock_on_feed_delete
    AFTER DELETE ON feeds
    BEGIN
        UPDATE properties SET value=value + 1 WHERE name='clock';
    END;

    CREATE TRIGGER update_clock_on_feed_update
    AFTER UPDATE ON feeds
    WHEN (NEW.last_fetched_ts IS NOT OLD.last_fetched_ts)
        OR (NEW.retry_after_ts IS NOT OLD.retry_after_ts)
        OR (NEW.status IS NOT OLD.status)
        OR (NEW.etag IS NOT OLD.etag)
        OR (NEW.modified IS NOT OLD.modified)
        OR (NEW.title IS NOT OLD.title)
        OR (NEW.link IS NOT OLD.link)
    BEGIN
        UPDATE properties SET value=value + 1 WHERE name='clock';
    END;

    CREATE TRIGGER update_clock_on_entries_insert
    AFTER INSERT ON entries
    BEGIN
        UPDATE properties SET value=value + 1 WHERE name='clock';
    END;

    CREATE TRIGGER update_clock_on_entries_delete
    AFTER DELETE ON entries
    BEGIN
        UPDATE properties SET value=value + 1 WHERE name='clock';
    END;

    CREATE TRIGGER update_clock_on_entries_update
    AFTER UPDATE ON entries
    WHEN (NEW.id IS NOT OLD.id)
        OR (NEW.inserted_at IS NOT OLD.inserted_at)
        OR (NEW.feed_url IS NOT OLD.feed_url)
        OR (NEW.title IS NOT OLD.title)
        OR (NEW.link IS NOT OLD.link)
    BEGIN
        UPDATE properties SET value=value + 1 WHERE name='clock';
    END;
    """,
    # Per peer origin, the peer's clock value we last reconciled against.
    """
    CREATE TABLE sync_status (
        origin VARCHAR NOT NULL PRIMARY KEY,
        clock INT NOT NULL
    );
    """,
]
|
|
|
|
|
|
def origin_path() -> pathlib.Path:
    """Return the path of the file that persists this machine's origin id."""
    return platformdirs.user_data_path("cry", "cry") / "origin"
|
|
|
|
|
|
def local_origin(path: pathlib.Path | None = None) -> str:
|
|
if path is None:
|
|
path = origin_path()
|
|
|
|
if path.exists():
|
|
with open(path, "r", encoding="utf-8") as f:
|
|
return f.read().strip()
|
|
|
|
host = socket.gethostname()
|
|
slug = "".join(
|
|
random.choices(
|
|
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=8
|
|
)
|
|
)
|
|
origin = f"{host}-{slug}"
|
|
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
with open(path, "w", encoding="utf-8") as f:
|
|
f.write(origin)
|
|
|
|
return origin
|
|
|
|
|
|
def databases_directory() -> pathlib.Path:
    """Return the shared directory where every instance's database lives."""
    return pathlib.Path.home().joinpath("Dropbox", "cry")
|
|
|
|
|
|
def database_path(origin: str) -> pathlib.Path:
    """Return the database file path for the instance named by *origin*."""
    filename = f"{origin}.db"
    return databases_directory() / filename
|
|
|
|
|
|
# TODO: Refactor into:
|
|
# -top level: transactions
|
|
# -bottom level: queries
|
|
# to enable reuse
|
|
class Database:
|
|
db: sqlite3.Connection
|
|
origin: str
|
|
|
|
def __init__(self, path: pathlib.Path | str, origin: str, readonly: bool = False):
|
|
uri = False
|
|
if not isinstance(path, str):
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
path = f"file:{str(path)}"
|
|
uri = True
|
|
if readonly:
|
|
path = f"{path}?mode=ro"
|
|
|
|
# Enable autocommit as a separate step so that I can enable foreign
|
|
# keys cleanly. (Can't enable foreign keys in a transaction.)
|
|
db = sqlite3.connect(str(path), uri=uri)
|
|
db.execute("PRAGMA foreign_keys = ON")
|
|
db.autocommit = False
|
|
|
|
cursor = db.execute("PRAGMA foreign_keys")
|
|
rows = cursor.fetchall()
|
|
assert str(rows[0][0]) == "1", f"Foreign keys not enabled! {rows[0][0]}"
|
|
|
|
self.db = db
|
|
self.origin = origin
|
|
|
|
@classmethod
|
|
def local(cls, origin: str | None = None) -> "Database":
|
|
if origin is None:
|
|
origin = local_origin()
|
|
|
|
db = Database(database_path(origin), origin)
|
|
db.ensure_database_schema()
|
|
db.set_property("origin", origin)
|
|
return db
|
|
|
|
@classmethod
|
|
def from_file(cls, path: pathlib.Path) -> "Database":
|
|
db = Database(path, "", readonly=True)
|
|
origin = db.get_property("origin")
|
|
if origin is None:
|
|
raise Exception("No origin!")
|
|
db.origin = str(origin)
|
|
return db
|
|
|
|
def get_property(self, prop: str, default=None) -> typing.Any:
|
|
with self.db:
|
|
return self._get_property(prop, default)
|
|
|
|
def set_property(self, prop: str, value):
|
|
with self.db:
|
|
return self._set_property(prop, value)
|
|
|
|
def get_clock(self) -> int:
|
|
return int(self.get_property("clock", 0))
|
|
|
|
    def ensure_database_schema(self):
        """Create or upgrade the schema to the current version.

        The properties table is bootstrapped unconditionally (IF NOT EXISTS);
        its "version" property then tells us how many SCHEMA_STATEMENTS
        migrations have already run, and only the remainder are applied.
        """
        with self.db:
            self.db.execute(
                """
                CREATE TABLE IF NOT EXISTS properties (
                    name VARCHAR NOT NULL PRIMARY KEY,
                    value VARCHAR NOT NULL
                )
                """
            )
            # 0 on a fresh database: apply every migration.
            version = int(self._get_property("version", 0))
            for script in SCHEMA_STATEMENTS[version:]:
                try:
                    self.db.executescript(script)
                except Exception as e:
                    # Re-raise with the failing script attached for debugging.
                    raise Exception(f"Error executing:\n{script}") from e
            self._set_property("version", len(SCHEMA_STATEMENTS))
            self._set_property("origin", self.origin)
|
|
|
|
def load_all_meta(self) -> list[feed.FeedMeta]:
|
|
with self.db:
|
|
cursor = self.db.execute(
|
|
"""
|
|
SELECT
|
|
url,
|
|
last_fetched_ts,
|
|
retry_after_ts,
|
|
status,
|
|
etag,
|
|
modified
|
|
FROM feeds
|
|
"""
|
|
)
|
|
rows = cursor.fetchall()
|
|
return [
|
|
feed.FeedMeta(
|
|
url=url,
|
|
last_fetched_ts=int(last_fetched_ts),
|
|
retry_after_ts=int(retry_after_ts),
|
|
status=int(status),
|
|
etag=etag,
|
|
modified=modified,
|
|
)
|
|
for url, last_fetched_ts, retry_after_ts, status, etag, modified in rows
|
|
]
|
|
|
|
    def load_all(self, feed_limit: int = 20, pattern: str = "") -> list[feed.Feed]:
        """Load every non-unsubscribed feed whose title or link contains
        *pattern*, each with up to *feed_limit* of its most recent entries.

        A feed_limit <= 0 returns the feeds with empty entry lists.
        """
        with self.db:
            # Escape LIKE metacharacters so *pattern* is matched literally,
            # then wrap in % on both sides for a substring search.
            pattern = (
                pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            sql_pattern = f"%{pattern}%"
            cursor = self.db.execute(
                """
                SELECT
                    url,
                    last_fetched_ts,
                    retry_after_ts,
                    status,
                    etag,
                    modified,
                    title,
                    link
                FROM feeds
                WHERE (title LIKE :sql_pattern ESCAPE '\\'
                    OR link LIKE :sql_pattern ESCAPE '\\')
                    AND status != 2 -- UNSUBSCRIBED
                """,
                {"sql_pattern": sql_pattern},
            )
            rows = cursor.fetchall()

            # First pass: build the metadata objects, keeping title/link
            # alongside, before issuing the per-feed entry queries.
            almost_feeds = []
            for row in rows:
                (
                    url,
                    last_fetched_ts,
                    retry_after_ts,
                    status,
                    etag,
                    modified,
                    title,
                    link,
                ) = row
                meta = feed.FeedMeta(
                    url=url,
                    last_fetched_ts=int(last_fetched_ts),
                    retry_after_ts=int(retry_after_ts),
                    status=int(status),
                    etag=etag,
                    modified=modified,
                )
                almost_feeds.append((meta, title, link))

            # Second pass: attach the newest entries for each matched feed.
            feeds = []
            for meta, title, link in almost_feeds:
                if feed_limit > 0:
                    cursor = self.db.execute(
                        """
                        SELECT
                            id,
                            inserted_at,
                            title,
                            link
                        FROM entries
                        WHERE feed_url=?
                        ORDER BY inserted_at DESC
                        LIMIT ?
                        """,
                        [meta.url, feed_limit],
                    )
                    rows = cursor.fetchall()
                else:
                    rows = []

                entries = [
                    feed.Entry(id=id, inserted_at=inserted_at, title=title, link=link)
                    for id, inserted_at, title, link in rows
                ]
                f = feed.Feed(meta=meta, title=title, link=link, entries=entries)
                feeds.append(f)

            return feeds
|
|
|
|
def load_meta(self, url: str) -> feed.FeedMeta | None:
|
|
with self.db:
|
|
return self._load_meta(url)
|
|
|
|
def update_meta(self, f: feed.FeedMeta):
|
|
with self.db:
|
|
self.db.execute(
|
|
"""
|
|
UPDATE feeds SET
|
|
last_fetched_ts=?,
|
|
retry_after_ts=?,
|
|
status=?,
|
|
etag=?,
|
|
modified=?
|
|
WHERE url=?
|
|
""",
|
|
[
|
|
f.last_fetched_ts,
|
|
f.retry_after_ts,
|
|
f.status,
|
|
f.etag,
|
|
f.modified,
|
|
f.url,
|
|
],
|
|
)
|
|
|
|
    def store_feed(self, f: feed.Feed) -> int:
        """Store the given feed in the database.

        Upserts the feed row and then merges its entries, all inside a
        single transaction.

        Returns the number of new entries inserted.
        """
        with self.db:
            self._insert_feed(f.meta, f.title, f.link)
            return self._insert_entries(f.meta.url, f.entries)
|
|
|
|
def update_feed_status(self, meta: feed.FeedMeta, status: int) -> int:
|
|
with self.db:
|
|
return self._update_feed_status(meta, status)
|
|
|
|
    def redirect_feed(self, old_url: str, new_url: str):
        """Handle a permanent redirect from *old_url* to *new_url*.

        If no feed row exists at new_url, the old row is simply renamed and
        its entries follow automatically (feed_url has ON UPDATE CASCADE).
        If new_url is already taken, we move what entries we can and mark the
        old feed unsubscribed instead.
        """
        with self.db:
            cursor = self.db.execute(
                "SELECT COUNT(1) FROM feeds WHERE url=?", [new_url]
            )
            row = cursor.fetchone()
            if row[0] == 0:
                self.db.execute(
                    "UPDATE feeds SET url = ? WHERE url = ?", [new_url, old_url]
                )
            else:
                # First update all the entries that you can with the old url.
                # (OR IGNORE skips entries whose (id, feed_url) would collide
                # with one already stored under new_url.)
                self.db.execute(
                    """
                    UPDATE OR IGNORE entries
                    SET feed_url = ?
                    WHERE feed_url = ?
                    """,
                    [new_url, old_url],
                )

                # TODO: It is expensive and not worth it to try to load and
                # re-insert all the old stuff so I'm not going to
                # bother.

                # Mark the old feed unsubscribed.
                # TODO: Rebuild with helpers
                self.db.execute(
                    """
                    UPDATE feeds
                    SET status = ?,
                        last_fetched_ts = ?
                    WHERE url = ?
                    """,
                    [feed.FEED_STATUS_UNSUBSCRIBED, int(time.time()), old_url],
                )
|
|
|
|
def get_sync_clock(self, origin: str) -> int | None:
|
|
with self.db:
|
|
cursor = self.db.execute(
|
|
"SELECT clock FROM sync_status WHERE origin = ?",
|
|
[origin],
|
|
)
|
|
row = cursor.fetchone()
|
|
if row is None:
|
|
return None
|
|
return int(row[0])
|
|
|
|
def set_sync_clock(self, origin: str, clock: int):
|
|
with self.db:
|
|
self.db.execute(
|
|
"""
|
|
INSERT INTO sync_status (origin, clock)
|
|
VALUES (?, ?)
|
|
ON CONFLICT DO UPDATE SET clock=excluded.clock
|
|
""",
|
|
[origin, clock],
|
|
)
|
|
|
|
def sync_from(self, other: "Database"):
|
|
with self.db:
|
|
with other.db:
|
|
feed_cursor = other.db.execute(
|
|
"""
|
|
SELECT
|
|
url,
|
|
last_fetched_ts,
|
|
retry_after_ts,
|
|
status,
|
|
etag,
|
|
modified,
|
|
title,
|
|
link
|
|
FROM feeds
|
|
"""
|
|
)
|
|
for row in feed_cursor:
|
|
(
|
|
url,
|
|
last_fetched_ts,
|
|
retry_after_ts,
|
|
status,
|
|
etag,
|
|
modified,
|
|
title,
|
|
link,
|
|
) = row
|
|
meta = feed.FeedMeta(
|
|
url=url,
|
|
last_fetched_ts=int(last_fetched_ts),
|
|
retry_after_ts=int(retry_after_ts),
|
|
status=int(status),
|
|
etag=etag,
|
|
modified=modified,
|
|
)
|
|
self._insert_feed(meta, title, link)
|
|
|
|
entries_cursor = other.db.execute(
|
|
"""
|
|
SELECT
|
|
id,
|
|
inserted_at,
|
|
title,
|
|
link
|
|
FROM entries
|
|
WHERE feed_url=?
|
|
""",
|
|
[url],
|
|
)
|
|
entries_results = entries_cursor.fetchmany()
|
|
while len(entries_results) > 0:
|
|
self._insert_entries(
|
|
url,
|
|
[
|
|
feed.Entry(
|
|
id=id,
|
|
inserted_at=int(inserted_at),
|
|
title=title,
|
|
link=link,
|
|
)
|
|
for id, inserted_at, title, link in entries_results
|
|
],
|
|
)
|
|
entries_results = entries_cursor.fetchmany()
|
|
|
|
def _get_property(self, prop: str, default=None) -> typing.Any:
|
|
cursor = self.db.execute("SELECT value FROM properties WHERE name=?", (prop,))
|
|
result = cursor.fetchone()
|
|
if result is None:
|
|
return default
|
|
return result[0]
|
|
|
|
def _set_property(self, prop: str, value):
|
|
self.db.execute(
|
|
"""
|
|
INSERT INTO properties (name, value) VALUES (?, ?)
|
|
ON CONFLICT DO UPDATE SET value=excluded.value
|
|
""",
|
|
(prop, value),
|
|
)
|
|
|
|
    def _insert_feed(self, meta: feed.FeedMeta, title: str, link: str):
        """Insert into the feeds table, handling collisions with UPSERT.

        Merge policy on conflict: both timestamps ratchet upward via MAX();
        every other column is taken from whichever side has the newer
        last_fetched_ts (ties go to the incoming/excluded row).
        """
        self.db.execute(
            """
            INSERT INTO feeds (
                url,
                last_fetched_ts,
                retry_after_ts,
                status,
                etag,
                modified,
                title,
                link
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT DO UPDATE
            SET
                last_fetched_ts=MAX(last_fetched_ts, excluded.last_fetched_ts),
                retry_after_ts=MAX(retry_after_ts, excluded.retry_after_ts),
                -- For all other fields, take the value that was computed by the
                -- most recent fetch.
                status=CASE
                    WHEN last_fetched_ts > excluded.last_fetched_ts THEN status
                    ELSE excluded.status
                END,
                etag=CASE
                    WHEN last_fetched_ts > excluded.last_fetched_ts THEN etag
                    ELSE excluded.etag
                END,
                modified=CASE
                    WHEN last_fetched_ts > excluded.last_fetched_ts THEN modified
                    ELSE excluded.modified
                END,
                title=CASE
                    WHEN last_fetched_ts > excluded.last_fetched_ts THEN title
                    ELSE excluded.title
                END,
                link=CASE
                    WHEN last_fetched_ts > excluded.last_fetched_ts THEN link
                    ELSE excluded.link
                END
            """,
            [
                meta.url,
                meta.last_fetched_ts,
                meta.retry_after_ts,
                meta.status,
                meta.etag,
                meta.modified,
                title,
                link,
            ],
        )
|
|
|
|
    def _insert_entries(self, feed_url: str, entries: list[feed.Entry]) -> int:
        """Upsert *entries* under *feed_url*.

        Returns the number of newly inserted rows, computed as the
        difference in the feed's entry count before and after the insert.
        """
        cursor = self.db.execute(
            "SELECT COUNT (*) FROM entries WHERE feed_url=?", [feed_url]
        )
        start_count = cursor.fetchone()[0]

        self.db.executemany(
            """
            INSERT INTO entries (
                id,
                inserted_at,
                feed_url,
                title,
                link
            ) VALUES (?, ?, ?, ?, ?)
            ON CONFLICT DO UPDATE
            SET
                -- NOTE: This is also part of the feed merge algorithm, BUT
                -- we implement it here because feeds tend to be rolling
                -- windows over some external content and we don't want
                -- to read and write the entire feed just to update the
                -- few new items. But we can't just do ON CONFLICT DO
                -- NOTHING because we *might* be storing a feed where we
                -- resolved conflicts with another instance. So we want
                -- to handle all the cases. (In theory we could make two
                -- different INSERTs to handle the two cases but that is
                -- more complexity than it is worth.)
                inserted_at=MIN(inserted_at, excluded.inserted_at),
                title=CASE
                    WHEN inserted_at < excluded.inserted_at THEN title
                    ELSE excluded.title
                END,
                link=CASE
                    WHEN inserted_at < excluded.inserted_at THEN link
                    ELSE excluded.link
                END
            """,
            [(e.id, e.inserted_at, feed_url, e.title, e.link) for e in entries],
        )

        cursor = self.db.execute(
            "SELECT COUNT (*) FROM entries WHERE feed_url=?", [feed_url]
        )
        end_count = cursor.fetchone()[0]
        return end_count - start_count
|
|
|
|
def _update_feed_status(self, meta: feed.FeedMeta, status: int) -> int:
|
|
new_ts = max(int(time.time()), meta.last_fetched_ts + 1)
|
|
cursor = self.db.execute(
|
|
"""
|
|
UPDATE feeds
|
|
SET status = ?,
|
|
last_fetched_ts = ?
|
|
WHERE url = ?
|
|
""",
|
|
[status, new_ts, meta.url],
|
|
)
|
|
return cursor.rowcount
|
|
|
|
def _load_meta(self, url: str) -> feed.FeedMeta | None:
|
|
cursor = self.db.execute(
|
|
"""
|
|
SELECT
|
|
last_fetched_ts,
|
|
retry_after_ts,
|
|
status,
|
|
etag,
|
|
modified
|
|
FROM feeds
|
|
WHERE url=?
|
|
""",
|
|
[url],
|
|
)
|
|
|
|
row = cursor.fetchone()
|
|
if row is None:
|
|
return None
|
|
|
|
last_fetched_ts, retry_after_ts, status, etag, modified = row
|
|
return feed.FeedMeta(
|
|
url=url,
|
|
last_fetched_ts=int(last_fetched_ts),
|
|
retry_after_ts=int(retry_after_ts),
|
|
status=int(status),
|
|
etag=etag,
|
|
modified=modified,
|
|
)
|
|
|
|
|
|
def sync(local_db: Database):
    """Reconcile *local_db* against every other instance's database file.

    Skips our own file, files with a mismatched schema version, and files
    whose clock we have already reconciled.  A failure on one file is logged
    and does not stop the sweep.
    """
    local_version = local_db.get_property("version", 0)
    for p in databases_directory().glob("*.db"):
        if not p.is_file():
            continue

        try:
            other_db = Database.from_file(p)
            if local_db.origin == other_db.origin:
                continue

            # Ensure the schema version is compatible so that we don't run
            # into trouble trying to query the other database.
            other_version = other_db.get_property("version", 0)
            if other_version != local_version:
                # Logger.warn() is a deprecated alias; use warning() with
                # lazy %-style arguments.
                LOG.warning(
                    "%s: Not reconciling version %s against %s",
                    other_db.origin,
                    other_version,
                    local_version,
                )
                continue

            # Check to see if we've already reconciled this other database.
            other_clock = other_db.get_clock()
            reconciled_clock = local_db.get_sync_clock(other_db.origin)
            if other_clock == reconciled_clock:
                continue

            local_db.sync_from(other_db)

            local_db.set_sync_clock(other_db.origin, other_clock)

        except Exception:
            # Broad catch is deliberate at this boundary: one corrupt peer
            # file must not abort the whole sweep.  Log with traceback.
            LOG.exception("Error loading %s", p)
|