Commit ae4bf5f2 authored by Martin Blanchard

client/cas.py: Introduce CAS downloader helper class

#79
parent c40ede7a
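
A minimal usage sketch of the new download helper, assuming the module is importable as buildgrid.client.cas; the endpoint address, instance name and digest values below are placeholders::

    import grpc

    from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
    from buildgrid.client.cas import download  # assumed import path for client/cas.py

    # Placeholder endpoint and digests, for illustration only:
    channel = grpc.insecure_channel('localhost:50051')
    command_digest = remote_execution_pb2.Digest(hash='0' * 64, size_bytes=142)
    output_digest = remote_execution_pb2.Digest(hash='1' * 64, size_bytes=1024)
    tree_digest = remote_execution_pb2.Digest(hash='2' * 64, size_bytes=380)

    with download(channel, instance='build') as cas:
        # Messages and blobs are fetched immediately:
        command = cas.get_message(command_digest, remote_execution_pb2.Command())

        # Small file downloads may be queued and sent as a single batch request;
        # flush() (or leaving the with-block) forces any pending batch out:
        cas.download_file(output_digest, '/tmp/outputs/output.bin')
        cas.flush()

        # Directory downloads are synchronous and recreate the tree locally:
        cas.download_directory(tree_digest, '/tmp/outputs/tree')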
@@ -19,12 +19,23 @@ import os
import grpc
from buildgrid._exceptions import NotFoundError
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc
from buildgrid._protos.google.rpc import code_pb2
from buildgrid.settings import HASH
# Maximum size for a queueable file:
FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
# Maximum size for a single gRPC request:
MAX_REQUEST_SIZE = 2 * 1024 * 1024
# Maximum number of elements per gRPC request:
MAX_REQUEST_COUNT = 500
class CallCache:
"""Per remote grpc.StatusCode.UNIMPLEMENTED call cache."""
__calls = {}
@@ -42,6 +53,399 @@ class CallCache:
return name in cls.__calls[channel]
@contextmanager
def download(channel, instance=None, u_uid=None):
downloader = Downloader(channel, instance=instance)
try:
yield downloader
finally:
downloader.close()
class Downloader:
"""Remote CAS files, directories and messages download helper.
The :class:`Downloader` class comes with a generator factory function that
can be used together with the `with` statement for context management::
with download(channel, instance='build') as cas:
cas.get_message(message_digest)
"""
def __init__(self, channel, instance=None):
"""Initializes a new :class:`Downloader` instance.
Args:
channel (grpc.Channel): A gRPC channel to the CAS endpoint.
instance (str, optional): the targeted instance's name.
"""
self.channel = channel
self.instance_name = instance
self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel)
self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
self.__file_requests = {}
self.__file_request_count = 0
self.__file_request_size = 0
self.__file_response_size = 0
# --- Public API ---
def get_blob(self, digest):
"""Retrieves a blob from the remote CAS server.
Args:
digest (:obj:`Digest`): the blob's digest to fetch.
Returns:
bytearray: the fetched blob data or None if not found.
"""
try:
blob = self._fetch_blob(digest)
except NotFoundError:
return None
return blob
def get_blobs(self, digests):
"""Retrieves a list of blobs from the remote CAS server.
Args:
digests (list): list of :obj:`Digest`s for the blobs to fetch.
Returns:
list: the fetched blob data list.
"""
return self._fetch_blob_batch(digests)
def get_message(self, digest, message):
"""Retrieves a :obj:`Message` from the remote CAS server.
Args:
digest (:obj:`Digest`): the message's digest to fetch.
message (:obj:`Message`): an empty message to fill.
Returns:
:obj:`Message`: `message` filled or emptied if not found.
"""
try:
message_blob = self._fetch_blob(digest)
except NotFoundError:
message_blob = None
if message_blob is not None:
message.ParseFromString(message_blob)
else:
message.Clear()
return message
def get_messages(self, digests, messages):
"""Retrieves a list of :obj:`Message`s from the remote CAS server.
Note:
The `digests` and `messages` list **must** contain the same number
of elements.
Args:
digests (list): list of :obj:`Digest`s for the messages to fetch.
messages (list): list of empty :obj:`Message`s to fill.
Returns:
list: the fetched and filled message list.
"""
assert len(digests) == len(messages)
message_blobs = self._fetch_blob_batch(digests)
assert len(message_blobs) == len(messages)
for message, message_blob in zip(messages, message_blobs):
message.ParseFromString(message_blob)
return messages
def download_file(self, digest, file_path, queue=True):
"""Retrieves a file from the remote CAS server.
If queuing is allowed (`queue=True`), the download request **may** be
deferred. An explicit call to :meth:`flush` can force the request to be
sent immediately (along with the rest of the queued batch).
Args:
digest (:obj:`Digest`): the file's digest to fetch.
file_path (str): absolute or relative path to the local file to write.
queue (bool, optional): whether or not the download request may be
queued and submitted as part of a batch download request. Defaults
to True.
Raises:
NotFoundError: if `digest` is not present in the remote CAS server.
OSError: if `file_path` cannot be created or written.
"""
if not os.path.isabs(file_path):
file_path = os.path.abspath(file_path)
if not queue or digest.size_bytes > FILE_SIZE_THRESHOLD:
self._fetch_file(digest, file_path)
else:
self._queue_file(digest, file_path)
def download_directory(self, digest, directory_path):
"""Retrieves a :obj:`Directory` from the remote CAS server.
Args:
digest (:obj:`Digest`): the directory's digest to fetch.
directory_path (str): absolute or relative path to the local
directory to write.
Raises:
NotFoundError: if `digest` is not present in the remote CAS server.
FileExistsError: if `directory_path` already contains parts of the
fetched directory's content.
"""
if not os.path.isabs(directory_path):
directory_path = os.path.abspath(directory_path)
# We want to start fresh here, the rest is very synchronous...
self.flush()
self._fetch_directory(digest, directory_path)
def flush(self):
"""Ensures any queued request gets sent."""
if self.__file_requests:
self._fetch_file_batch(self.__file_requests)
self.__file_requests.clear()
self.__file_request_count = 0
self.__file_request_size = 0
self.__file_response_size = 0
def close(self):
"""Closes the underlying connection stubs.
Note:
This will always send pending requests before closing connections,
if any.
"""
self.flush()
self.__bytestream_stub = None
self.__cas_stub = None
# --- Private API ---
def _fetch_blob(self, digest):
"""Fetches a blob using ByteStream.Read()"""
read_blob = bytearray()
if self.instance_name is not None:
resource_name = '/'.join([self.instance_name, 'blobs',
digest.hash, str(digest.size_bytes)])
else:
resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
read_request = bytestream_pb2.ReadRequest()
read_request.resource_name = resource_name
read_request.read_offset = 0
try:
# TODO: Handle connection loss/recovery
for read_response in self.__bytestream_stub.Read(read_request):
read_blob += read_response.data
assert len(read_blob) == digest.size_bytes
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.NOT_FOUND:
raise NotFoundError("Requested data does not exist on the remote.")
else:
assert False
return read_blob
def _fetch_blob_batch(self, digests):
"""Fetches blobs using ContentAddressableStorage.BatchReadBlobs()"""
batch_fetched = False
read_blobs = []
# First, try BatchReadBlobs(), if not already known not being implemented:
if not CallCache.unimplemented(self.channel, 'BatchReadBlobs'):
batch_request = remote_execution_pb2.BatchReadBlobsRequest()
batch_request.digests.extend(digests)
if self.instance_name is not None:
batch_request.instance_name = self.instance_name
try:
batch_response = self.__cas_stub.BatchReadBlobs(batch_request)
for response in batch_response.responses:
assert response.digest in digests
read_blobs.append(response.data)
if response.status.code != code_pb2.OK:
assert False
batch_fetched = True
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.UNIMPLEMENTED:
CallCache.mark_unimplemented(self.channel, 'BatchReadBlobs')
else:
assert False
# Fallback to Read() if no BatchReadBlobs():
if not batch_fetched:
for digest in digests:
read_blobs.append(self._fetch_blob(digest))
return read_blobs
def _fetch_file(self, digest, file_path):
"""Fetches a file using ByteStream.Read()"""
if self.instance_name is not None:
resource_name = '/'.join([self.instance_name, 'blobs',
digest.hash, str(digest.size_bytes)])
else:
resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)])
read_request = bytestream_pb2.ReadRequest()
read_request.resource_name = resource_name
read_request.read_offset = 0
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'wb') as byte_file:
# TODO: Handle connection loss/recovery
for read_response in self.__bytestream_stub.Read(read_request):
byte_file.write(read_response.data)
assert byte_file.tell() == digest.size_bytes
def _queue_file(self, digest, file_path):
"""Queues a file for later batch download"""
if self.__file_request_size + digest.ByteSize() > MAX_REQUEST_SIZE:
self.flush()
elif self.__file_response_size + digest.size_bytes > MAX_REQUEST_SIZE:
self.flush()
elif self.__file_request_count >= MAX_REQUEST_COUNT:
self.flush()
self.__file_requests[digest.hash] = (digest, file_path)
self.__file_request_count += 1
self.__file_request_size += digest.ByteSize()
self.__file_response_size += digest.size_bytes
def _fetch_file_batch(self, batch):
"""Sends queued data using ContentAddressableStorage.BatchReadBlobs()"""
batch_digests = [digest for digest, _ in batch.values()]
batch_blobs = self._fetch_blob_batch(batch_digests)
for (_, file_path), file_blob in zip(batch.values(), batch_blobs):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'wb') as byte_file:
byte_file.write(file_blob)
def _fetch_directory(self, digest, directory_path):
"""Fetches a file using ByteStream.GetTree()"""
# Better fail early if the local root path cannot be created:
os.makedirs(directory_path, exist_ok=True)
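# Directory messages received so far, keyed by the hash of their serialized blob: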
directories = {}
directory_fetched = False
# First, try GetTree() if not known to be unimplemented yet:
if not CallCache.unimplemented(self.channel, 'GetTree'):
tree_request = remote_execution_pb2.GetTreeRequest()
tree_request.root_digest.CopyFrom(digest)
tree_request.page_size = MAX_REQUEST_COUNT
if self.instance_name is not None:
tree_request.instance_name = self.instance_name
try:
for tree_response in self.__cas_stub.GetTree(tree_request):
for directory in tree_response.directories:
directory_blob = directory.SerializeToString()
directory_hash = HASH(directory_blob).hexdigest()
directories[directory_hash] = directory
assert digest.hash in directories
directory = directories[digest.hash]
self._write_directory(directory, directory_path,
directories=directories, root_barrier=directory_path)
directory_fetched = True
except grpc.RpcError as e:
status_code = e.code()
if status_code == grpc.StatusCode.UNIMPLEMENTED:
CallCache.mark_unimplemented(self.channel, 'GetTree')
elif status_code == grpc.StatusCode.NOT_FOUND:
raise NotFoundError("Requested directory does not exist on the remote.")
else:
assert False
# TODO: Try with BatchReadBlobs().
# Fallback to Read() if no GetTree():
if not directory_fetched:
directory = remote_execution_pb2.Directory()
directory.ParseFromString(self._fetch_blob(digest))
self._write_directory(directory, directory_path,
root_barrier=directory_path)
def _write_directory(self, root_directory, root_path, directories=None, root_barrier=None):
"""Generates a local directory structure"""
for file_node in root_directory.files:
file_path = os.path.join(root_path, file_node.name)
self._queue_file(file_node.digest, file_path)
for directory_node in root_directory.directories:
directory_path = os.path.join(root_path, directory_node.name)
if directories and directory_node.digest.hash in directories:
directory = directories[directory_node.digest.hash]
else:
directory = remote_execution_pb2.Directory()
directory.ParseFromString(self._fetch_blob(directory_node.digest))
os.makedirs(directory_path, exist_ok=True)
self._write_directory(directory, directory_path,
directories=directories, root_barrier=root_barrier)
for symlink_node in root_directory.symlinks:
symlink_path = os.path.join(root_path, symlink_node.name)
if not os.path.isabs(symlink_node.target):
target_path = os.path.join(root_path, symlink_node.target)
else:
target_path = symlink_node.target
target_path = os.path.normpath(target_path)
# Do not create links pointing outside the barrier:
if root_barrier is not None:
common_path = os.path.commonprefix([root_barrier, target_path])
if not common_path.startswith(root_barrier):
continue
os.symlink(target_path, symlink_path)
@contextmanager
def upload(channel, instance=None, u_uid=None):
uploader = Uploader(channel, instance=instance, u_uid=u_uid)
@@ -59,16 +463,8 @@ class Uploader:
with upload(channel, instance='build') as cas:
cas.upload_file('/path/to/local/file')
Attributes:
FILE_SIZE_THRESHOLD (int): maximum size for a queueable file.
MAX_REQUEST_SIZE (int): maximum size for a single gRPC request.
""" """
FILE_SIZE_THRESHOLD = 1 * 1024 * 1024
MAX_REQUEST_SIZE = 2 * 1024 * 1024
MAX_REQUEST_COUNT = 500
def __init__(self, channel, instance=None, u_uid=None):
"""Initializes a new :class:`Uploader` instance.
@@ -111,7 +507,7 @@ class Uploader:
Returns:
:obj:`Digest`: the sent blob's digest.
"""
if not queue or len(blob) > FILE_SIZE_THRESHOLD:
blob_digest = self._send_blob(blob, digest=digest)
else:
blob_digest = self._queue_blob(blob, digest=digest)
@@ -137,7 +533,7 @@ class Uploader:
"""
message_blob = message.SerializeToString()
if not queue or len(message_blob) > FILE_SIZE_THRESHOLD:
message_digest = self._send_blob(message_blob, digest=digest)
else:
message_digest = self._queue_blob(message_blob, digest=digest)
@@ -169,7 +565,7 @@ class Uploader:
with open(file_path, 'rb') as bytes_steam:
file_bytes = bytes_steam.read()
if not queue or len(file_bytes) > FILE_SIZE_THRESHOLD:
file_digest = self._send_blob(file_bytes)
else:
file_digest = self._queue_blob(file_bytes)
@@ -274,9 +670,9 @@ class Uploader:
blob_digest.hash = HASH(blob).hexdigest()
blob_digest.size_bytes = len(blob)
if self.__request_size + blob_digest.size_bytes > MAX_REQUEST_SIZE:
self.flush()
elif self.__request_count >= MAX_REQUEST_COUNT:
self.flush()
self.__requests[blob_digest.hash] = (blob, blob_digest)
...