diff --git a/minode/connection.py b/minode/connection.py index c74af5d..b7f9a75 100644 --- a/minode/connection.py +++ b/minode/connection.py @@ -15,8 +15,11 @@ import time from . import message, shared, structure -class Connection(threading.Thread): - """The connection object""" +class ConnectionBase(threading.Thread): + """ + Common code for the connection thread + with minimum command handlers to reuse + """ def __init__( self, host, port, s=None, network='ip', server=False, i2p_remote_dest=b'' @@ -344,89 +347,13 @@ class Connection(threading.Thread): break def _process_message(self, m): - if m.command == b'version': - version = message.Version.from_message(m) - if shared.stream not in version.streams: - raise ValueError('message not for stream %i' % shared.stream) - logging.debug('%s:%s -> %s', self.host_print, self.port, version) - if ( - version.protocol_version != shared.protocol_version - or version.nonce == shared.nonce - ): - self.status = 'disconnecting' - self.send_queue.put(None) - else: - logging.info( - '%s:%s claims to be %s', - self.host_print, self.port, version.user_agent) - self.send_queue.put(message.Message(b'verack', b'')) - self.verack_sent = True - self.remote_version = version - if not self.server: - self.send_queue.put('fully_established') - if self.network == 'ip': - shared.address_advertise_queue.put(structure.NetAddr( - version.services, self.host, self.port)) - shared.node_pool.add((self.host, self.port)) - elif self.network == 'i2p': - shared.i2p_node_pool.add((self.host, 'i2p')) - if self.network == 'ip': - shared.address_advertise_queue.put(structure.NetAddr( - shared.services, version.host, shared.listening_port)) - if self.server: - if self.network == 'ip': - self.send_queue.put( - message.Version(self.host, self.port)) - else: - self.send_queue.put(message.Version('127.0.0.1', 7656)) - - elif m.command == b'verack': + if m.command == b'verack': self.verack_received = True logging.debug( '%s:%s -> %s', self.host_print, self.port, 'verack') if self.server: self.send_queue.put('fully_established') - elif m.command == b'inv': - inv = message.Inv.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, inv) - to_get = inv.vectors.copy() - to_get.difference_update(shared.objects.keys()) - self.vectors_to_get.update(to_get) - # Do not send objects they already have. - self.vectors_to_send.difference_update(inv.vectors) - - elif m.command == b'object': - obj = structure.Object.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, obj) - self.vectors_requested.pop(obj.vector, None) - self.vectors_to_get.discard(obj.vector) - if obj.is_valid() and obj.vector not in shared.objects: - with shared.objects_lock: - shared.objects[obj.vector] = obj - if ( - obj.object_type == shared.i2p_dest_obj_type - and obj.version == shared.i2p_dest_obj_version - ): - dest = base64.b64encode(obj.object_payload, altchars=b'-~') - logging.debug( - 'Received I2P destination object,' - ' adding to i2p_unchecked_node_pool') - logging.debug(dest) - shared.i2p_unchecked_node_pool.add((dest, 'i2p')) - shared.vector_advertise_queue.put(obj.vector) - - elif m.command == b'getdata': - getdata = message.GetData.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, getdata) - self.vectors_to_send.update(getdata.vectors) - - elif m.command == b'addr': - addr = message.Addr.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, addr) - for a in addr.addresses: - shared.unchecked_node_pool.add((a.host, a.port)) - elif m.command == b'ping': logging.debug('%s:%s -> ping', self.host_print, self.port) self.send_queue.put(message.Message(b'pong', b'')) @@ -440,7 +367,51 @@ class Connection(threading.Thread): shared.unchecked_node_pool.discard((self.host, self.port)) else: - logging.debug('%s:%s -> %s', self.host_print, self.port, m) + try: + getattr(self, '_process_msg_{}'.format(m.command.decode()))(m) + except (AttributeError, UnicodeDecodeError): + logging.debug('%s:%s -> %s', self.host_print, self.port, m) + + def _process_msg_version(self, m): + version = message.Version.from_message(m) + if shared.stream not in version.streams: + raise ValueError('message not for stream %i' % shared.stream) + logging.debug('%s:%s -> %s', self.host_print, self.port, version) + if ( + version.protocol_version != shared.protocol_version + or version.nonce == shared.nonce + ): + self.status = 'disconnecting' + self.send_queue.put(None) + else: + logging.info( + '%s:%s claims to be %s', + self.host_print, self.port, version.user_agent) + self.send_queue.put(message.Message(b'verack', b'')) + self.verack_sent = True + self.remote_version = version + if not self.server: + self.send_queue.put('fully_established') + if self.network == 'ip': + shared.address_advertise_queue.put(structure.NetAddr( + version.services, self.host, self.port)) + shared.node_pool.add((self.host, self.port)) + elif self.network == 'i2p': + shared.i2p_node_pool.add((self.host, 'i2p')) + if self.network == 'ip': + shared.address_advertise_queue.put(structure.NetAddr( + shared.services, version.host, shared.listening_port)) + if self.server: + if self.network == 'ip': + self.send_queue.put(message.Version(self.host, self.port)) + else: + self.send_queue.put(message.Version('127.0.0.1', 7656)) + + def _process_msg_addr(self, m): + addr = message.Addr.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, addr) + for a in addr.addresses: + shared.unchecked_node_pool.add((a.host, a.port)) def _request_objects(self): if self.vectors_to_get and len(self.vectors_requested) < 100: @@ -499,4 +470,41 @@ class Connection(threading.Thread): message.Message(b'object', obj.to_bytes())) +class Connection(ConnectionBase): + """The connection with all commands implementation""" + def _process_msg_inv(self, m): + inv = message.Inv.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, inv) + to_get = inv.vectors.copy() + to_get.difference_update(shared.objects.keys()) + self.vectors_to_get.update(to_get) + # Do not send objects they already have. + self.vectors_to_send.difference_update(inv.vectors) + + def _process_msg_object(self, m): + obj = structure.Object.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, obj) + self.vectors_requested.pop(obj.vector, None) + self.vectors_to_get.discard(obj.vector) + if obj.is_valid() and obj.vector not in shared.objects: + with shared.objects_lock: + shared.objects[obj.vector] = obj + if ( + obj.object_type == shared.i2p_dest_obj_type + and obj.version == shared.i2p_dest_obj_version + ): + dest = base64.b64encode(obj.object_payload, altchars=b'-~') + logging.debug( + 'Received I2P destination object,' + ' adding to i2p_unchecked_node_pool') + logging.debug(dest) + shared.i2p_unchecked_node_pool.add((dest, 'i2p')) + shared.vector_advertise_queue.put(obj.vector) + + def _process_msg_getdata(self, m): + getdata = message.GetData.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, getdata) + self.vectors_to_send.update(getdata.vectors) + + shared.connection = Connection