Commit 0e32837c authored by Martin Blanchard's avatar Martin Blanchard

client/cas.py: Introduce CAS downloader helper class

#79
parent 82b9806d
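A minimal usage sketch of the downloader API this commit introduces (the module path, endpoint and digests below are assumptions for illustration, not part of the change):

    import grpc

    from buildgrid.client.cas import download  # assumed module path for client/cas.py

    channel = grpc.insecure_channel('localhost:50051')  # assumed CAS endpoint

    with download(channel, instance='build') as cas:
        blob = cas.get_blob(blob_digest)                 # hypothetical Digest, fetched via ByteStream.Read()
        cas.download_file(file_digest, '/tmp/file.bin')  # small files may be queued for a batch request
        cas.flush()                                      # force any queued batch to be sent now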
@@ -14,12 +14,452 @@
from contextlib import contextmanager
import io
import uuid
import os
import stat
import grpc
from buildgrid._exceptions import NotFoundError
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
from buildgrid._protos.google.rpc import code_pb2
from buildgrid.settings import HASH
from buildgrid.utils import write_file
# Maximum size for a queueable file:
_FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
# Maximum size for a single gRPC request:
_MAX_REQUEST_SIZE = 2 * 1024 * 1024
# Maximum number of elements per gRPC request:
_MAX_REQUEST_COUNT = 500
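# Note (illustrative, not part of the change): download_file() streams any file
# larger than _FILE_SIZE_THRESHOLD individually through ByteStream.Read(), while
# smaller files are queued and grouped into BatchReadBlobs() calls that stay
# under _MAX_REQUEST_SIZE bytes and _MAX_REQUEST_COUNT entries each.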
class CallCache:
"""Per remote grpc.StatusCode.UNIMPLEMENTED call cache."""
__calls = dict()
@classmethod
def mark_unimplemented(cls, channel, name):
if channel not in cls.__calls:
cls.__calls[channel] = set()
cls.__calls[channel].add(name)
@classmethod
def unimplemented(cls, channel, name):
if channel not in cls.__calls:
return True
return name in cls.__calls[channel]
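# Illustrative usage (not part of the change): optional RPCs are probed once per
# channel and permanently skipped after an UNIMPLEMENTED response, e.g.:
#
#   if not CallCache.unimplemented(channel, 'BatchReadBlobs'):
#       try:
#           ...  # attempt the batch RPC
#       except grpc.RpcError as e:
#           if e.code() == grpc.StatusCode.UNIMPLEMENTED:
#               CallCache.mark_unimplemented(channel, 'BatchReadBlobs')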
@contextmanager
def download(channel, instance=None, u_uid=None):
downloader = Downloader(channel, instance=instance)
try:
yield downloader
finally:
downloader.close()
class Downloader:
"""Remote CAS files, directories and messages download helper.
The :class:`Downloader` class comes with a generator factory function that
can be used together with the `with` statement for context management::
with download(channel, instance='build') as cas:
cas.get_message(message_digest)
"""
def __init__(self, channel, instance=None):
"""Initializes a new :class:`Downloader` instance.
Args:
channel (grpc.Channel): A gRPC channel to the CAS endpoint.
instance (str, optional): the targeted instance's name.
"""
self.channel = channel
self.instance_name = instance
self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
self.__file_requests = dict()
self.__file_request_count = 0
self.__file_request_size = 0
self.__file_response_size = 0
## Public API:
def get_blob(self, digest):
"""Retrieves a blob from the remote CAS server.
Args:
digest (:obj:`Digest`): the blob's digest to fetch.
Returns:
bytearray: the fetched blob data or None if not found.
"""
try:
blob = self._fetch_blob(digest)
except NotFoundError:
return None
return blob
def get_blobs(self, digests):
"""Retrieves a list of blobs from the remote CAS server.
Args:
digests (list): list of :obj:`Digest`s for the blobs to fetch.
Returns:
list: the fetched blob data list.
"""
return self._fetch_blob_batch(digests)
def get_message(self, digest, message):
"""Retrieves a :obj:`Message` from the remote CAS server.
Args:
digest (:obj:`Digest`): the message's digest to fetch.
message (:obj:`Message`): an empty message to fill.
Returns:
:obj:`Message`: `message` filled or emptied if not found.
"""
try:
message_blob = self._fetch_blob(digest)
except NotFoundError:
message_blob = None
if message_blob is not None:
message.ParseFromString(message_blob)
else:
message.Clear()
return message
def get_messages(self, digests, messages):
"""Retrieves a list of :obj:`Message`s from the remote CAS server.
Note:
The `digests` and `messages` list **must** contain the same number
of elements.
Args:
digests (list): list of :obj:`Digest`s for the messages to fetch.
messages (list): list of empty :obj:`Message`s to fill.
Returns:
list: the fetched and filled message list.
"""
assert len(digests) == len(messages)
message_blobs = self._fetch_blob_batch(digests)
assert len(message_blobs) == len(messages)
for message, message_blob in zip(messages, message_blobs):
message.ParseFromString(message_blob)
return messages
def download_file(self, digest, file_path, queue=True):
"""Retrieves a file from the remote CAS server.
If queuing is allowed (`queue=True`), the download request **may** be
deferred. An explicit call to :meth:`flush` can force the request to be
sent immediately (along with the rest of the queued batch).
Args:
digest (:obj:`Digest`): the file's digest to fetch.
file_path (str): absolute or relative path to the local file to write.
queue (bool, optional): whether or not the download request may be
queued and submitted as part of a batch download request. Defaults
to True.
Raises:
NotFoundError: if `digest` is not present in the remote CAS server.
OSError: if `file_path` cannot be created or is not writable.
"""
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
if not queue or digest.size_bytes > _FILE_SIZE_THRESHOLD:
self._fetch_file(digest, file_path)
else:
self._queue_file(digest, file_path)
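# Example (illustrative, digests and paths are assumptions): queued downloads are
# only guaranteed to be on disk after a flush, which close() performs implicitly:
#
#   with download(channel) as cas:
#       cas.download_file(digest_a, 'out/a.bin')  # queued
#       cas.download_file(digest_b, 'out/b.bin')  # queued
#       cas.flush()                               # both files written here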
def download_directory(self, digest, directory_path):
"""Retrieves a :obj:`Directory` from the remote CAS server.
Args:
digest (:obj:`Digest`): the directory's digest to fetch.
directory_path (str): absolute or relative path to the local
directory to write.
Raises:
NotFoundError: if `digest` is not present in the remote CAS server.
FileExistsError: if `directory_path` already contains parts of the
fetched directory's content.
"""
if not os.path.isabs(directory_path):
directory_path = os.path.abspath(directory_path)
# We want to start fresh here, the rest is very synchronous...
self.flush()
self._fetch_directory(digest, directory_path)
def flush(self):
"""Ensures any queued request gets sent."""
if self.__file_requests:
self._fetch_batch()
def close(self):
"""Closes the underlying connection stubs.
Note:
This will always send pending requests before closing connections,
if any.
"""
self.flush()
self.__bytestream_stub = None
self.__cas_stub = None
## Private API:
def _fetch_blob(self, digest):
"""Fetches a blob using ByteStream.Read()"""
read_blob = bytearray()
if self.instance_name is not None:
resource_name = '/'.join([self.instance_name, 'blobs',
digest.hash, str(digest.size_bytes)])
else:
resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
read_request = bytestream_pb2.ReadRequest()
read_request.resource_name = resource_name
read_request.read_offset = 0
try:
# TODO: Handle connection loss/recovery
for read_response in self.__bytestream_stub.Read(read_request):
read_blob += read_response.data
assert len(read_blob) == digest.size_bytes
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.NOT_FOUND:
raise NotFoundError("Requested data does not exist on the remote.")
else:
assert False
return read_blob
def _fetch_blob_batch(self, digests):
"""Fetches blobs using ContentAddressableStorage.BatchReadBlobs()"""
batch_fetched = False
read_blobs = list()
# First, try BatchReadBlobs(), if not already known not being implemented:
if not CallCache.unimplemented(self.channel, 'BatchReadBlobs'):
batch_request = remote_execution_pb2.BatchReadBlobsRequest()
batch_request.digests.extend(digests)
if self.instance_name is not None:
batch_request.instance_name = self.instance_name
try:
batch_response = self.__cas_stub.BatchReadBlobs(batch_request)
for response in batch_response.responses:
assert response.digest in digests
read_blobs.append(response.data)
if response.status.code != code_pb2.OK:
assert False
batch_fetched = True
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.UNIMPLEMENTED:
CallCache.mark_unimplemented(self.channel, 'BatchReadBlobs')
else:
assert False
# Fallback to Read() if no BatchReadBlobs():
if not batch_fetched:
for digest in digests:
read_blobs.append(self._fetch_blob(digest))
return read_blobs
def _fetch_file(self, digest, file_path):
"""Fetches a file using ByteStream.Read()"""
if self.instance_name is not None:
resource_name = '/'.join([self.instance_name, 'blobs',
digest.hash, str(digest.size_bytes)])
else:
resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
read_request = bytestream_pb2.ReadRequest()
read_request.resource_name = resource_name
read_request.read_offset = 0
with open(file_path, 'wb') as byte_file:
# TODO: Handle connection loss/recovery
for read_response in self.__bytestream_stub.Read(read_request):
byte_file.write(read_response.data)
assert byte_file.tell() == digest.size_bytes
def _queue_file(self, digest, file_path):
"""Queues a file for later batch download"""
if self.__file_request_size + digest.ByteSize() > _MAX_REQUEST_SIZE:
self._fetch_file_batch(self.__file_requests.values())
elif self.__file_response_size + digest.size_bytes > _MAX_REQUEST_SIZE:
self._fetch_file_batch(self.__file_requests.values())
elif self.__file_request_count >= _MAX_REQUEST_COUNT:
self._fetch_file_batch(self.__file_requests.values())
self.__file_requests[digest.hash] = (digest, file_path)
self.__file_request_count += 1
self.__file_request_size += digest.ByteSize()
self.__file_response_size += digest.size_bytes
def _fetch_file_batch(self, digests_paths):
"""Sends queued data using ContentAddressableStorage.BatchReadBlobs()"""
batch_digests = [digest for digest, _ in digests_paths]
batch_blobs = self._fetch_blob_batch(batch_digests)
for (_, file_path), file_blob in zip(digests_paths, batch_blobs):
self._write_file(file_blob, file_path)
self.__file_requests.clear()
self.__file_request_count = 0
self.__file_request_size = 0
self.__file_response_size = 0
def _write_file(self, blob, file_path, create_parent=False):
"""Dumps a memory blob to a local file"""
if create_parent:
os.makedirs(os.path.dirname(file_path), exist_ok=True)
write_file(file_path, blob)
def _fetch_directory(self, digest, directory_path):
"""Fetches a file using ByteStream.GetTree()"""
# Better fail early if the local root path cannot be created:
os.makedirs(directory_path, exist_ok=True)
directories = dict()
directory_fetched = False
# First, try GetTree() if not known to be unimplemented yet:
if not CallCache.unimplemented(self.channel, 'GetTree'):
tree_request = remote_execution_pb2.GetTreeRequest()
tree_request.root_digest.CopyFrom(digest)
tree_request.page_size = _MAX_REQUEST_COUNT
if self.instance_name is not None:
tree_request.instance_name = self.instance_name
try:
tree_fetched = False
while not tree_fetched:
tree_response = self.__cas_stub.GetTree(tree_request)
for directory in tree_response.directories:
directory_blob = directory.SerializeToString()
directory_hash = HASH(directory_blob).hexdigest()
directories[directory_hash] = directory
if tree_response.next_page_token:
tree_request = remote_execution_pb2.GetTreeRequest()
tree_request.root_digest.CopyFrom(digest)
tree_request.page_size = _MAX_REQUEST_COUNT
tree_request.page_token = tree_response.next_page_token
else:
tree_fetched = True
assert digest.hash in directories
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.UNIMPLEMENTED:
CallCache.mark_unimplemented(self.channel, 'GetTree')
elif status_code == grpc.StatusCode.NOT_FOUND:
raise NotFoundError("Requested directory does not exist on the remote.")
else:
assert False
directory = directories[digest.hash]
self._write_directory(directory, directory_path,
directories=directories, root_barrier=directory_path)
directory_fetched = True
# TODO: Try with BatchReadBlobs().
# Fallback to Read() if no GetTree():
if not directory_fetched:
directory = remote_execution_pb2.Directory()
directory.ParseFromString(self._fetch_blob(digest))
self._write_directory(directory, directory_path,
root_barrier=directory_path)
def _write_directory(self, root_directory, root_path, directories=None, root_barrier=None):
"""Generates a local directory structure"""
for file_node in root_directory.files:
file_path = os.path.join(root_path, file_node.name)
self._queue_file(file_node.digest, file_path)
for directory_node in root_directory.directories:
directory_path = os.path.join(root_path, directory_node.name)
if directories and directory_node.digest.hash in directories:
directory = directories[directory_node.digest.hash]
else:
directory = remote_execution_pb2.Directory()
directory.ParseFromString(self._fetch_blob(directory_node.digest))
os.makedirs(directory_path, exist_ok=True)
self._write_directory(directory, directory_path,
directories=directories, root_barrier=root_barrier)
for symlink_node in root_directory.symlinks:
symlink_path = os.path.join(root_path, symlink_node.name)
if not os.path.isabs(symlink_node.target):
target_path = os.path.join(root_path, symlink_node.target)
else:
target_path = symlink_node.target
target_path = os.path.normpath(target_path)
# Do not create links pointing outside the barrier:
if root_barrier is not None:
common_path = os.path.commonprefix([root_barrier, target_path])
if not common_path.startswith(root_barrier):
continue
os.symlink(target_path, symlink_path)
@contextmanager
@@ -39,15 +479,8 @@ class Uploader:
with upload(channel, instance='build') as cas:
cas.upload_file('/path/to/local/file')
Attributes:
FILE_SIZE_THRESHOLD (int): maximum size for a queueable file.
MAX_REQUEST_SIZE (int): maximum size for a single gRPC request.
"""
FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
MAX_REQUEST_SIZE = 2 * 1024 * 1024
def __init__(self, channel, instance=None, u_uid=None):
"""Initializes a new :class:`Uploader` instance.
@@ -95,7 +528,7 @@ class Uploader:
with open(file_path, 'rb') as bytes_stream:
file_bytes = bytes_stream.read()
if not queue or len(file_bytes) > Uploader.FILE_SIZE_THRESHOLD:
if not queue or len(file_bytes) > _FILE_SIZE_THRESHOLD:
blob_digest = self._send_blob(file_bytes)
else:
blob_digest = self._queue_blob(file_bytes)
@@ -148,7 +581,7 @@ class Uploader:
blob_digest.hash = HASH(blob).hexdigest()
blob_digest.size_bytes = len(blob)
if self.__request_size + len(blob) > Uploader.MAX_REQUEST_SIZE:
if self.__request_size + len(blob) > _MAX_REQUEST_SIZE:
self._send_batch()
update_request = remote_execution_pb2.BatchUpdateBlobsRequest.Request()
@@ -156,7 +589,7 @@ class Uploader:
update_request.data = blob
update_request_size = update_request.ByteSize()
if self.__request_size + update_request_size > Uploader.MAX_REQUEST_SIZE:
if self.__request_size + update_request_size > _MAX_REQUEST_SIZE:
self._send_batch()
self.__requests[update_request.digest.hash] = update_request