From 9f16fe87028f806d3eba5f956de062d449e2f26a Mon Sep 17 00:00:00 2001
From: Martin Blanchard <martin.blanchard@codethink.co.uk>
Date: Tue, 27 Nov 2018 15:43:54 +0000
Subject: [PATCH] server/_authentication.py: New JWT based gRPC interceptor

https://gitlab.com/BuildGrid/buildgrid/issues/144
---
 buildgrid/server/_authentication.py | 214 ++++++++++++++++++++++++++++
 1 file changed, 214 insertions(+)
 create mode 100644 buildgrid/server/_authentication.py

diff --git a/buildgrid/server/_authentication.py b/buildgrid/server/_authentication.py
new file mode 100644
index 000000000..b603f1893
--- /dev/null
+++ b/buildgrid/server/_authentication.py
@@ -0,0 +1,214 @@
+# Copyright (C) 2018 Bloomberg LP
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  <http://www.apache.org/licenses/LICENSE-2.0>
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from datetime import datetime
+from enum import Enum
+import logging
+
+import grpc
+
+from buildgrid._exceptions import InvalidArgumentError
+
+
+try:
+    import jwt
+except ImportError:
+    HAVE_JWT = False
+else:
+    HAVE_JWT = True
+
+
+class AuthMetadataMethod(Enum):
+    # No authentication:
+    NONE = 'none'
+    # JWT based authentication:
+    JWT = 'JWT'
+
+
+class AuthMetadataAlgorithm(Enum):
+    # No encryption involved:
+    NONE = 'none'
+    # JWT related algorithms:
+    JWT_ES256 = 'ES256'  # ECDSA signature algorithm using SHA-256 hash algorithm
+    JWT_ES384 = 'ES384'  # ECDSA signature algorithm using SHA-384 hash algorithm
+    JWT_ES512 = 'ES512'  # ECDSA signature algorithm using SHA-512 hash algorithm
+    JWT_HS256 = 'HS256'  # HMAC using SHA-256 hash algorithm
+    JWT_HS384 = 'HS384'  # HMAC using SHA-384 hash algorithm
+    JWT_HS512 = 'HS512'  # HMAC using SHA-512 hash algorithm
+    JWT_PS256 = 'PS256'  # RSASSA-PSS using SHA-256 and MGF1 padding with SHA-256
+    JWT_PS384 = 'PS384'  # RSASSA-PSS signature using SHA-384 and MGF1 padding with SHA-384
+    JWT_PS512 = 'PS512'  # RSASSA-PSS signature using SHA-512 and MGF1 padding with SHA-512
+    JWT_RS256 = 'RS256'  # RSASSA-PKCS1-v1_5 signature algorithm using SHA-256 hash algorithm
+    JWT_RS384 = 'RS384'  # RSASSA-PKCS1-v1_5 signature algorithm using SHA-384 hash algorithm
+    JWT_RS512 = 'RS512'  # RSASSA-PKCS1-v1_5 signature algorithm using SHA-512 hash algorithm
+
+
+class _InvalidTokenError(Exception):
+    pass
+
+
+class _ExpiredTokenError(Exception):
+    pass
+
+
+class _UnboundedTokenError(Exception):
+    pass
+
+
+class AuthMetadataServerInterceptor(grpc.ServerInterceptor):
+
+    __auth_errors = {
+        'missing-bearer': 'Missing authentication header field',
+        'invalid-bearer': 'Invalid authentication header field',
+        'invalid-token': 'Invalid authentication token',
+        'expired-token': 'Expired authentication token',
+        'unbounded-token': 'Unbounded authentication token',
+    }
+
+    def __init__(self, method, secret=None, algorithm=AuthMetadataAlgorithm.NONE):
+        """Initialises a new :class:`AuthMetadataServerInterceptor`.
+
+        Args:
+            method (AuthMetadataMethod): Type of authorization method.
+            secret (str): The secret or key to be used for validating request,
+                depending on `method`. Defaults to ``None``.
+            algorithm (AuthMetadataAlgorithm): The crytographic algorithm used
+                to encode `secret`. Defaults to ``AuthMetadataAlgorithm.NONE``.
+
+        Raises:
+            InvalidArgumentError: If `method` is not supported or if `algorithm`
+                is not supported for the given `method`.
+        """
+        self.__logger = logging.getLogger(__name__)
+
+        self.__bearer_cache = {}
+        self.__terminators = {}
+        self.__validator = None
+        self.__secret = secret
+
+        self._method = method
+        self._algorithm = algorithm
+
+        if self._method == AuthMetadataMethod.JWT:
+            if not HAVE_JWT:
+                raise InvalidArgumentError("JWT authorization method requires PyJWT")
+
+            try:
+                jwt.register_algorithm(self._algorithm.value, None)
+            except TypeError:
+                raise InvalidArgumentError("Algorithm not supported for JWT decoding: [{}]"
+                                           .format(self._algorithm))
+            except ValueError:
+                pass
+
+            self.__validator = self._validate_jwt_token
+
+        for code, message in self.__auth_errors.items():
+            self.__terminators[code] = _unary_unary_rpc_terminator(message)
+
+    # --- Public API ---
+
+    @property
+    def method(self):
+        return self._method
+
+    @property
+    def algorithm(self):
+        return self._algorithm
+
+    def intercept_service(self, continuation, handler_call_details):
+        try:
+            # Reject requests not carrying a token:
+            bearer = dict(handler_call_details.invocation_metadata)['authorization']
+
+        except KeyError:
+            self.__logger.error("Rejecting '{}' request: {}"
+                                .format(handler_call_details.method.split('/')[-1],
+                                        self.__auth_errors['missing-bearer']))
+            return self.__terminators['missing-bearer']
+
+        # Reject requests with malformated bearer:
+        if not bearer.startswith('Bearer '):
+            self.__logger.error("Rejecting '{}' request: {}"
+                                .format(handler_call_details.method.split('/')[-1],
+                                        self.__auth_errors['invalid-bearer']))
+            return self.__terminators['invalid-bearer']
+
+        try:
+            # Hit the cache for already validated token:
+            expiration_time = self.__bearer_cache[bearer]
+
+            # Accept request if cached token hasn't expired yet:
+            if expiration_time < datetime.utcnow():
+                return continuation(handler_call_details)  # Accepted
+
+        except KeyError:
+            pass
+
+        assert self.__validator is not None
+
+        try:
+            # Decode and validate the new token:
+            expiration_time = self.__validator(bearer[7:])
+
+        except _InvalidTokenError as e:
+            self.__logger.error("Rejecting '{}' request: {}; {}"
+                                .format(handler_call_details.method.split('/')[-1],
+                                        self.__auth_errors['invalid-token'], str(e)))
+            return self.__terminators['invalid-token']
+
+        except _ExpiredTokenError as e:
+            self.__logger.error("Rejecting '{}' request: {}; {}"
+                                .format(handler_call_details.method.split('/')[-1],
+                                        self.__auth_errors['expired-token'], str(e)))
+            return self.__terminators['expired-token']
+
+        except _UnboundedTokenError as e:
+            self.__logger.error("Rejecting '{}' request: {}; {}"
+                                .format(handler_call_details.method.split('/')[-1],
+                                        self.__auth_errors['unbounded-token'], str(e)))
+            return self.__terminators['unbounded-token']
+
+        # Cache the validated token and store expiration time:
+        self.__bearer_cache[bearer] = expiration_time
+
+        return continuation(handler_call_details)  # Accepted
+
+    # --- Private API ---
+
+    def _validate_jwt_token(self, token):
+        """Validates a JWT token and returns its expiry date."""
+        try:
+            payload = jwt.decode(
+                token, self.__secret, algorithms=[self._algorithm.value])
+
+        except jwt.exceptions.ExpiredSignatureError as e:
+            raise _ExpiredTokenError(e)
+
+        except jwt.exceptions.InvalidTokenError as e:
+            raise _InvalidTokenError(e)
+
+        if 'exp' not in payload or not isinstance(payload['exp'], int):
+            raise _UnboundedTokenError("Missing 'exp' in payload")
+
+        return datetime.fromtimestamp(payload['exp'])
+
+
+def _unary_unary_rpc_terminator(details):
+
+    def terminate(ignored_request, context):
+        context.abort(grpc.StatusCode.UNAUTHENTICATED, details)
+
+    return grpc.unary_unary_rpc_method_handler(terminate)
-- 
GitLab