scientist.py 9.43 KB
Newer Older
1
import base64
2 3 4
import json
import logging
import os
Gero Vermaas's avatar
Gero Vermaas committed
5
import random
6
import requests
7
import time
8

9
from datetime import datetime
10
from datetime import timedelta
11
from threading import Thread
12

Gero Vermaas's avatar
Gero Vermaas committed
13
import boto3
Gero Vermaas's avatar
Gero Vermaas committed
14
import botocore
15 16
import yaml

Gero Vermaas's avatar
Gero Vermaas committed
17
from dynamodb_json import json_util as dyndb_json_util
18
from metrics_publisher import MetricsPubisher
19
import scientist_utils
20

21 22
from result_store import RunResultsStore

23
logging.basicConfig()
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)

AWSREGION = os.environ.get("AWSREGION")

# Experiment definitions are cached at module level and reloaded from S3 by
# lambda_handler once they are older than EXPERIMENTS_REFRESH_MINUTES.
EXPERIMENTS_REFRESH_MINUTES = int(os.environ.get("EXPERIMENTS_REFRESH_MINUTES", "1"))
# Start two refresh intervals in the past so the first invocation loads them.
experiments_update_ts = datetime.now() - timedelta(
    minutes=2 * EXPERIMENTS_REFRESH_MINUTES
)
# Cached experiment definitions: experiment name -> experiment details.
experiments = {}

EXPERIMENTS_BUCKET = os.environ.get("EXPERIMENTS_BUCKET")
EXPERIMENTOR_ARN = os.environ.get("EXPERIMENTOR_ARN")
RESULT_COLLECTOR_ARN = os.environ.get("RESULT_COLLECTOR_ARN")

# Metric names.
RESPONSE_TIME = "response_time"
SCIENTIST_ADDED_RESPONSE_TIME = "scientists_added_response_time"

# Initialize the AWS clients once at module load instead of inside
# lambda_handler to reduce per-invocation overhead: initializing the AWS
# clients takes between 3 and 45 ms.
LAMBDA_CLIENT = boto3.client("lambda")
DYNAMODB_CLIENT = boto3.client("dynamodb")
COUNTERS_TABLE = os.environ["COUNTERS_TABLE"]

48

49
def list_experiment_files(bucket_name, s3_client):
    """
    Retrieve the list of experiment YAML file keys in a bucket.

    Pages through the bucket with the ``list_objects_v2`` paginator and skips
    "folder" placeholder keys (keys ending in "/").

    Args:
        bucket_name: name of the S3 bucket holding the experiment files.
        s3_client: boto3 S3 client used for the listing.

    Returns:
        list of object keys in listing order; empty when the bucket is empty.
    """
    paginator = s3_client.get_paginator("list_objects_v2")

    keys = []
    for page in paginator.paginate(Bucket=bucket_name):
        # "Contents" is absent from a page that lists no objects (e.g. an
        # empty bucket); indexing it directly raised KeyError.
        keys.extend(
            s3_object["Key"]
            for s3_object in page.get("Contents", [])
            if not s3_object["Key"].endswith("/")
        )
    return keys
67 68


69
def retrieve_experiments(aws_account_id):
    """Download and merge all experiment definitions from EXPERIMENTS_BUCKET.

    Every file returned by list_experiment_files is downloaded, its
    "{AWSACCOUNT_ID}" and "{AWSREGION}" placeholders are substituted, and it is
    parsed as YAML. The mappings under each file's top-level "experiments" key
    are merged; for duplicate names, later files override earlier ones.

    Args:
        aws_account_id: value substituted for "{AWSACCOUNT_ID}".

    Returns:
        dict mapping experiment name to its details.
    """
    LOGGER.info("Retrieving experiments from S3")

    s3_client = boto3.client("s3")
    experiment_files = list_experiment_files(EXPERIMENTS_BUCKET, s3_client)
    tmp_file = "/tmp/experiments.yaml"

    experiments = {}
    for key in experiment_files:
        s3_client.download_file(EXPERIMENTS_BUCKET, key, tmp_file)

        with open(tmp_file, "r", encoding="utf-8") as experiments_file:
            experiments_str = experiments_file.read()

        # NOTE(review): AWSREGION comes from the environment; if it is unset
        # (None) this replace raises TypeError.
        experiments_str = experiments_str.replace(
            "{AWSACCOUNT_ID}", aws_account_id
        ).replace("{AWSREGION}", AWSREGION)

        # safe_load: yaml.load without an explicit Loader can construct
        # arbitrary Python objects and is a TypeError on PyYAML >= 6.
        experiment_details = yaml.safe_load(experiments_str)

        # An empty file parses to None; skip files without experiments
        # instead of failing the whole refresh.
        file_experiments = (experiment_details or {}).get("experiments")
        if not file_experiments:
            LOGGER.warning("No experiments defined in %s", key)
            continue
        experiments.update(file_experiments)

    LOGGER.info("Retrieved experiments from S3: %s", experiments)

    return experiments

94

95
def find_experiment(experiments, path):
    """Look up the experiment whose configured "path" equals *path*.

    Args:
        experiments: mapping of experiment name -> details; every details
            mapping must have a "path" entry.
        path: requested experiment path.

    Returns:
        (experiment_name, details) for the first matching experiment in the
        mapping's iteration order.

    Raises:
        ValueError: if no experiment has the given path.
    """
    candidates = (
        (name, details)
        for name, details in experiments.items()
        if details["path"] == path
    )
    match = next(candidates, None)
    if match is None:
        raise ValueError("No experiment found for path: {}".format(path))
    return match

103

104
def execute_control(experiment_run, control_arn, experiment_details, experiment_name):
    """Synchronously invoke the control implementation and package its result.

    The control Lambda receives experiment_run["payload"] as its JSON payload.
    The tail of its log is requested so metrics can be extracted from it.

    Args:
        experiment_run: dict with at least "payload" and "comparators".
        control_arn: ARN of the control Lambda function.
        experiment_details: experiment definition; ["control"]["name"] is
            reported as the implementation name.
        experiment_name: name of the experiment being run.

    Returns:
        dict describing the control run: run type, implementation name, ARN,
        experiment name, the decoded JSON response, metrics (including the
        measured response time in ms) and the experiment's comparators.
    """
    LOGGER.info("invoking control")

    # Time only the control invocation itself.
    invoked_at = time.time()
    invocation = LAMBDA_CLIENT.invoke(
        FunctionName=control_arn,
        InvocationType="RequestResponse",
        Payload=json.dumps(experiment_run["payload"]),
        LogType="Tail",
    )
    elapsed_ms = int((time.time() - invoked_at) * 1000)

    control_metrics = scientist_utils.extract_metrics_from_log(
        invocation["LogResult"]
    )
    LOGGER.info("Execution of control took %s ms", elapsed_ms)
    control_metrics[RESPONSE_TIME] = elapsed_ms

    # The invoke response's Payload is a stream; read it once and keep the
    # decoded JSON in the result instead.
    raw_body = invocation["Payload"].read()
    decoded_response = json.loads(raw_body.decode("utf-8"))

    return dict(
        run_type="control",
        implementation_name=experiment_details["control"]["name"],
        arn=control_arn,
        experiment_name=experiment_name,
        received_response=decoded_response,
        metrics=control_metrics,
        comparators=experiment_run["comparators"],
    )


def invoke_experimentor(experiment_run):
    """Hand the experiment run over to the Experimentor Lambda to run the candidates.

    The Experimentor is invoked with InvocationType "Event" (asynchronous), so
    this call does not wait for the candidates to run. The time spent posting
    the event is logged.

    Args:
        experiment_run: JSON-serializable description of the run.
    """
    LOGGER.info("invoking experimentor")
    posted_at = time.time()

    event_payload = json.dumps(experiment_run)
    LAMBDA_CLIENT.invoke(
        FunctionName=EXPERIMENTOR_ARN,
        InvocationType="Event",
        Payload=event_payload,
    )

    posting_ms = int((time.time() - posted_at) * 1000)
    LOGGER.info("Posting to experimentor took %s ms", posting_ms)

156

Gero Vermaas's avatar
Gero Vermaas committed
157
def get_run_id():
    """Allocate and return the next experiment run id.

    The counter is the "last_run_id" attribute of the COUNTERS_TABLE item
    whose counter_id is "run_id". A single UpdateItem both seeds a missing
    counter and increments it. The first id handed out is therefore 100000,
    and there is no read-then-initialize race between concurrent callers.

    Returns:
        the new run id, as decoded by dyndb_json_util.

    Raises:
        botocore.exceptions.ClientError: if DynamoDB rejects the update
            (e.g. throttling or missing permissions). Such errors used to
            trigger a fallback that reset the counter to 100000, which
            produced duplicate run ids.
    """
    LOGGER.info("Generating new run_id")
    start = time.time()

    first_run_id = 100000

    # if_not_exists() supplies (first_run_id - 1) when the attribute or item
    # does not exist yet, so the increment yields first_run_id.
    response = DYNAMODB_CLIENT.update_item(
        TableName=COUNTERS_TABLE,
        Key={"counter_id": {"S": "run_id"}},
        UpdateExpression=(
            "SET last_run_id = if_not_exists(last_run_id, :seed) + :incr"
        ),
        ExpressionAttributeValues={
            ":seed": {"N": str(first_run_id - 1)},
            ":incr": {"N": "1"},
        },
        ReturnValues="UPDATED_NEW",
    )

    response = dyndb_json_util.loads(response)
    run_id = response["Attributes"]["last_run_id"]

    duration = int((time.time() - start) * 1000)
    LOGGER.info("Getting next run_id (%s) took %sms", run_id, duration)
    LOGGER.info("Generated new run_id %s", run_id)
    return run_id
Gero Vermaas's avatar
Gero Vermaas committed
196 197


Gero Vermaas's avatar
Gero Vermaas committed
198 199 200 201 202 203 204 205 206
def base64decode_event_body(event):
    """Decode a base64-encoded request body in place.

    The body is decoded only when the event has a truthy "isBase64Encoded"
    flag and a non-empty "body". In that case the body is base64-decoded and
    interpreted as UTF-8 text, and the flag is set to False. Otherwise the
    event is left untouched.

    Returns:
        the same (possibly mutated) event object.
    """
    if not event.get("isBase64Encoded"):
        return event

    encoded_body = event.get("body")
    if not encoded_body:
        return event

    event["body"] = base64.b64decode(encoded_body).decode("utf-8")
    event["isBase64Encoded"] = False
    return event


207
def _post_response(response_url, payload):
    """POST *payload*, JSON-encoded, to *response_url*."""
    requests.post(response_url, data=json.dumps(payload))


def _requested_experiment_path(event):
    """Return the experiment path from the request path, or "" if none was given.

    All "/scientist" occurrences are removed, then one leading "/" is dropped.
    So "/scientist/foo" -> "foo", and both "/scientist" and "/scientist/"
    -> "". Previously "/scientist/" yielded "/", which matched no experiment.
    """
    path_part = event["path"].replace("/scientist", "")
    if path_part.startswith("/"):
        path_part = path_part[1:]
    return path_part


def _available_experiments_text(event, known_experiments):
    """Build the plain-text listing of experiment URLs for the request's host."""
    headers = event["headers"]
    # NOTE(review): header-name casing depends on the API Gateway flavour;
    # fall back to the lowercase name when "Host" is absent.
    host = headers["Host"] if "Host" in headers else headers["host"]
    lines = ["The following experiments are available:"]
    lines.extend(
        f"https://{host}/v1/scientist/{details['path']}"
        for details in known_experiments.values()
    )
    return "\n".join(lines) + "\n"


def lambda_handler(event, response_url):
    """Run the control of the requested experiment and hand off to the Experimentor.

    Flow:
      * "keephot" events only get a 200 "I'm still warm." reply.
      * Experiment definitions are reloaded from S3 once they are older than
        EXPERIMENTS_REFRESH_MINUTES.
      * Without an experiment path, a plain-text list of experiment URLs is
        returned. An unknown path gets a 404.
      * Otherwise the control is invoked and its response is posted to
        response_url right away. Afterwards a run id is allocated, the
        scientist overhead metric is published, the control result is stored,
        and the Experimentor is invoked asynchronously for the candidates.

    Args:
        event: API Gateway-style request event; reads "path", "headers",
            "requestContext" and "body"/"isBase64Encoded".
        response_url: URL that every response is POSTed to as JSON.
    """
    global experiments_update_ts
    global experiments

    LOGGER.info("event: %s", event)
    LOGGER.info("response_url: %s", response_url)
    if event.get("keephot"):
        _post_response(response_url, {"statusCode": 200, "body": "I'm still warm."})
        return

    start = time.time()

    if experiments_update_ts < datetime.now() - timedelta(
        minutes=EXPERIMENTS_REFRESH_MINUTES
    ):
        aws_account_id = event["requestContext"]["accountId"]
        experiments = retrieve_experiments(aws_account_id)
        experiments_update_ts = datetime.now()

    experiment_path_part = _requested_experiment_path(event)
    if not experiment_path_part:
        # Scientist invoked without an experiment: list the available ones.
        _post_response(
            response_url,
            {
                "statusCode": 200,
                "body": _available_experiments_text(event, experiments),
                "headers": {"Content-Type": "text/plain"},
            },
        )
        return

    try:
        experiment_name, experiment_details = find_experiment(
            experiments, experiment_path_part
        )
    except ValueError:
        # Otherwise the ValueError would propagate before any response is
        # posted to response_url.
        LOGGER.warning("No experiment found for path: %s", experiment_path_part)
        _post_response(
            response_url,
            {
                "statusCode": 404,
                "body": f"No experiment found for path: {experiment_path_part}\n",
                "headers": {"Content-Type": "text/plain"},
            },
        )
        return
    control_arn = experiment_details["control"]["arn"]

    experiment_run = {
        "candidates": experiment_details["candidates"],
        # Because of the binary asset setting in API Gateway, the body may
        # arrive base64-encoded; if so, decode it first.
        "payload": base64decode_event_body(event),
        "experiment_name": experiment_name,
        "comparators": experiment_details["comparators"],
    }

    control_run_result = execute_control(
        experiment_run, control_arn, experiment_details, experiment_name
    )

    _post_response(response_url, control_run_result["received_response"])
    # Measured right after the response is posted: everything below runs
    # after the caller has its answer and so adds no response time.
    scientist_duration = int((time.time() - start) * 1000)
    LOGGER.info(
        "Posting of response done after %sms: %s",
        scientist_duration,
        control_run_result["received_response"],
    )

    # run_id is determined here so that it happens after the control has been
    # invoked and answered, and does not add latency.
    experiment_run["run_id"] = get_run_id()
    control_run_result["run_id"] = experiment_run["run_id"]
    control_run_result["request_payload"] = experiment_run["payload"]

    metrics_publisher = MetricsPubisher()
    overhead = scientist_duration - control_run_result["metrics"][RESPONSE_TIME]
    metrics_publisher.publish_value(
        SCIENTIST_ADDED_RESPONSE_TIME,
        experiment_name,
        "scientist-overhead",
        "",
        overhead,
    )
    LOGGER.warning("Overhead of scientist %s ms", overhead)

    results_store = RunResultsStore()
    results_store.store_result(control_run_result)
    invoke_experimentor(experiment_run)
    LOGGER.info("Returning after %s ms", int((time.time() - start) * 1000))