From a1211a6dfd8190ac0e2b2c1cd3988780655a4833 Mon Sep 17 00:00:00 2001
From: Martin Blanchard <martin.blanchard@codethink.co.uk>
Date: Wed, 7 Nov 2018 10:59:28 +0000
Subject: [PATCH] execution/service.py: Calculate client counts

https://gitlab.com/BuildGrid/buildgrid/issues/132
---
 buildgrid/server/execution/service.py | 113 +++++++++++++++++++++-----
 buildgrid/server/instance.py          |   3 +-
 2 files changed, 96 insertions(+), 20 deletions(-)

diff --git a/buildgrid/server/execution/service.py b/buildgrid/server/execution/service.py
index 7d2b2f068..5bf1ac720 100644
--- a/buildgrid/server/execution/service.py
+++ b/buildgrid/server/execution/service.py
@@ -33,30 +33,62 @@ from buildgrid._protos.google.longrunning import operations_pb2
 
 class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
 
-    def __init__(self, server):
+    def __init__(self, server, monitor=False):
         self.__logger = logging.getLogger(__name__)
 
+        self.__peers_by_instance = None
+        self.__peers = None
+
         self._instances = {}
+
         remote_execution_pb2_grpc.add_ExecutionServicer_to_server(self, server)
 
-    def add_instance(self, name, instance):
-        self._instances[name] = instance
+        self._is_instrumented = monitor
+
+        if self._is_instrumented:
+            self.__peers_by_instance = {}
+            self.__peers = {}
+
+    # --- Public API ---
+
+    def add_instance(self, instance_name, instance):
+        self._instances[instance_name] = instance
+
+        if self._is_instrumented:
+            self.__peers_by_instance[instance_name] = set()
+
+    # --- Public API: Servicer ---
 
     def Execute(self, request, context):
+        """Handles ExecuteRequest messages.
+
+        Args:
+            request (ExecuteRequest): The incoming RPC request.
+            context (grpc.ServicerContext): Context for the RPC call.
+        """
         self.__logger.debug("Execute request from [%s]", context.peer())
 
+        instance_name = request.instance_name
+        message_queue = queue.Queue()
+        peer = context.peer()
+
         try:
-            message_queue = queue.Queue()
-            instance = self._get_instance(request.instance_name)
+            instance = self._get_instance(instance_name)
             operation = instance.execute(request.action_digest,
                                          request.skip_cache_lookup,
                                          message_queue)
 
-            context.add_callback(partial(instance.unregister_message_client,
-                                         operation.name, message_queue))
+            context.add_callback(partial(self._rpc_termination_callback,
+                                         peer, instance_name, operation.name, message_queue))
 
-            instanced_op_name = "{}/{}".format(request.instance_name,
-                                               operation.name)
+            if self._is_instrumented:
+                if peer not in self.__peers:
+                    self.__peers_by_instance[instance_name].add(peer)
+                    self.__peers[peer] = 1
+                else:
+                    self.__peers[peer] += 1
+
+            instanced_op_name = "{}/{}".format(instance_name, operation.name)
 
             self.__logger.info("Operation name: [%s]", instanced_op_name)
 
@@ -86,23 +118,33 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
             yield operations_pb2.Operation()
 
     def WaitExecution(self, request, context):
-        self.__logger.debug("WaitExecution request from [%s]", context.peer())
+        """Handles WaitExecutionRequest messages.
 
-        try:
-            names = request.name.split("/")
+        Args:
+            request (WaitExecutionRequest): The incoming RPC request.
+            context (grpc.ServicerContext): Context for the RPC call.
+        """
+        self.__logger.debug("WaitExecution request from [%s]", context.peer())
 
-            # Operation name should be in format:
-            # {instance/name}/{operation_id}
-            instance_name = ''.join(names[0:-1])
+        names = request.name.split('/')
+        instance_name = '/'.join(names[:-1])
+        operation_name = names[-1]
+        message_queue = queue.Queue()
+        peer = context.peer()
 
-            message_queue = queue.Queue()
-            operation_name = names[-1]
+        try:
             instance = self._get_instance(instance_name)
 
             instance.register_message_client(operation_name, message_queue)
+            context.add_callback(partial(self._rpc_termination_callback,
+                                         peer, instance_name, operation_name, message_queue))
 
-            context.add_callback(partial(instance.unregister_message_client,
-                                         operation_name, message_queue))
+            if self._is_instrumented:
+                if peer not in self.__peers:
+                    self.__peers_by_instance[instance_name].add(peer)
+                    self.__peers[peer] = 1
+                else:
+                    self.__peers[peer] += 1
 
             for operation in instance.stream_operation_updates(message_queue,
                                                                operation_name):
@@ -123,6 +165,39 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
             context.set_code(grpc.StatusCode.CANCELLED)
             yield operations_pb2.Operation()
 
+    # --- Public API: Monitoring ---
+
+    @property
+    def is_instrumented(self):
+        return self._is_instrumented
+
+    def query_n_clients(self):
+        if self.__peers is not None:
+            return len(self.__peers)
+        return 0
+
+    def query_n_clients_for_instance(self, instance_name):
+        try:
+            if self.__peers_by_instance is not None:
+                return len(self.__peers_by_instance[instance_name])
+        except KeyError:
+            pass
+        return 0
+
+    # --- Private API ---
+
+    def _rpc_termination_callback(self, peer, instance_name, job_name, message_queue):
+        instance = self._get_instance(instance_name)
+
+        instance.unregister_message_client(job_name, message_queue)
+
+        if self._is_instrumented:
+            if self.__peers[peer] > 1:
+                self.__peers[peer] -= 1
+            else:
+                self.__peers_by_instance[instance_name].remove(peer)
+                del self.__peers[peer]
+
     def _get_instance(self, name):
         try:
             return self._instances[name]
diff --git a/buildgrid/server/instance.py b/buildgrid/server/instance.py
index ca5afaa6b..4d1f45815 100644
--- a/buildgrid/server/instance.py
+++ b/buildgrid/server/instance.py
@@ -130,7 +130,8 @@ class BuildGridServer:
             instance_name (str): Instance name.
         """
         if self._execution_service is None:
-            self._execution_service = ExecutionService(self.__grpc_server)
+            self._execution_service = ExecutionService(
+                self.__grpc_server, monitor=self._is_instrumented)
 
         self._execution_service.add_instance(instance_name, instance)
         self._add_capabilities_instance(instance_name, execution_instance=instance)
-- 
GitLab