diff --git a/cry/cli.py b/cry/cli.py index 5164e24..e47a8b1 100644 --- a/cry/cli.py +++ b/cry/cli.py @@ -56,7 +56,7 @@ def subscribe(url, literal): if not literal: click.echo(f"Searching for feeds for {url} ...") - feeds = asyncio.run(feed.feed_search(url, db.origin)) + feeds = asyncio.run(feed.feed_search(url)) if len(feeds) == 0: click.echo(f"Unable to find a suitable feed for {url}") return 1 @@ -83,7 +83,7 @@ def subscribe(url, literal): click.echo(f"Identified {result.meta.url} as a feed for {url}") else: click.echo(f"Fetching {url} ...") - meta = feed.FeedMeta.from_url(url, db.origin) + meta = feed.FeedMeta.from_url(url) d, meta = asyncio.run(feed.fetch_feed(meta)) if d is None: click.echo(f"Unable to fetch {url}") @@ -112,7 +112,7 @@ def import_opml(opml_file): db = database.Database.local() urls = opml.parse_opml(opml_file.read()) - metas = [feed.FeedMeta.from_url(url, db.origin) for url in urls] + metas = [feed.FeedMeta.from_url(url) for url in urls] click.echo(f"Fetching {len(urls)} feeds ...") results = asyncio.run(feed.fetch_many(metas)) diff --git a/cry/database.py b/cry/database.py index 334fe7e..b2d812d 100644 --- a/cry/database.py +++ b/cry/database.py @@ -76,7 +76,6 @@ def local_origin(path: pathlib.Path | None = None) -> str: def database_path(origin: str) -> pathlib.Path: - # TODO: Determine the name/slug from local state if necessary return pathlib.Path.home() / "Dropbox" / "cry" / f"{origin}.db" @@ -168,7 +167,6 @@ class Database: status=int(status), etag=etag, modified=modified, - origin=self.origin, ) for url, last_fetched_ts, retry_after_ts, status, etag, modified in rows ] @@ -218,7 +216,6 @@ class Database: status=status, etag=etag, modified=modified, - origin=self.origin, ) almost_feeds.append((meta, title, link)) @@ -282,7 +279,6 @@ class Database: status=status, etag=etag, modified=modified, - origin=self.origin, ) cursor = self.db.execute( diff --git a/cry/feed.py b/cry/feed.py index 2af594d..6e7b4fa 100644 --- a/cry/feed.py +++ b/cry/feed.py @@ -38,10 +38,9 @@ class FeedMeta: status: int etag: str | None modified: str | None - origin: str @classmethod - def from_url(cls, url: str, origin: str) -> "FeedMeta": + def from_url(cls, url: str) -> "FeedMeta": return FeedMeta( url=url, last_fetched_ts=0, @@ -49,7 +48,6 @@ class FeedMeta: status=FEED_STATUS_ALIVE, etag=None, modified=None, - origin=origin, ) def should_fetch(self, now) -> bool: @@ -147,6 +145,7 @@ class Feed: link = None if d.feed is not None: + assert not isinstance(d.feed, list) title = d.feed.get("title") link = d.feed.get("link") @@ -428,7 +427,7 @@ async def fetch_many( return [t.result() for t in tasks] -def merge_feeds(a: Feed, b: Feed) -> Feed: +def merge_feeds(a: Feed, a_origin: str, b: Feed, b_origin: str) -> Feed: """Merge two known feeds. There are two conflict resolution policies: 1. The newer fetch of feed metadata wins. @@ -449,7 +448,7 @@ def merge_feeds(a: Feed, b: Feed) -> Feed: if a.meta.last_fetched_ts > b.meta.last_fetched_ts: source_feed = a elif a.meta.last_fetched_ts == b.meta.last_fetched_ts: - source_feed = a if a.meta.origin < b.meta.origin else b + source_feed = a if a_origin < b_origin else b else: source_feed = b @@ -569,11 +568,11 @@ def is_XML_related_link(link: str) -> bool: return "rss" in link or "rdf" in link or "xml" in link or "atom" in link -async def check_feed(url: str, origin: str) -> Feed | None: +async def check_feed(url: str) -> Feed | None: """Check to see if the given URL is a feed. If it is, return the feed, otherwise return None. """ - meta = FeedMeta.from_url(url, origin) + meta = FeedMeta.from_url(url) result, meta = await fetch_feed(meta) if isinstance(result, Feed): return result @@ -581,13 +580,13 @@ async def check_feed(url: str, origin: str) -> Feed | None: return None -async def check_links(links: typing.Iterable[str], origin: str) -> list[Feed]: +async def check_links(links: typing.Iterable[str]) -> list[Feed]: """Fetch all the links and return the ones that appear to have feeds in them. If none of them are fetchable or none of them have feeds then this will return nothing. """ async with asyncio.TaskGroup() as group: - tasks = [group.create_task(check_feed(link, origin)) for link in links] + tasks = [group.create_task(check_feed(link)) for link in links] outfeeds: list[Feed] = [] for task in tasks: @@ -598,8 +597,8 @@ async def check_links(links: typing.Iterable[str], origin: str) -> list[Feed]: return outfeeds -async def feed_search(uri: str, origin: str) -> list[Feed]: - meta = FeedMeta.from_url(massage_url(uri), origin) +async def feed_search(uri: str) -> list[Feed]: + meta = FeedMeta.from_url(massage_url(uri)) result, meta = await fetch_feed(meta) if result is None: return [] @@ -611,22 +610,22 @@ async def feed_search(uri: str, origin: str) -> list[Feed]: parser.feed(result) LOG.debug("Checking links...") - outfeeds = await check_links(parser.link_links, origin) + outfeeds = await check_links(parser.link_links) if len(outfeeds) > 0: return outfeeds LOG.debug("No links, checking A tags...") local_links, remote_links = classify_links(parser.a_links, meta.url) - outfeeds = await check_links(filter(is_feed_link, local_links), origin) + outfeeds = await check_links(filter(is_feed_link, local_links)) if len(outfeeds) > 0: return outfeeds - outfeeds = await check_links(filter(is_XML_related_link, local_links), origin) + outfeeds = await check_links(filter(is_XML_related_link, local_links)) if len(outfeeds) > 0: return outfeeds - outfeeds = await check_links(filter(is_feed_link, remote_links), origin) + outfeeds = await check_links(filter(is_feed_link, remote_links)) if len(outfeeds) > 0: return outfeeds - outfeeds = await check_links(filter(is_XML_related_link, remote_links), origin) + outfeeds = await check_links(filter(is_XML_related_link, remote_links)) if len(outfeeds) > 0: return outfeeds @@ -639,7 +638,5 @@ async def feed_search(uri: str, origin: str) -> list[Feed]: "index.xml", # MT "index.rss", # Slash ] - outfeeds = await check_links( - [urllib.parse.urljoin(meta.url, x) for x in suffixes], origin - ) + outfeeds = await check_links([urllib.parse.urljoin(meta.url, x) for x in suffixes]) return outfeeds diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..b8d4010 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,66 @@ +import pathlib +import random +import string +import tempfile + +from cry import database + + +def random_slug() -> str: + return "".join( + random.choices( + string.ascii_uppercase + string.ascii_lowercase + string.digits, k=8 + ) + ) + + +def test_database_origin_path(): + op = database.origin_path() + assert op is not None + + +def test_database_local_origin(): + with tempfile.TemporaryDirectory() as op: + origin_file = pathlib.Path(op) / "origin" + assert not origin_file.exists() + + origin = database.local_origin(origin_file) + + assert origin_file.exists() + assert len(origin) > 0 + + +def test_database_local_origin_repeatable(): + with tempfile.TemporaryDirectory() as op: + origin_file = pathlib.Path(op) / "origin" + + a = database.local_origin(origin_file) + b = database.local_origin(origin_file) + + assert len(a) > 0 + assert a == b + + +def test_database_origin_in_path(): + slug = random_slug() + p = database.database_path(slug) + assert slug in str(p) + + +def test_database_schema(): + db = database.Database(":memory:", random_slug()) + db.ensure_database_schema() + + c = db.db.execute("SELECT value FROM properties WHERE name = 'version'") + row = c.fetchone() + assert int(row[0]) == len(database.SCHEMA_STATEMENTS) + + +def test_database_prop_get_set(): + db = database.Database(":memory:", random_slug()) + db.ensure_database_schema() + + assert db.get_property("foo") is None + val = random_slug() + db.set_property("foo", val) + assert db.get_property("foo") == val diff --git a/tests/test_feed.py b/tests/test_feed.py index 235eaf1..84a1bdc 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -4,7 +4,6 @@ import http.server import threading import typing -import requests from cry import feed @@ -118,7 +117,7 @@ def test_basic_successful_fetch(): with TestWebServer() as server: server.handle("/", TEST_FEED, content_type="text/xml") - meta = feed.FeedMeta.from_url(server.make_url("/"), "asdf") + meta = feed.FeedMeta.from_url(server.make_url("/")) result, new_meta = asyncio.run(feed.fetch_feed(meta)) assert new_meta.url == meta.url @@ -132,7 +131,7 @@ def test_fetch_after_temp_redirect(): server.handle("/old", code=307, headers=[("location", "/temp")]) server.handle("/temp", TEST_FEED, content_type="text/xml") - meta = feed.FeedMeta.from_url(server.make_url("/old"), "asdf") + meta = feed.FeedMeta.from_url(server.make_url("/old")) result, new_meta = asyncio.run(feed.fetch_feed(meta)) assert new_meta.url == meta.url assert isinstance(result, feed.Feed) @@ -143,7 +142,7 @@ def test_fetch_after_permanent_redirect(): server.handle("/old", code=308, headers=[("location", "/perm")]) server.handle("/perm", TEST_FEED, content_type="text/xml") - meta = feed.FeedMeta.from_url(server.make_url("/old"), "asdf") + meta = feed.FeedMeta.from_url(server.make_url("/old")) result, new_meta = asyncio.run(feed.fetch_feed(meta)) assert new_meta.url == server.make_url("/perm") assert isinstance(result, feed.Feed) @@ -155,7 +154,7 @@ def test_fetch_after_permanent_to_temporary_redirect(): server.handle("/perm", code=307, headers=[("location", "/temp")]) server.handle("/temp", TEST_FEED, content_type="text/xml") - meta = feed.FeedMeta.from_url(server.make_url("/old"), "asdf") + meta = feed.FeedMeta.from_url(server.make_url("/old")) result, new_meta = asyncio.run(feed.fetch_feed(meta)) # NOTE: we should record the PERMANENT redirect, not the temporary one. @@ -169,7 +168,7 @@ def test_fetch_after_permanent_to_permanent_redirect(): server.handle("/one", code=308, headers=[("location", "/two")]) server.handle("/two", TEST_FEED, content_type="text/xml") - meta = feed.FeedMeta.from_url(server.make_url("/old"), "asdf") + meta = feed.FeedMeta.from_url(server.make_url("/old")) result, new_meta = asyncio.run(feed.fetch_feed(meta)) # NOTE: we should record the latest redirect.