Use non-blocking sockets

This commit is contained in:
TheKysek 2016-08-03 19:05:01 +02:00
parent e7dcec142a
commit 9fd1006f2d

View File

@ -1,4 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import errno
import logging import logging
import os import os
import random import random
@ -41,9 +42,12 @@ class Connection(threading.Thread):
if self.server: if self.server:
self.status = 'connected' self.status = 'connected'
self.buffer = b'' self.buffer_receive = b''
self.buffer_send = b''
self.next_message_size = shared.header_length self.next_message_size = shared.header_length
self.next_header = True self.next_header = True
self.on_connection_fully_established_scheduled = False
self.last_message_received = time.time() self.last_message_received = time.time()
self.last_message_sent = time.time() self.last_message_sent = time.time()
@ -53,29 +57,44 @@ class Connection(threading.Thread):
self._connect() self._connect()
if self.status != 'connected': if self.status != 'connected':
return return
self.s.settimeout(0.5) self.s.settimeout(0)
if not self.server: if not self.server:
self.send_queue.put(message.Version(self.host, self.port)) self.send_queue.put(message.Version(self.host, self.port))
while True: while True:
time.sleep(0.3)
data = True data = True
try: try:
data = self.s.recv(1024) data = self.s.recv(self.next_message_size - len(self.buffer_receive))
self.buffer += data self.buffer_receive += data
except socket.timeout or ssl.SSLError: except ssl.SSLWantReadError:
if time.time() - self.last_message_received > shared.timeout: pass
logging.debug('Disconnecting from {};{}. Reason: time.time() - self.last_message_received > shared.timeout'.format(self.host, self.port)) except socket.error as e:
data = None err = e.args[0]
if time.time() - self.last_message_received > 30 and self.status != 'fully_established': if err == errno.EAGAIN or err == errno.EWOULDBLOCK:
logging.debug('Disconnecting from {};{}. Reason: time.time() - self.last_message_received > 30 and self.status != \'fully_established\''.format(self.host, self.port)) pass
data = None else:
if time.time() - self.last_message_sent > 300 and self.status == 'fully_established': raise
self.send_queue.put(message.Message(b'pong', b''))
except ConnectionResetError: except ConnectionResetError:
logging.debug('Disconnecting from {};{}. Reason: ConnectionResetError'.format(self.host, self.port)) logging.debug('Disconnecting from {}:{}. Reason: ConnectionResetError'.format(self.host, self.port))
data = None data = None
self._process_buffer() self._process_buffer_receive()
self._request_objects() self._request_objects()
self._process_queue() self._process_queue()
self._send_data()
if time.time() - self.last_message_received > shared.timeout:
logging.debug(
'Disconnecting from {}:{}. Reason: time.time() - self.last_message_received > shared.timeout'.format(
self.host, self.port))
data = None
if time.time() - self.last_message_received > 30 and self.status != 'fully_established':
logging.debug(
'Disconnecting from {}:{}. Reason: time.time() - self.last_message_received > 30 and self.status != \'fully_established\''.format(
self.host, self.port))
data = None
if time.time() - self.last_message_sent > 300 and self.status == 'fully_established':
self.send_queue.put(message.Message(b'pong', b''))
if self.on_connection_fully_established_scheduled and not (self.buffer_send or self.buffer_receive):
self._on_connection_fully_established()
if self.status == 'disconnecting': if self.status == 'disconnecting':
data = None data = None
if not data: if not data:
@ -97,28 +116,32 @@ class Connection(threading.Thread):
self.status = 'failed' self.status = 'failed'
def _send_data(self):
try:
amount = self.s.send(self.buffer_send[:1000])
self.buffer_send = self.buffer_send[amount:]
except (BlockingIOError, ssl.SSLWantWriteError):
pass
def _do_tls_handshake(self): def _do_tls_handshake(self):
self.s.settimeout(30)
logging.debug('Initializing TLS connection with {}:{}'.format(self.host, self.port)) logging.debug('Initializing TLS connection with {}:{}'.format(self.host, self.port))
self.s = ssl.wrap_socket(self.s, keyfile=os.path.join(shared.source_directory, 'tls', 'key.pem'), self.s = ssl.wrap_socket(self.s, keyfile=os.path.join(shared.source_directory, 'tls', 'key.pem'),
certfile=os.path.join(shared.source_directory, 'tls', 'cert.pem'), certfile=os.path.join(shared.source_directory, 'tls', 'cert.pem'),
server_side=self.server, ssl_version=ssl.PROTOCOL_TLSv1, do_handshake_on_connect=False, server_side=self.server, ssl_version=ssl.PROTOCOL_TLSv1, do_handshake_on_connect=False,
ciphers='AECDH-AES256-SHA') ciphers='AECDH-AES256-SHA', suppress_ragged_eofs=True)
if hasattr(self.s, "context"): if hasattr(self.s, "context"):
self.s.context.set_ecdh_curve("secp256k1") self.s.context.set_ecdh_curve("secp256k1")
while True: while True:
try: try:
self.s.do_handshake() self.s.do_handshake()
break break
except ssl.SSLError as e: except ssl.SSLWantReadError:
if e.errno == 2: select.select([self.s], [], [])
select.select([self.s], [self.s], []) except ssl.SSLWantWriteError:
else: select.select([], [self.s], [])
break
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
break break
self.s.settimeout(0.5)
self.tls = True self.tls = True
logging.debug('Established TLS connection with {}:{}'.format(self.host, self.port)) logging.debug('Established TLS connection with {}:{}'.format(self.host, self.port))
@ -127,15 +150,13 @@ class Connection(threading.Thread):
logging.debug('{}:{} <- {}'.format(self.host, self.port, structure.Object.from_message(m))) logging.debug('{}:{} <- {}'.format(self.host, self.port, structure.Object.from_message(m)))
else: else:
logging.debug('{}:{} <- {}'.format(self.host, self.port, m)) logging.debug('{}:{} <- {}'.format(self.host, self.port, m))
self.s.settimeout(60) self.buffer_send += m.to_bytes()
self.s.sendall(m.to_bytes())
self.s.settimeout(0.5)
def _on_connection_fully_established(self): def _on_connection_fully_established(self):
logging.info('Established Bitmessage protocol connection to {}:{}'.format(self.host, self.port)) logging.info('Established Bitmessage protocol connection to {}:{}'.format(self.host, self.port))
self.on_connection_fully_established_scheduled = False
if self.remote_version.services & 2: # NODE_SSL if self.remote_version.services & 2: # NODE_SSL
self._do_tls_handshake() self._do_tls_handshake()
time.sleep(5)
with shared.objects_lock: with shared.objects_lock:
if len(shared.objects) > 0: if len(shared.objects) > 0:
self.send_queue.put(message.Inv({vector for vector in shared.objects.keys() if shared.objects[vector].expires_time > time.time()})) self.send_queue.put(message.Inv({vector for vector in shared.objects.keys() if shared.objects[vector].expires_time > time.time()}))
@ -149,7 +170,7 @@ class Connection(threading.Thread):
m = self.send_queue.get() m = self.send_queue.get()
if m: if m:
if m == 'fully_established': if m == 'fully_established':
self._on_connection_fully_established() self.on_connection_fully_established_scheduled = True
else: else:
self._send_message(m) self._send_message(m)
self.last_message_sent = time.time() self.last_message_sent = time.time()
@ -157,17 +178,17 @@ class Connection(threading.Thread):
self.status = 'disconnecting' self.status = 'disconnecting'
break break
def _process_buffer(self): def _process_buffer_receive(self):
while len(self.buffer) >= self.next_message_size: while len(self.buffer_receive) >= self.next_message_size:
if self.next_header: if self.next_header:
self.next_header = False self.next_header = False
h = message.Header.from_bytes(self.buffer[:shared.header_length]) h = message.Header.from_bytes(self.buffer_receive[:shared.header_length])
self.next_message_size += h.payload_length self.next_message_size += h.payload_length
else: else:
m = message.Message.from_bytes(self.buffer[:self.next_message_size]) m = message.Message.from_bytes(self.buffer_receive[:self.next_message_size])
self.next_header = True self.next_header = True
self.buffer = self.buffer[self.next_message_size:] self.buffer_receive = self.buffer_receive[self.next_message_size:]
self.next_message_size = shared.header_length self.next_message_size = shared.header_length
self.last_message_received = time.time() self.last_message_received = time.time()
self._process_message(m) self._process_message(m)