Commit 92bc19e5 authored by Szilárd Pfeiffer's avatar Szilárd Pfeiffer
Browse files

tls: Handle the case when SSL handshake replied by TLS alert message

parent 323a1d5e
Loading
Loading
Loading
Loading
Loading
+70 −41
Original line number Diff line number Diff line
@@ -16,7 +16,8 @@ from cryptoparser.common.utils import get_leaf_classes

from cryptoparser.tls.ciphersuite import TlsCipherSuite, SslCipherKind
from cryptoparser.tls.subprotocol import SslMessageType, SslHandshakeClientHello
from cryptoparser.tls.subprotocol import TlsHandshakeClientHello, TlsCipherSuiteVector, TlsContentType, TlsHandshakeType
from cryptoparser.tls.subprotocol import TlsHandshakeClientHello
from cryptoparser.tls.subprotocol import TlsCipherSuiteVector, TlsContentType, TlsHandshakeType
from cryptoparser.tls.subprotocol import TlsAlertLevel, TlsAlertDescription
from cryptoparser.tls.extension import TlsExtensionServerName
from cryptoparser.tls.extension import TlsExtensionSignatureAlgorithms, TlsSignatureAndHashAlgorithm
@@ -139,6 +140,7 @@ class L7Client(object):
        self._host = host
        self._port = port
        self._socket = None
        self._buffer = bytearray()

    def _do_handshake(
            self,
@@ -193,20 +195,17 @@ class L7Client(object):
            total_sent_byte_num = total_sent_byte_num + actual_sent_byte_num

    def receive(self, receivable_byte_num):
        total_received_bytes = bytearray()

        while len(total_received_bytes) < receivable_byte_num:
        total_received_byte_num = 0
        while total_received_byte_num < receivable_byte_num:
            try:
                actual_received_bytes = self._socket.recv(min(receivable_byte_num - len(total_received_bytes), 1024))
                actual_received_bytes = self._socket.recv(min(receivable_byte_num - total_received_byte_num, 1024))
                self._buffer += actual_received_bytes
                total_received_byte_num += len(actual_received_bytes)
            except socket.error:
                actual_received_bytes = None

            if not actual_received_bytes:
                raise NotEnoughData(receivable_byte_num - len(total_received_bytes))

            total_received_bytes += actual_received_bytes

        return total_received_bytes
                raise NotEnoughData(receivable_byte_num - total_received_byte_num)

    @property
    def host(self):
@@ -216,6 +215,16 @@ class L7Client(object):
    def port(self):
        return self._port

    @property
    def buffer(self):
        return bytearray(self._buffer)

    def flush_buffer(self, byte_num=None):
        if byte_num is None:
            byte_num = len(self._buffer)

        self._buffer = self._buffer[byte_num:]

    @classmethod
    def from_scheme(cls, scheme, host, port=None):
        for client_class in get_leaf_classes(L7Client):
@@ -377,6 +386,8 @@ class TlsAlert(ValueError):
class TlsClient(object):
    def __init__(self, l4_client):
        self._l4_client = l4_client
        self._last_processed_message_type = None
        self.server_messages = {}

    @abc.abstractmethod
    def do_handshake(self, hello_message, protocol_version, last_handshake_message_type):
@@ -384,20 +395,31 @@ class TlsClient(object):


class TlsClientHandshake(TlsClient):
    def _process_message(self, handshake_message, protocol_version):
        handshake_type = handshake_message.get_handshake_type()
        if handshake_type in self.server_messages:
            raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE)
        if (handshake_type == TlsHandshakeType.SERVER_HELLO and
                handshake_message.protocol_version != protocol_version):
            raise TlsAlert(TlsAlertDescription.PROTOCOL_VERSION)

    def do_handshake(
            self,
            hello_message,
            protocol_version=TlsProtocolVersionFinal(TlsVersion.TLS1_0),
            last_handshake_message_type=TlsHandshakeType.SERVER_HELLO_DONE
    ):
        self.server_messages = {}
        self._last_processed_message_type = None

        tls_record = TlsRecord([hello_message, ], protocol_version)
        self._l4_client.send(tls_record.compose())

        server_messages = {}
        received_bytes = bytearray()
        while True:
            try:
                record = TlsRecord.parse_mutable(received_bytes)
                record = TlsRecord.parse_exact_size(self._l4_client.buffer)
                self._l4_client.flush_buffer()

                if record.content_type == TlsContentType.ALERT:
                    if record.messages[0].level == TlsAlertLevel.FATAL:
                        raise TlsAlert(record.messages[0].description)
@@ -407,30 +429,24 @@ class TlsClientHandshake(TlsClient):
                    raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE)

                for handshake_message in record.messages:
                    handshake_type = handshake_message.get_handshake_type()
                    if handshake_type in server_messages:
                        raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE)
                    if (handshake_type == TlsHandshakeType.SERVER_HELLO and
                            handshake_message.protocol_version != protocol_version):
                        raise TlsAlert(TlsAlertDescription.PROTOCOL_VERSION)

                    server_messages[handshake_message.get_handshake_type()] = handshake_message
                    self._process_message(handshake_message, protocol_version)
                    self._last_processed_message_type = handshake_message.get_handshake_type()
                    self.server_messages[self._last_processed_message_type] = handshake_message

                    if handshake_message.get_handshake_type() == last_handshake_message_type:
                        return server_messages
                    if self._last_processed_message_type == last_handshake_message_type:
                        return

                receivable_byte_num = 0
            except NotEnoughData as e:
                receivable_byte_num = e.bytes_needed

            try:
                actual_received_bytes = self._l4_client.receive(receivable_byte_num)
                self._l4_client.receive(receivable_byte_num)
            except NotEnoughData:
                if received_bytes:
                if self._l4_client.buffer:
                    raise NetworkError(NetworkErrorType.NO_CONNECTION)

                raise NetworkError(NetworkErrorType.NO_RESPONSE)
            received_bytes += actual_received_bytes


class SslError(ValueError):
@@ -457,28 +473,41 @@ class SslClientHandshake(TlsClient):
        ssl_record = SslRecord(hello_message)
        self._l4_client.send(ssl_record.compose())

        server_messages = {}
        received_bytes = bytearray()
        self.server_messages = {}
        while True:
            try:
                record = SslRecord.parse_mutable(received_bytes)
                message = record.message
                if message.get_message_type() == SslMessageType.ERROR:
                    raise SslError(message.get_message_type())
                record = SslRecord.parse_exact_size(self._l4_client.buffer)
                self._l4_client.flush_buffer()
                if record.message.get_message_type() == SslMessageType.ERROR:
                    raise SslError(record.message.error_type)

                server_messages[message.get_message_type()] = message
                if message.get_message_type() == last_handshake_message_type:
                    return server_messages
                self._last_processed_message_type = record.message.get_message_type()
                self.server_messages[self._last_processed_message_type] = record.message
                if self._last_processed_message_type == last_handshake_message_type:
                    break

                receivable_byte_num = 0
            except NotEnoughData as e:
                receivable_byte_num = e.bytes_needed

            try:
                actual_received_bytes = self._l4_client.receive(receivable_byte_num)
                self._l4_client.receive(receivable_byte_num)
            except NotEnoughData:
                if received_bytes:
                if self._l4_client.buffer:
                    try:
                        tls_record = TlsRecord.parse_exact_size(self._l4_client.buffer)
                        self._l4_client.flush_buffer()
                    except ValueError:
                        raise NetworkError(NetworkErrorType.NO_CONNECTION)
                    else:
                        if (tls_record.content_type == TlsContentType.ALERT and
                                (tls_record.messages[0].description in [
                                    TlsAlertDescription.PROTOCOL_VERSION,
                                    TlsAlertDescription.HANDSHAKE_FAILURE,
                                    TlsAlertDescription.INTERNAL_ERROR,
                                ])):
                            raise NetworkError(NetworkErrorType.NO_RESPONSE)

                        raise NetworkError(NetworkErrorType.NO_CONNECTION)
                else:
                    raise NetworkError(NetworkErrorType.NO_RESPONSE)
            received_bytes += actual_received_bytes