Commit ae2b4f6c authored by finn's avatar finn
Browse files

Added WaitExecution method for Remote Execution

parent 371e9a49
Loading
Loading
Loading
Loading
Loading
+21 −18
Original line number Diff line number Diff line
@@ -53,7 +53,6 @@ def cli(context, host, port):
@pass_context
def request(context, number, instance_name, wait_for_completion):
    action_digest = remote_execution_pb2.Digest()
    action_digest.hash = 'zhora'

    context.logger.info("Sending execution request...\n")
    stub = remote_execution_pb2_grpc.ExecutionStub(context.channel)
@@ -61,24 +60,16 @@ def request(context, number, instance_name, wait_for_completion):
    request = remote_execution_pb2.ExecuteRequest(instance_name = instance_name,
                                                  action_digest = action_digest,
                                                  skip_cache_lookup = True)
    responses = []
    for i in range(0, number):
        response = stub.Execute(request)
        for r in response:
            context.logger.info(r)
        responses.append(stub.Execute(request))

    try:
        while wait_for_completion:
            request = operations_pb2.ListOperationsRequest()
            context.logger.debug('Querying to see if jobs are complete.')
            stub = operations_pb2_grpc.OperationsStub(context.channel)
            response = stub.ListOperations(request)
            if all(operation.done for operation in response.operations):
                context.logger.info('Jobs complete')
                break
            time.sleep(1)

    except KeyboardInterrupt:
        pass
    for response in responses:
        if wait_for_completion:
            for stream in response:
                context.logger.info(stream)
        else:
            context.logger.info(next(response))

@cli.command('status', short_help='Get the status of an operation')
@click.argument('operation-name')
@@ -108,3 +99,15 @@ def list_operations(context):

    for op in response.operations:
        context.logger.info(op)

@cli.command('wait', short_help='Streams an operation until it is complete')
@click.argument('operation-name')
@pass_context
def wait_execution(context, operation_name):
    stub = remote_execution_pb2_grpc.ExecutionStub(context.channel)
    request = remote_execution_pb2.WaitExecutionRequest(name=operation_name)

    response = stub.WaitExecution(request)

    for stream in response:
        context.logger.info(stream)
+0 −3
Original line number Diff line number Diff line
@@ -50,7 +50,6 @@ class ExecutionInstance():
        return job.get_operation()

    def get_operation(self, name):
        self.logger.debug("Getting operation: {}".format(name))
        operation = self._scheduler.jobs.get(name)
        if operation is None:
            raise InvalidArgumentError("Operation name does not exist: {}".format(name))
@@ -60,11 +59,9 @@ class ExecutionInstance():
    def list_operations(self, name, list_filter, page_size, page_token):
        # TODO: Pages
        # Spec says number of pages and length of a page are optional
        self.logger.debug("Listing operations")
        return self._scheduler.get_operations()

    def delete_operation(self, name):
        self.logger.debug("Deleting operation {}".format(name))
        try:
            self._scheduler.jobs.pop(name)
        except KeyError:
+25 −4
Original line number Diff line number Diff line
@@ -22,8 +22,10 @@ ExecutionService
Serves remote execution requests.
"""

import copy
import grpc
import logging
import time

from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from buildgrid._protos.google.longrunning import operations_pb2_grpc, operations_pb2
@@ -40,17 +42,36 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
        # Ignore request.instance_name for now
        # Have only one instance
        try:
            yield self._instance.execute(request.action_digest,
            operation = self._instance.execute(request.action_digest,
                                               request.skip_cache_lookup)

            yield from self._stream_operation_updates(operation.name)

        except InvalidArgumentError as e:
            self.logger.error(e)
            context.set_details(str(e))
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            yield operations_pb2.Operation()

        except NotImplementedError as e:
            self.logger.error(e)
            context.set_details(str(e))
            context.set_code(grpc.StatusCode.UNIMPLEMENTED)
            yield operations_pb2.Operation()

    def WaitExecution(self, request, context):
        try:
            yield from self._stream_operation_updates(request.name)

        except InvalidArgumentError as e:
            self.logger.error(e)
            context.set_details(str(e))
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)

    def _stream_operation_updates(self, name):
        stream_previous = None
        while True:
            stream = self._instance.get_operation(name)
            if stream != stream_previous:
                yield stream
                if stream.done == True: break
                stream_previous = copy.deepcopy(stream)
            time.sleep(1)
+25 −10
Original line number Diff line number Diff line
@@ -56,15 +56,30 @@ def test_execute(skip_cache_lookup, instance, context):
                                                  action_digest = action_digest,
                                                  skip_cache_lookup = skip_cache_lookup)
    response = instance.Execute(request, context)

    for result in response:
        assert isinstance(result, operations_pb2.Operation)

    if skip_cache_lookup is False:
        [r for r in response]
        context.set_code.assert_called_once_with(grpc.StatusCode.UNIMPLEMENTED)
    else:
        result = next(response)
        assert isinstance(result, operations_pb2.Operation)
        metadata = remote_execution_pb2.ExecuteOperationMetadata()
        result.metadata.Unpack(metadata)
        assert metadata.stage == job.ExecuteStage.QUEUED.value
        assert uuid.UUID(result.name, version=4)
        assert result.done is False

def test_wait_execution(instance, context):
    action_digest = remote_execution_pb2.Digest()
    action_digest.hash = 'zhora'

    execution_request = remote_execution_pb2.ExecuteRequest(instance_name = '',
                                                            action_digest = action_digest,
                                                            skip_cache_lookup = True)
    execution_response = next(instance.Execute(execution_request, context))


    request = remote_execution_pb2.WaitExecutionRequest(name=execution_response.name)

    response = next(instance.WaitExecution(request, context))

    assert response == execution_response