Skip to content
Snippets Groups Projects
Commit 0fb70b6e authored by Martin Blanchard's avatar Martin Blanchard
Browse files

server/_authentication.py: New JWT based gRPC interceptor

parent 11c9e26b
No related branches found
No related tags found
No related merge requests found
# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
from enum import Enum
import logging
import grpc
from buildgrid._exceptions import InvalidArgumentError
try:
import jwt
except ImportError:
HAVE_JWT = False
else:
HAVE_JWT = True
class AuthMetadataMethod(Enum):
# No authentication:
NONE = 'none'
# JWT based authentication:
JWT = 'JWT'
class AuthMetadataAlgorithm(Enum):
# No encryption involved:
NONE = 'none'
# JWT related algorithms:
JWT_ES256 = 'ES256' # ECDSA signature algorithm using SHA-256 hash algorithm
JWT_ES384 = 'ES384' # ECDSA signature algorithm using SHA-384 hash algorithm
JWT_ES512 = 'ES512' # ECDSA signature algorithm using SHA-512 hash algorithm
JWT_HS256 = 'HS256' # HMAC using SHA-256 hash algorithm
JWT_HS384 = 'HS384' # HMAC using SHA-384 hash algorithm
JWT_HS512 = 'HS512' # HMAC using SHA-512 hash algorithm
JWT_PS256 = 'PS256' # RSASSA-PSS using SHA-256 and MGF1 padding with SHA-256
JWT_PS384 = 'PS384' # RSASSA-PSS signature using SHA-384 and MGF1 padding with SHA-384
JWT_PS512 = 'PS512' # RSASSA-PSS signature using SHA-512 and MGF1 padding with SHA-512
JWT_RS256 = 'RS256' # RSASSA-PKCS1-v1_5 signature algorithm using SHA-256 hash algorithm
JWT_RS384 = 'RS384' # RSASSA-PKCS1-v1_5 signature algorithm using SHA-384 hash algorithm
JWT_RS512 = 'RS512' # RSASSA-PKCS1-v1_5 signature algorithm using SHA-512 hash algorithm
class _InvalidTokenError(Exception):
pass
class _ExpiredTokenError(Exception):
pass
class _UnboundedTokenError(Exception):
pass
class AuthMetadataServerInterceptor(grpc.ServerInterceptor):
__auth_errors = {
'missing-bearer': 'Missing authentication header field',
'invalid-bearer': 'Invalid authentication header field',
'invalid-token': 'Invalid authentication token',
'expired-token': 'Expired authentication token',
'unbounded-token': 'Unbounded authentication token',
}
def __init__(self, method, secret=None, algorithm=AuthMetadataAlgorithm.NONE):
"""Initialises a new :class:`AuthMetadataServerInterceptor`.
Args:
method (AuthMetadataMethod): Type of authorization method.
secret (str): The secret or key to be used for validating request,
depending on `method`. Defaults to ``None``.
algorithm (AuthMetadataAlgorithm): The crytographic algorithm used
to encode `secret`. Defaults to ``AuthMetadataAlgorithm.NONE``.
Raises:
InvalidArgumentError: If `method` is not supported or if `algorithm`
is not supported for the given `method`.
"""
self.__logger = logging.getLogger(__name__)
self.__bearer_cache = {}
self.__terminators = {}
self.__validator = None
self.__secret = secret
self._method = method
self._algorithm = algorithm
if self._method == AuthMetadataMethod.JWT:
if not HAVE_JWT:
raise InvalidArgumentError("JWT authorization method requires PyJWT")
try:
jwt.register_algorithm(self._algorithm.value, None)
except TypeError:
raise InvalidArgumentError("Algorithm not supported for JWT decoding: [{}]"
.format(self._algorithm))
except ValueError:
pass
self.__validator = self._validate_jwt_token
for code, message in self.__auth_errors.items():
self.__terminators[code] = _unary_unary_rpc_terminator(message)
@property
def method(self):
return self._method
@property
def algorithm(self):
return self._algorithm
def intercept_service(self, continuation, handler_call_details):
try:
# Reject requests not carrying a token:
bearer = dict(handler_call_details.invocation_metadata)['authorization']
except KeyError:
self.__logger.error("Rejecting '{}' request: {}"
.format(handler_call_details.method.split('/')[-1],
self.__auth_errors['missing-bearer']))
return self.__terminators['missing-bearer']
# Reject requests with malformated bearer:
if not bearer.startswith('Bearer '):
self.__logger.error("Rejecting '{}' request: {}"
.format(handler_call_details.method.split('/')[-1],
self.__auth_errors['invalid-bearer']))
return self.__terminators['invalid-bearer']
try:
# Hit the cache for already validated token:
expiration_time = self.__bearer_cache[bearer]
# Accept request if cached token hasn't expired yet:
if expiration_time < datetime.utcnow():
return continuation(handler_call_details) # Accepted
except KeyError:
pass
assert self.__validator is not None
try:
# Decode and validate the new token:
expiration_time = self.__validator(bearer[7:])
except _InvalidTokenError as e:
self.__logger.error("Rejecting '{}' request: {}; {}"
.format(handler_call_details.method.split('/')[-1],
self.__auth_errors['invalid-token'], str(e)))
return self.__terminators['invalid-token']
except _ExpiredTokenError as e:
self.__logger.error("Rejecting '{}' request: {}; {}"
.format(handler_call_details.method.split('/')[-1],
self.__auth_errors['expired-token'], str(e)))
return self.__terminators['expired-token']
except _UnboundedTokenError as e:
self.__logger.error("Rejecting '{}' request: {}; {}"
.format(handler_call_details.method.split('/')[-1],
self.__auth_errors['unbounded-token'], str(e)))
return self.__terminators['unbounded-token']
# Cache the validated token and store expiration time:
self.__bearer_cache[bearer] = expiration_time
return continuation(handler_call_details) # Accepted
def _validate_jwt_token(self, token):
try:
payload = jwt.decode(
token, self.__secret, algorithms=[self._algorithm.value])
except jwt.exceptions.ExpiredSignatureError as e:
raise _ExpiredTokenError(e)
except jwt.exceptions.InvalidTokenError as e:
raise _InvalidTokenError(e)
if 'exp' not in payload or not isinstance(payload['exp'], int):
raise _UnboundedTokenError("Missing 'exp' in payload")
return datetime.fromtimestamp(payload['exp'])
def _unary_unary_rpc_terminator(details):
def terminate(ignored_request, context):
context.abort(grpc.StatusCode.UNAUTHENTICATED, details)
return grpc.unary_unary_rpc_method_handler(terminate)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment