cry/cry/database.py

472 lines
15 KiB
Python

import pathlib
import random
import socket
import sqlite3
import string
import time
import typing
import platformdirs
from . import feed
SCHEMA_STATEMENTS = [
"""
CREATE TABLE feeds (
url VARCHAR NOT NULL PRIMARY KEY,
last_fetched_ts INTEGER NOT NULL,
retry_after_ts INTEGER NOT NULL,
status INTEGER NOT NULL,
etag VARCHAR,
modified VARCHAR,
title VARCHAR,
link VARCHAR
);
CREATE TABLE entries(
id VARCHAR NOT NULL,
inserted_at INTEGER NOT NULL,
feed_url VARCHAR NOT NULL,
title VARCHAR,
link VARCHAR,
PRIMARY KEY (id, feed_url),
FOREIGN KEY (feed_url) REFERENCES feeds(url)
ON UPDATE CASCADE
ON DELETE CASCADE
);
""",
# I went and changed the status enum to make ALIVE == 0 when I added the
# "unsubscribed" status. I should probably make these strings huh.
"""
UPDATE feeds
SET status=CASE
WHEN status = 0 THEN 1
WHEN status = 1 THEN 0
ELSE status
END
""",
]
def origin_path() -> pathlib.Path:
    """Return the path of the file that persists this machine's origin id."""
    data_dir = platformdirs.user_data_path("cry", "cry")
    return data_dir / "origin"
def local_origin(path: pathlib.Path | None = None) -> str:
if path is None:
path = origin_path()
if path.exists():
with open(path, "r", encoding="utf-8") as f:
return f.read().strip()
host = socket.gethostname()
slug = "".join(
random.choices(
string.ascii_uppercase + string.ascii_lowercase + string.digits, k=8
)
)
origin = f"{host}-{slug}"
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
f.write(origin)
return origin
def database_path(origin: str) -> pathlib.Path:
    """Return the Dropbox-synced database file used for *origin*."""
    # TODO: Determine the name/slug from local state if necessary
    filename = f"{origin}.db"
    return pathlib.Path.home().joinpath("Dropbox", "cry", filename)
# TODO: Refactor into:
# -top level: transactions
# -bottom level: queries
# to enable reuse
class Database:
    """SQLite-backed store of feed metadata and entries for one origin."""

    # Open SQLite connection; used via ``with self.db`` transaction blocks.
    db: sqlite3.Connection
    # Identifier of the machine/instance this database file belongs to.
    origin: str
def __init__(self, path: pathlib.Path | str, origin: str):
if not isinstance(path, str):
path.parent.mkdir(parents=True, exist_ok=True)
db = sqlite3.Connection(str(path), autocommit=False)
db.execute("PRAGMA foreign_keys = ON")
self.db = db
self.origin = origin
@classmethod
def local(cls, origin: str | None = None) -> "Database":
if origin is None:
origin = local_origin()
db = Database(database_path(origin), origin)
db.ensure_database_schema()
return db
def get_property(self, prop: str, default=None) -> typing.Any:
with self.db:
cursor = self.db.execute(
"SELECT value FROM properties WHERE name=?", (prop,)
)
result = cursor.fetchone()
if result is None:
return default
return result[0]
    def set_property(self, prop: str, value):
        """Insert or overwrite the *prop* -> *value* row in ``properties``."""
        with self.db:
            self.db.execute(
                """
                INSERT INTO properties (name, value) VALUES (?, ?)
                ON CONFLICT DO UPDATE SET value=excluded.value
                """,
                (prop, value),
            )
def ensure_database_schema(self):
with self.db:
self.db.execute(
"""
CREATE TABLE IF NOT EXISTS properties (
name VARCHAR NOT NULL PRIMARY KEY,
value VARCHAR NOT NULL
)
"""
)
version = int(self.get_property("version", 0))
for script in SCHEMA_STATEMENTS[version:]:
for statement in script.split(";"):
try:
self.db.execute(statement)
except Exception as e:
raise Exception(f"Error executing:\n{statement}") from e
self.set_property("version", len(SCHEMA_STATEMENTS))
self.set_property("origin", self.origin)
def load_all_meta(self) -> list[feed.FeedMeta]:
with self.db:
cursor = self.db.execute(
"""
SELECT
url,
last_fetched_ts,
retry_after_ts,
status,
etag,
modified
FROM feeds
"""
)
rows = cursor.fetchall()
return [
feed.FeedMeta(
url=url,
last_fetched_ts=int(last_fetched_ts),
retry_after_ts=int(retry_after_ts),
status=int(status),
etag=etag,
modified=modified,
origin=self.origin,
)
for url, last_fetched_ts, retry_after_ts, status, etag, modified in rows
]
def load_all(self, feed_limit: int = 20, pattern: str = "") -> list[feed.Feed]:
with self.db:
pattern = (
pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
)
sql_pattern = f"%{pattern}%"
cursor = self.db.execute(
"""
SELECT
url,
last_fetched_ts,
retry_after_ts,
status,
etag,
modified,
title,
link
FROM feeds
WHERE (title LIKE :sql_pattern ESCAPE '\\'
OR link LIKE :sql_pattern ESCAPE '\\')
AND status != 2 -- UNSUBSCRIBED
""",
{"sql_pattern": sql_pattern},
)
rows = cursor.fetchall()
almost_feeds = []
for row in rows:
(
url,
last_fetched_ts,
retry_after_ts,
status,
etag,
modified,
title,
link,
) = row
meta = feed.FeedMeta(
url=url,
last_fetched_ts=last_fetched_ts,
retry_after_ts=retry_after_ts,
status=status,
etag=etag,
modified=modified,
origin=self.origin,
)
almost_feeds.append((meta, title, link))
feeds = []
for meta, title, link in almost_feeds:
if feed_limit > 0:
cursor = self.db.execute(
"""
SELECT
id,
inserted_at,
title,
link
FROM entries
WHERE feed_url=?
ORDER BY inserted_at DESC
LIMIT ?
""",
[meta.url, feed_limit],
)
rows = cursor.fetchall()
else:
rows = []
entries = [
feed.Entry(id=id, inserted_at=inserted_at, title=title, link=link)
for id, inserted_at, title, link in rows
]
f = feed.Feed(meta=meta, title=title, link=link, entries=entries)
feeds.append(f)
return feeds
def load_feed(self, url: str) -> feed.Feed | None:
with self.db:
cursor = self.db.execute(
"""
SELECT
last_fetched_ts,
retry_after_ts,
status,
etag,
modified,
title,
link
FROM feeds
WHERE url=?
""",
[url],
)
row = cursor.fetchone()
if row is None:
return None
last_fetched_ts, retry_after_ts, status, etag, modified, title, link = row
meta = feed.FeedMeta(
url=url,
last_fetched_ts=last_fetched_ts,
retry_after_ts=retry_after_ts,
status=status,
etag=etag,
modified=modified,
origin=self.origin,
)
cursor = self.db.execute(
"""
SELECT
id,
inserted_at,
title,
link
FROM entries
WHERE feed_url=?
""",
[url],
)
rows = cursor.fetchall()
entries = [
feed.Entry(id=id, inserted_at=inserted_at, title=title, link=link)
for id, inserted_at, title, link in rows
]
return feed.Feed(meta=meta, title=title, link=link, entries=entries)
    def update_meta(self, f: feed.FeedMeta):
        """Write *f*'s fetch-state columns back to its feeds row.

        Only metadata columns are touched; title/link/entries are untouched.
        A no-op if no row with f.url exists.
        """
        with self.db:
            self.db.execute(
                """
                UPDATE feeds SET
                last_fetched_ts=?,
                retry_after_ts=?,
                status=?,
                etag=?,
                modified=?
                WHERE url=?
                """,
                [
                    f.last_fetched_ts,
                    f.retry_after_ts,
                    f.status,
                    f.etag,
                    f.modified,
                    f.url,
                ],
            )
    def store_feed(self, f: feed.Feed) -> int:
        """Store the given feed in the database.

        Upserts the feeds row, then upserts every entry, merging conflicting
        entries by keeping the oldest inserted_at and the title/link of the
        older row (see the NOTE inside the SQL below).

        Returns the number of new entries inserted, computed by counting the
        feed's entries before and after the batch upsert (updates to existing
        entries therefore don't count).
        """
        with self.db:
            self.db.execute(
                """
                INSERT INTO feeds (
                url,
                last_fetched_ts,
                retry_after_ts,
                status,
                etag,
                modified,
                title,
                link
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT DO UPDATE
                SET
                last_fetched_ts=excluded.last_fetched_ts,
                retry_after_ts=excluded.retry_after_ts,
                status=excluded.status,
                etag=excluded.etag,
                modified=excluded.modified,
                title=excluded.title,
                link=excluded.link
                """,
                [
                    f.meta.url,
                    f.meta.last_fetched_ts,
                    f.meta.retry_after_ts,
                    f.meta.status,
                    f.meta.etag,
                    f.meta.modified,
                    f.title,
                    f.link,
                ],
            )
            # Snapshot the entry count so new insertions can be reported.
            cursor = self.db.execute(
                "SELECT COUNT (*) FROM entries WHERE feed_url=?", [f.meta.url]
            )
            start_count = cursor.fetchone()[0]
            self.db.executemany(
                """
                INSERT INTO entries (
                id,
                inserted_at,
                feed_url,
                title,
                link
                ) VALUES (?, ?, ?, ?, ?)
                ON CONFLICT DO UPDATE
                SET
                -- NOTE: This is also part of the feed merge algorithm, BUT
                -- we implement it here because feeds tend to be rolling
                -- windows over some external content and we don't want
                -- to read and write the entire feed just to update the
                -- few new items. But we can't just do ON CONFLICT DO
                -- NOTHING because we *might* be storing a feed where we
                -- resolved conflicts with another instance. So we want
                -- to handle all the cases. (In theory we could make two
                -- different INSERTs to handle the two cases but that is
                -- more complexity than it is worth.)
                inserted_at=MIN(inserted_at, excluded.inserted_at),
                title=CASE
                WHEN inserted_at < excluded.inserted_at THEN title
                ELSE excluded.title
                END,
                link=CASE
                WHEN inserted_at < excluded.inserted_at THEN link
                ELSE excluded.link
                END
                """,
                [(e.id, e.inserted_at, f.meta.url, e.title, e.link) for e in f.entries],
            )
            cursor = self.db.execute(
                "SELECT COUNT (*) FROM entries WHERE feed_url=?", [f.meta.url]
            )
            end_count = cursor.fetchone()[0]
            return end_count - start_count
def set_feed_status(self, url: str, status: int) -> int:
with self.db:
cursor = self.db.execute(
"""
UPDATE feeds
SET status = ?,
last_fetched_ts = ?
WHERE url = ?
""",
[status, int(time.time()), url],
)
return cursor.rowcount
def redirect_feed(self, old_url: str, new_url: str):
with self.db:
cursor = self.db.execute(
"SELECT COUNT(1) FROM feeds WHERE url=?", [new_url]
)
row = cursor.fetchone()
if row[0] == 0:
self.db.execute(
"UPDATE feeds SET url = ? WHERE url = ?", [new_url, old_url]
)
else:
# Preserve the entries that were under the old url.
self.db.execute(
"""
UPDATE entries
SET feed_url = ?
WHERE feed_url = ?
ON CONFLICT DO UPDATE
SET
-- NOTE: This is also part of the feed merge algorithm, BUT
-- we implement it here. See the comment in store_feed
-- for the rationale.
inserted_at=MIN(inserted_at, excluded.inserted_at),
title=CASE
WHEN inserted_at < excluded.inserted_at THEN title
ELSE excluded.title
END,
link=CASE
WHEN inserted_at < excluded.inserted_at THEN link
ELSE excluded.link
END
"""
)
# Mark the old feed dead.
self.db.execute(
"""
UPDATE feeds
SET status = ?,
last_fetched_ts = ?
WHERE url = ?
""",
[feed.FEED_STATUS_DEAD, int(time.time()), old_url],
)