From d5483e72ee5d27dbea7ef04758df506fce7ef181 Mon Sep 17 00:00:00 2001
From: finn <finn.ball@codethink.com>
Date: Tue, 31 Jul 2018 15:44:33 +0100
Subject: [PATCH] Added WaitExecution method for Remote Execution

---
 app/commands/cmd_execute.py                   | 38 ++++++++++---------
 .../server/execution/execution_service.py     | 29 +++++++++++++-
 2 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/app/commands/cmd_execute.py b/app/commands/cmd_execute.py
index 196ea9d46..9bb92dfc7 100644
--- a/app/commands/cmd_execute.py
+++ b/app/commands/cmd_execute.py
@@ -61,24 +61,16 @@ def request(context, number, instance_name, wait_for_completion):
     request = remote_execution_pb2.ExecuteRequest(instance_name = instance_name,
                                                   action_digest = action_digest,
                                                   skip_cache_lookup = True)
+    responses = []
     for i in range(0, number):
-        response = stub.Execute(request)
-        for r in response:
-            context.logger.info(r)
-
-    try:
-        while wait_for_completion:
-            request = operations_pb2.ListOperationsRequest()
-            context.logger.debug('Querying to see if jobs are complete.')
-            stub = operations_pb2_grpc.OperationsStub(context.channel)
-            response = stub.ListOperations(request)
-            if all(operation.done for operation in response.operations):
-                context.logger.info('Jobs complete')
-                break
-            time.sleep(1)
-
-    except KeyboardInterrupt:
-        pass
+        responses.append(stub.Execute(request))
+
+    for response in responses:
+        if wait_for_completion:
+            for stream in response:
+                context.logger.info(stream)
+        else:
+            context.logger.info(next(response))
 
 @cli.command('status', short_help='Get the status of an operation')
 @click.argument('operation-name')
@@ -108,3 +100,15 @@ def list_operations(context):
 
     for op in response.operations:
         context.logger.info(op)
+
+@cli.command('wait', short_help='Streams an operation until it is complete')
+@click.argument('operation-name')
+@pass_context
+def wait_execution(context, operation_name):
+    stub = remote_execution_pb2_grpc.ExecutionStub(context.channel)
+    request = remote_execution_pb2.WaitExecutionRequest(name=operation_name)
+
+    response = stub.WaitExecution(request)
+
+    for stream in response:
+        context.logger.info(stream)
diff --git a/buildgrid/server/execution/execution_service.py b/buildgrid/server/execution/execution_service.py
index af00a72bb..effeda7b0 100644
--- a/buildgrid/server/execution/execution_service.py
+++ b/buildgrid/server/execution/execution_service.py
@@ -22,6 +22,7 @@ ExecutionService
 Serves remote execution requests.
 """
 
+import copy
 import grpc
 import logging
 
@@ -40,8 +41,16 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
         # Ignore request.instance_name for now
         # Have only one instance
         try:
-            yield self._instance.execute(request.action_digest,
-                                         request.skip_cache_lookup)
+            operation = self._instance.execute(request.action_digest,
+                                               request.skip_cache_lookup)
+
+            stream_previous = None
+            while True:
+                stream = self._instance.get_operation(operation.name)
+                if stream != stream_previous:
+                    yield stream
+                    if stream.done == True: break
+                    stream_previous = copy.deepcopy(stream)
 
         except InvalidArgumentError as e:
             self.logger.error(e)
@@ -54,3 +63,19 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
             context.set_details(str(e))
             context.set_code(grpc.StatusCode.UNIMPLEMENTED)
             yield operations_pb2.Operation()
+
+    def WaitExecution(self, request, context):
+        try:
+            stream_previous = None
+            while True:
+                stream = self._instance.get_operation(request.name)
+                if stream != stream_previous:
+                    yield stream
+                    if stream.done == True: break
+                    stream_previous = copy.deepcopy(stream)
+
+        except InvalidArgumentError as e:
+            self.logger.error(e)
+            context.set_details(str(e))
+            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
+            yield operations_pb2.Operation()
-- 
GitLab