diff --git a/src/bsv/simple_cas/cas.py b/src/bsv/simple_cas/cas.py index 504fe43..bd203ca 100644 --- a/src/bsv/simple_cas/cas.py +++ b/src/bsv/simple_cas/cas.py @@ -108,6 +108,20 @@ class SimpleCas: return digest + def get_ref(self, key: str) -> bytes | None: + ref_path = self._ref_path(key) + if not ref_path.is_file(): + return None + hex = ref_path.read_text().strip() + if len(hex) != 2 * self._digest_size: + return None + return bytes.fromhex(hex) + + def set_ref(self, key: str, digest: bytes): + ref_path = self._ref_path(key) + ref_path.parent.mkdir(parents=True, exist_ok=True) + ref_path.write_text(digest.hex()) + def _open_writer(self, digest: bytes, object_type: bytes, size: int) -> BinaryIO: dat_file = (self._root_dir / "cas.dat").open("ab") offset = dat_file.tell() @@ -121,6 +135,14 @@ class SimpleCas: return dat_file + def _ref_path(self, key: str) -> Path: + assert key + assert key[0] not in ("/", "\\") + assert key[-1] not in ("/", "\\") + key_path = Path(key) + assert not any(part in (".", "..") for part in key_path.parts) + return self._root_dir / "refs" / key_path + @dataclass class Object: diff --git a/tests/test_simple_cas.py b/tests/test_simple_cas.py index c219fe7..28fc5ea 100644 --- a/tests/test_simple_cas.py +++ b/tests/test_simple_cas.py @@ -74,3 +74,11 @@ def test_simple_cas(tmp_dir: Path): digest2 = cas.write(b"blob", data) assert digest2 == digest assert len(cas) == 1 + + +def test_refs(cas: SimpleCas): + digest = bytes([42] * cas._digest_size) + assert cas.get_ref("foo/bar") is None + cas.set_ref("foo/bar", digest) + assert cas.get_ref("foo/bar") == digest + assert cas.get_ref("foo") is None