Commit d2d0e907 authored by Raoul Hidalgo Charman's avatar Raoul Hidalgo Charman
Browse files

casremote.py: Move remote CAS classes into its own file

Part of #802
parent ead41b31
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -17,4 +17,5 @@
#  Authors:
#        Tristan Van Berkom <tristan.vanberkom@codethink.co.uk>

from .cascache import CASCache, CASRemote, CASRemoteSpec
from .cascache import CASCache
from .casremote import CASRemote, CASRemoteSpec
+19 −256
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@
#  Authors:
#        Jürg Billeter <juerg.billeter@codethink.co.uk>

from collections import namedtuple
import hashlib
import itertools
import io
@@ -26,76 +25,17 @@ import stat
import tempfile
import uuid
import contextlib
from urllib.parse import urlparse

import grpc

from .._protos.google.rpc import code_pb2
from .._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from .._protos.buildstream.v2 import buildstream_pb2, buildstream_pb2_grpc
from .._protos.google.bytestream import bytestream_pb2
from .._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
from .._protos.buildstream.v2 import buildstream_pb2

from .. import utils
from .._exceptions import CASError, LoadError, LoadErrorReason
from .. import _yaml
from .._exceptions import CASCacheError


# The default limit for gRPC messages is 4 MiB.
# Limit payload to 1 MiB to leave sufficient headroom for metadata.
_MAX_PAYLOAD_BYTES = 1024 * 1024


class CASRemoteSpec(namedtuple('CASRemoteSpec', 'url push server_cert client_key client_cert instance_name')):

    # _new_from_config_node
    #
    # Creates an CASRemoteSpec() from a YAML loaded node
    #
    @staticmethod
    def _new_from_config_node(spec_node, basedir=None):
        _yaml.node_validate(spec_node, ['url', 'push', 'server-cert', 'client-key', 'client-cert', 'instance-name'])
        url = _yaml.node_get(spec_node, str, 'url')
        push = _yaml.node_get(spec_node, bool, 'push', default_value=False)
        if not url:
            provenance = _yaml.node_get_provenance(spec_node, 'url')
            raise LoadError(LoadErrorReason.INVALID_DATA,
                            "{}: empty artifact cache URL".format(provenance))

        instance_name = _yaml.node_get(spec_node, str, 'instance-name', default_value=None)

        server_cert = _yaml.node_get(spec_node, str, 'server-cert', default_value=None)
        if server_cert and basedir:
            server_cert = os.path.join(basedir, server_cert)

        client_key = _yaml.node_get(spec_node, str, 'client-key', default_value=None)
        if client_key and basedir:
            client_key = os.path.join(basedir, client_key)

        client_cert = _yaml.node_get(spec_node, str, 'client-cert', default_value=None)
        if client_cert and basedir:
            client_cert = os.path.join(basedir, client_cert)

        if client_key and not client_cert:
            provenance = _yaml.node_get_provenance(spec_node, 'client-key')
            raise LoadError(LoadErrorReason.INVALID_DATA,
                            "{}: 'client-key' was specified without 'client-cert'".format(provenance))

        if client_cert and not client_key:
            provenance = _yaml.node_get_provenance(spec_node, 'client-cert')
            raise LoadError(LoadErrorReason.INVALID_DATA,
                            "{}: 'client-cert' was specified without 'client-key'".format(provenance))

        return CASRemoteSpec(url, push, server_cert, client_key, client_cert, instance_name)


CASRemoteSpec.__new__.__defaults__ = (None, None, None, None)


class BlobNotFound(CASError):

    def __init__(self, blob, msg):
        self.blob = blob
        super().__init__(msg)
from .casremote import CASRemote, BlobNotFound, _CASBatchRead, _CASBatchUpdate, _MAX_PAYLOAD_BYTES


# A CASCache manages a CAS repository as specified in the Remote Execution API.
@@ -120,7 +60,7 @@ class CASCache():
        headdir = os.path.join(self.casdir, 'refs', 'heads')
        objdir = os.path.join(self.casdir, 'objects')
        if not (os.path.isdir(headdir) and os.path.isdir(objdir)):
            raise CASError("CAS repository check failed for '{}'".format(self.casdir))
            raise CASCacheError("CAS repository check failed for '{}'".format(self.casdir))

    # contains():
    #
@@ -169,7 +109,7 @@ class CASCache():
    #     subdir (str): Optional specific dir to extract
    #
    # Raises:
    #     CASError: In cases there was an OSError, or if the ref did not exist.
    #     CASCacheError: In cases there was an OSError, or if the ref did not exist.
    #
    # Returns: path to extracted directory
    #
@@ -201,7 +141,7 @@ class CASCache():
                # Another process beat us to rename
                pass
            except OSError as e:
                raise CASError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e
                raise CASCacheError("Failed to extract directory for ref '{}': {}".format(ref, e)) from e

        return originaldest

@@ -306,7 +246,7 @@ class CASCache():
            return True
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.NOT_FOUND:
                raise CASError("Failed to pull ref {}: {}".format(ref, e)) from e
                raise CASCacheError("Failed to pull ref {}: {}".format(ref, e)) from e
            else:
                return False
        except BlobNotFound as e:
@@ -360,7 +300,7 @@ class CASCache():
    #   (bool): True if any remote was updated, False if no pushes were required
    #
    # Raises:
    #   (CASError): if there was an error
    #   (CASCacheError): if there was an error
    #
    def push(self, refs, remote):
        skipped_remote = True
@@ -395,7 +335,7 @@ class CASCache():
                skipped_remote = False
        except grpc.RpcError as e:
            if e.code() != grpc.StatusCode.RESOURCE_EXHAUSTED:
                raise CASError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e
                raise CASCacheError("Failed to push ref {}: {}".format(refs, e), temporary=True) from e

        return not skipped_remote

@@ -408,7 +348,7 @@ class CASCache():
    #     directory (Directory): A virtual directory object to push.
    #
    # Raises:
    #     (CASError): if there was an error
    #     (CASCacheError): if there was an error
    #
    def push_directory(self, remote, directory):
        remote.init()
@@ -424,7 +364,7 @@ class CASCache():
    #     message (Message): A protobuf message to push.
    #
    # Raises:
    #     (CASError): if there was an error
    #     (CASCacheError): if there was an error
    #
    def push_message(self, remote, message):

@@ -531,7 +471,7 @@ class CASCache():
            pass

        except OSError as e:
            raise CASError("Failed to hash object: {}".format(e)) from e
            raise CASCacheError("Failed to hash object: {}".format(e)) from e

        return digest

@@ -572,7 +512,7 @@ class CASCache():
                return digest

        except FileNotFoundError as e:
            raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e

    # update_mtime()
    #
@@ -585,7 +525,7 @@ class CASCache():
        try:
            os.utime(self._refpath(ref))
        except FileNotFoundError as e:
            raise CASError("Attempt to access unavailable ref: {}".format(e)) from e
            raise CASCacheError("Attempt to access unavailable ref: {}".format(e)) from e

    # calculate_cache_size()
    #
@@ -676,7 +616,7 @@ class CASCache():
        # Remove cache ref
        refpath = self._refpath(ref)
        if not os.path.exists(refpath):
            raise CASError("Could not find ref '{}'".format(ref))
            raise CASCacheError("Could not find ref '{}'".format(ref))

        os.unlink(refpath)

@@ -792,7 +732,7 @@ class CASCache():
                # The process serving the socket can't be cached anyway
                pass
            else:
                raise CASError("Unsupported file type for {}".format(full_path))
                raise CASCacheError("Unsupported file type for {}".format(full_path))

        return self.add_object(digest=dir_digest,
                               buffer=directory.SerializeToString())
@@ -811,7 +751,7 @@ class CASCache():
            if dirnode.name == name:
                return dirnode.digest

        raise CASError("Subdirectory {} not found".format(name))
        raise CASCacheError("Subdirectory {} not found".format(name))

    def _diff_trees(self, tree_a, tree_b, *, added, removed, modified, path=""):
        dir_a = remote_execution_pb2.Directory()
@@ -1150,183 +1090,6 @@ class CASCache():
        batch.send()


# Represents a single remote CAS cache.
#
class CASRemote():
    def __init__(self, spec):
        self.spec = spec
        self._initialized = False
        self.channel = None
        self.bytestream = None
        self.cas = None
        self.ref_storage = None
        self.batch_update_supported = None
        self.batch_read_supported = None
        self.capabilities = None
        self.max_batch_total_size_bytes = None

    def init(self):
        if not self._initialized:
            url = urlparse(self.spec.url)
            if url.scheme == 'http':
                port = url.port or 80
                self.channel = grpc.insecure_channel('{}:{}'.format(url.hostname, port))
            elif url.scheme == 'https':
                port = url.port or 443

                if self.spec.server_cert:
                    with open(self.spec.server_cert, 'rb') as f:
                        server_cert_bytes = f.read()
                else:
                    server_cert_bytes = None

                if self.spec.client_key:
                    with open(self.spec.client_key, 'rb') as f:
                        client_key_bytes = f.read()
                else:
                    client_key_bytes = None

                if self.spec.client_cert:
                    with open(self.spec.client_cert, 'rb') as f:
                        client_cert_bytes = f.read()
                else:
                    client_cert_bytes = None

                credentials = grpc.ssl_channel_credentials(root_certificates=server_cert_bytes,
                                                           private_key=client_key_bytes,
                                                           certificate_chain=client_cert_bytes)
                self.channel = grpc.secure_channel('{}:{}'.format(url.hostname, port), credentials)
            else:
                raise CASError("Unsupported URL: {}".format(self.spec.url))

            self.bytestream = bytestream_pb2_grpc.ByteStreamStub(self.channel)
            self.cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
            self.capabilities = remote_execution_pb2_grpc.CapabilitiesStub(self.channel)
            self.ref_storage = buildstream_pb2_grpc.ReferenceStorageStub(self.channel)

            self.max_batch_total_size_bytes = _MAX_PAYLOAD_BYTES
            try:
                request = remote_execution_pb2.GetCapabilitiesRequest(instance_name=self.spec.instance_name)
                response = self.capabilities.GetCapabilities(request)
                server_max_batch_total_size_bytes = response.cache_capabilities.max_batch_total_size_bytes
                if 0 < server_max_batch_total_size_bytes < self.max_batch_total_size_bytes:
                    self.max_batch_total_size_bytes = server_max_batch_total_size_bytes
            except grpc.RpcError as e:
                # Simply use the defaults for servers that don't implement GetCapabilities()
                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                    raise

            # Check whether the server supports BatchReadBlobs()
            self.batch_read_supported = False
            try:
                request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=self.spec.instance_name)
                response = self.cas.BatchReadBlobs(request)
                self.batch_read_supported = True
            except grpc.RpcError as e:
                if e.code() != grpc.StatusCode.UNIMPLEMENTED:
                    raise

            # Check whether the server supports BatchUpdateBlobs()
            self.batch_update_supported = False
            try:
                request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=self.spec.instance_name)
                response = self.cas.BatchUpdateBlobs(request)
                self.batch_update_supported = True
            except grpc.RpcError as e:
                if (e.code() != grpc.StatusCode.UNIMPLEMENTED and
                        e.code() != grpc.StatusCode.PERMISSION_DENIED):
                    raise

            self._initialized = True


# Represents a batch of blobs queued for fetching.
#
class _CASBatchRead():
    def __init__(self, remote):
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchReadBlobsRequest(instance_name=remote.spec.instance_name)
        self._size = 0
        self._sent = False

    def add(self, digest):
        assert not self._sent

        new_batch_size = self._size + digest.size_bytes
        if new_batch_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        request_digest = self._request.digests.add()
        request_digest.hash = digest.hash
        request_digest.size_bytes = digest.size_bytes
        self._size = new_batch_size
        return True

    def send(self):
        assert not self._sent
        self._sent = True

        if not self._request.digests:
            return

        batch_response = self._remote.cas.BatchReadBlobs(self._request)

        for response in batch_response.responses:
            if response.status.code == code_pb2.NOT_FOUND:
                raise BlobNotFound(response.digest.hash, "Failed to download blob {}: {}".format(
                    response.digest.hash, response.status.code))
            if response.status.code != code_pb2.OK:
                raise CASError("Failed to download blob {}: {}".format(
                    response.digest.hash, response.status.code))
            if response.digest.size_bytes != len(response.data):
                raise CASError("Failed to download blob {}: expected {} bytes, received {} bytes".format(
                    response.digest.hash, response.digest.size_bytes, len(response.data)))

            yield (response.digest, response.data)


# Represents a batch of blobs queued for upload.
#
class _CASBatchUpdate():
    def __init__(self, remote):
        self._remote = remote
        self._max_total_size_bytes = remote.max_batch_total_size_bytes
        self._request = remote_execution_pb2.BatchUpdateBlobsRequest(instance_name=remote.spec.instance_name)
        self._size = 0
        self._sent = False

    def add(self, digest, stream):
        assert not self._sent

        new_batch_size = self._size + digest.size_bytes
        if new_batch_size > self._max_total_size_bytes:
            # Not enough space left in current batch
            return False

        blob_request = self._request.requests.add()
        blob_request.digest.hash = digest.hash
        blob_request.digest.size_bytes = digest.size_bytes
        blob_request.data = stream.read(digest.size_bytes)
        self._size = new_batch_size
        return True

    def send(self):
        assert not self._sent
        self._sent = True

        if not self._request.requests:
            return

        batch_response = self._remote.cas.BatchUpdateBlobs(self._request)

        for response in batch_response.responses:
            if response.status.code != code_pb2.OK:
                raise CASError("Failed to upload blob {}: {}".format(
                    response.digest.hash, response.status.code))


def _grouper(iterable, n):
    while True:
        try:
+247 −0

File added.

Preview size limit exceeded, changes collapsed.

+15 −0
Original line number Diff line number Diff line
@@ -284,6 +284,21 @@ class CASError(BstError):
        super().__init__(message, detail=detail, domain=ErrorDomain.CAS, reason=reason, temporary=True)


# CASRemoteError
#
# Raised when errors are encountered in the remote CAS
class CASRemoteError(CASError):
    pass


# CASCacheError
#
# Raised when errors are encountered in the local CASCacheError
#
class CASCacheError(CASError):
    pass


# PipelineError
#
# Raised from pipeline operations