Commit 130fdcb8 authored by Szilárd Pfeiffer's avatar Szilárd Pfeiffer
Browse files

common: Implement parser class for protocol message variants

parent 9dd5dcc3
Loading
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -155,6 +155,12 @@ class ParserBinary(object):
        except ValueError as e:
            raise InvalidValue(e.args[0], item_base_class)

    def parse_variant(self, name, variant):
        parsed_object, value_length = variant.parse(self._parsable[self._parsed_length:])

        self._parsed_values[name] = parsed_object
        self._parsed_length += value_length


class ComposerBinary(object):
    _INT_FORMATER_BY_SIZE = {
+12 −32
Original line number Diff line number Diff line
@@ -3,12 +3,11 @@

import abc

import cryptoparser.common.utils as utils

from cryptoparser.common.parse import ParsableBase, ParserBinary, ComposerBinary
from cryptoparser.common.exception import NotEnoughData, InvalidValue, InvalidType
from cryptoparser.common.exception import NotEnoughData, InvalidValue
from cryptoparser.tls.version import TlsVersion, TlsProtocolVersionBase, TlsProtocolVersionFinal, SslVersion
from cryptoparser.tls.subprotocol import TlsSubprotocolMessageBase, TlsContentType, SslMessageBase, SslMessageType
from cryptoparser.tls.subprotocol import TlsSubprotocolMessageBase, TlsSubprotocolMessageParser, TlsContentType
from cryptoparser.tls.subprotocol import SslMessageBase, SslMessageType, SslSubprotocolMessageParser


class RecordBase(ParsableBase):
@@ -52,24 +51,15 @@ class TlsRecord(RecordBase):
        if parser.unparsed_length < parser['record_length']:
            raise NotEnoughData(parser['record_length'] - parser.unparsed_length)

        header_size = parser.parsed_length

        messages = []
        while parser.parsed_length < parser['record_length'] + header_size:
            for subclass in utils.get_leaf_classes(TlsSubprotocolMessageBase):
                if subclass.get_content_type() != parser['content_type']:
                    continue

                try:
                    parser.parse_parsable('message', subclass)
        while parser.parsed_length < len(parsable):
            parser.parse_variant('message', TlsSubprotocolMessageParser(parser['content_type']))
            messages.append(parser['message'])
                    break
                except InvalidType:
                    continue
            else:
                raise InvalidValue(parser['content_type'], TlsRecord, 'content type')

        return TlsRecord(messages=messages, protocol_version=parser['protocol_version']), parser.parsed_length
        return TlsRecord(
            messages=messages,
            protocol_version=parser['protocol_version']
        ), parser.parsed_length

    def compose(self):
        body_composer = ComposerBinary()
@@ -133,17 +123,7 @@ class SslRecord(RecordBase):
        except InvalidValue as e:
            raise InvalidValue(e.value, SslMessageType)

        for subclass in utils.get_leaf_classes(SslMessageBase):
            if subclass.get_message_type() != parser['message_type']:
                continue

            try:
                parser.parse_parsable('message', subclass)
                break
            except InvalidValue:
                continue
        else:
            raise InvalidValue(parser['message_type'], SslRecord, 'message type')
        parser.parse_variant('message', SslSubprotocolMessageParser(parser['message_type']))
        parser.parse_bytes('padding', padding_length)

        return SslRecord(message=parser['message']), parser.parsed_length
+98 −0
Original line number Diff line number Diff line
@@ -24,6 +24,62 @@ class TlsContentType(enum.IntEnum):
    HEARTBEAT = 0x18


class SubprotocolParser(object):
    def __init__(self, subprotocol_type):
        self._subprotocol_type = subprotocol_type

    @classmethod
    @abc.abstractmethod
    def _get_subprotocol_parsers(cls):
        raise NotImplementedError()

    @classmethod
    def register_subprotocol_parser(cls, subprotocol_type, parsable_class):
        subprotocol_parsers = cls._get_subprotocol_parsers()
        subprotocol_parsers[subprotocol_type] = parsable_class

    def parse(self, parsable):
        subprotocol_parsers = self._get_subprotocol_parsers()

        if self._subprotocol_type in subprotocol_parsers:
            parsed_object, unparsed_bytes = subprotocol_parsers[self._subprotocol_type].parse_immutable(parsable)
            return parsed_object, len(parsable) - len(unparsed_bytes)

        raise InvalidValue(self._subprotocol_type, TlsSubprotocolMessageBase)


class VariantParsable(ParsableBase):
    def __init__(self, variant):
        self._variant = variant

    @classmethod
    @abc.abstractmethod
    def _get_variants(cls):
        raise NotImplementedError()

    @classmethod
    def register_variant_parser(cls, variant_tag, parsable_class):
        variants = cls._get_variants()
        variants[variant_tag] = parsable_class

    @classmethod
    def _parse(cls, parsable):
        for variant_parser in cls._get_variants().values():
            try:
                parsed_object, unparsed_bytes = variant_parser.parse_immutable(parsable)
                return cls(parsed_object), len(parsable) - len(unparsed_bytes)
            except InvalidType:
                continue

        raise InvalidValue(parsable, cls)

    def compose(self):
        return self._variant.compose()

    def __getattr__(self, name):
        return getattr(self._variant, name)


class TlsSubprotocolMessageBase(ParsableBase):
    @classmethod
    @abc.abstractmethod
@@ -127,6 +183,9 @@ class TlsAlertMessage(TlsSubprotocolMessageBase):
        return self.level == other.level and self.description == other.description


TlsSubprotocolMessageBase.register(TlsAlertMessage)


class TlsChangeCipherSpecType(enum.IntEnum):
    CHANGE_CIPHER_SPEC = 0x01

@@ -785,3 +844,42 @@ class SslHandshakeServerHello(SslMessageBase):
        composer.compose_bytes(self.connection_id)

        return composer.composed_bytes


class TlsHandshakeMessageVariant(VariantParsable):
    _VARIANTS = {
        TlsHandshakeType.CLIENT_HELLO: TlsHandshakeClientHello,
        TlsHandshakeType.SERVER_HELLO: TlsHandshakeServerHello,
        TlsHandshakeType.CERTIFICATE: TlsHandshakeCertificate,
        TlsHandshakeType.SERVER_KEY_EXCHANGE: TlsHandshakeServerKeyExchange,
        TlsHandshakeType.SERVER_HELLO_DONE: TlsHandshakeServerHelloDone,
    }

    @classmethod
    def _get_variants(cls):
        return cls._VARIANTS


class TlsSubprotocolMessageParser(SubprotocolParser):
    _SUBPROTOCOL_PARSERS = {
        TlsContentType.CHANGE_CIPHER_SPEC: TlsChangeCipherSpecMessage,
        TlsContentType.ALERT: TlsAlertMessage,
        TlsContentType.HANDSHAKE: TlsHandshakeMessageVariant,
        TlsContentType.APPLICATION_DATA: TlsApplicationDataMessage,
    }

    @classmethod
    def _get_subprotocol_parsers(cls):
        return cls._SUBPROTOCOL_PARSERS


class SslSubprotocolMessageParser(SubprotocolParser):
    _SUBPROTOCOL_PARSERS = {
        SslMessageType.ERROR: SslError,
        SslMessageType.CLIENT_HELLO: SslHandshakeClientHello,
        SslMessageType.SERVER_HELLO: SslHandshakeServerHello,
    }

    @classmethod
    def _get_subprotocol_parsers(cls):
        return cls._SUBPROTOCOL_PARSERS

tests/tls/classes.py

0 → 100644
+23 −0
Original line number Diff line number Diff line
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from cryptoparser.tls.subprotocol import TlsSubprotocolMessageBase, TlsHandshakeMessage, TlsHandshakeType


class TestMessage(TlsSubprotocolMessageBase):
    @classmethod
    def get_handshake_type(cls):
        raise NotImplementedError


class TestVariantMessage(TlsHandshakeMessage):
    @classmethod
    def get_handshake_type(cls):
        return TlsHandshakeType.SERVER_HELLO_DONE

    @classmethod
    def _parse(cls, parsable):
        raise NotImplementedError

    def compose(self):
        raise NotImplementedError
+65 −1
Original line number Diff line number Diff line
@@ -10,15 +10,79 @@ from cryptoparser.common.exception import InvalidValue, InvalidType, NotEnoughDa

from cryptoparser.tls.ciphersuite import TlsCipherSuite, SslCipherKind
from cryptoparser.tls.extension import TlsExtensionSupportedVersions
from cryptoparser.tls.subprotocol import TlsSubprotocolMessageParser, TlsHandshakeMessageVariant
from cryptoparser.tls.subprotocol import TlsHandshakeClientHello, TlsHandshakeServerHello, TlsHandshakeHelloRandom
from cryptoparser.tls.subprotocol import TlsCipherSuiteVector, TlsCompressionMethodVector, TlsCompressionMethod
from cryptoparser.tls.subprotocol import TlsSessionIdVector, TlsExtensions, TlsContentType, TlsHandshakeType
from cryptoparser.tls.subprotocol import TlsHandshakeCertificate, TlsCertificates, TlsCertificate
from cryptoparser.tls.subprotocol import TlsHandshakeServerHelloDone, TlsHandshakeServerKeyExchange
from cryptoparser.tls.subprotocol import TlsHandshakeServerHelloDone, TlsHandshakeServerKeyExchange, TlsAlertMessage
from cryptoparser.tls.subprotocol import SslMessageType, SslHandshakeClientHello, SslHandshakeServerHello
from cryptoparser.tls.record import TlsRecord
from cryptoparser.tls.version import TlsVersion, TlsProtocolVersionFinal

from tests.tls.classes import TestMessage, TestVariantMessage


class TestSubprotocolParser(unittest.TestCase):
    def test_registered_parser(self):
        tls_message_bytes = bytes(
            b'\x02' +      # level = FATAL
            b'\x28' +      # description = HANDSHAKE_FAILURE
            b''
        )
        tls_parser = TlsSubprotocolMessageParser(TlsContentType.ALERT)
        tls_parser.parse(tls_message_bytes)

        tls_parser.register_subprotocol_parser(TlsContentType.ALERT, TestMessage)
        with self.assertRaises(NotImplementedError):
            tls_parser.parse(tls_message_bytes)

        tls_parser.register_subprotocol_parser(TlsContentType.ALERT, TlsAlertMessage)
        parsed_object, _ = tls_parser.parse(tls_message_bytes)
        self.assertEqual(parsed_object.compose(), tls_message_bytes)


class TestVariantParsable(unittest.TestCase):
    def setUp(self):
        self.server_hello_done_bytes = bytes(
            b'\x0e' +                              # handshake_type = SERVER_HELLO_DONE
            b'\x00\x00\x00' +                      # length = 0x00
            b''
        )
        self.server_hello_done = TlsHandshakeServerHelloDone()

    def test_error(self):
        invalid_tls_message_bytes = bytes(
            b'\x16' +      # type = HANDSHAKE
            b'\x03\x01' +  # version = TLS1_0
            b'\x00\x01' +  # length = 2
            b'\xff' +
            b''
        )

        with six.assertRaisesRegex(self, InvalidValue, 'is not a valid TlsHandshakeMessageVariant'):
            TlsHandshakeMessageVariant.parse_exact_size(invalid_tls_message_bytes)

    def test_compose(self):
        self.assertEqual(TlsHandshakeMessageVariant(self.server_hello_done).compose(), self.server_hello_done_bytes)

    def test_registered_parser(self):
        message = TlsHandshakeMessageVariant.parse_exact_size(self.server_hello_done_bytes)
        self.assertEqual(message.compose(), self.server_hello_done_bytes)

        TlsHandshakeMessageVariant.register_variant_parser(TlsHandshakeType.SERVER_HELLO_DONE, TestVariantMessage)
        with self.assertRaises(NotImplementedError):
            TlsHandshakeMessageVariant.parse_exact_size(self.server_hello_done_bytes)

        TlsHandshakeMessageVariant.register_variant_parser(
            TlsHandshakeType.SERVER_HELLO_DONE,
            TlsHandshakeServerHelloDone
        )
        parsed_object = TlsHandshakeMessageVariant.parse_exact_size(self.server_hello_done_bytes)
        self.assertEqual(parsed_object.compose(), self.server_hello_done_bytes)
        self.assertEqual(parsed_object.get_content_type(), TlsContentType.HANDSHAKE)
        self.assertEqual(parsed_object.get_handshake_type(), TlsHandshakeType.SERVER_HELLO_DONE)


class TestTlsHandshake(unittest.TestCase):
    def setUp(self):