From 9fd1006f2dda9296c0bf55be5b44c05f6f30ca0f Mon Sep 17 00:00:00 2001 From: TheKysek Date: Wed, 3 Aug 2016 19:05:01 +0200 Subject: [PATCH] Use non-blocking sockets --- src/connection.py | 87 +++++++++++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/src/connection.py b/src/connection.py index 7113a85..02d3e69 100644 --- a/src/connection.py +++ b/src/connection.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import errno import logging import os import random @@ -41,9 +42,12 @@ class Connection(threading.Thread): if self.server: self.status = 'connected' - self.buffer = b'' + self.buffer_receive = b'' + self.buffer_send = b'' + self.next_message_size = shared.header_length self.next_header = True + self.on_connection_fully_established_scheduled = False self.last_message_received = time.time() self.last_message_sent = time.time() @@ -53,29 +57,44 @@ class Connection(threading.Thread): self._connect() if self.status != 'connected': return - self.s.settimeout(0.5) + self.s.settimeout(0) if not self.server: self.send_queue.put(message.Version(self.host, self.port)) while True: + time.sleep(0.3) data = True try: - data = self.s.recv(1024) - self.buffer += data - except socket.timeout or ssl.SSLError: - if time.time() - self.last_message_received > shared.timeout: - logging.debug('Disconnecting from {};{}. Reason: time.time() - self.last_message_received > shared.timeout'.format(self.host, self.port)) - data = None - if time.time() - self.last_message_received > 30 and self.status != 'fully_established': - logging.debug('Disconnecting from {};{}. Reason: time.time() - self.last_message_received > 30 and self.status != \'fully_established\''.format(self.host, self.port)) - data = None - if time.time() - self.last_message_sent > 300 and self.status == 'fully_established': - self.send_queue.put(message.Message(b'pong', b'')) + data = self.s.recv(self.next_message_size - len(self.buffer_receive)) + self.buffer_receive += data + except ssl.SSLWantReadError: + pass + except socket.error as e: + err = e.args[0] + if err == errno.EAGAIN or err == errno.EWOULDBLOCK: + pass + else: + raise except ConnectionResetError: - logging.debug('Disconnecting from {};{}. Reason: ConnectionResetError'.format(self.host, self.port)) + logging.debug('Disconnecting from {}:{}. Reason: ConnectionResetError'.format(self.host, self.port)) data = None - self._process_buffer() + self._process_buffer_receive() self._request_objects() self._process_queue() + self._send_data() + if time.time() - self.last_message_received > shared.timeout: + logging.debug( + 'Disconnecting from {}:{}. Reason: time.time() - self.last_message_received > shared.timeout'.format( + self.host, self.port)) + data = None + if time.time() - self.last_message_received > 30 and self.status != 'fully_established': + logging.debug( + 'Disconnecting from {}:{}. Reason: time.time() - self.last_message_received > 30 and self.status != \'fully_established\''.format( + self.host, self.port)) + data = None + if time.time() - self.last_message_sent > 300 and self.status == 'fully_established': + self.send_queue.put(message.Message(b'pong', b'')) + if self.on_connection_fully_established_scheduled and not (self.buffer_send or self.buffer_receive): + self._on_connection_fully_established() if self.status == 'disconnecting': data = None if not data: @@ -97,28 +116,32 @@ class Connection(threading.Thread): self.status = 'failed' + def _send_data(self): + try: + amount = self.s.send(self.buffer_send[:1000]) + self.buffer_send = self.buffer_send[amount:] + except (BlockingIOError, ssl.SSLWantWriteError): + pass + def _do_tls_handshake(self): - self.s.settimeout(30) logging.debug('Initializing TLS connection with {}:{}'.format(self.host, self.port)) self.s = ssl.wrap_socket(self.s, keyfile=os.path.join(shared.source_directory, 'tls', 'key.pem'), certfile=os.path.join(shared.source_directory, 'tls', 'cert.pem'), server_side=self.server, ssl_version=ssl.PROTOCOL_TLSv1, do_handshake_on_connect=False, - ciphers='AECDH-AES256-SHA') + ciphers='AECDH-AES256-SHA', suppress_ragged_eofs=True) if hasattr(self.s, "context"): self.s.context.set_ecdh_curve("secp256k1") while True: try: self.s.do_handshake() break - except ssl.SSLError as e: - if e.errno == 2: - select.select([self.s], [self.s], []) - else: - break + except ssl.SSLWantReadError: + select.select([self.s], [], []) + except ssl.SSLWantWriteError: + select.select([], [self.s], []) except Exception as e: logging.error(e) break - self.s.settimeout(0.5) self.tls = True logging.debug('Established TLS connection with {}:{}'.format(self.host, self.port)) @@ -127,15 +150,13 @@ class Connection(threading.Thread): logging.debug('{}:{} <- {}'.format(self.host, self.port, structure.Object.from_message(m))) else: logging.debug('{}:{} <- {}'.format(self.host, self.port, m)) - self.s.settimeout(60) - self.s.sendall(m.to_bytes()) - self.s.settimeout(0.5) + self.buffer_send += m.to_bytes() def _on_connection_fully_established(self): logging.info('Established Bitmessage protocol connection to {}:{}'.format(self.host, self.port)) + self.on_connection_fully_established_scheduled = False if self.remote_version.services & 2: # NODE_SSL self._do_tls_handshake() - time.sleep(5) with shared.objects_lock: if len(shared.objects) > 0: self.send_queue.put(message.Inv({vector for vector in shared.objects.keys() if shared.objects[vector].expires_time > time.time()})) @@ -149,7 +170,7 @@ class Connection(threading.Thread): m = self.send_queue.get() if m: if m == 'fully_established': - self._on_connection_fully_established() + self.on_connection_fully_established_scheduled = True else: self._send_message(m) self.last_message_sent = time.time() @@ -157,17 +178,17 @@ class Connection(threading.Thread): self.status = 'disconnecting' break - def _process_buffer(self): - while len(self.buffer) >= self.next_message_size: + def _process_buffer_receive(self): + while len(self.buffer_receive) >= self.next_message_size: if self.next_header: self.next_header = False - h = message.Header.from_bytes(self.buffer[:shared.header_length]) + h = message.Header.from_bytes(self.buffer_receive[:shared.header_length]) self.next_message_size += h.payload_length else: - m = message.Message.from_bytes(self.buffer[:self.next_message_size]) + m = message.Message.from_bytes(self.buffer_receive[:self.next_message_size]) self.next_header = True - self.buffer = self.buffer[self.next_message_size:] + self.buffer_receive = self.buffer_receive[self.next_message_size:] self.next_message_size = shared.header_length self.last_message_received = time.time() self._process_message(m)