Commit cb1644cd authored by Szilárd Pfeiffer's avatar Szilárd Pfeiffer
Browse files

Merge branch 'tls_pubkeys_order'

parents 5894bd18 592a035d
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
@@ -9,7 +9,11 @@ import ssl
from collections import OrderedDict
from six import iteritems

import cryptography
import cryptography.x509 as cryptography_x509  # pylint: disable=import-error
import cryptography.hazmat.primitives.asymmetric.rsa as cryptography_rsa
import cryptography.hazmat.primitives.asymmetric.ec as cryptography_ec
from cryptography.hazmat.primitives.asymmetric import padding as cryptography_padding
from cryptography.hazmat.primitives import hashes as cryptography_hashes  # pylint: disable=import-error
from cryptography.hazmat.primitives import serialization as cryptography_serialization  # pylint: disable=import-error
from cryptography.hazmat.backends import default_backend as cryptography_default_backend  # pylint: disable=import-error
@@ -132,6 +136,18 @@ class PublicKeyX509(PublicKey):
    def __init__(self, certificate):
        self._certificate = certificate

    def __eq__(self, other):
        self_in_der_format = self._certificate.public_key().public_bytes(
            encoding=cryptography_serialization.Encoding.DER,
            format=cryptography_serialization.PublicFormat.SubjectPublicKeyInfo
        )
        other_in_der_format = other._certificate.public_key().public_bytes(  # pylint: disable=protected-access
            encoding=cryptography_serialization.Encoding.DER,
            format=cryptography_serialization.PublicFormat.SubjectPublicKeyInfo
        )

        return self_in_der_format == other_in_der_format

    @property
    def valid_not_before(self):
        return self._certificate.not_valid_before
@@ -278,6 +294,34 @@ class PublicKeyX509(PublicKey):

        return extension.value.ca

    @property
    def is_self_signed(self):
        return self._certificate.subject and self._certificate.subject == self._certificate.issuer

    def verify(self, public_key):
        verify_args = {
            'signature': public_key._certificate.signature,  # pylint: disable=protected-access
            'data': public_key._certificate.tbs_certificate_bytes,  # pylint: disable=protected-access
        }
        public_key_signature_hash_algorithm = \
            public_key._certificate.signature_hash_algorithm  # pylint: disable=protected-access
        if isinstance(self._certificate.public_key(), cryptography_rsa.RSAPublicKey):
            verify_args['padding'] = cryptography_padding.PKCS1v15()
            verify_args['algorithm'] = public_key_signature_hash_algorithm
        if isinstance(self._certificate.public_key(), cryptography_ec.EllipticCurvePublicKey):
            verify_args['signature_algorithm'] = cryptography_ec.ECDSA(
                public_key_signature_hash_algorithm
            )
        else:
            verify_args['algorithm'] = public_key_signature_hash_algorithm

        try:
            self._certificate.public_key().verify(**verify_args)
        except cryptography.exceptions.InvalidSignature:
            return False

        return True

    def _asdict(self):
        return OrderedDict([
            ('serial_number', str(self._certificate.serial_number)),
+55 −3
Original line number Diff line number Diff line
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cryptography.x509 as cryptography_x509  # pylint: disable=import-error
from cryptography.hazmat.backends import default_backend as cryptography_default_backend  # pylint: disable=import-error
from collections import OrderedDict

import cryptography.x509 as cryptography_x509

from cryptography.hazmat.backends import default_backend as cryptography_default_backend

from cryptoparser.common.base import Serializable
from cryptoparser.tls.subprotocol import TlsHandshakeType, TlsAlertDescription
@@ -23,6 +26,47 @@ class TlsCertificateChain(Serializable): # pylint: disable=too-few-public-metho
    def __init__(self, certificate_bytes, certificate_chain):
        self._certificate_bytes = certificate_bytes
        self.items = certificate_chain
        self.verified = None

        ordered_certificate_chain = [cert for cert in certificate_chain if not cert.is_ca]

        while True:
            try:
                next_certificate = self._get_issuer(ordered_certificate_chain[-1])
                ordered_certificate_chain.append(next_certificate)
            except StopIteration:
                break

        if len(ordered_certificate_chain) > 1:
            self.ordered = self.items == ordered_certificate_chain
            self.items = ordered_certificate_chain

            for chain_index in range(len(self.items) - 1):
                issuer_public_key = self.items[chain_index + 1]
                cert_to_check = self.items[chain_index]

                if not issuer_public_key.verify(cert_to_check):
                    break
            else:
                self.verified = True
        else:
            self.ordered = None
            self.verified = None

    def _get_issuer(self, certificate):
        issuer_certificates = [
            issuer_certificate
            for issuer_certificate in self.items
            if issuer_certificate.is_ca and issuer_certificate.subject == certificate.issuer
        ]
        if len(issuer_certificates) == 1 and certificate != issuer_certificates[0]:
            return issuer_certificates[0]

        raise StopIteration()

    @property
    def contains_anchor(self):
        return any([cert.is_self_signed for cert in self.items])

    def __hash__(self):
        return hash(tuple([bytes(certificate_byte) for certificate_byte in self._certificate_bytes]))
@@ -30,6 +74,14 @@ class TlsCertificateChain(Serializable): # pylint: disable=too-few-public-metho
    def __eq__(self, other):
        return hash(self) == hash(other)

    def _asdict(self):
        return OrderedDict([
            ('items_chain', self.items),
            ('ordered', self.ordered),
            ('verified', self.verified),
            ('contains_anchor', self.contains_anchor),
        ])


class TlsPublicKey(Serializable):
    def __init__(self, sni_sent, subject_matches, tls_certificate_chain):
@@ -103,7 +155,7 @@ class AnalyzerPublicKeys(AnalyzerTlsBase):
            else:
                sni_sent = not isinstance(client_hello, TlsHandshakeClientHelloBasic)
                certificate_chain = self._get_tls_certificate_chain(server_messages)
                leaf_certificate = [cert for cert in certificate_chain.items if not cert.is_ca][0]
                leaf_certificate = certificate_chain.items[0]
                subject_matches = cryptolyzer.common.x509.is_subject_matches(
                    leaf_certificate.common_names,
                    leaf_certificate.subject_alternative_names,
+41 −0
Original line number Diff line number Diff line
@@ -27,3 +27,44 @@ class TestTlsPubKeys(unittest.TestCase):
    def test_fallback_certificate(self):
        result = self.get_result('unexisting-hostname-to-get-wildcard-certificate-without-sni.badssl.com', 443)
        self.assertEqual(len(result.pubkeys), 1)

    def test_certificate_chain(self):
        result = self.get_result('badssl.com', 443)
        self.assertEqual(len(result.pubkeys), 1)

        trusted_root_chain = result.pubkeys[0].certificate_chain
        self.assertEqual(len(trusted_root_chain.items), 2)
        self.assertFalse(trusted_root_chain.contains_anchor)
        self.assertTrue(trusted_root_chain.ordered)
        self.assertTrue(trusted_root_chain.verified)

        result = self.get_result('self-signed.badssl.com', 443)
        self.assertEqual(len(result.pubkeys), 1)

        self_signed_chain = result.pubkeys[0].certificate_chain
        self.assertEqual(len(self_signed_chain.items), 1)
        self.assertTrue(self_signed_chain.contains_anchor)
        self.assertEqual(self_signed_chain.ordered, None)
        self.assertEqual(self_signed_chain.verified, None)

        result = self.get_result('untrusted-root.badssl.com', 443)
        self.assertEqual(len(result.pubkeys), 1)

        untrusted_root_chain = result.pubkeys[0].certificate_chain
        self.assertEqual(len(untrusted_root_chain.items), 2)
        self.assertTrue(untrusted_root_chain.contains_anchor)
        self.assertTrue(untrusted_root_chain.ordered)
        self.assertTrue(untrusted_root_chain.verified)

        self.assertNotEqual(self_signed_chain.items[0], untrusted_root_chain.items[1])

        result = self.get_result('incomplete-chain.badssl.com', 443)
        self.assertEqual(len(result.pubkeys), 1)

        incomplete_chain = result.pubkeys[0].certificate_chain
        self.assertEqual(len(incomplete_chain.items), 1)
        self.assertFalse(incomplete_chain.contains_anchor)
        self.assertEqual(incomplete_chain.ordered, None)
        self.assertEqual(incomplete_chain.verified, None)

        self.assertEqual(trusted_root_chain.items[0], incomplete_chain.items[0])