Commit cf00c0a1 authored by Jürg Billeter's avatar Jürg Billeter
Browse files

Merge branch 'juerg/cas-batch' into 'master'

CAS: Implement BatchUpdateBlobs support

Closes #677 and #676

See merge request !839
parents fafa8136 f47895c0
Loading
Loading
Loading
Loading
Loading
+74 −4
Original line number Original line Diff line number Diff line
@@ -1048,10 +1048,29 @@ class CASCache(ArtifactCache):
                missing_blobs[d.hash] = d
                missing_blobs[d.hash] = d


        # Upload any blobs missing on the server
        # Upload any blobs missing on the server
        for blob_digest in missing_blobs.values():
        self._send_blobs(remote, missing_blobs.values(), u_uid)
            with open(self.objpath(blob_digest), 'rb') as f:

                assert os.fstat(f.fileno()).st_size == blob_digest.size_bytes
    def _send_blobs(self, remote, digests, u_uid=None):
        """Upload the given blobs to the remote CAS.

        Blobs that fit under the remote's batch size limit are grouped into
        BatchUpdateBlobs requests; oversized blobs, or all blobs when the
        server lacks batch support, are uploaded in independent requests
        via _send_blob().

        Args:
            remote (_CASRemote): The remote to upload to.
            digests (iterable): Digests of the blobs to send; each blob must
                exist at the corresponding local object path.
            u_uid (uuid.UUID): Optional upload id. A fresh one is generated
                per call when not supplied.
        """
        # NOTE(review): the previous default, ``u_uid=uuid.uuid4()``, was
        # evaluated once at definition time, so every caller omitting u_uid
        # shared a single process-wide upload id.  Use a None sentinel and
        # generate a fresh id per call instead.
        if u_uid is None:
            u_uid = uuid.uuid4()

        batch = _CASBatchUpdate(remote)

        for digest in digests:
            with open(self.objpath(digest), 'rb') as f:
                # Sanity check: local object must match the advertised size.
                assert os.fstat(f.fileno()).st_size == digest.size_bytes

                if (digest.size_bytes >= remote.max_batch_total_size_bytes or
                        not remote.batch_update_supported):
                    # Too large for batch request, upload in independent request.
                    self._send_blob(remote, digest, f, u_uid=u_uid)
                else:
                    if not batch.add(digest, f):
                        # Not enough space left in batch request.
                        # Complete pending batch first.
                        batch.send()
                        batch = _CASBatchUpdate(remote)
                        batch.add(digest, f)

        # Send final batch
        batch.send()




# Represents a single remote CAS cache.
# Represents a single remote CAS cache.
@@ -1126,6 +1145,17 @@ class _CASRemote():
                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                    raise
                    raise


            # Check whether the server supports BatchUpdateBlobs()
            self.batch_update_supported = False
            try:
                request = remote_execution_pb2.BatchUpdateBlobsRequest()
                response = self.cas.BatchUpdateBlobs(request)
                self.batch_update_supported = True
            except grpc.RpcError as e:
                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
                    raise

            self._initialized = True
            self._initialized = True




@@ -1173,6 +1203,46 @@ class _CASBatchRead():
            yield (response.digest, response.data)
            yield (response.digest, response.data)




# Represents a batch of blobs queued for upload.
#
class _CASBatchUpdate():
    def __init__(self, remote):
        # Accumulates per-blob sub-requests until the remote's batch size
        # limit would be exceeded; send() ships them all in one RPC.
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchUpdateBlobsRequest()
        self._size = 0
        self._sent = False

    def add(self, digest, stream):
        """Queue one blob for upload.

        Reads digest.size_bytes from stream into the pending request.
        Returns True if the blob was added, False if it would overflow
        the batch size limit (caller must send() and start a new batch).
        """
        assert not self._sent

        projected_size = self._size + digest.size_bytes
        if projected_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        entry = self._request.requests.add()
        entry.digest.hash = digest.hash
        entry.digest.size_bytes = digest.size_bytes
        entry.data = stream.read(digest.size_bytes)
        self._size = projected_size
        return True

    def send(self):
        """Perform the BatchUpdateBlobs RPC for all queued blobs.

        A batch may be sent at most once; an empty batch is a no-op.
        Raises ArtifactError if the server reports a per-blob failure.
        """
        assert not self._sent
        self._sent = True

        if not self._request.requests:
            return

        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)

        for blob_response in batch_response.responses:
            # status.code is the numeric grpc status value
            if blob_response.status.code != grpc.StatusCode.OK.value[0]:
                raise ArtifactError("Failed to upload blob {}: {}".format(
                    blob_response.digest.hash, blob_response.status.code))


def _grouper(iterable, n):
def _grouper(iterable, n):
    while True:
    while True:
        try:
        try:
+43 −2
Original line number Original line Diff line number Diff line
@@ -68,7 +68,7 @@ def create_server(repo, *, enable_push):
        _ByteStreamServicer(artifactcache, enable_push=enable_push), server)
        _ByteStreamServicer(artifactcache, enable_push=enable_push), server)


    remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
    remote_execution_pb2_grpc.add_ContentAddressableStorageServicer_to_server(
        _ContentAddressableStorageServicer(artifactcache), server)
        _ContentAddressableStorageServicer(artifactcache, enable_push=enable_push), server)


    remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(
    remote_execution_pb2_grpc.add_CapabilitiesServicer_to_server(
        _CapabilitiesServicer(), server)
        _CapabilitiesServicer(), server)
@@ -222,9 +222,10 @@ class _ByteStreamServicer(bytestream_pb2_grpc.ByteStreamServicer):




class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddressableStorageServicer):
    def __init__(self, cas, *, enable_push):
        # cas: local CAS cache backing this servicer.
        # enable_push: whether clients may upload blobs; when False,
        # BatchUpdateBlobs() is rejected with PERMISSION_DENIED.
        super().__init__()
        self.cas = cas
        self.enable_push = enable_push


    def FindMissingBlobs(self, request, context):
    def FindMissingBlobs(self, request, context):
        response = remote_execution_pb2.FindMissingBlobsResponse()
        response = remote_execution_pb2.FindMissingBlobsResponse()
@@ -260,6 +261,46 @@ class _ContentAddressableStorageServicer(remote_execution_pb2_grpc.ContentAddres


        return response
        return response


    def BatchUpdateBlobs(self, request, context):
        """Store a batch of uploaded blobs in the local CAS.

        Args:
            request: BatchUpdateBlobsRequest with per-blob digests and data.
            context: gRPC servicer context, used for top-level status codes.

        Returns:
            BatchUpdateBlobsResponse with a per-blob status.  Top-level
            failures (push disabled, batch too large) are reported via the
            context status code with an empty/partial response.
        """
        response = remote_execution_pb2.BatchUpdateBlobsResponse()

        if not self.enable_push:
            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
            return response

        batch_size = 0

        for blob_request in request.requests:
            digest = blob_request.digest

            batch_size += digest.size_bytes
            if batch_size > _MAX_PAYLOAD_BYTES:
                # Client exceeded the advertised batch payload limit.
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                return response

            blob_response = response.responses.add()
            blob_response.digest.hash = digest.hash
            blob_response.digest.size_bytes = digest.size_bytes

            if len(blob_request.data) != digest.size_bytes:
                # status.code is an int32 protobuf field: assign the numeric
                # status value (.value[0]), not the StatusCode enum member,
                # matching how the client compares against OK.value[0].
                blob_response.status.code = grpc.StatusCode.FAILED_PRECONDITION.value[0]
                continue

            try:
                # Make room in the cache before writing the new object.
                _clean_up_cache(self.cas, digest.size_bytes)

                with tempfile.NamedTemporaryFile(dir=self.cas.tmpdir) as out:
                    out.write(blob_request.data)
                    out.flush()
                    server_digest = self.cas.add_object(path=out.name)
                    if server_digest.hash != digest.hash:
                        # Uploaded data does not hash to the claimed digest.
                        blob_response.status.code = grpc.StatusCode.FAILED_PRECONDITION.value[0]

            except ArtifactTooLargeException:
                blob_response.status.code = grpc.StatusCode.RESOURCE_EXHAUSTED.value[0]

        return response



class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
class _CapabilitiesServicer(remote_execution_pb2_grpc.CapabilitiesServicer):
    def GetCapabilities(self, request, context):
    def GetCapabilities(self, request, context):