diff --git a/minode/connection.py b/minode/connection.py index c74af5d..b7f9a75 100644 --- a/minode/connection.py +++ b/minode/connection.py @@ -15,8 +15,11 @@ import time from . import message, shared, structure -class Connection(threading.Thread): - """The connection object""" +class ConnectionBase(threading.Thread): + """ + Common code for the connection thread + with minimum command handlers to reuse + """ def __init__( self, host, port, s=None, network='ip', server=False, i2p_remote_dest=b'' @@ -344,89 +347,13 @@ class Connection(threading.Thread): break def _process_message(self, m): - if m.command == b'version': - version = message.Version.from_message(m) - if shared.stream not in version.streams: - raise ValueError('message not for stream %i' % shared.stream) - logging.debug('%s:%s -> %s', self.host_print, self.port, version) - if ( - version.protocol_version != shared.protocol_version - or version.nonce == shared.nonce - ): - self.status = 'disconnecting' - self.send_queue.put(None) - else: - logging.info( - '%s:%s claims to be %s', - self.host_print, self.port, version.user_agent) - self.send_queue.put(message.Message(b'verack', b'')) - self.verack_sent = True - self.remote_version = version - if not self.server: - self.send_queue.put('fully_established') - if self.network == 'ip': - shared.address_advertise_queue.put(structure.NetAddr( - version.services, self.host, self.port)) - shared.node_pool.add((self.host, self.port)) - elif self.network == 'i2p': - shared.i2p_node_pool.add((self.host, 'i2p')) - if self.network == 'ip': - shared.address_advertise_queue.put(structure.NetAddr( - shared.services, version.host, shared.listening_port)) - if self.server: - if self.network == 'ip': - self.send_queue.put( - message.Version(self.host, self.port)) - else: - self.send_queue.put(message.Version('127.0.0.1', 7656)) - - elif m.command == b'verack': + if m.command == b'verack': self.verack_received = True logging.debug( '%s:%s -> %s', self.host_print, self.port, 'verack') if self.server: self.send_queue.put('fully_established') - elif m.command == b'inv': - inv = message.Inv.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, inv) - to_get = inv.vectors.copy() - to_get.difference_update(shared.objects.keys()) - self.vectors_to_get.update(to_get) - # Do not send objects they already have. - self.vectors_to_send.difference_update(inv.vectors) - - elif m.command == b'object': - obj = structure.Object.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, obj) - self.vectors_requested.pop(obj.vector, None) - self.vectors_to_get.discard(obj.vector) - if obj.is_valid() and obj.vector not in shared.objects: - with shared.objects_lock: - shared.objects[obj.vector] = obj - if ( - obj.object_type == shared.i2p_dest_obj_type - and obj.version == shared.i2p_dest_obj_version - ): - dest = base64.b64encode(obj.object_payload, altchars=b'-~') - logging.debug( - 'Received I2P destination object,' - ' adding to i2p_unchecked_node_pool') - logging.debug(dest) - shared.i2p_unchecked_node_pool.add((dest, 'i2p')) - shared.vector_advertise_queue.put(obj.vector) - - elif m.command == b'getdata': - getdata = message.GetData.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, getdata) - self.vectors_to_send.update(getdata.vectors) - - elif m.command == b'addr': - addr = message.Addr.from_message(m) - logging.debug('%s:%s -> %s', self.host_print, self.port, addr) - for a in addr.addresses: - shared.unchecked_node_pool.add((a.host, a.port)) - elif m.command == b'ping': logging.debug('%s:%s -> ping', self.host_print, self.port) self.send_queue.put(message.Message(b'pong', b'')) @@ -440,7 +367,51 @@ class Connection(threading.Thread): shared.unchecked_node_pool.discard((self.host, self.port)) else: - logging.debug('%s:%s -> %s', self.host_print, self.port, m) + try: + getattr(self, '_process_msg_{}'.format(m.command.decode()))(m) + except (AttributeError, UnicodeDecodeError): + logging.debug('%s:%s -> %s', self.host_print, self.port, m) + + def _process_msg_version(self, m): + version = message.Version.from_message(m) + if shared.stream not in version.streams: + raise ValueError('message not for stream %i' % shared.stream) + logging.debug('%s:%s -> %s', self.host_print, self.port, version) + if ( + version.protocol_version != shared.protocol_version + or version.nonce == shared.nonce + ): + self.status = 'disconnecting' + self.send_queue.put(None) + else: + logging.info( + '%s:%s claims to be %s', + self.host_print, self.port, version.user_agent) + self.send_queue.put(message.Message(b'verack', b'')) + self.verack_sent = True + self.remote_version = version + if not self.server: + self.send_queue.put('fully_established') + if self.network == 'ip': + shared.address_advertise_queue.put(structure.NetAddr( + version.services, self.host, self.port)) + shared.node_pool.add((self.host, self.port)) + elif self.network == 'i2p': + shared.i2p_node_pool.add((self.host, 'i2p')) + if self.network == 'ip': + shared.address_advertise_queue.put(structure.NetAddr( + shared.services, version.host, shared.listening_port)) + if self.server: + if self.network == 'ip': + self.send_queue.put(message.Version(self.host, self.port)) + else: + self.send_queue.put(message.Version('127.0.0.1', 7656)) + + def _process_msg_addr(self, m): + addr = message.Addr.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, addr) + for a in addr.addresses: + shared.unchecked_node_pool.add((a.host, a.port)) def _request_objects(self): if self.vectors_to_get and len(self.vectors_requested) < 100: @@ -499,4 +470,41 @@ class Connection(threading.Thread): message.Message(b'object', obj.to_bytes())) +class Connection(ConnectionBase): + """The connection with all commands implementation""" + def _process_msg_inv(self, m): + inv = message.Inv.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, inv) + to_get = inv.vectors.copy() + to_get.difference_update(shared.objects.keys()) + self.vectors_to_get.update(to_get) + # Do not send objects they already have. + self.vectors_to_send.difference_update(inv.vectors) + + def _process_msg_object(self, m): + obj = structure.Object.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, obj) + self.vectors_requested.pop(obj.vector, None) + self.vectors_to_get.discard(obj.vector) + if obj.is_valid() and obj.vector not in shared.objects: + with shared.objects_lock: + shared.objects[obj.vector] = obj + if ( + obj.object_type == shared.i2p_dest_obj_type + and obj.version == shared.i2p_dest_obj_version + ): + dest = base64.b64encode(obj.object_payload, altchars=b'-~') + logging.debug( + 'Received I2P destination object,' + ' adding to i2p_unchecked_node_pool') + logging.debug(dest) + shared.i2p_unchecked_node_pool.add((dest, 'i2p')) + shared.vector_advertise_queue.put(obj.vector) + + def _process_msg_getdata(self, m): + getdata = message.GetData.from_message(m) + logging.debug('%s:%s -> %s', self.host_print, self.port, getdata) + self.vectors_to_send.update(getdata.vectors) + + shared.connection = Connection diff --git a/minode/listener.py b/minode/listener.py index 88b26db..bec5365 100644 --- a/minode/listener.py +++ b/minode/listener.py @@ -28,13 +28,14 @@ class Listener(threading.Thread): break try: conn, addr = self.s.accept() - logging.info('Incoming connection from: %s:%i', *addr[:2]) - with shared.connections_lock: - if len(shared.connections) > shared.connection_limit: - conn.close() - else: - c = Connection(*addr[:2], conn, server=True) - c.start() - shared.connections.add(c) except socket.timeout: - pass + continue + + logging.info('Incoming connection from: %s:%i', *addr[:2]) + with shared.connections_lock: + if len(shared.connections) > shared.connection_limit: + conn.close() + else: + c = Connection(*addr[:2], conn, server=True) + c.start() + shared.connections.add(c) diff --git a/minode/manager.py b/minode/manager.py index 93a3f30..a03f9d6 100644 --- a/minode/manager.py +++ b/minode/manager.py @@ -77,7 +77,7 @@ class Manager(threading.Thread): with shared.connections_lock: shared.connections.remove(c) else: - hosts.add(c.host) + hosts.add(structure.NetAddrNoPrefix.network_group(c.host)) if not c.server: outgoing_connections += 1 @@ -121,15 +121,16 @@ class Manager(threading.Thread): else: to_connect.update(shared.i2p_node_pool) - for addr in to_connect: - if addr[0] in hosts: + for host, port in to_connect: + group = structure.NetAddrNoPrefix.network_group(host) + if group in hosts: continue - if addr[1] == 'i2p' and shared.i2p_enabled: - if shared.i2p_session_nick and addr[0] != shared.i2p_dest_pub: + if port == 'i2p' and shared.i2p_enabled: + if shared.i2p_session_nick and host != shared.i2p_dest_pub: try: d = I2PDialer( shared, - addr[0], shared.i2p_session_nick, + host, shared.i2p_session_nick, shared.i2p_sam_host, shared.i2p_sam_port) d.start() hosts.add(d.destination) @@ -141,9 +142,9 @@ class Manager(threading.Thread): else: continue else: - c = Connection(addr[0], addr[1]) + c = Connection(host, port) c.start() - hosts.add(c.host) + hosts.add(group) with shared.connections_lock: shared.connections.add(c) shared.hosts = hosts diff --git a/minode/message.py b/minode/message.py index eee5e3f..ee01ea8 100644 --- a/minode/message.py +++ b/minode/message.py @@ -145,10 +145,13 @@ class Version(): def from_message(cls, m): payload = m.payload - ( # unused: timestamp, net_addr_local - protocol_version, services, _, net_addr_remote, _, nonce + ( # unused: net_addr_local + protocol_version, services, timestamp, net_addr_remote, _, nonce ) = struct.unpack('>IQQ26s26s8s', payload[:80]) + if abs(time.time() - timestamp) > 3600: + raise ValueError('remote time offset is too large') + net_addr_remote = structure.NetAddrNoPrefix.from_bytes(net_addr_remote) host = net_addr_remote.host diff --git a/minode/shared.py b/minode/shared.py index d49786c..72864ec 100644 --- a/minode/shared.py +++ b/minode/shared.py @@ -21,7 +21,7 @@ protocol_version = 3 services = 3 # NODE_NETWORK, NODE_SSL stream = 1 nonce = os.urandom(8) -user_agent = b'/MiNode:0.3.2/' +user_agent = b'/MiNode:0.3.3/' timeout = 600 header_length = 24 i2p_dest_obj_type = 0x493250 diff --git a/minode/structure.py b/minode/structure.py index 53785b8..405da1a 100644 --- a/minode/structure.py +++ b/minode/structure.py @@ -173,6 +173,21 @@ class NetAddrNoPrefix(): b += struct.pack('>H', int(self.port)) return b + @staticmethod + def network_group(host): + """A simplified network group identifier from pybitmessage protocol""" + try: + host = socket.inet_pton(socket.AF_INET, host) + return host[:2] + except socket.error: + try: + host = socket.inet_pton(socket.AF_INET6, host) + return host[:12] + except OSError: + return host + except TypeError: + return host + @classmethod def from_bytes(cls, b): services, host, port = struct.unpack('>Q16sH', b) diff --git a/minode/tests/test_message.py b/minode/tests/test_message.py index 1a6c089..7c0511e 100644 --- a/minode/tests/test_message.py +++ b/minode/tests/test_message.py @@ -1,4 +1,6 @@ """Tests for messages""" +import struct +import time import unittest from binascii import unhexlify @@ -75,6 +77,13 @@ class TestMessage(unittest.TestCase): """Test version message""" msg = message.Message.from_bytes(sample_version_msg) self.assertEqual(msg.command, b'version') + with self.assertRaises(ValueError): + # large time offset + version_packet = message.Version.from_message(msg) + msg.payload = ( + msg.payload[:12] + struct.pack('>Q', int(time.time())) + + msg.payload[20:]) + version_packet = message.Version.from_message(msg) self.assertEqual(version_packet.host, '127.0.0.1') self.assertEqual(version_packet.port, 8444) diff --git a/minode/tests/test_network.py b/minode/tests/test_network.py new file mode 100644 index 0000000..c8cdd31 --- /dev/null +++ b/minode/tests/test_network.py @@ -0,0 +1,217 @@ +"""Tests for network connections""" +import logging +import os +import random +import unittest +import tempfile +import time +from contextlib import contextmanager + +from minode import connection, main, shared +from minode.listener import Listener +from minode.manager import Manager + +from .test_process import TestProcessProto + + +logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] [%(levelname)s] %(message)s') + + +@contextmanager +def time_offset(offset): + """ + Replace time.time() by a mock returning a constant value + with given offset from current time. + """ + started = time.time() + time_call = time.time + try: + time.time = lambda: started + offset + yield time_call + finally: + time.time = time_call + + +@contextmanager +def run_listener(host='localhost', port=8444): + """ + Run the Listener with zero connection limit and + reset variables in shared after its stop. + """ + connection_limit = shared.connection_limit + shared.connection_limit = 0 + try: + listener = Listener(host, port) + listener.start() + yield listener + except OSError: + yield + finally: + shared.connection_limit = connection_limit + shared.connections.clear() + shared.shutting_down = True + time.sleep(1) + + +class TestNetwork(unittest.TestCase): + """Test case starting connections""" + + @classmethod + def setUpClass(cls): + shared.data_directory = tempfile.gettempdir() + + def setUp(self): + shared.core_nodes.clear() + shared.unchecked_node_pool.clear() + shared.objects = {} + try: + os.remove(os.path.join(shared.data_directory, 'objects.pickle')) + except FileNotFoundError: + pass + + def _make_initial_nodes(self): + Manager.load_data() + self.assertGreaterEqual(len(shared.core_nodes), 3) + + main.bootstrap_from_dns() + self.assertGreaterEqual(len(shared.unchecked_node_pool), 3) + + def test_connection(self): + """Check a normal connection - should receive objects""" + self._make_initial_nodes() + + started = time.time() + nodes = list(shared.core_nodes.union(shared.unchecked_node_pool)) + random.shuffle(nodes) + + for node in nodes: + # unknown = node not in shared.node_pool + # self.assertTrue(unknown) + unknown = True + shared.node_pool.discard(node) + + c = connection.Connection(*node) + c.start() + connection_started = time.time() + while c.status not in ('disconnected', 'failed'): + # The addr of established connection is added to nodes pool + if unknown and c.status == 'fully_established': + unknown = False + self.assertIn(node, shared.node_pool) + if shared.objects or time.time() - connection_started > 90: + c.status = 'disconnecting' + if time.time() - started > 300: + c.status = 'disconnecting' + self.fail('Failed to receive an object in %s sec' % 300) + time.sleep(0.2) + if shared.objects: # got some objects + break + else: + self.fail('Failed to establish a proper connection') + + def test_time_offset(self): + """Assert the network bans for large time offset""" + def try_connect(nodes, timeout, call): + started = call() + for node in nodes: + c = connection.Connection(*node) + c.start() + while call() < started + timeout: + if c.status == 'fully_established': + return 'Established a connection' + if c.status in ('disconnected', 'failed'): + break + time.sleep(0.2) + else: + return 'Spent too much time trying to connect' + + def time_offset_connections(nodes, offset): + """Spoof time.time and open connections with given time offset""" + with time_offset(offset) as time_call: + result = try_connect(nodes, 200, time_call) + if result: + self.fail(result) + + self._make_initial_nodes() + nodes = random.sample( + tuple(shared.core_nodes.union(shared.unchecked_node_pool)), 5) + + time_offset_connections(nodes, 4000) + time_offset_connections(nodes, -4000) + + +class TestListener(TestProcessProto): + """A separate test case for Listener with a process with --trusted-peer""" + _process_cmd = ['minode', '--trusted-peer', '127.0.0.1'] + + def setUp(self): + shared.shutting_down = False + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + shared.shutting_down = False + + def test_listener(self): + """Start Listener and try to connect""" + with run_listener() as listener: + if not listener: + self.fail('Failed to start listener') + + c = connection.Connection('127.0.0.1', 8444) + shared.connections.add(c) + + for _ in range(30): + if len(shared.connections) > 1: + self.fail('The listener ignored connection limit') + time.sleep(0.5) + + shared.connection_limit = 2 + c.start() + started = time.time() + while c.status not in ('disconnected', 'failed'): + if c.status == 'fully_established': + self.fail('Connected to itself') + if time.time() - started > 90: + c.status = 'disconnecting' + time.sleep(0.2) + + server = None + started = time.time() + while not server: + time.sleep(0.2) + if time.time() - started > 90: + self.fail('Failed to establish the connection') + for c in shared.connections: + if c.status == 'fully_established': + server = c + self.assertTrue(server.server) + + while not self.process.connections(): + time.sleep(0.2) + if time.time() - started > 90: + self.fail('Failed to connect to listener') + + client = self.process.connections()[0] + self.assertEqual(client.raddr[0], '127.0.0.1') + self.assertEqual(client.raddr[1], 8444) + self.assertEqual(server.host, client.laddr[0]) + # self.assertEqual(server.port, client.laddr[1]) + server.status = 'disconnecting' + + self.assertFalse(listener.is_alive()) + + def test_listener_timeoffset(self): + """Run listener with a large time offset - shouldn't connect""" + with time_offset(4000): + with run_listener() as listener: + if not listener: + self.fail('Failed to start listener') + shared.connection_limit = 2 + for _ in range(30): + for c in shared.connections: + if c.status == 'fully_established': + self.fail('Established a connection') + time.sleep(0.5) diff --git a/minode/tests/test_process.py b/minode/tests/test_process.py index b6c0b7e..e6a391c 100644 --- a/minode/tests/test_process.py +++ b/minode/tests/test_process.py @@ -9,6 +9,8 @@ import time import psutil +from minode.structure import NetAddrNoPrefix + try: socket.socket().bind(('127.0.0.1', 7656)) i2p_port_free = True @@ -19,7 +21,7 @@ except (OSError, socket.error): class TestProcessProto(unittest.TestCase): """Test process attributes, common flow""" _process_cmd = ['minode'] - _connection_limit = 4 if sys.platform.startswith('win') else 10 + _connection_limit = 4 if sys.platform.startswith('win') else 8 _listen = False _listening_port = None @@ -82,7 +84,7 @@ class TestProcessShutdown(TestProcessProto): class TestProcess(TestProcessProto): """The test case for minode process""" - _wait_time = 120 + _wait_time = 180 _check_limit = False def test_connections(self): @@ -102,13 +104,19 @@ class TestProcess(TestProcessProto): time.sleep(1) for _ in range(self._wait_time * 2): - if len(self.connections()) > self._connection_limit / 2: + if len(self.connections()) >= self._connection_limit / 2: _time_to_connect = round(time.time() - _started) break + if '--i2p' not in self._process_cmd: + groups = [] + for c in self.connections(): + group = NetAddrNoPrefix.network_group(c.raddr[0]) + self.assertNotIn(group, groups) + groups.append(group) time.sleep(0.5) else: self.fail( - 'Failed establish at least %i connections in %s sec' + 'Failed to establish at least %i connections in %s sec' % (int(self._connection_limit / 2), self._wait_time)) if self._check_limit: diff --git a/minode/tests/test_structure.py b/minode/tests/test_structure.py index 16e41fe..970c152 100644 --- a/minode/tests/test_structure.py +++ b/minode/tests/test_structure.py @@ -104,6 +104,27 @@ class TestStructure(unittest.TestCase): addr = structure.NetAddr(1, '2607:5300:201:3000::57ae', 8080, 1) self.assertEqual(addr.to_bytes()[8:], sample_addr_data[8:]) + def test_network_group(self): + """Test various types of network groups""" + test_ip = '1.2.3.4' + self.assertEqual( + b'\x01\x02', structure.NetAddrNoPrefix.network_group(test_ip)) + self.assertEqual( + structure.NetAddrNoPrefix.network_group('8.8.8.8'), + structure.NetAddrNoPrefix.network_group('8.8.4.4')) + self.assertNotEqual( + structure.NetAddrNoPrefix.network_group('1.1.1.1'), + structure.NetAddrNoPrefix.network_group('8.8.8.8')) + test_ip = '0102:0304:0506:0708:090A:0B0C:0D0E:0F10' + self.assertEqual( + b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C', + structure.NetAddrNoPrefix.network_group(test_ip)) + for test_ip in ( + 'bootstrap8444.bitmessage.org', 'quzwelsuziwqgpt2.onion', None + ): + self.assertEqual( + test_ip, structure.NetAddrNoPrefix.network_group(test_ip)) + def test_object(self): """Create and check objects""" obj = structure.Object.from_message(