Skip to content
Snippets Groups Projects
Commit 7a532ac6 authored by Martin Blanchard
Browse files

Stop exposing Job objects outside of Scheduler

#75
parent d29de94e
No related branches found
No related tags found
Loading
......@@ -44,7 +44,7 @@ class ExecutionInstance:
# Returns whatever get_hash_type() yields; the instance itself holds no
# hash-type state in this method.
def hash_type(self):
return get_hash_type()
def execute(self, action_digest, skip_cache_lookup, peer=None, message_queue=None):
def execute(self, action_digest, skip_cache_lookup):
""" Sends a job for execution.
Queues an action and creates an Operation instance to be associated with
this action.
......@@ -54,33 +54,36 @@ class ExecutionInstance:
if not action:
raise FailedPreconditionError("Could not get action from storage.")
job = self._scheduler.queue_job(action, action_digest,
skip_cache_lookup=skip_cache_lookup)
return self._scheduler.queue_job(action, action_digest,
skip_cache_lookup=skip_cache_lookup)
if peer is not None and message_queue is not None:
job.register_operation_peer(peer, message_queue)
return job.operation
def register_operation_peer(self, job_name, peer, message_queue):
def register_operation_peer(self, operation_name, peer, message_queue):
try:
self._scheduler.register_operation_peer(job_name, peer, message_queue)
return self._scheduler.register_operation_peer(operation_name,
peer, message_queue)
except NotFoundError:
raise InvalidArgumentError("Operation name does not exist: [{}]".format(job_name))
raise InvalidArgumentError("Operation name does not exist: [{}]"
.format(operation_name))
def unregister_operation_peer(self, job_name, peer):
def unregister_operation_peer(self, operation_name, peer):
try:
self._scheduler.unregister_operation_peer(job_name, peer)
self._scheduler.unregister_operation_peer(operation_name, peer)
except NotFoundError:
raise InvalidArgumentError("Operation name does not exist: [{}]".format(job_name))
raise InvalidArgumentError("Operation name does not exist: [{}]"
.format(operation_name))
def stream_operation_updates(self, message_queue):
error, operation = message_queue.get()
if error is not None:
raise error
while not operation.done:
yield operation
def stream_operation_updates(self, message_queue, operation_name):
job = message_queue.get()
while not job.operation.done:
yield job.operation
job = message_queue.get()
job.check_operation_status()
error, operation = message_queue.get()
if error is not None:
raise error
yield job.operation
yield operation
......@@ -99,13 +99,14 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
try:
instance = self._get_instance(instance_name)
operation = instance.execute(request.action_digest,
request.skip_cache_lookup,
peer=peer,
message_queue=message_queue)
job_name = instance.execute(request.action_digest,
request.skip_cache_lookup)
operation_name = instance.register_operation_peer(job_name,
peer, message_queue)
context.add_callback(partial(self._rpc_termination_callback,
peer, instance_name, operation.name))
peer, instance_name, operation_name))
if self._is_instrumented:
if peer not in self.__peers:
......@@ -114,16 +115,13 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
else:
self.__peers[peer] += 1
instanced_op_name = "{}/{}".format(instance_name, operation.name)
operation_full_name = "{}/{}".format(instance_name, operation_name)
self.__logger.info("Operation name: [%s]", instanced_op_name)
self.__logger.info("Operation name: [%s]", operation_full_name)
for operation in instance.stream_operation_updates(message_queue,
operation.name):
op = operations_pb2.Operation()
op.CopyFrom(operation)
op.name = instanced_op_name
yield op
for operation in instance.stream_operation_updates(message_queue):
operation.name = operation_full_name
yield operation
except InvalidArgumentError as e:
self.__logger.error(e)
......@@ -162,8 +160,8 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
try:
instance = self._get_instance(instance_name)
instance.register_operation_peer(operation_name,
peer, message_queue)
operation_name = instance.register_operation_peer(operation_name,
peer, message_queue)
context.add_callback(partial(self._rpc_termination_callback,
peer, instance_name, operation_name))
......@@ -175,12 +173,11 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
else:
self.__peers[peer] += 1
for operation in instance.stream_operation_updates(message_queue,
operation_name):
op = operations_pb2.Operation()
op.CopyFrom(operation)
op.name = request.name
yield op
operation_full_name = "{}/{}".format(instance_name, operation_name)
for operation in instance.stream_operation_updates(message_queue):
operation.name = operation_full_name
yield operation
except InvalidArgumentError as e:
self.__logger.error(e)
......@@ -215,10 +212,10 @@ class ExecutionService(remote_execution_pb2_grpc.ExecutionServicer):
# --- Private API ---
def _rpc_termination_callback(self, peer, instance_name, job_name):
def _rpc_termination_callback(self, peer, instance_name, operation_name):
instance = self._get_instance(instance_name)
instance.unregister_operation_peer(job_name, peer)
instance.unregister_operation_peer(operation_name, peer)
if self._is_instrumented:
if self.__peers[peer] > 1:
......
......@@ -166,11 +166,18 @@ class Job:
Args:
peer (str): a unique string identifying the client.
message_queue (queue.Queue): the event queue to register.
Returns:
str: The name of the subscribed :class:`Operation`.
"""
if peer not in self.__operation_message_queues:
self.__operation_message_queues[peer] = message_queue
message_queue.put(self)
message = (None, self._copy_operation(self._operation),)
message_queue.put(message)
return self._operation.name
def unregister_operation_peer(self, peer):
"""Unsubscribes to the job's :class:`Operation` stage change.
......@@ -297,17 +304,14 @@ class Job:
self._operation.metadata.Pack(self.__operation_metadata)
for message_queue in self.__operation_message_queues.values():
message_queue.put(self)
def check_operation_status(self):
"""Reports errors on unexpected job's :class:Operation state.
if not self.__operation_cancelled:
message = (None, self._copy_operation(self._operation),)
else:
message = (CancelledError(self.__execute_response.status.message),
self._copy_operation(self._operation),)
Raises:
CancelledError: if the job's :class:Operation was cancelled.
"""
if self.__operation_cancelled:
raise CancelledError(self.__execute_response.status.message)
for message_queue in self.__operation_message_queues.values():
message_queue.put(message)
def cancel_operation(self):
"""Triggers a job's :class:Operation cancellation.
......@@ -331,3 +335,12 @@ class Job:
# Returns self._n_tries - 1 when self._n_tries is positive, otherwise 0,
# so the result never goes below zero.
# NOTE(review): presumably _n_tries counts attempts including the first
# one -- confirm against where it is incremented.
def query_n_retries(self):
return self._n_tries - 1 if self._n_tries > 0 else 0
# --- Private API ---
# Builds a fresh operations_pb2.Operation, fills it from `operation` with
# CopyFrom(), and returns the new message; the argument is left untouched.
def _copy_operation(self, operation):
new_operation = operations_pb2.Operation()
new_operation.CopyFrom(operation)
return new_operation
......@@ -52,11 +52,11 @@ class OperationsInstance:
# TODO: Pages
# Spec says number of pages and length of a page are optional
response = operations_pb2.ListOperationsResponse()
operations = []
for job in self._scheduler.list_jobs():
op = operations_pb2.Operation()
op.CopyFrom(job.operation)
operations.append(op)
for job_name in self._scheduler.list_jobs():
operation = self._scheduler.get_job_operation(job_name)
operations.append(operation)
response.operations.extend(operations)
......
......@@ -56,40 +56,45 @@ class Scheduler:
# --- Public API ---
def register_operation_peer(self, job_name, peer, message_queue):
def register_operation_peer(self, operation_name, peer, message_queue):
"""Subscribes to one of the job's :class:`Operation` stage changes.
Args:
job_name (str): name of the job subscribe to.
operation_name (str): name of the operation to subscribe to.
peer (str): a unique string identifying the client.
message_queue (queue.Queue): the event queue to register.
Returns:
str: The name of the subscribed :class:`Operation`.
Raises:
NotFoundError: If no job with `job_name` exists.
NotFoundError: If no operation with `operation_name` exists.
"""
try:
job = self.__jobs_by_name[job_name]
job = self.__jobs_by_name[operation_name]
except KeyError:
raise NotFoundError("Job name does not exist: [{}]".format(job_name))
raise NotFoundError("Operation name does not exist: [{}]"
.format(operation_name))
job.register_operation_peer(peer, message_queue)
return job.register_operation_peer(peer, message_queue)
def unregister_operation_peer(self, job_name, peer):
def unregister_operation_peer(self, operation_name, peer):
"""Unsubscribes to one of the job's :class:`Operation` stage change.
Args:
job_name (str): name of the job to unsubscribe from.
operation_name (str): name of the operation to unsubscribe from.
peer (str): a unique string identifying the client.
Raises:
NotFoundError: If no job with `job_name` exists.
NotFoundError: If no operation with `operation_name` exists.
"""
try:
job = self.__jobs_by_name[job_name]
job = self.__jobs_by_name[operation_name]
except KeyError:
raise NotFoundError("Job name does not exist: [{}]".format(job_name))
raise NotFoundError("Operation name does not exist: [{}]"
.format(operation_name))
job.unregister_operation_peer(peer)
......@@ -105,6 +110,9 @@ class Scheduler:
priority (int): the execution job's priority.
skip_cache_lookup (bool): whether or not to look for pre-computed
result for the given action.
Returns:
str: the newly created operation's name.
"""
if action_digest.hash in self.__jobs_by_action:
job = self.__jobs_by_action[action_digest.hash]
......@@ -116,7 +124,7 @@ class Scheduler:
if job.operation_stage == OperationStage.QUEUED:
self._queue_job(job.name)
return job
return job.name
job = Job(action, action_digest, priority=priority)
......@@ -146,7 +154,7 @@ class Scheduler:
self._update_job_operation_stage(job.name, operation_stage)
return job
return job.name
def retry_job(self, job_name):
try:
......@@ -170,7 +178,7 @@ class Scheduler:
self._update_job_operation_stage(job.name, operation_stage)
def list_jobs(self):
return self.__jobs_by_name.values()
return self.__jobs_by_name.keys()
def request_job_leases(self, worker_capabilities):
"""Generates a list of the highest priority leases to be run.
......
......@@ -20,11 +20,11 @@
import uuid
from unittest import mock
from google.protobuf import any_pb2
import grpc
from grpc._server import _Context
import pytest
from buildgrid._enums import OperationStage
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
from buildgrid._protos.google.longrunning import operations_pb2
......@@ -82,7 +82,7 @@ def test_execute(skip_cache_lookup, instance, context):
assert isinstance(result, operations_pb2.Operation)
metadata = remote_execution_pb2.ExecuteOperationMetadata()
result.metadata.Unpack(metadata)
assert metadata.stage == job.OperationStage.QUEUED.value
assert metadata.stage == OperationStage.QUEUED.value
operation_uuid = result.name.split('/')[-1]
assert uuid.UUID(operation_uuid, version=4)
assert result.done is False
......@@ -106,18 +106,14 @@ def test_no_action_digest_in_storage(instance, context):
def test_wait_execution(instance, controller, context):
j = controller.execution_instance._scheduler.queue_job(action,
action_digest,
skip_cache_lookup=True)
j._operation.done = True
job_name = controller.execution_instance._scheduler.queue_job(action,
action_digest,
skip_cache_lookup=True)
request = remote_execution_pb2.WaitExecutionRequest(name=j.name)
controller.execution_instance._scheduler._update_job_operation_stage(job_name,
OperationStage.COMPLETED)
action_result_any = any_pb2.Any()
action_result = remote_execution_pb2.ActionResult()
action_result_any.Pack(action_result)
j.update_operation_stage(job.OperationStage.COMPLETED)
request = remote_execution_pb2.WaitExecutionRequest(name=job_name)
response = instance.WaitExecution(request, context)
......@@ -127,7 +123,6 @@ def test_wait_execution(instance, controller, context):
metadata = remote_execution_pb2.ExecuteOperationMetadata()
result.metadata.Unpack(metadata)
assert metadata.stage == job.OperationStage.COMPLETED.value
assert uuid.UUID(result.name, version=4)
assert result.done is True
......
......@@ -86,8 +86,8 @@ def blank_instance(controller):
# Queue an execution, get operation corresponding to that request
def test_get_operation(instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.GetOperationRequest()
......@@ -95,25 +95,23 @@ def test_get_operation(instance, controller, execute_request, context):
# we're manually creating the instance here, it doesn't get a name.
# Therefore we need to manually add the instance name to the operation
# name in the GetOperation request.
request.name = "{}/{}".format(instance_name, response_execute.name)
request.name = "{}/{}".format(instance_name, job_name)
response = instance.GetOperation(request, context)
assert response.name == "{}/{}".format(instance_name, response_execute.name)
assert response.done == response_execute.done
assert response.name == "{}/{}".format(instance_name, job_name)
# Queue an execution, get operation corresponding to that request
def test_get_operation_blank(blank_instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.GetOperationRequest()
request.name = response_execute.name
request.name = job_name
response = blank_instance.GetOperation(request, context)
assert response.name == response_execute.name
assert response.done == response_execute.done
assert response.name == job_name
def test_get_operation_fail(instance, context):
......@@ -133,25 +131,25 @@ def test_get_operation_instance_fail(instance, context):
def test_list_operations(instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.ListOperationsRequest(name=instance_name)
response = instance.ListOperations(request, context)
names = response.operations[0].name.split('/')
assert names[0] == instance_name
assert names[1] == response_execute.name
assert names[1] == job_name
def test_list_operations_blank(blank_instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.ListOperationsRequest(name='')
response = blank_instance.ListOperations(request, context)
assert response.operations[0].name.split('/')[-1] == response_execute.name
assert response.operations[0].name.split('/')[-1] == job_name
def test_list_operations_instance_fail(instance, controller, execute_request, context):
......@@ -174,14 +172,14 @@ def test_list_operations_empty(instance, context):
# Send execution off, delete, try to find operation should fail
def test_delete_operation(instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.DeleteOperationRequest()
request.name = response_execute.name
request.name = job_name
instance.DeleteOperation(request, context)
request_name = "{}/{}".format(instance_name, response_execute.name)
request_name = "{}/{}".format(instance_name, job_name)
with pytest.raises(InvalidArgumentError):
controller.operations_instance.get_operation(request_name)
......@@ -189,14 +187,14 @@ def test_delete_operation(instance, controller, execute_request, context):
# Send execution off, delete, try to find operation should fail
def test_delete_operation_blank(blank_instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.DeleteOperationRequest()
request.name = response_execute.name
request.name = job_name
blank_instance.DeleteOperation(request, context)
request_name = response_execute.name
request_name = job_name
with pytest.raises(InvalidArgumentError):
controller.operations_instance.get_operation(request_name)
......@@ -211,11 +209,11 @@ def test_delete_operation_fail(instance, context):
def test_cancel_operation(instance, controller, execute_request, context):
response_execute = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
job_name = controller.execution_instance.execute(execute_request.action_digest,
execute_request.skip_cache_lookup)
request = operations_pb2.CancelOperationRequest()
request.name = "{}/{}".format(instance_name, response_execute.name)
request.name = "{}/{}".format(instance_name, job_name)
instance.CancelOperation(request, context)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment