From 52c12785c82204fb8c0758e8168449f9019091c7 Mon Sep 17 00:00:00 2001 From: John Doty Date: Wed, 17 Jul 2024 06:39:59 -0700 Subject: [PATCH] Some database tests --- cry/database.py | 1 - tests/test_database.py | 66 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 tests/test_database.py diff --git a/cry/database.py b/cry/database.py index 3fde287..b2d812d 100644 --- a/cry/database.py +++ b/cry/database.py @@ -76,7 +76,6 @@ def local_origin(path: pathlib.Path | None = None) -> str: def database_path(origin: str) -> pathlib.Path: - # TODO: Determine the name/slug from local state if necessary return pathlib.Path.home() / "Dropbox" / "cry" / f"{origin}.db" diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..b8d4010 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,66 @@ +import pathlib +import random +import string +import tempfile + +from cry import database + + +def random_slug() -> str: + return "".join( + random.choices( + string.ascii_uppercase + string.ascii_lowercase + string.digits, k=8 + ) + ) + + +def test_database_origin_path(): + op = database.origin_path() + assert op is not None + + +def test_database_local_origin(): + with tempfile.TemporaryDirectory() as op: + origin_file = pathlib.Path(op) / "origin" + assert not origin_file.exists() + + origin = database.local_origin(origin_file) + + assert origin_file.exists() + assert len(origin) > 0 + + +def test_database_local_origin_repeatable(): + with tempfile.TemporaryDirectory() as op: + origin_file = pathlib.Path(op) / "origin" + + a = database.local_origin(origin_file) + b = database.local_origin(origin_file) + + assert len(a) > 0 + assert a == b + + +def test_database_origin_in_path(): + slug = random_slug() + p = database.database_path(slug) + assert slug in str(p) + + +def test_database_schema(): + db = database.Database(":memory:", random_slug()) + db.ensure_database_schema() + + c = db.db.execute("SELECT value FROM properties WHERE name = 'version'") + row = c.fetchone() + assert int(row[0]) == len(database.SCHEMA_STATEMENTS) + + +def test_database_prop_get_set(): + db = database.Database(":memory:", random_slug()) + db.ensure_database_schema() + + assert db.get_property("foo") is None + val = random_slug() + db.set_property("foo", val) + assert db.get_property("foo") == val