Skip to content
Snippets Groups Projects
Commit cf00c0a1 authored by Jürg Billeter's avatar Jürg Billeter
Browse files

Merge branch 'juerg/cas-batch' into 'master'

CAS: Implement BatchUpdateBlobs support

Closes #677 and #676

See merge request !839
parents fafa8136 f47895c0
No related branches found
No related tags found
1 merge request: !839 CAS: Implement BatchUpdateBlobs support
Pipeline #31693615 passed
......@@ -1048,10 +1048,29 @@ class CASCache(ArtifactCache):
missing_blobs[d.hash] = d
# Upload any blobs missing on the server
for blob_digest in missing_blobs.values():
with open(self.objpath(blob_digest), 'rb') as f:
assert os.fstat(f.fileno()).st_size == blob_digest.size_bytes
self._send_blob(remote, blob_digest, f, u_uid=u_uid)
self._send_blobs(remote, missing_blobs.values(), u_uid)
def _send_blobs(self, remote, digests, u_uid=uuid.uuid4()):
batch = _CASBatchUpdate(remote)
for digest in digests:
with open(self.objpath(digest), 'rb') as f:
assert os.fstat(f.fileno()).st_size == digest.size_bytes
if (digest.size_bytes >= remote.max_batch_total_size_bytes or
not remote.batch_update_supported):
# Too large for batch request, upload in independent request.
self._send_blob(remote, digest, f, u_uid=u_uid)
else:
if not batch.add(digest, f):
# Not enough space left in batch request.
# Complete pending batch first.
batch.send()
batch = _CASBatchUpdate(remote)
batch.add(digest, f)
# Send final batch
batch.send()
# Represents a single remote CAS cache.
......@@ -1126,6 +1145,17 @@ class _CASRemote():
if e.code() != grpc.StatusCode.UNIMPLEMENTED:
raise
# Check whether the server supports BatchUpdateBlobs()
self.batch_update_supported = False
try:
request = remote_execution_pb2.BatchUpdateBlobsRequest()
response = self.cas.BatchUpdateBlobs(request)
self.batch_update_supported = True
except grpc.RpcError as e:
if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
e.code() != grpc.StatusCode.PERMISSION_DENIED):
raise
self._initialized = True
......@@ -1173,6 +1203,46 @@ class _CASBatchRead():
yield (response.digest, response.data)
# Represents a batch of blobs queued for upload.
#
class _CASBatchUpdate():
    """A single BatchUpdateBlobs request being assembled for one remote.

    Blobs are appended with add() until the accumulated digest sizes would
    exceed the remote's ``max_batch_total_size_bytes``; send() then issues
    the request. An instance can be sent only once.
    """

    def __init__(self, remote):
        self._sent = False
        self._size = 0
        self._request = remote_execution_pb2.BatchUpdateBlobsRequest()
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes

    def add(self, digest, stream):
        """Queue one blob, reading ``digest.size_bytes`` bytes from ``stream``.

        Returns True when the blob was queued, or False (without reading from
        ``stream``) when adding it would exceed the batch size limit.
        """
        assert not self._sent

        projected_size = self._size + digest.size_bytes
        if projected_size <= self._max_total_size_bytes:
            entry = self._request.requests.add()
            entry.digest.hash = digest.hash
            entry.digest.size_bytes = digest.size_bytes
            entry.data = stream.read(digest.size_bytes)
            self._size = projected_size
            return True

        # Not enough space left in current batch
        return False

    def send(self):
        """Issue the batched request; a batch with no entries is a no-op.

        Raises ArtifactError for the first per-blob response whose status
        code is not OK.
        """
        assert not self._sent
        self._sent = True

        if len(self._request.requests) > 0:
            reply = self._remote.cas.BatchUpdateBlobs(self._request)

            for blob_reply in reply.responses:
                if blob_reply.status.code != grpc.StatusCode.OK.value[0]:
                    raise ArtifactError("Failed to upload blob {}: {}".format(
                        blob_reply.digest.hash, blob_reply.status.code))
def _grouper(iterable, n):
while True:
try:
......
......@@ -68,7 +68,7 @@ def create_server(repo, *, enable_push):
_ByteStreamServicer(artifactcache, enable_push=enable_push), server)
remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
_ContentAddressableStorageServicer(artifactcache), server)
_ContentAddressableStorageServicer(artifactcache, enable_push=enable_push), server)
remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(
_CapabilitiesServicer(), server)
......@@ -222,9 +222,10 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):
class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
def __init__(self, cas):
def __init__(self, cas, *, enable_push):
    """Store the CAS backend and whether uploads are permitted.

    Args:
        cas: Object kept as ``self.cas`` for the servicer's handlers.
        enable_push: Keyword-only flag kept as ``self.enable_push``.
    """
    super().__init__()
    self.enable_push = enable_push
    self.cas = cas
def FindMissingBlobs(self, request, context):
response = remote_execution_pb2.FindMissingBlobsResponse()
......@@ -260,6 +261,46 @@ class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddres
return response
def BatchUpdateBlobs(self, request, context):
    """Store every blob carried by ``request`` in the local CAS.

    The call is rejected with PERMISSION_DENIED when pushing is disabled,
    and aborted with INVALID_ARGUMENT as soon as the running total of the
    advertised digest sizes exceeds ``_MAX_PAYLOAD_BYTES``. Otherwise one
    response entry is appended per blob; its ``status.code`` is left at its
    default on success and set to an integer status code on failure.

    Args:
        request: The BatchUpdateBlobsRequest to process.
        context: The gRPC servicer context, used to set call-level errors.

    Returns:
        A BatchUpdateBlobsResponse, possibly partial if the call was aborted.
    """
    response = remote_execution_pb2.BatchUpdateBlobsResponse()

    if not self.enable_push:
        context.set_code(grpc.StatusCode.PERMISSION_DENIED)
        return response

    # ``grpc.StatusCode`` members carry an ``(int, str)`` tuple as their
    # value. The per-blob ``status.code`` must hold the integer part, which
    # is also what the client compares against (``StatusCode.OK.value[0]``).
    failed_precondition = grpc.StatusCode.FAILED_PRECONDITION.value[0]
    resource_exhausted = grpc.StatusCode.RESOURCE_EXHAUSTED.value[0]

    batch_size = 0

    for blob_request in request.requests:
        digest = blob_request.digest

        batch_size += digest.size_bytes
        if batch_size > _MAX_PAYLOAD_BYTES:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            return response

        blob_response = response.responses.add()
        blob_response.digest.hash = digest.hash
        blob_response.digest.size_bytes = digest.size_bytes

        # The payload must be exactly as long as the digest claims.
        if len(blob_request.data) != digest.size_bytes:
            blob_response.status.code = failed_precondition
            continue

        try:
            _clean_up_cache(self.cas, digest.size_bytes)

            with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
                out.write(blob_request.data)
                # Flush so add_object() sees the full content via the path.
                out.flush()
                server_digest = self.cas.add_object(path=out.name)
                # Reject blobs whose content does not hash to the digest
                # the client supplied.
                if server_digest.hash != digest.hash:
                    blob_response.status.code = failed_precondition

        except ArtifactTooLargeException:
            blob_response.status.code = resource_exhausted

    return response
class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
def GetCapabilities(self, request, context):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment