From 9f16fe87028f806d3eba5f956de062d449e2f26a Mon Sep 17 00:00:00 2001 From: Martin Blanchard <martin.blanchard@codethink.co.uk> Date: Tue, 27 Nov 2018 15:43:54 +0000 Subject: [PATCH] server/_authentication.py: New JWT based gRPC interceptor https://gitlab.com/BuildGrid/buildgrid/issues/144 --- buildgrid/server/_authentication.py | 214 ++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 buildgrid/server/_authentication.py diff --git a/buildgrid/server/_authentication.py b/buildgrid/server/_authentication.py new file mode 100644 index 000000000..b603f1893 --- /dev/null +++ b/buildgrid/server/_authentication.py @@ -0,0 +1,214 @@ +# Copyright (C) 2018 Bloomberg LP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# <http://www.apache.org/licenses/LICENSE-2.0> +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import datetime +from enum import Enum +import logging + +import grpc + +from buildgrid._exceptions import InvalidArgumentError + + +try: + import jwt +except ImportError: + HAVE_JWT = False +else: + HAVE_JWT = True + + +class AuthMetadataMethod(Enum): + # No authentication: + NONE = 'none' + # JWT based authentication: + JWT = 'JWT' + + +class AuthMetadataAlgorithm(Enum): + # No encryption involved: + NONE = 'none' + # JWT related algorithms: + JWT_ES256 = 'ES256' # ECDSA signature algorithm using SHA-256 hash algorithm + JWT_ES384 = 'ES384' # ECDSA signature algorithm using SHA-384 hash algorithm + JWT_ES512 = 'ES512' # ECDSA signature algorithm using SHA-512 hash algorithm + JWT_HS256 = 'HS256' # HMAC using SHA-256 hash algorithm + JWT_HS384 = 'HS384' # HMAC using SHA-384 hash algorithm + JWT_HS512 = 'HS512' # HMAC using SHA-512 hash algorithm + JWT_PS256 = 'PS256' # RSASSA-PSS using SHA-256 and MGF1 padding with SHA-256 + JWT_PS384 = 'PS384' # RSASSA-PSS signature using SHA-384 and MGF1 padding with SHA-384 + JWT_PS512 = 'PS512' # RSASSA-PSS signature using SHA-512 and MGF1 padding with SHA-512 + JWT_RS256 = 'RS256' # RSASSA-PKCS1-v1_5 signature algorithm using SHA-256 hash algorithm + JWT_RS384 = 'RS384' # RSASSA-PKCS1-v1_5 signature algorithm using SHA-384 hash algorithm + JWT_RS512 = 'RS512' # RSASSA-PKCS1-v1_5 signature algorithm using SHA-512 hash algorithm + + +class _InvalidTokenError(Exception): + pass + + +class _ExpiredTokenError(Exception): + pass + + +class _UnboundedTokenError(Exception): + pass + + +class AuthMetadataServerInterceptor(grpc.ServerInterceptor): + + __auth_errors = { + 'missing-bearer': 'Missing authentication header field', + 'invalid-bearer': 'Invalid authentication header field', + 'invalid-token': 'Invalid authentication token', + 'expired-token': 'Expired authentication token', + 'unbounded-token': 'Unbounded authentication token', + } + + def __init__(self, method, secret=None, algorithm=AuthMetadataAlgorithm.NONE): + """Initialises a new :class:`AuthMetadataServerInterceptor`. + + Args: + method (AuthMetadataMethod): Type of authorization method. + secret (str): The secret or key to be used for validating request, + depending on `method`. Defaults to ``None``. + algorithm (AuthMetadataAlgorithm): The crytographic algorithm used + to encode `secret`. Defaults to ``AuthMetadataAlgorithm.NONE``. + + Raises: + InvalidArgumentError: If `method` is not supported or if `algorithm` + is not supported for the given `method`. + """ + self.__logger = logging.getLogger(__name__) + + self.__bearer_cache = {} + self.__terminators = {} + self.__validator = None + self.__secret = secret + + self._method = method + self._algorithm = algorithm + + if self._method == AuthMetadataMethod.JWT: + if not HAVE_JWT: + raise InvalidArgumentError("JWT authorization method requires PyJWT") + + try: + jwt.register_algorithm(self._algorithm.value, None) + except TypeError: + raise InvalidArgumentError("Algorithm not supported for JWT decoding: [{}]" + .format(self._algorithm)) + except ValueError: + pass + + self.__validator = self._validate_jwt_token + + for code, message in self.__auth_errors.items(): + self.__terminators[code] = _unary_unary_rpc_terminator(message) + + # --- Public API --- + + @property + def method(self): + return self._method + + @property + def algorithm(self): + return self._algorithm + + def intercept_service(self, continuation, handler_call_details): + try: + # Reject requests not carrying a token: + bearer = dict(handler_call_details.invocation_metadata)['authorization'] + + except KeyError: + self.__logger.error("Rejecting '{}' request: {}" + .format(handler_call_details.method.split('/')[-1], + self.__auth_errors['missing-bearer'])) + return self.__terminators['missing-bearer'] + + # Reject requests with malformated bearer: + if not bearer.startswith('Bearer '): + self.__logger.error("Rejecting '{}' request: {}" + .format(handler_call_details.method.split('/')[-1], + self.__auth_errors['invalid-bearer'])) + return self.__terminators['invalid-bearer'] + + try: + # Hit the cache for already validated token: + expiration_time = self.__bearer_cache[bearer] + + # Accept request if cached token hasn't expired yet: + if expiration_time < datetime.utcnow(): + return continuation(handler_call_details) # Accepted + + except KeyError: + pass + + assert self.__validator is not None + + try: + # Decode and validate the new token: + expiration_time = self.__validator(bearer[7:]) + + except _InvalidTokenError as e: + self.__logger.error("Rejecting '{}' request: {}; {}" + .format(handler_call_details.method.split('/')[-1], + self.__auth_errors['invalid-token'], str(e))) + return self.__terminators['invalid-token'] + + except _ExpiredTokenError as e: + self.__logger.error("Rejecting '{}' request: {}; {}" + .format(handler_call_details.method.split('/')[-1], + self.__auth_errors['expired-token'], str(e))) + return self.__terminators['expired-token'] + + except _UnboundedTokenError as e: + self.__logger.error("Rejecting '{}' request: {}; {}" + .format(handler_call_details.method.split('/')[-1], + self.__auth_errors['unbounded-token'], str(e))) + return self.__terminators['unbounded-token'] + + # Cache the validated token and store expiration time: + self.__bearer_cache[bearer] = expiration_time + + return continuation(handler_call_details) # Accepted + + # --- Private API --- + + def _validate_jwt_token(self, token): + """Validates a JWT token and returns its expiry date.""" + try: + payload = jwt.decode( + token, self.__secret, algorithms=[self._algorithm.value]) + + except jwt.exceptions.ExpiredSignatureError as e: + raise _ExpiredTokenError(e) + + except jwt.exceptions.InvalidTokenError as e: + raise _InvalidTokenError(e) + + if 'exp' not in payload or not isinstance(payload['exp'], int): + raise _UnboundedTokenError("Missing 'exp' in payload") + + return datetime.fromtimestamp(payload['exp']) + + +def _unary_unary_rpc_terminator(details): + + def terminate(ignored_request, context): + context.abort(grpc.StatusCode.UNAUTHENTICATED, details) + + return grpc.unary_unary_rpc_method_handler(terminate) -- GitLab