Commit 336a2229 authored by Martin Blanchard

tests/cas/test_storage.py: Fix existing storage tests

#77
parent dd4adf29
Pipeline #31065815 passed
@@ -19,220 +19,286 @@
import tempfile

import boto3
import grpc
import pytest
from moto import mock_s3

from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
from buildgrid.server.cas.storage.remote import RemoteStorage
from buildgrid.server.cas.storage.lru_memory_cache import LRUMemoryCache
from buildgrid.server.cas.storage.disk import DiskStorage
from buildgrid.server.cas.storage.s3 import S3Storage
from buildgrid.server.cas.storage.with_cache import WithCacheStorage
from buildgrid.settings import HASH

from ..utils.cas import serve_cas, run_in_subprocess


BLOBS = [(b'abc', b'defg', b'hijk', b'')]
BLOBS_DIGESTS = [tuple([remote_execution_pb2.Digest(hash=HASH(blob).hexdigest(),
                                                    size_bytes=len(blob)) for blob in blobs])
                 for blobs in BLOBS]
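# Note: HASH comes from buildgrid.settings and is presumably hashlib.sha256,
# in which case the empty blob b'' above gets the well-known empty digest:
#
#     >>> import hashlib
#     >>> hashlib.sha256(b'').hexdigest()
#     'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'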

# General tests for all storage providers

@pytest.fixture(params=['lru', 'disk', 's3', 'lru_disk', 'disk_s3', 'remote'])
def any_storage(request):
    if request.param == 'lru':
        yield LRUMemoryCache(256)
    elif request.param == 'disk':
        with tempfile.TemporaryDirectory() as path:
            yield DiskStorage(path)
    elif request.param == 's3':
        with mock_s3():
            boto3.resource('s3').create_bucket(Bucket='testing')
            yield S3Storage('testing')
    elif request.param == 'lru_disk':
        # LRU cache with a uselessly small limit, so requests always fall back
        with tempfile.TemporaryDirectory() as path:
            yield WithCacheStorage(LRUMemoryCache(1), DiskStorage(path))
    elif request.param == 'disk_s3':
        # Disk-based cache of S3, but we don't delete files, so requests
        # are always handled by the cache
        with tempfile.TemporaryDirectory() as path:
            with mock_s3():
                boto3.resource('s3').create_bucket(Bucket='testing')
                yield WithCacheStorage(DiskStorage(path), S3Storage('testing'))
    elif request.param == 'remote':
        with serve_cas(['testing']) as server:
            yield server.remote
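# ``serve_cas`` is imported from ..utils.cas and is not part of this diff.
# A hypothetical sketch of the interface the 'remote' branch above relies on,
# assuming it launches a CAS server for the given instance names and exposes
# its endpoint address:
#
#     @contextmanager
#     def serve_cas(instances):
#         server = Server(instances)   # hypothetical helper class
#         try:
#             yield server             # server.remote: "host:port" string
#         finally:
#             server.quit()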

def write(storage, digest, blob):
    session = storage.begin_write(digest)
    session.write(blob)
    storage.commit_write(digest, session)
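# ``run_in_subprocess`` is also imported from ..utils.cas and not shown in
# this diff. A minimal sketch of how such a helper could look, assuming it
# hands the target function a multiprocessing queue followed by the extra
# arguments, and returns whatever the function puts on the queue:
import multiprocessing


def _sketch_run_in_subprocess(function, *arguments):
    # Hypothetical stand-in, not the real ..utils.cas helper. Assumes a
    # fork-based start method, since the tests below pass locally-defined
    # functions that a spawn-based method could not pickle.
    queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=function,
                                      args=(queue, *arguments))
    process.start()
    result = queue.get()  # blocks until the helper reports its verdict
    process.join()
    return result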

@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
def test_initially_empty(any_storage, blobs_digests):
    _, digests = blobs_digests

    # Actual test function, failing on assertions:
    def __test_initially_empty(any_storage, digests):
        for digest in digests:
            assert not any_storage.has_blob(digest)

    # Helper test function for remote storage, to be run in a subprocess:
    def __test_remote_initially_empty(queue, remote, serialized_digests):
        channel = grpc.insecure_channel(remote)
        remote_storage = RemoteStorage(channel, 'testing')

        digests = []
        for data in serialized_digests:
            digest = remote_execution_pb2.Digest()
            digest.ParseFromString(data)
            digests.append(digest)

        try:
            __test_initially_empty(remote_storage, digests)
        except AssertionError:
            queue.put(False)
        else:
            queue.put(True)

    if isinstance(any_storage, str):
        serialized_digests = [digest.SerializeToString() for digest in digests]
        assert run_in_subprocess(__test_remote_initially_empty,
                                 any_storage, serialized_digests)
    else:
        __test_initially_empty(any_storage, digests)
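# The dispatch above is shared by all general tests: for the 'remote'
# parameter the fixture yields the server's endpoint as a plain string, so
# the test serializes its digests and re-creates a RemoteStorage inside a
# fresh subprocess (presumably because a gRPC channel cannot safely be
# shared across process boundaries); every other storage object is
# exercised directly in-process.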

@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
def test_basic_write_read(any_storage, blobs_digests):
    blobs, digests = blobs_digests

    # Actual test function, failing on assertions:
    def __test_basic_write_read(any_storage, blobs, digests):
        for blob, digest in zip(blobs, digests):
            assert not any_storage.has_blob(digest)
            write(any_storage, digest, blob)

            assert any_storage.has_blob(digest)
            assert any_storage.get_blob(digest).read() == blob

            # Try writing the same digest again (since it's valid to do that)
            write(any_storage, digest, blob)

            assert any_storage.has_blob(digest)
            assert any_storage.get_blob(digest).read() == blob

    # Helper test function for remote storage, to be run in a subprocess:
    def __test_remote_basic_write_read(queue, remote, blobs, serialized_digests):
        channel = grpc.insecure_channel(remote)
        remote_storage = RemoteStorage(channel, 'testing')

        digests = []
        for data in serialized_digests:
            digest = remote_execution_pb2.Digest()
            digest.ParseFromString(data)
            digests.append(digest)

        try:
            __test_basic_write_read(remote_storage, blobs, digests)
        except AssertionError:
            queue.put(False)
        else:
            queue.put(True)

    if isinstance(any_storage, str):
        serialized_digests = [digest.SerializeToString() for digest in digests]
        assert run_in_subprocess(__test_remote_basic_write_read,
                                 any_storage, blobs, serialized_digests)
    else:
        __test_basic_write_read(any_storage, blobs, digests)

@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
def test_bulk_write_read(any_storage, blobs_digests):
    blobs, digests = blobs_digests

    # Actual test function, failing on assertions:
    def __test_bulk_write_read(any_storage, blobs, digests):
        missing_digests = any_storage.missing_blobs(digests)
        assert len(missing_digests) == len(digests)
        for digest in digests:
            assert digest in missing_digests

        faulty_blobs = list(blobs)
        faulty_blobs[-1] = b'this-is-not-matching'

        results = any_storage.bulk_update_blobs(list(zip(digests, faulty_blobs)))
        assert len(results) == len(digests)
        # Status codes follow gRPC conventions: 0 means OK, anything else is
        # an error (here, the last blob's content doesn't match its digest):
        for result, blob, digest in zip(results[:-1], faulty_blobs[:-1], digests[:-1]):
            assert result.code == 0
            assert any_storage.get_blob(digest).read() == blob
        assert results[-1].code != 0

        missing_digests = any_storage.missing_blobs(digests)
        assert len(missing_digests) == 1
        assert missing_digests[0] == digests[-1]

    # Helper test function for remote storage, to be run in a subprocess:
    def __test_remote_bulk_write_read(queue, remote, blobs, serialized_digests):
        channel = grpc.insecure_channel(remote)
        remote_storage = RemoteStorage(channel, 'testing')

        digests = []
        for data in serialized_digests:
            digest = remote_execution_pb2.Digest()
            digest.ParseFromString(data)
            digests.append(digest)

        try:
            __test_bulk_write_read(remote_storage, blobs, digests)
        except AssertionError:
            queue.put(False)
        else:
            queue.put(True)

    if isinstance(any_storage, str):
        serialized_digests = [digest.SerializeToString() for digest in digests]
        assert run_in_subprocess(__test_remote_bulk_write_read,
                                 any_storage, blobs, serialized_digests)
    else:
        __test_bulk_write_read(any_storage, blobs, digests)

@pytest.mark.parametrize('blobs_digests', zip(BLOBS, BLOBS_DIGESTS))
def test_nonexistent_read(any_storage, blobs_digests):
    _, digests = blobs_digests

    # Actual test function, failing on assertions:
    def __test_nonexistent_read(any_storage, digests):
        for digest in digests:
            assert any_storage.get_blob(digest) is None

    # Helper test function for remote storage, to be run in a subprocess:
    def __test_remote_nonexistent_read(queue, remote, serialized_digests):
        channel = grpc.insecure_channel(remote)
        remote_storage = RemoteStorage(channel, 'testing')

        digests = []
        for data in serialized_digests:
            digest = remote_execution_pb2.Digest()
            digest.ParseFromString(data)
            digests.append(digest)

        try:
            __test_nonexistent_read(remote_storage, digests)
        except AssertionError:
            queue.put(False)
        else:
            queue.put(True)

    if isinstance(any_storage, str):
        serialized_digests = [digest.SerializeToString() for digest in digests]
        assert run_in_subprocess(__test_remote_nonexistent_read,
                                 any_storage, serialized_digests)
    else:
        __test_nonexistent_read(any_storage, digests)

# Tests for special behavior of individual storage providers

@pytest.mark.parametrize('blobs_digests', [(BLOBS[0], BLOBS_DIGESTS[0])])
def test_lru_eviction(blobs_digests):
    blobs, digests = blobs_digests
    blob1, blob2, blob3, *_ = blobs
    digest1, digest2, digest3, *_ = digests

    # Capacity of 8 bytes: blob1 (3 bytes) and blob2 (4 bytes) fit together,
    # but adding blob3 (4 more bytes) must evict something.
    lru = LRUMemoryCache(8)
    write(lru, digest1, blob1)
    write(lru, digest2, blob2)
    assert lru.has_blob(digest1)
    assert lru.has_blob(digest2)

    write(lru, digest3, blob3)
    # Check that the LRU evicted blob1 (it was written first)
    assert not lru.has_blob(digest1)
    assert lru.has_blob(digest2)
    assert lru.has_blob(digest3)

    assert lru.get_blob(digest2).read() == blob2
    write(lru, digest1, blob1)
    # Check that the LRU evicted blob3 (since we just read blob2)
    assert lru.has_blob(digest1)
    assert lru.has_blob(digest2)
    assert not lru.has_blob(digest3)

    assert lru.has_blob(digest2)
    write(lru, digest3, blob1)
    # Check that the LRU evicted blob1 (since we just checked blob3)
    assert not lru.has_blob(digest1)
    assert lru.has_blob(digest2)
    assert lru.has_blob(digest3)

@pytest.mark.parametrize('blobs_digests', [(BLOBS[0], BLOBS_DIGESTS[0])])
def test_with_cache(blobs_digests):
    blobs, digests = blobs_digests
    blob1, blob2, blob3, *_ = blobs
    digest1, digest2, digest3, *_ = digests

    cache = LRUMemoryCache(256)
    fallback = LRUMemoryCache(256)
    with_cache_storage = WithCacheStorage(cache, fallback)

    assert not with_cache_storage.has_blob(digest1)
    write(with_cache_storage, digest1, blob1)
    assert cache.has_blob(digest1)
    assert fallback.has_blob(digest1)
    assert with_cache_storage.get_blob(digest1).read() == blob1

    # Even if a blob is in cache, we still need to check if the fallback
    # has it.
    write(cache, digest2, blob2)
    assert not with_cache_storage.has_blob(digest2)
    write(fallback, digest2, blob2)
    assert with_cache_storage.has_blob(digest2)

    # When a blob is in the fallback but not the cache, reading it should
    # put it into the cache.
    write(fallback, digest3, blob3)
    assert with_cache_storage.get_blob(digest3).read() == blob3
    assert cache.has_blob(digest3)
    assert cache.get_blob(digest3).read() == blob3
    assert cache.has_blob(digest3)
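# To run just these tests from a checkout (assuming the usual pytest setup):
#
#     pytest tests/cas/test_storage.py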