Loading cryptoparser/tls/client.py +70 −41 Original line number Diff line number Diff line Loading @@ -16,7 +16,8 @@ from cryptoparser.common.utils import get_leaf_classes from cryptoparser.tls.ciphersuite import TlsCipherSuite, SslCipherKind from cryptoparser.tls.subprotocol import SslMessageType, SslHandshakeClientHello from cryptoparser.tls.subprotocol import TlsHandshakeClientHello, TlsCipherSuiteVector, TlsContentType, TlsHandshakeType from cryptoparser.tls.subprotocol import TlsHandshakeClientHello from cryptoparser.tls.subprotocol import TlsCipherSuiteVector, TlsContentType, TlsHandshakeType from cryptoparser.tls.subprotocol import TlsAlertLevel, TlsAlertDescription from cryptoparser.tls.extension import TlsExtensionServerName from cryptoparser.tls.extension import TlsExtensionSignatureAlgorithms, TlsSignatureAndHashAlgorithm Loading Loading @@ -139,6 +140,7 @@ class L7Client(object): self._host = host self._port = port self._socket = None self._buffer = bytearray() def _do_handshake( self, Loading Loading @@ -193,20 +195,17 @@ class L7Client(object): total_sent_byte_num = total_sent_byte_num + actual_sent_byte_num def receive(self, receivable_byte_num): total_received_bytes = bytearray() while len(total_received_bytes) < receivable_byte_num: total_received_byte_num = 0 while total_received_byte_num < receivable_byte_num: try: actual_received_bytes = self._socket.recv(min(receivable_byte_num - len(total_received_bytes), 1024)) actual_received_bytes = self._socket.recv(min(receivable_byte_num - total_received_byte_num, 1024)) self._buffer += actual_received_bytes total_received_byte_num += len(actual_received_bytes) except socket.error: actual_received_bytes = None if not actual_received_bytes: raise NotEnoughData(receivable_byte_num - len(total_received_bytes)) total_received_bytes += actual_received_bytes return total_received_bytes raise NotEnoughData(receivable_byte_num - total_received_byte_num) @property def host(self): Loading @@ -216,6 +215,16 @@ class L7Client(object): def port(self): return self._port @property def buffer(self): return bytearray(self._buffer) def flush_buffer(self, byte_num=None): if byte_num is None: byte_num = len(self._buffer) self._buffer = self._buffer[byte_num:] @classmethod def from_scheme(cls, scheme, host, port=None): for client_class in get_leaf_classes(L7Client): Loading Loading @@ -377,6 +386,8 @@ class TlsAlert(ValueError): class TlsClient(object): def __init__(self, l4_client): self._l4_client = l4_client self._last_processed_message_type = None self.server_messages = {} @abc.abstractmethod def do_handshake(self, hello_message, protocol_version, last_handshake_message_type): Loading @@ -384,20 +395,31 @@ class TlsClient(object): class TlsClientHandshake(TlsClient): def _process_message(self, handshake_message, protocol_version): handshake_type = handshake_message.get_handshake_type() if handshake_type in self.server_messages: raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE) if (handshake_type == TlsHandshakeType.SERVER_HELLO and handshake_message.protocol_version != protocol_version): raise TlsAlert(TlsAlertDescription.PROTOCOL_VERSION) def do_handshake( self, hello_message, protocol_version=TlsProtocolVersionFinal(TlsVersion.TLS1_0), last_handshake_message_type=TlsHandshakeType.SERVER_HELLO_DONE ): self.server_messages = {} self._last_processed_message_type = None tls_record = TlsRecord([hello_message, ], protocol_version) self._l4_client.send(tls_record.compose()) server_messages = {} received_bytes = bytearray() while True: try: record = TlsRecord.parse_mutable(received_bytes) record = TlsRecord.parse_exact_size(self._l4_client.buffer) self._l4_client.flush_buffer() if record.content_type == TlsContentType.ALERT: if record.messages[0].level == TlsAlertLevel.FATAL: raise TlsAlert(record.messages[0].description) Loading @@ -407,30 +429,24 @@ class TlsClientHandshake(TlsClient): raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE) for handshake_message in record.messages: handshake_type = handshake_message.get_handshake_type() if handshake_type in server_messages: raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE) if (handshake_type == TlsHandshakeType.SERVER_HELLO and handshake_message.protocol_version != protocol_version): raise TlsAlert(TlsAlertDescription.PROTOCOL_VERSION) server_messages[handshake_message.get_handshake_type()] = handshake_message self._process_message(handshake_message, protocol_version) self._last_processed_message_type = handshake_message.get_handshake_type() self.server_messages[self._last_processed_message_type] = handshake_message if handshake_message.get_handshake_type() == last_handshake_message_type: return server_messages if self._last_processed_message_type == last_handshake_message_type: return receivable_byte_num = 0 except NotEnoughData as e: receivable_byte_num = e.bytes_needed try: actual_received_bytes = self._l4_client.receive(receivable_byte_num) self._l4_client.receive(receivable_byte_num) except NotEnoughData: if received_bytes: if self._l4_client.buffer: raise NetworkError(NetworkErrorType.NO_CONNECTION) raise NetworkError(NetworkErrorType.NO_RESPONSE) received_bytes += actual_received_bytes class SslError(ValueError): Loading @@ -457,28 +473,41 @@ class SslClientHandshake(TlsClient): ssl_record = SslRecord(hello_message) self._l4_client.send(ssl_record.compose()) server_messages = {} received_bytes = bytearray() self.server_messages = {} while True: try: record = SslRecord.parse_mutable(received_bytes) message = record.message if message.get_message_type() == SslMessageType.ERROR: raise SslError(message.get_message_type()) record = SslRecord.parse_exact_size(self._l4_client.buffer) self._l4_client.flush_buffer() if record.message.get_message_type() == SslMessageType.ERROR: raise SslError(record.message.error_type) server_messages[message.get_message_type()] = message if message.get_message_type() == last_handshake_message_type: return server_messages self._last_processed_message_type = record.message.get_message_type() self.server_messages[self._last_processed_message_type] = record.message if self._last_processed_message_type == last_handshake_message_type: break receivable_byte_num = 0 except NotEnoughData as e: receivable_byte_num = e.bytes_needed try: actual_received_bytes = self._l4_client.receive(receivable_byte_num) self._l4_client.receive(receivable_byte_num) except NotEnoughData: if received_bytes: if self._l4_client.buffer: try: tls_record = TlsRecord.parse_exact_size(self._l4_client.buffer) self._l4_client.flush_buffer() except ValueError: raise NetworkError(NetworkErrorType.NO_CONNECTION) else: if (tls_record.content_type == TlsContentType.ALERT and (tls_record.messages[0].description in [ TlsAlertDescription.PROTOCOL_VERSION, TlsAlertDescription.HANDSHAKE_FAILURE, TlsAlertDescription.INTERNAL_ERROR, ])): raise NetworkError(NetworkErrorType.NO_RESPONSE) raise NetworkError(NetworkErrorType.NO_CONNECTION) else: raise NetworkError(NetworkErrorType.NO_RESPONSE) received_bytes += actual_received_bytes Loading
cryptoparser/tls/client.py +70 −41 Original line number Diff line number Diff line Loading @@ -16,7 +16,8 @@ from cryptoparser.common.utils import get_leaf_classes from cryptoparser.tls.ciphersuite import TlsCipherSuite, SslCipherKind from cryptoparser.tls.subprotocol import SslMessageType, SslHandshakeClientHello from cryptoparser.tls.subprotocol import TlsHandshakeClientHello, TlsCipherSuiteVector, TlsContentType, TlsHandshakeType from cryptoparser.tls.subprotocol import TlsHandshakeClientHello from cryptoparser.tls.subprotocol import TlsCipherSuiteVector, TlsContentType, TlsHandshakeType from cryptoparser.tls.subprotocol import TlsAlertLevel, TlsAlertDescription from cryptoparser.tls.extension import TlsExtensionServerName from cryptoparser.tls.extension import TlsExtensionSignatureAlgorithms, TlsSignatureAndHashAlgorithm Loading Loading @@ -139,6 +140,7 @@ class L7Client(object): self._host = host self._port = port self._socket = None self._buffer = bytearray() def _do_handshake( self, Loading Loading @@ -193,20 +195,17 @@ class L7Client(object): total_sent_byte_num = total_sent_byte_num + actual_sent_byte_num def receive(self, receivable_byte_num): total_received_bytes = bytearray() while len(total_received_bytes) < receivable_byte_num: total_received_byte_num = 0 while total_received_byte_num < receivable_byte_num: try: actual_received_bytes = self._socket.recv(min(receivable_byte_num - len(total_received_bytes), 1024)) actual_received_bytes = self._socket.recv(min(receivable_byte_num - total_received_byte_num, 1024)) self._buffer += actual_received_bytes total_received_byte_num += len(actual_received_bytes) except socket.error: actual_received_bytes = None if not actual_received_bytes: raise NotEnoughData(receivable_byte_num - len(total_received_bytes)) total_received_bytes += actual_received_bytes return total_received_bytes raise NotEnoughData(receivable_byte_num - total_received_byte_num) @property def host(self): Loading @@ -216,6 +215,16 @@ class L7Client(object): def port(self): return self._port @property def buffer(self): return bytearray(self._buffer) def flush_buffer(self, byte_num=None): if byte_num is None: byte_num = len(self._buffer) self._buffer = self._buffer[byte_num:] @classmethod def from_scheme(cls, scheme, host, port=None): for client_class in get_leaf_classes(L7Client): Loading Loading @@ -377,6 +386,8 @@ class TlsAlert(ValueError): class TlsClient(object): def __init__(self, l4_client): self._l4_client = l4_client self._last_processed_message_type = None self.server_messages = {} @abc.abstractmethod def do_handshake(self, hello_message, protocol_version, last_handshake_message_type): Loading @@ -384,20 +395,31 @@ class TlsClient(object): class TlsClientHandshake(TlsClient): def _process_message(self, handshake_message, protocol_version): handshake_type = handshake_message.get_handshake_type() if handshake_type in self.server_messages: raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE) if (handshake_type == TlsHandshakeType.SERVER_HELLO and handshake_message.protocol_version != protocol_version): raise TlsAlert(TlsAlertDescription.PROTOCOL_VERSION) def do_handshake( self, hello_message, protocol_version=TlsProtocolVersionFinal(TlsVersion.TLS1_0), last_handshake_message_type=TlsHandshakeType.SERVER_HELLO_DONE ): self.server_messages = {} self._last_processed_message_type = None tls_record = TlsRecord([hello_message, ], protocol_version) self._l4_client.send(tls_record.compose()) server_messages = {} received_bytes = bytearray() while True: try: record = TlsRecord.parse_mutable(received_bytes) record = TlsRecord.parse_exact_size(self._l4_client.buffer) self._l4_client.flush_buffer() if record.content_type == TlsContentType.ALERT: if record.messages[0].level == TlsAlertLevel.FATAL: raise TlsAlert(record.messages[0].description) Loading @@ -407,30 +429,24 @@ class TlsClientHandshake(TlsClient): raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE) for handshake_message in record.messages: handshake_type = handshake_message.get_handshake_type() if handshake_type in server_messages: raise TlsAlert(TlsAlertDescription.UNEXPECTED_MESSAGE) if (handshake_type == TlsHandshakeType.SERVER_HELLO and handshake_message.protocol_version != protocol_version): raise TlsAlert(TlsAlertDescription.PROTOCOL_VERSION) server_messages[handshake_message.get_handshake_type()] = handshake_message self._process_message(handshake_message, protocol_version) self._last_processed_message_type = handshake_message.get_handshake_type() self.server_messages[self._last_processed_message_type] = handshake_message if handshake_message.get_handshake_type() == last_handshake_message_type: return server_messages if self._last_processed_message_type == last_handshake_message_type: return receivable_byte_num = 0 except NotEnoughData as e: receivable_byte_num = e.bytes_needed try: actual_received_bytes = self._l4_client.receive(receivable_byte_num) self._l4_client.receive(receivable_byte_num) except NotEnoughData: if received_bytes: if self._l4_client.buffer: raise NetworkError(NetworkErrorType.NO_CONNECTION) raise NetworkError(NetworkErrorType.NO_RESPONSE) received_bytes += actual_received_bytes class SslError(ValueError): Loading @@ -457,28 +473,41 @@ class SslClientHandshake(TlsClient): ssl_record = SslRecord(hello_message) self._l4_client.send(ssl_record.compose()) server_messages = {} received_bytes = bytearray() self.server_messages = {} while True: try: record = SslRecord.parse_mutable(received_bytes) message = record.message if message.get_message_type() == SslMessageType.ERROR: raise SslError(message.get_message_type()) record = SslRecord.parse_exact_size(self._l4_client.buffer) self._l4_client.flush_buffer() if record.message.get_message_type() == SslMessageType.ERROR: raise SslError(record.message.error_type) server_messages[message.get_message_type()] = message if message.get_message_type() == last_handshake_message_type: return server_messages self._last_processed_message_type = record.message.get_message_type() self.server_messages[self._last_processed_message_type] = record.message if self._last_processed_message_type == last_handshake_message_type: break receivable_byte_num = 0 except NotEnoughData as e: receivable_byte_num = e.bytes_needed try: actual_received_bytes = self._l4_client.receive(receivable_byte_num) self._l4_client.receive(receivable_byte_num) except NotEnoughData: if received_bytes: if self._l4_client.buffer: try: tls_record = TlsRecord.parse_exact_size(self._l4_client.buffer) self._l4_client.flush_buffer() except ValueError: raise NetworkError(NetworkErrorType.NO_CONNECTION) else: if (tls_record.content_type == TlsContentType.ALERT and (tls_record.messages[0].description in [ TlsAlertDescription.PROTOCOL_VERSION, TlsAlertDescription.HANDSHAKE_FAILURE, TlsAlertDescription.INTERNAL_ERROR, ])): raise NetworkError(NetworkErrorType.NO_RESPONSE) raise NetworkError(NetworkErrorType.NO_CONNECTION) else: raise NetworkError(NetworkErrorType.NO_RESPONSE) received_bytes += actual_received_bytes