Commit d93c9acd authored by Szilárd Pfeiffer's avatar Szilárd Pfeiffer
Browse files

Merge branch '40-tls-extension-parser'

Closes: #40
parents cd571f51 fcb11882
Loading
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -2,6 +2,21 @@
Changelog
=========

.. _v0-7-0:

* TLS (``tls``)

  * Extensions (``extensions``)

    * add `application-layer protocol negotiation <https://www.rfc-editor.org/rfc/rfc5077.html>`_ extension related
      messages (#40)
    * add `encrypt-then-MAC <https://www.rfc-editor.org/rfc/rfc7366.html>`_ extension related messages (#40)
    * add `extended master secret <https://www.rfc-editor.org/rfc/rfc7627.html>`_ extension related messages (#40)
    * add `next protocol negotiation <https://tools.ietf.org/id/draft-agl-tls-nextprotoneg-03.html>`_ extension related
      messages (#40)
    * add `renegotiation indication <https://www.rfc-editor.org/rfc/rfc5746.html>`_ extension related messages (#40)
    * add `session ticket <https://www.rfc-editor.org/rfc/rfc5077.html>`_ extension related messages (#40)

.. _v0-6-0:

0.6.0 - 2021-05-27
+52 −6
Original line number Diff line number Diff line
@@ -391,8 +391,11 @@ class ArrayBase(ParsableBase, MutableSequence, Serializable):
    def append(self, value):
        self.insert(len(self._items), value)

    def _asdict(self):
        return self._items

    def _as_markdown(self, level):
        return self._markdown_result(self._items, level)
        return self._markdown_result(self._asdict(), level)


class Vector(ArrayBase):
@@ -526,13 +529,12 @@ class VectorParsableDerived(ArrayBase):
        return header_composer.composed_bytes + body_composer.composed_bytes


@attr.s(init=False)
class Opaque(ArrayBase):
    def __init__(self, items):
        if isinstance(items, (bytes, bytearray)):
            items = [ord(items[i:i + 1]) for i in range(len(items))]
    def __attrs_post_init__(self):
        if isinstance(self._items, (bytes, bytearray)):
            self._items = [ord(self._items[i:i + 1]) for i in range(len(self._items))]

        super(Opaque, self).__init__(items)
        super(Opaque, self).__attrs_post_init__()

    @classmethod
    def _parse(cls, parsable):
@@ -769,3 +771,47 @@ class ListParsable(ArrayBase):
            composer.compose_raw(separator)

        return composer.composed_bytes


class OpaqueEnumParsable(Vector):
    @classmethod
    def _parse(cls, parsable):
        opaque, parsed_length = super(OpaqueEnumParsable, cls)._parse(parsable)
        code = bytearray(opaque).decode(cls.get_encoding())

        try:
            parsed_object = next(iter([
                enum_item
                for enum_item in cls.get_enum_class()
                if enum_item.value.code == code
            ]))
        except StopIteration as e:
            six.raise_from(InvalidValue(code.encode(cls.get_encoding()), cls), e)

        return parsed_object, parsed_length

    @classmethod
    @abc.abstractmethod
    def get_enum_class(cls):
        raise NotImplementedError()

    @classmethod
    def get_encoding(cls):
        return 'utf-8'


class OpaqueEnumComposer(enum.Enum):
    def __repr__(self):
        return self.__class__.__name__ + '.' + self.name

    def compose(self):
        composer = ComposerBinary()
        value = self.value.code.encode(self.get_encoding())  # pylint: disable=no-member

        composer.compose_bytes(value, 1)

        return composer.composed_bytes

    @classmethod
    def get_encoding(cls):
        return 'utf-8'
+106 −3
Original line number Diff line number Diff line
# -*- coding: utf-8 -*-

import abc
import enum

import attr

from cryptoparser.common.algorithm import Authentication, Hash, NamedGroup
from cryptoparser.common.base import OneByteEnumParsable, OneByteEnumComposer, TwoByteEnumComposer, TwoByteEnumParsable
from cryptoparser.common.base import (
    OneByteEnumComposer,
    OneByteEnumParsable,
    OpaqueEnumComposer,
    TwoByteEnumComposer,
    TwoByteEnumParsable
)


class TlsNamedCurveFactory(TwoByteEnumParsable):
@@ -450,7 +455,105 @@ class TlsECPointFormatParams(object):
    code = attr.ib(validator=attr.validators.instance_of(int))


class TlsECPointFormat(OneByteEnumComposer, enum.Enum):
class TlsECPointFormat(OneByteEnumComposer):
    UNCOMPRESSED = TlsECPointFormatParams(code=0x00)
    ANSIX962_COMPRESSED_PRIME = TlsECPointFormatParams(code=0x01)
    ANSIX962_COMPRESSED_CHAR2 = TlsECPointFormatParams(code=0x02)


@attr.s(frozen=True)
class TlsProtocolNameParams(object):
    code = attr.ib(validator=attr.validators.instance_of(str))


class TlsProtocolName(OpaqueEnumComposer):
    C_WEBRTC = TlsProtocolNameParams(
        code='c-webrtc',
    )
    COAP = TlsProtocolNameParams(
        code='coap',
    )
    FTP = TlsProtocolNameParams(
        code='ftp',
    )
    H2 = TlsProtocolNameParams(
        code='h2',
    )
    H2_14 = TlsProtocolNameParams(
        code='h2-14',
    )
    H2_15 = TlsProtocolNameParams(
        code='h2-15',
    )
    H2_16 = TlsProtocolNameParams(
        code='h2-16',
    )
    H2C = TlsProtocolNameParams(
        code='h2c',
    )
    HTTP_0_9 = TlsProtocolNameParams(
        code='http/0.9',
    )
    HTTP_1_0 = TlsProtocolNameParams(
        code='http/1.0',
    )
    HTTP_1_1 = TlsProtocolNameParams(
        code='http/1.1',
    )
    IMAP = TlsProtocolNameParams(
        code='imap',
    )
    MANAGESIEVE = TlsProtocolNameParams(
        code='managesieve',
    )
    POP3 = TlsProtocolNameParams(
        code='pop3',
    )
    SPDY_1 = TlsProtocolNameParams(
        code='spdy/1',
    )
    SPDY_2 = TlsProtocolNameParams(
        code='spdy/2',
    )
    SPDY_3 = TlsProtocolNameParams(
        code='spdy/3',
    )
    SPDY_3_1 = TlsProtocolNameParams(
        code='spdy/3.1',
    )
    STUN_NAT_DISCOVERY = TlsProtocolNameParams(
        code='stun.nat-discovery',
    )
    STUN_TURN = TlsProtocolNameParams(
        code='stun.turn',
    )
    WEBRTC = TlsProtocolNameParams(
        code='webrtc',
    )
    XMPP_CLIENT = TlsProtocolNameParams(
        code='xmpp-client',
    )
    XMPP_SERVER = TlsProtocolNameParams(
        code='xmpp-server',
    )


class TlsNextProtocolName(OpaqueEnumComposer):
    HTTP_1_1 = TlsProtocolNameParams(
        code='http/1.1',
    )
    SPDY_1 = TlsProtocolNameParams(
        code='spdy/1',
    )
    SPDY_2 = TlsProtocolNameParams(
        code='spdy/2',
    )
    SPDY_3 = TlsProtocolNameParams(
        code='spdy/3',
    )
    SPDY_3_1 = TlsProtocolNameParams(
        code='spdy/3.1',
    )
    SPDY_4_A_2 = TlsProtocolNameParams(
        code='spdy/4a2',
    )
+232 −8
Original line number Diff line number Diff line
@@ -7,10 +7,12 @@ import enum
import six
import attr

from cryptoparser.tls.algorithm import TlsNextProtocolName, TlsProtocolName
from cryptoparser.common.base import (
    Opaque,
    OpaqueParam,
    TwoByteEnumComposer,
    OpaqueEnumParsable,
    TwoByteEnumParsable,
    VariantParsable,
    Vector,
@@ -19,7 +21,7 @@ from cryptoparser.common.base import (
    VectorParsable,
    VectorParsableDerived,
)
from cryptoparser.common.exception import NotEnoughData, InvalidType
from cryptoparser.common.exception import NotEnoughData, InvalidType, InvalidValue
from cryptoparser.common.parse import ParsableBase, ParserBinary, ComposerBinary
from cryptoparser.tls.algorithm import (
    TlsNamedCurve,
@@ -301,7 +303,7 @@ class TlsExtensionUnparsed(TlsExtensionBase):

    @classmethod
    def _parse(cls, parsable):
        parser = super(TlsExtensionUnparsed, cls)._check_header(parsable)
        parser = cls._check_header(parsable)

        parser.parse_raw('extension_data', parser['extension_length'])

@@ -311,12 +313,7 @@ class TlsExtensionUnparsed(TlsExtensionBase):
        composer.compose_parsable(self.extension_type)

    def compose(self):
        payload_composer = ComposerBinary()
        payload_composer.compose_raw(self.extension_data)

        header_bytes = self._compose_header(payload_composer.composed_length)

        return header_bytes + payload_composer.composed_bytes
        return self._compose_header(len(self.extension_data)) + self.extension_data


@attr.s
@@ -348,6 +345,27 @@ class TlsExtensionParsed(TlsExtensionBase):
        return parser


class TlsExtensionUnusedData(TlsExtensionParsed):
    @classmethod
    def _parse(cls, parsable):
        parser = cls._parse_header(parsable)

        parser.parse_raw('extension_data', parser['extension_length'])

        if parser['extension_data']:
            raise InvalidValue(parser['extension_data'], cls)

        return cls(), parser.parsed_length

    def compose(self):
        return self._compose_header(0)

    @classmethod
    @abc.abstractmethod
    def get_extension_type(cls):
        raise NotImplementedError()


class TlsServerNameType(enum.IntEnum):
    HOST_NAME = 0x00

@@ -823,6 +841,197 @@ class TlsExtensionCertificateStatusRequest(TlsExtensionParsed):
        return header_bytes + payload_composer.composed_bytes


class TlsRenegotiatedConnection(Opaque):
    @classmethod
    def get_param(cls):
        return OpaqueParam(
            min_byte_num=0,
            max_byte_num=2 ** 8 - 1,
        )


@attr.s
class TlsExtensionRenegotiationInfo(TlsExtensionParsed):
    renegotiated_connection = attr.ib(
        default=TlsRenegotiatedConnection([]),
        validator=attr.validators.instance_of(TlsRenegotiatedConnection)
    )

    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.RENEGOTIATION_INFO

    @classmethod
    def _parse(cls, parsable):
        parser = super(TlsExtensionRenegotiationInfo, cls)._parse_header(parsable)

        parser.parse_parsable('renegotiated_connection', TlsRenegotiatedConnection)

        return TlsExtensionRenegotiationInfo(parser['renegotiated_connection']), parser.parsed_length

    def compose(self):
        payload_composer = ComposerBinary()

        payload_composer.compose_parsable(self.renegotiated_connection)

        header_bytes = self._compose_header(payload_composer.composed_length)

        return header_bytes + payload_composer.composed_bytes


@attr.s
class TlsExtensionSessionTicket(TlsExtensionParsed):
    session_ticket = attr.ib(
        default=bytearray([]),
        validator=attr.validators.instance_of((bytes, bytearray))
    )

    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.SESSION_TICKET

    @classmethod
    def _parse(cls, parsable):
        parser = super(TlsExtensionSessionTicket, cls)._parse_header(parsable)

        parser.parse_raw('session_ticket', parser['extension_length'])

        return TlsExtensionSessionTicket(parser['session_ticket']), parser.parsed_length

    def compose(self):
        payload_composer = ComposerBinary()

        payload_composer.compose_raw(self.session_ticket)

        header_bytes = self._compose_header(payload_composer.composed_length)

        return header_bytes + payload_composer.composed_bytes


class TlsProtocolNameFactory(OpaqueEnumParsable):
    @classmethod
    def get_enum_class(cls):
        return TlsProtocolName

    @classmethod
    def get_param(cls):
        return OpaqueParam(
            min_byte_num=1, max_byte_num=2 ** 8 - 1
        )


class TlsProtocolNameList(VectorParsable):
    @classmethod
    def get_param(cls):
        return VectorParamParsable(
            item_class=TlsProtocolNameFactory,
            fallback_class=None,
            min_byte_num=2, max_byte_num=2 ** 16 - 1
        )


@attr.s
class TlsExtensionApplicationLayerProtocolNegotiation(TlsExtensionParsed):
    protocol_names = attr.ib(
        converter=TlsProtocolNameList,
        validator=attr.validators.instance_of(TlsProtocolNameList),
    )

    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION

    @classmethod
    def _parse(cls, parsable):
        parser = super(TlsExtensionApplicationLayerProtocolNegotiation, cls)._parse_header(parsable)

        parser.parse_parsable('protocol_names', TlsProtocolNameList)

        return TlsExtensionApplicationLayerProtocolNegotiation(parser['protocol_names']), parser.parsed_length

    def compose(self):
        payload_composer = ComposerBinary()

        payload_composer.compose_parsable(self.protocol_names)

        header_bytes = self._compose_header(payload_composer.composed_length)

        return header_bytes + payload_composer.composed_bytes


class TlsExtensionNextProtocolNegotiationClient(TlsExtensionUnusedData):
    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.NEXT_PROTOCOL_NEGOTIATION


class TlsNextProtocolNameFactory(OpaqueEnumParsable):
    @classmethod
    def get_enum_class(cls):
        return TlsNextProtocolName

    @classmethod
    def get_param(cls):
        return OpaqueParam(
            min_byte_num=1, max_byte_num=2 ** 8 - 1
        )


class TlsNextProtocolNameList(VectorParsable):
    @classmethod
    def get_param(cls):
        return VectorParamParsable(
            item_class=TlsNextProtocolNameFactory,
            fallback_class=None,
            min_byte_num=1, max_byte_num=2 ** 16 - 1
        )


@attr.s
class TlsExtensionNextProtocolNegotiationServer(TlsExtensionParsed):
    protocol_names = attr.ib(
        converter=TlsNextProtocolNameList,
        validator=attr.validators.instance_of(TlsNextProtocolNameList),
    )

    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.NEXT_PROTOCOL_NEGOTIATION

    @classmethod
    def _parse(cls, parsable):
        parser = ParserBinary(parsable)

        cls._parse_type(parser, 'extension_type')
        if parser['extension_type'] != cls.get_extension_type():
            raise InvalidType()

        parser.parse_parsable('protocol_names', TlsNextProtocolNameList)

        return cls(parser['protocol_names']), parser.parsed_length

    def compose(self):
        payload_composer = ComposerBinary()
        payload_composer.compose_parsable(self.protocol_names)

        header_composer = ComposerBinary()
        self._compose_type(header_composer)

        return header_composer.composed_bytes + payload_composer.composed_bytes


class TlsExtensionEncryptThenMAC(TlsExtensionUnusedData):
    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.ENCRYPT_THEN_MAC


class TlsExtensionExtendedMasterSecret(TlsExtensionUnusedData):
    @classmethod
    def get_extension_type(cls):
        return TlsExtensionType.EXTENDED_MASTER_SECRET


class TlsExtensionVariantBase(VariantParsable):
    @classmethod
    @abc.abstractmethod
@@ -846,7 +1055,14 @@ class TlsExtensionVariantClient(TlsExtensionVariantBase):
    @classmethod
    def _get_parsed_extensions(cls):
        return collections.OrderedDict([
            (TlsExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION,
                [TlsExtensionApplicationLayerProtocolNegotiation, ]),
            (TlsExtensionType.ENCRYPT_THEN_MAC, [TlsExtensionEncryptThenMAC, ]),
            (TlsExtensionType.EXTENDED_MASTER_SECRET, [TlsExtensionExtendedMasterSecret, ]),
            (TlsExtensionType.RENEGOTIATION_INFO, [TlsExtensionRenegotiationInfo, ]),
            (TlsExtensionType.NEXT_PROTOCOL_NEGOTIATION, [TlsExtensionNextProtocolNegotiationClient, ]),
            (TlsExtensionType.SERVER_NAME, [TlsExtensionServerName, ]),
            (TlsExtensionType.SESSION_TICKET, [TlsExtensionSessionTicket, ]),
            (TlsExtensionType.SUPPORTED_GROUPS, [TlsExtensionEllipticCurves, ]),
            (TlsExtensionType.EC_POINT_FORMATS, [TlsExtensionECPointFormats, ]),
            (TlsExtensionType.KEY_SHARE, [TlsExtensionKeyShareClient, ]),
@@ -859,6 +1075,14 @@ class TlsExtensionVariantServer(TlsExtensionVariantBase):
    @classmethod
    def _get_parsed_extensions(cls):
        return collections.OrderedDict([
            (TlsExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION,
                [TlsExtensionApplicationLayerProtocolNegotiation, ]),
            (TlsExtensionType.EC_POINT_FORMATS, [TlsExtensionECPointFormats, ]),
            (TlsExtensionType.ENCRYPT_THEN_MAC, [TlsExtensionEncryptThenMAC, ]),
            (TlsExtensionType.EXTENDED_MASTER_SECRET, [TlsExtensionExtendedMasterSecret, ]),
            (TlsExtensionType.KEY_SHARE, [TlsExtensionKeyShareClientHelloRetry, TlsExtensionKeyShareServer]),
            (TlsExtensionType.NEXT_PROTOCOL_NEGOTIATION, [TlsExtensionNextProtocolNegotiationServer, ]),
            (TlsExtensionType.RENEGOTIATION_INFO, [TlsExtensionRenegotiationInfo, ]),
            (TlsExtensionType.SESSION_TICKET, [TlsExtensionSessionTicket, ]),
            (TlsExtensionType.SUPPORTED_VERSIONS, [TlsExtensionSupportedVersionsServer, ]),
        ])
+30 −10
Original line number Diff line number Diff line
@@ -12,6 +12,8 @@ import attr
import six

from cryptoparser.common.base import (
    OneByteEnumComposer,
    OneByteEnumParsable,
    Opaque,
    OpaqueParam,
    VariantParsable,
@@ -384,18 +386,35 @@ class TlsCipherSuiteVector(VectorParsable):
        )


class TlsCompressionMethod(enum.IntEnum):
    NULL = 0
class TlsCompressionMethodFactory(OneByteEnumParsable):
    @classmethod
    def get_enum_class(cls):
        return TlsCompressionMethod

    @abc.abstractmethod
    def compose(self):
        raise NotImplementedError()


@attr.s(frozen=True)
class TlsCompressionMethodParams(object):
    code = attr.ib(validator=attr.validators.instance_of(int))


class TlsCompressionMethodVector(Vector):
class TlsCompressionMethod(OneByteEnumComposer, enum.Enum):
    NULL = TlsCompressionMethodParams(code=0x00)
    DEFLATE = TlsCompressionMethodParams(code=0x01)
    LZS = TlsCompressionMethodParams(code=0x40)


class TlsCompressionMethodVector(VectorParsable):
    @classmethod
    def get_param(cls):
        return VectorParamNumeric(
            item_size=1,
        return VectorParamParsable(
            item_class=TlsCompressionMethodFactory,
            fallback_class=TlsInvalidTypeOneByte,
            min_byte_num=1,
            max_byte_num=2 ** 8 - 1,
            numeric_class=TlsCompressionMethod
        )


@@ -552,6 +571,7 @@ class TlsHandshakeServerHello(TlsHandshakeHello):
    )
    compression_method = attr.ib(
        default=TlsCompressionMethod.NULL,
        converter=TlsCompressionMethod,
        validator=attr.validators.in_(TlsCompressionMethod),
    )
    cipher_suite = attr.ib(default=None, validator=attr.validators.in_(TlsCipherSuite))
@@ -572,7 +592,7 @@ class TlsHandshakeServerHello(TlsHandshakeHello):
        parser = cls._parse_hello_header(handshake_header_parser['payload'])

        parser.parse_parsable('cipher_suite', TlsCipherSuiteFactory)
        parser.parse_numeric('compression_method', 1, TlsCompressionMethod)
        parser.parse_parsable('compression_method', TlsCompressionMethodFactory)

        extension_parser = cls._parse_extensions(handshake_header_parser, parser, TlsExtensionsServer)

@@ -592,7 +612,7 @@ class TlsHandshakeServerHello(TlsHandshakeHello):
        payload_composer.compose_parsable(self.random)
        payload_composer.compose_parsable(self.session_id)
        payload_composer.compose_parsable(self.cipher_suite)
        payload_composer.compose_numeric(self.compression_method.value, 1)
        payload_composer.compose_parsable(self.compression_method)

        extension_bytes = self._compose_extensions(self.extensions)

@@ -886,7 +906,7 @@ class TlsHandshakeHelloRetryRequest(TlsHandshakeHello):

        parser.parse_parsable('cipher_suite', TlsCipherSuiteFactory)

        parser.parse_numeric('compression_method', 1)
        parser.parse_parsable('compression_method', TlsCompressionMethodFactory)
        compression_method = parser['compression_method']
        session_id = parser['session_id']

@@ -908,7 +928,7 @@ class TlsHandshakeHelloRetryRequest(TlsHandshakeHello):
        payload_composer.compose_parsable(self.random_bytes)
        payload_composer.compose_parsable(self.session_id)
        payload_composer.compose_parsable(self.cipher_suite)
        payload_composer.compose_numeric(self.compression_method.value, 1)
        payload_composer.compose_parsable(self.compression_method)

        extension_bytes = self._compose_extensions(self.extensions)

Loading