Commit 63983fa3 authored by finn

Added remote storage unittests

Also updated unittests to reflect CAS service
and instance separation
parent b251c674
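For context, the servicers are no longer constructed directly from a storage backend; each one now takes a dictionary mapping instance names to instance objects that wrap the storage. A minimal sketch of the new pattern, using the SimpleStorage test helper defined below (the empty string stands for the unnamed instance in the parametrized tests):

    storage = SimpleStorage([b"abc"])

    # Wrap the storage in per-service instance objects...
    bs_instance = ByteStreamInstance(storage)
    cas_instance = ContentAddressableStorageInstance(storage)

    # ...and register them with the servicers under an instance name.
    bytestream_service = ByteStreamService({"test_inst": bs_instance})
    cas_service = ContentAddressableStorageService({"test_inst": cas_instance})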
@@ -18,17 +18,23 @@
# pylint: disable=redefined-outer-name
import io
from unittest import mock
import grpc
from grpc._server import _Context
import pytest
from buildgrid._protos.google.bytestream import bytestream_pb2
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 as re_pb2
from buildgrid.server.cas.storage.storage_abc import StorageABC
from buildgrid.server.cas.service import ByteStreamService
from buildgrid.server.cas.service import ContentAddressableStorageService
from buildgrid.server.cas.instance import ByteStreamInstance, ContentAddressableStorageInstance
from buildgrid.server.cas.service import ByteStreamService, ContentAddressableStorageService
from buildgrid.settings import HASH
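# Shared mock of the gRPC servicer context; the rewritten tests assert on context.set_code() instead of aborting through a hand-rolled mock.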
context = mock.create_autospec(_Context)
class SimpleStorage(StorageABC):
"""Storage provider wrapper around a dictionary.
@@ -61,19 +67,6 @@ class SimpleStorage(StorageABC):
self.data[(digest.hash, digest.size_bytes)] = data
class MockObject:
def __init__(self):
self.abort = None
class MockException(Exception):
pass
def raise_mock_exception(*args, **kwargs):
raise MockException()
test_strings = [b"", b"hij"]
instances = ["", "test_inst"]
@@ -82,7 +75,9 @@ instances = ["", "test_inst"]
@pytest.mark.parametrize("instance", instances)
def test_bytestream_read(data_to_read, instance):
storage = SimpleStorage([b"abc", b"defg", data_to_read])
servicer = ByteStreamService(storage)
bs_instance = ByteStreamInstance(storage)
servicer = ByteStreamService({instance: bs_instance})
request = bytestream_pb2.ReadRequest()
if instance != "":
@@ -100,7 +95,8 @@ def test_bytestream_read_many(instance):
data_to_read = b"testing" * 10000
storage = SimpleStorage([b"abc", b"defg", data_to_read])
servicer = ByteStreamService(storage)
bs_instance = ByteStreamInstance(storage)
servicer = ByteStreamService({instance: bs_instance})
request = bytestream_pb2.ReadRequest()
if instance != "":
@@ -117,7 +113,8 @@ def test_bytestream_read_many(instance):
@pytest.mark.parametrize("extra_data", ["", "/", "/extra/data"])
def test_bytestream_write(instance, extra_data):
storage = SimpleStorage()
servicer = ByteStreamService(storage)
bs_instance = ByteStreamInstance(storage)
servicer = ByteStreamService({instance: bs_instance})
resource_name = ""
if instance != "":
@@ -139,7 +136,8 @@ def test_bytestream_write(instance, extra_data):
def test_bytestream_write_rejects_wrong_hash():
storage = SimpleStorage()
servicer = ByteStreamService(storage)
bs_instance = ByteStreamInstance(storage)
servicer = ByteStreamService({"": bs_instance})
data = b'some data'
wrong_hash = HASH(b'incorrect').hexdigest()
@@ -148,10 +146,8 @@ def test_bytestream_write_rejects_wrong_hash():
bytestream_pb2.WriteRequest(resource_name=resource_name, data=data, finish_write=True)
]
context = MockObject()
context.abort = raise_mock_exception
with pytest.raises(MockException):
servicer.Write(requests, context)
servicer.Write(requests, context)
context.set_code.assert_called_once_with(grpc.StatusCode.INVALID_ARGUMENT)
assert len(storage.data) == 0
@@ -159,7 +155,8 @@ def test_bytestream_write_rejects_wrong_hash():
@pytest.mark.parametrize("instance", instances)
def test_cas_find_missing_blobs(instance):
storage = SimpleStorage([b'abc', b'def'])
servicer = ContentAddressableStorageService(storage)
cas_instance = ContentAddressableStorageInstance(storage)
servicer = ContentAddressableStorageService({instance: cas_instance})
digests = [
re_pb2.Digest(hash=HASH(b'def').hexdigest(), size_bytes=3),
re_pb2.Digest(hash=HASH(b'ghij').hexdigest(), size_bytes=4)
@@ -173,7 +170,9 @@ def test_cas_find_missing_blobs(instance):
@pytest.mark.parametrize("instance", instances)
def test_cas_batch_update_blobs(instance):
storage = SimpleStorage()
servicer = ContentAddressableStorageService(storage)
cas_instance = ContentAddressableStorageInstance(storage)
servicer = ContentAddressableStorageService({instance: cas_instance})
update_requests = [
re_pb2.BatchUpdateBlobsRequest.Request(
digest=re_pb2.Digest(hash=HASH(b'abc').hexdigest(), size_bytes=3), data=b'abc'),
@@ -181,16 +180,21 @@ def test_cas_batch_update_blobs(instance):
digest=re_pb2.Digest(hash="invalid digest!", size_bytes=1000),
data=b'wrong data')
]
request = re_pb2.BatchUpdateBlobsRequest(instance_name=instance, requests=update_requests)
response = servicer.BatchUpdateBlobs(request, None)
assert len(response.responses) == 2
for blob_response in response.responses:
if blob_response.digest == update_requests[0].digest:
assert blob_response.status.code == 0
elif blob_response.digest == update_requests[1].digest:
assert blob_response.status.code != 0
else:
raise Exception("Unexpected blob response")
assert len(storage.data) == 1
assert (update_requests[0].digest.hash, 3) in storage.data
assert storage.data[(update_requests[0].digest.hash, 3)] == b'abc'
@@ -19,18 +19,26 @@
import tempfile
from unittest import mock
import boto3
from grpc._server import _Context
import pytest
from moto import mock_s3
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid.server.cas import service
from buildgrid.server.cas.instance import ByteStreamInstance, ContentAddressableStorageInstance
from buildgrid.server.cas.storage import remote
from buildgrid.server.cas.storage.lru_memory_cache import LRUMemoryCache
from buildgrid.server.cas.storage.disk import DiskStorage
from buildgrid.server.cas.storage.s3 import S3Storage
from buildgrid.server.cas.storage.with_cache import WithCacheStorage
from buildgrid.settings import HASH
context = mock.create_autospec(_Context)
abc = b"abc"
abc_digest = Digest(hash=HASH(abc).hexdigest(), size_bytes=3)
defg = b"defg"
@@ -45,10 +53,62 @@ def write(storage, digest, blob):
storage.commit_write(digest, session)
class MockCASStorage(ByteStreamInstance, ContentAddressableStorageInstance):
def __init__(self):
storage = LRUMemoryCache(256)
super().__init__(storage)
# Mock CAS server backed by LRU memory storage; records the calls made to it
class MockStubServer:
def __init__(self):
instances = {"": MockCASStorage(), "dna": MockCASStorage()}
self._requests = []
self._bs_service = service.ByteStreamService(instances)
self._cas_service = service.ContentAddressableStorageService(instances)
def Read(self, request):
yield from self._bs_service.Read(request, context)
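    # Buffer streamed WriteRequests until finish_write, then hand the whole sequence to ByteStreamService.Write in a single call.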
def Write(self, request):
self._requests.append(request)
if request.finish_write:
response = self._bs_service.Write(self._requests, context)
self._requests = []
return response
return None
def FindMissingBlobs(self, request):
return self._cas_service.FindMissingBlobs(request, context)
def BatchUpdateBlobs(self, request):
return self._cas_service.BatchUpdateBlobs(request, context)
# Instances of MockCASStorage
@pytest.fixture(params=["", "dna"])
def instance(request):
    return {request.param: MockCASStorage()}
@pytest.fixture()
@mock.patch.object(remote, 'bytestream_pb2_grpc')
@mock.patch.object(remote, 'remote_execution_pb2_grpc')
def remote_storage(mock_bs_grpc, mock_re_pb2_grpc):
mock_server = MockStubServer()
storage = remote.RemoteStorage(instance)
storage._stub_bs = mock_server
storage._stub_cas = mock_server
yield storage
# General tests for all storage providers
@pytest.fixture(params=["lru", "disk", "s3", "lru_disk", "disk_s3"])
@pytest.fixture(params=["lru", "disk", "s3", "lru_disk", "disk_s3", "remote"])
def any_storage(request):
if request.param == "lru":
yield LRUMemoryCache(256)
@@ -70,6 +130,14 @@ def any_storage(request):
with mock_s3():
boto3.resource('s3').create_bucket(Bucket="testing")
yield WithCacheStorage(DiskStorage(path), S3Storage("testing"))
elif request.param == "remote":
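        # Patching the generated gRPC stub modules keeps RemoteStorage from opening a real channel; the in-process MockStubServer answers the ByteStream and CAS calls instead.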
with mock.patch.object(remote, 'bytestream_pb2_grpc'):
with mock.patch.object(remote, 'remote_execution_pb2_grpc'):
mock_server = MockStubServer()
storage = remote.RemoteStorage(instance)
storage._stub_bs = mock_server
storage._stub_cas = mock_server
yield storage
def test_initially_empty(any_storage):
......