diff --git a/buildgrid/client/cas.py b/buildgrid/client/cas.py
index a38c3de1c93b7543c66c66bdb03190e90eaaa02a..fc35b6bf29bcafe08ec84e30e86e890a43fac72d 100644
--- a/buildgrid/client/cas.py
+++ b/buildgrid/client/cas.py
@@ -19,6 +19,7 @@
 import os
 
 import grpc
 
+from buildgrid._exceptions import NotFoundError
 from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
 from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
 from buildgrid._protos.google.rpc import code_pb2
@@ -26,6 +27,16 @@
 from buildgrid.settings import HASH
 from buildgrid.utils import merkle_tree_maker
 
 
+# Maximum size for a queueable file:
+FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
+
+# Maximum size for a single gRPC request:
+MAX_REQUEST_SIZE = 2 * 1024 * 1024
+
+# Maximum number of elements per gRPC request:
+MAX_REQUEST_COUNT = 500
+
+
 class _CallCache:
     """Per remote grpc.StatusCode.UNIMPLEMENTED call cache."""
     __calls = {}
@@ -43,6 +54,397 @@ class _CallCache:
         return name in cls.__calls[channel]
 
 
+@contextmanager
+def download(channel, instance=None, u_uid=None):
+    """Context manager generator for the :class:`Downloader` class."""
+    downloader = Downloader(channel, instance=instance)
+    try:
+        yield downloader
+    finally:
+        downloader.close()
+
+
+class Downloader:
+    """Remote CAS files, directories and messages download helper.
+
+    The :class:`Downloader` class comes with a generator factory function that
+    can be used together with the `with` statement for context management::
+
+        from buildgrid.client.cas import download
+
+        with download(channel, instance='build') as downloader:
+            downloader.get_message(message_digest)
+    """
+
+    def __init__(self, channel, instance=None):
+        """Initializes a new :class:`Downloader` instance.
+
+        Args:
+            channel (grpc.Channel): A gRPC channel to the CAS endpoint.
+            instance (str, optional): the targeted instance's name.
+        """
+        self.channel = channel
+
+        self.instance_name = instance
+
+        self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
+        self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
+
+        self.__file_requests = {}
+        self.__file_request_count = 0
+        self.__file_request_size = 0
+        self.__file_response_size = 0
+
+    # --- Public API ---
+
+    def get_blob(self, digest):
+        """Retrieves a blob from the remote CAS server.
+
+        Args:
+            digest (:obj:`Digest`): the blob's digest to fetch.
+
+        Returns:
+            bytearray: the fetched blob data or None if not found.
+        """
+        try:
+            blob = self._fetch_blob(digest)
+        except NotFoundError:
+            return None
+
+        return blob
+
+    def get_blobs(self, digests):
+        """Retrieves a list of blobs from the remote CAS server.
+
+        Args:
+            digests (list): list of :obj:`Digest`s for the blobs to fetch.
+
+        Returns:
+            list: the fetched blob data list.
+        """
+        return self._fetch_blob_batch(digests)
+
+    def get_message(self, digest, message):
+        """Retrieves a :obj:`Message` from the remote CAS server.
+
+        Args:
+            digest (:obj:`Digest`): the message's digest to fetch.
+            message (:obj:`Message`): an empty message to fill.
+
+        Returns:
+            :obj:`Message`: `message` filled or emptied if not found.
+        """
+        try:
+            message_blob = self._fetch_blob(digest)
+        except NotFoundError:
+            message_blob = None
+
+        if message_blob is not None:
+            message.ParseFromString(message_blob)
+        else:
+            message.Clear()
+
+        return message
+
+    def get_messages(self, digests, messages):
+        """Retrieves a list of :obj:`Message`s from the remote CAS server.
+
+        Note:
+            The `digests` and `messages` lists **must** contain the same number
+            of elements.
+
+        Args:
+            digests (list): list of :obj:`Digest`s for the messages to fetch.
+            messages (list): list of empty :obj:`Message`s to fill.
+
+        Returns:
+            list: the fetched and filled message list.
+        """
+        assert len(digests) == len(messages)
+
+        message_blobs = self._fetch_blob_batch(digests)
+
+        assert len(message_blobs) == len(messages)
+
+        for message, message_blob in zip(messages, message_blobs):
+            message.ParseFromString(message_blob)
+
+        return messages
+
+    def download_file(self, digest, file_path, queue=True):
+        """Retrieves a file from the remote CAS server.
+
+        If queuing is allowed (`queue=True`), the download request **may** be
+        deferred. An explicit call to :func:`~flush` can force the request to
+        be sent immediately (along with the rest of the queued batch).
+
+        Args:
+            digest (:obj:`Digest`): the file's digest to fetch.
+            file_path (str): absolute or relative path to the local file to write.
+            queue (bool, optional): whether or not the download request may be
+                queued and submitted as part of a batch download request.
+                Defaults to True.
+
+        Raises:
+            NotFoundError: if `digest` is not present in the remote CAS server.
+            OSError: if `file_path` cannot be created or is not writable.
+        """
+        if not os.path.isabs(file_path):
+            file_path = os.path.abspath(file_path)
+
+        if not queue or digest.size_bytes > FILE_SIZE_THRESHOLD:
+            self._fetch_file(digest, file_path)
+        else:
+            self._queue_file(digest, file_path)
+
+    def download_directory(self, digest, directory_path):
+        """Retrieves a :obj:`Directory` from the remote CAS server.
+
+        Args:
+            digest (:obj:`Digest`): the directory's digest to fetch.
+            directory_path (str): absolute or relative path to the local
+                directory to create.
+
+        Raises:
+            NotFoundError: if `digest` is not present in the remote CAS server.
+            FileExistsError: if `directory_path` already contains parts of the
+                fetched directory's content.
+        """
+        if not os.path.isabs(directory_path):
+            directory_path = os.path.abspath(directory_path)
+
+        # We want to start fresh here, the rest is very synchronous...
+        self.flush()
+
+        self._fetch_directory(digest, directory_path)
+
+    def flush(self):
+        """Ensures any queued request gets sent."""
+        if self.__file_requests:
+            self._fetch_file_batch(self.__file_requests)
+
+        self.__file_requests.clear()
+        self.__file_request_count = 0
+        self.__file_request_size = 0
+        self.__file_response_size = 0
+
+    def close(self):
+        """Closes the underlying connection stubs.
+
+        Note:
+            This will send any pending request before closing the connections.
+        """
+        self.flush()
+
+        self.__bytestream_stub = None
+        self.__cas_stub = None
+
+    # --- Private API ---
+
+    def _fetch_blob(self, digest):
+        """Fetches a blob using ByteStream.Read()"""
+        read_blob = bytearray()
+
+        if self.instance_name is not None:
+            resource_name = '/'.join([self.instance_name, 'blobs',
+                                      digest.hash, str(digest.size_bytes)])
+        else:
+            resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
+
+        read_request = bytestream_pb2.ReadRequest()
+        read_request.resource_name = resource_name
+        read_request.read_offset = 0
+
+        try:
+            # TODO: Handle connection loss/recovery
+            for read_response in self.__bytestream_stub.Read(read_request):
+                read_blob += read_response.data
+
+            assert len(read_blob) == digest.size_bytes
+
+        except grpc.RpcError as e:
+            status_code = e.code()
+            if status_code == grpc.StatusCode.NOT_FOUND:
+                raise NotFoundError("Requested data does not exist on the remote.")
+
+            else:
+                assert False
+
+        return read_blob
+
+    def _fetch_blob_batch(self, digests):
+        """Fetches blobs using ContentAddressableStorage.BatchReadBlobs()"""
+        batch_fetched = False
+        read_blobs = []
+
+        # First, try BatchReadBlobs() if it is not already known to be unimplemented:
+        if not _CallCache.unimplemented(self.channel, 'BatchReadBlobs'):
+            batch_request = remote_execution_pb2.BatchReadBlobsRequest()
+            batch_request.digests.extend(digests)
+            if self.instance_name is not None:
+                batch_request.instance_name = self.instance_name
+
+            try:
+                batch_response = self.__cas_stub.BatchReadBlobs(batch_request)
+                for response in batch_response.responses:
+                    assert response.digest.hash in [digest.hash for digest in digests]
+
+                    read_blobs.append(response.data)
+
+                    if response.status.code != code_pb2.OK:
+                        assert False
+
+                batch_fetched = True
+
+            except grpc.RpcError as e:
+                status_code = e.code()
+                if status_code == grpc.StatusCode.UNIMPLEMENTED:
+                    _CallCache.mark_unimplemented(self.channel, 'BatchReadBlobs')
+
+                else:
+                    assert False
+
+        # Fallback to Read() if no BatchReadBlobs():
+        if not batch_fetched:
+            for digest in digests:
+                read_blobs.append(self._fetch_blob(digest))
+
+        return read_blobs
+
+    def _fetch_file(self, digest, file_path):
+        """Fetches a file using ByteStream.Read()"""
+        if self.instance_name is not None:
+            resource_name = '/'.join([self.instance_name, 'blobs',
+                                      digest.hash, str(digest.size_bytes)])
+        else:
+            resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
+
+        read_request = bytestream_pb2.ReadRequest()
+        read_request.resource_name = resource_name
+        read_request.read_offset = 0
+
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+        with open(file_path, 'wb') as byte_file:
+            # TODO: Handle connection loss/recovery
+            for read_response in self.__bytestream_stub.Read(read_request):
+                byte_file.write(read_response.data)
+
+            assert byte_file.tell() == digest.size_bytes
+
+    def _queue_file(self, digest, file_path):
+        """Queues a file for later batch download"""
+        # Flush first if adding this file would overflow the request size, the
+        # expected response size or the element count for a single batch:
+        if self.__file_request_size + digest.ByteSize() > MAX_REQUEST_SIZE:
+            self.flush()
+        elif self.__file_response_size + digest.size_bytes > MAX_REQUEST_SIZE:
+            self.flush()
+        elif self.__file_request_count >= MAX_REQUEST_COUNT:
+            self.flush()
+
+        self.__file_requests[digest.hash] = (digest, file_path)
+        self.__file_request_count += 1
+        self.__file_request_size += digest.ByteSize()
+        self.__file_response_size += digest.size_bytes
+
+    def _fetch_file_batch(self, batch):
+        """Fetches queued files using ContentAddressableStorage.BatchReadBlobs()"""
+        batch_digests = [digest for digest, _ in batch.values()]
+        batch_blobs = self._fetch_blob_batch(batch_digests)
+
+        for (_, file_path), file_blob in zip(batch.values(), batch_blobs):
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+            with open(file_path, 'wb') as byte_file:
+                byte_file.write(file_blob)
+
+    def _fetch_directory(self, digest, directory_path):
+        """Fetches a directory using ContentAddressableStorage.GetTree()"""
+        # Better fail early if the local root path cannot be created:
+        os.makedirs(directory_path, exist_ok=True)
+
+        directories = {}
+        directory_fetched = False
+        # First, try GetTree() if it is not known to be unimplemented yet:
+        if not _CallCache.unimplemented(self.channel, 'GetTree'):
+            tree_request = remote_execution_pb2.GetTreeRequest()
+            tree_request.root_digest.CopyFrom(digest)
+            tree_request.page_size = MAX_REQUEST_COUNT
+            if self.instance_name is not None:
+                tree_request.instance_name = self.instance_name
+
+            try:
+                for tree_response in self.__cas_stub.GetTree(tree_request):
+                    for directory in tree_response.directories:
+                        directory_blob = directory.SerializeToString()
+                        directory_hash = HASH(directory_blob).hexdigest()
+
+                        directories[directory_hash] = directory
+
+                assert digest.hash in directories
+
+                directory = directories[digest.hash]
+                self._write_directory(directory, directory_path,
+                                      directories=directories, root_barrier=directory_path)
+
+                directory_fetched = True
+
+            except grpc.RpcError as e:
+                status_code = e.code()
+                if status_code == grpc.StatusCode.UNIMPLEMENTED:
+                    _CallCache.mark_unimplemented(self.channel, 'GetTree')
+
+                elif status_code == grpc.StatusCode.NOT_FOUND:
+                    raise NotFoundError("Requested directory does not exist on the remote.")
+
+                else:
+                    assert False
+
+            # TODO: Try with BatchReadBlobs().
+
+        # Fallback to Read() if no GetTree():
+        if not directory_fetched:
+            directory = remote_execution_pb2.Directory()
+            directory.ParseFromString(self._fetch_blob(digest))
+
+            self._write_directory(directory, directory_path,
+                                  root_barrier=directory_path)
+
+    def _write_directory(self, root_directory, root_path, directories=None, root_barrier=None):
+        """Generates a local directory structure"""
+        for file_node in root_directory.files:
+            file_path = os.path.join(root_path, file_node.name)
+
+            self._queue_file(file_node.digest, file_path)
+
+        for directory_node in root_directory.directories:
+            directory_path = os.path.join(root_path, directory_node.name)
+            if directories and directory_node.digest.hash in directories:
+                directory = directories[directory_node.digest.hash]
+            else:
+                directory = remote_execution_pb2.Directory()
+                directory.ParseFromString(self._fetch_blob(directory_node.digest))
+
+            os.makedirs(directory_path, exist_ok=True)
+
+            self._write_directory(directory, directory_path,
+                                  directories=directories, root_barrier=root_barrier)
+
+        for symlink_node in root_directory.symlinks:
+            symlink_path = os.path.join(root_path, symlink_node.name)
+            if not os.path.isabs(symlink_node.target):
+                target_path = os.path.join(root_path, symlink_node.target)
+            else:
+                target_path = symlink_node.target
+            target_path = os.path.normpath(target_path)
+
+            # Do not create links pointing outside the barrier:
+            if root_barrier is not None:
+                common_path = os.path.commonprefix([root_barrier, target_path])
+                if not common_path.startswith(root_barrier):
+                    continue
+
+            os.symlink(target_path, symlink_path)
+
+
 @contextmanager
 def upload(channel, instance=None, u_uid=None):
     """Context manager generator for the :class:`Uploader` class."""
@@ -63,16 +465,8 @@ class Uploader:
 
         with upload(channel, instance='build') as uploader:
             uploader.upload_file('/path/to/local/file')
-
-    Attributes:
-        FILE_SIZE_THRESHOLD (int): maximum size for a queueable file.
-        MAX_REQUEST_SIZE (int): maximum size for a single gRPC request.
     """
 
-    FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
-    MAX_REQUEST_SIZE = 2 * 1024 * 1024
-    MAX_REQUEST_COUNT = 500
-
     def __init__(self, channel, instance=None, u_uid=None):
         """Initializes a new :class:`Uploader` instance.
 
@@ -115,7 +509,7 @@ class Uploader:
         Returns:
             :obj:`Digest`: the sent blob's digest.
         """
-        if not queue or len(blob) > Uploader.FILE_SIZE_THRESHOLD:
+        if not queue or len(blob) > FILE_SIZE_THRESHOLD:
             blob_digest = self._send_blob(blob, digest=digest)
         else:
             blob_digest = self._queue_blob(blob, digest=digest)
@@ -141,7 +535,7 @@ class Uploader:
         """
         message_blob = message.SerializeToString()
 
-        if not queue or len(message_blob) > Uploader.FILE_SIZE_THRESHOLD:
+        if not queue or len(message_blob) > FILE_SIZE_THRESHOLD:
             message_digest = self._send_blob(message_blob, digest=digest)
         else:
             message_digest = self._queue_blob(message_blob, digest=digest)
@@ -174,7 +568,7 @@ class Uploader:
         with open(file_path, 'rb') as bytes_steam:
             file_bytes = bytes_steam.read()
 
-        if not queue or len(file_bytes) > Uploader.FILE_SIZE_THRESHOLD:
+        if not queue or len(file_bytes) > FILE_SIZE_THRESHOLD:
            file_digest = self._send_blob(file_bytes)
         else:
             file_digest = self._queue_blob(file_bytes)
@@ -347,9 +741,9 @@ class Uploader:
         blob_digest.hash = HASH(blob).hexdigest()
         blob_digest.size_bytes = len(blob)
 
-        if self.__request_size + blob_digest.size_bytes > Uploader.MAX_REQUEST_SIZE:
+        if self.__request_size + blob_digest.size_bytes > MAX_REQUEST_SIZE:
             self.flush()
-        elif self.__request_count >= Uploader.MAX_REQUEST_COUNT:
+        elif self.__request_count >= MAX_REQUEST_COUNT:
             self.flush()
 
         self.__requests[blob_digest.hash] = (blob, blob_digest)
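
A minimal usage sketch of the download helper introduced by this patch. The channel address, instance name, digest values and local paths below are placeholders invented for illustration; only the download()/Downloader API itself comes from the patch:

    import grpc

    from buildgrid.client.cas import download
    from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2

    # Hypothetical CAS endpoint and digests, for illustration only:
    channel = grpc.insecure_channel('localhost:50051')
    action_digest = remote_execution_pb2.Digest(hash='a' * 64, size_bytes=142)
    file_digest = remote_execution_pb2.Digest(hash='b' * 64, size_bytes=2048)
    tree_digest = remote_execution_pb2.Digest(hash='c' * 64, size_bytes=512)

    with download(channel, instance='build') as downloader:
        # Blobs and messages are fetched immediately:
        action = downloader.get_message(action_digest, remote_execution_pb2.Action())

        # Small files may be queued and fetched together in one batch request;
        # close() (via the context manager) flushes whatever is still queued:
        downloader.download_file(file_digest, '/tmp/cas-example/action.bin')

        # Directories are fetched synchronously, via GetTree() when available:
        downloader.download_directory(tree_digest, '/tmp/cas-example/tree')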