SimpleCas refs.

This commit is contained in:
2023-11-10 18:41:48 +01:00
parent 8937b51a45
commit 8d6f8ad7db
2 changed files with 30 additions and 0 deletions

View File

@@ -108,6 +108,20 @@ class SimpleCas:
return digest
def get_ref(self, key: str) -> bytes | None:
ref_path = self._ref_path(key)
if not ref_path.is_file():
return None
hex = ref_path.read_text().strip()
if len(hex) != 2 * self._digest_size:
return None
return bytes.fromhex(hex)
def set_ref(self, key: str, digest: bytes):
ref_path = self._ref_path(key)
ref_path.parent.mkdir(parents=True, exist_ok=True)
ref_path.write_text(digest.hex())
def _open_writer(self, digest: bytes, object_type: bytes, size: int) -> BinaryIO:
dat_file = (self._root_dir / "cas.dat").open("ab")
offset = dat_file.tell()
@@ -121,6 +135,14 @@ class SimpleCas:
return dat_file
def _ref_path(self, key: str) -> Path:
assert key
assert key[0] not in ("/", "\\")
assert key[-1] not in ("/", "\\")
key_path = Path(key)
assert not any(part in (".", "..") for part in key_path.parts)
return self._root_dir / "refs" / key_path
@dataclass
class Object:

View File

@@ -74,3 +74,11 @@ def test_simple_cas(tmp_dir: Path):
digest2 = cas.write(b"blob", data)
assert digest2 == digest
assert len(cas) == 1
def test_refs(cas: SimpleCas):
digest = bytes([42] * cas._digest_size)
assert cas.get_ref("foo/bar") is None
cas.set_ref("foo/bar", digest)
assert cas.get_ref("foo/bar") == digest
assert cas.get_ref("foo") is None