Improving the network communication #4

Merged
PeterSurda merged 12 commits from lee.miller/MiNode:network into v0.3 2024-07-15 07:58:57 +02:00
10 changed files with 387 additions and 104 deletions

View File

@ -15,8 +15,11 @@ import time
from . import message, shared, structure from . import message, shared, structure
class Connection(threading.Thread): class ConnectionBase(threading.Thread):
"""The connection object""" """
Common code for the connection thread
with minimum command handlers to reuse
"""
def __init__( def __init__(
self, host, port, s=None, network='ip', server=False, self, host, port, s=None, network='ip', server=False,
i2p_remote_dest=b'' i2p_remote_dest=b''
@ -344,7 +347,32 @@ class Connection(threading.Thread):
break break
def _process_message(self, m): def _process_message(self, m):
if m.command == b'version': if m.command == b'verack':
self.verack_received = True
logging.debug(
'%s:%s -> %s', self.host_print, self.port, 'verack')
if self.server:
self.send_queue.put('fully_established')
elif m.command == b'ping':
logging.debug('%s:%s -> ping', self.host_print, self.port)
self.send_queue.put(message.Message(b'pong', b''))
elif m.command == b'error':
error = message.Error.from_message(m)
logging.warning(
'%s:%s -> %s', self.host_print, self.port, error)
if error.fatal == 2:
# reduce probability to connect soon
shared.unchecked_node_pool.discard((self.host, self.port))
else:
try:
getattr(self, '_process_msg_{}'.format(m.command.decode()))(m)
except (AttributeError, UnicodeDecodeError):
logging.debug('%s:%s -> %s', self.host_print, self.port, m)
def _process_msg_version(self, m):
version = message.Version.from_message(m) version = message.Version.from_message(m)
if shared.stream not in version.streams: if shared.stream not in version.streams:
raise ValueError('message not for stream %i' % shared.stream) raise ValueError('message not for stream %i' % shared.stream)
@ -375,73 +403,16 @@ class Connection(threading.Thread):
shared.services, version.host, shared.listening_port)) shared.services, version.host, shared.listening_port))
if self.server: if self.server:
if self.network == 'ip': if self.network == 'ip':
self.send_queue.put( self.send_queue.put(message.Version(self.host, self.port))
message.Version(self.host, self.port))
else: else:
self.send_queue.put(message.Version('127.0.0.1', 7656)) self.send_queue.put(message.Version('127.0.0.1', 7656))
elif m.command == b'verack': def _process_msg_addr(self, m):
self.verack_received = True
logging.debug(
'%s:%s -> %s', self.host_print, self.port, 'verack')
if self.server:
self.send_queue.put('fully_established')
elif m.command == b'inv':
inv = message.Inv.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
to_get = inv.vectors.copy()
to_get.difference_update(shared.objects.keys())
self.vectors_to_get.update(to_get)
# Do not send objects they already have.
self.vectors_to_send.difference_update(inv.vectors)
elif m.command == b'object':
obj = structure.Object.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock:
shared.objects[obj.vector] = obj
if (
obj.object_type == shared.i2p_dest_obj_type
and obj.version == shared.i2p_dest_obj_version
):
dest = base64.b64encode(obj.object_payload, altchars=b'-~')
logging.debug(
'Received I2P destination object,'
' adding to i2p_unchecked_node_pool')
logging.debug(dest)
shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
shared.vector_advertise_queue.put(obj.vector)
elif m.command == b'getdata':
getdata = message.GetData.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, getdata)
self.vectors_to_send.update(getdata.vectors)
elif m.command == b'addr':
addr = message.Addr.from_message(m) addr = message.Addr.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, addr) logging.debug('%s:%s -> %s', self.host_print, self.port, addr)
for a in addr.addresses: for a in addr.addresses:
shared.unchecked_node_pool.add((a.host, a.port)) shared.unchecked_node_pool.add((a.host, a.port))
elif m.command == b'ping':
logging.debug('%s:%s -> ping', self.host_print, self.port)
self.send_queue.put(message.Message(b'pong', b''))
elif m.command == b'error':
error = message.Error.from_message(m)
logging.warning(
'%s:%s -> %s', self.host_print, self.port, error)
if error.fatal == 2:
# reduce probability to connect soon
shared.unchecked_node_pool.discard((self.host, self.port))
else:
logging.debug('%s:%s -> %s', self.host_print, self.port, m)
def _request_objects(self): def _request_objects(self):
if self.vectors_to_get and len(self.vectors_requested) < 100: if self.vectors_to_get and len(self.vectors_requested) < 100:
self.vectors_to_get.difference_update(shared.objects.keys()) self.vectors_to_get.difference_update(shared.objects.keys())
@ -499,4 +470,41 @@ class Connection(threading.Thread):
message.Message(b'object', obj.to_bytes())) message.Message(b'object', obj.to_bytes()))
class Connection(ConnectionBase):
"""The connection with all commands implementation"""
def _process_msg_inv(self, m):
inv = message.Inv.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
to_get = inv.vectors.copy()
to_get.difference_update(shared.objects.keys())
self.vectors_to_get.update(to_get)
# Do not send objects they already have.
self.vectors_to_send.difference_update(inv.vectors)
def _process_msg_object(self, m):
obj = structure.Object.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock:
shared.objects[obj.vector] = obj
if (
obj.object_type == shared.i2p_dest_obj_type
and obj.version == shared.i2p_dest_obj_version
):
dest = base64.b64encode(obj.object_payload, altchars=b'-~')
logging.debug(
'Received I2P destination object,'
' adding to i2p_unchecked_node_pool')
logging.debug(dest)
shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
shared.vector_advertise_queue.put(obj.vector)
def _process_msg_getdata(self, m):
getdata = message.GetData.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, getdata)
self.vectors_to_send.update(getdata.vectors)
shared.connection = Connection shared.connection = Connection

View File

@ -28,6 +28,9 @@ class Listener(threading.Thread):
break break
try: try:
conn, addr = self.s.accept() conn, addr = self.s.accept()
except socket.timeout:
continue
logging.info('Incoming connection from: %s:%i', *addr[:2]) logging.info('Incoming connection from: %s:%i', *addr[:2])
with shared.connections_lock: with shared.connections_lock:
if len(shared.connections) > shared.connection_limit: if len(shared.connections) > shared.connection_limit:
@ -36,5 +39,3 @@ class Listener(threading.Thread):
c = Connection(*addr[:2], conn, server=True) c = Connection(*addr[:2], conn, server=True)
c.start() c.start()
shared.connections.add(c) shared.connections.add(c)
except socket.timeout:
pass

View File

@ -77,7 +77,7 @@ class Manager(threading.Thread):
with shared.connections_lock: with shared.connections_lock:
shared.connections.remove(c) shared.connections.remove(c)
else: else:
hosts.add(c.host) hosts.add(structure.NetAddrNoPrefix.network_group(c.host))
if not c.server: if not c.server:
outgoing_connections += 1 outgoing_connections += 1
@ -121,15 +121,16 @@ class Manager(threading.Thread):
else: else:
to_connect.update(shared.i2p_node_pool) to_connect.update(shared.i2p_node_pool)
for addr in to_connect: for host, port in to_connect:
if addr[0] in hosts: group = structure.NetAddrNoPrefix.network_group(host)
if group in hosts:
continue continue
if addr[1] == 'i2p' and shared.i2p_enabled: if port == 'i2p' and shared.i2p_enabled:
if shared.i2p_session_nick and addr[0] != shared.i2p_dest_pub: if shared.i2p_session_nick and host != shared.i2p_dest_pub:
try: try:
d = I2PDialer( d = I2PDialer(
shared, shared,
addr[0], shared.i2p_session_nick, host, shared.i2p_session_nick,
shared.i2p_sam_host, shared.i2p_sam_port) shared.i2p_sam_host, shared.i2p_sam_port)
d.start() d.start()
hosts.add(d.destination) hosts.add(d.destination)
@ -141,9 +142,9 @@ class Manager(threading.Thread):
else: else:
continue continue
else: else:
c = Connection(addr[0], addr[1]) c = Connection(host, port)
c.start() c.start()
hosts.add(c.host) hosts.add(group)
with shared.connections_lock: with shared.connections_lock:
shared.connections.add(c) shared.connections.add(c)
shared.hosts = hosts shared.hosts = hosts

View File

@ -145,10 +145,13 @@ class Version():
def from_message(cls, m): def from_message(cls, m):
payload = m.payload payload = m.payload
( # unused: timestamp, net_addr_local ( # unused: net_addr_local
protocol_version, services, _, net_addr_remote, _, nonce protocol_version, services, timestamp, net_addr_remote, _, nonce
) = struct.unpack('>IQQ26s26s8s', payload[:80]) ) = struct.unpack('>IQQ26s26s8s', payload[:80])
if abs(time.time() - timestamp) > 3600:
raise ValueError('remote time offset is too large')
net_addr_remote = structure.NetAddrNoPrefix.from_bytes(net_addr_remote) net_addr_remote = structure.NetAddrNoPrefix.from_bytes(net_addr_remote)
host = net_addr_remote.host host = net_addr_remote.host

View File

@ -21,7 +21,7 @@ protocol_version = 3
services = 3 # NODE_NETWORK, NODE_SSL services = 3 # NODE_NETWORK, NODE_SSL
stream = 1 stream = 1
nonce = os.urandom(8) nonce = os.urandom(8)
user_agent = b'/MiNode:0.3.2/' user_agent = b'/MiNode:0.3.3/'
timeout = 600 timeout = 600
header_length = 24 header_length = 24
i2p_dest_obj_type = 0x493250 i2p_dest_obj_type = 0x493250

View File

@ -173,6 +173,21 @@ class NetAddrNoPrefix():
b += struct.pack('>H', int(self.port)) b += struct.pack('>H', int(self.port))
return b return b
@staticmethod
def network_group(host):
"""A simplified network group identifier from pybitmessage protocol"""
try:
host = socket.inet_pton(socket.AF_INET, host)
return host[:2]
except socket.error:
try:
host = socket.inet_pton(socket.AF_INET6, host)
return host[:12]
except OSError:
return host
except TypeError:
return host
@classmethod @classmethod
def from_bytes(cls, b): def from_bytes(cls, b):
services, host, port = struct.unpack('>Q16sH', b) services, host, port = struct.unpack('>Q16sH', b)

View File

@ -1,4 +1,6 @@
"""Tests for messages""" """Tests for messages"""
import struct
import time
import unittest import unittest
from binascii import unhexlify from binascii import unhexlify
@ -75,6 +77,13 @@ class TestMessage(unittest.TestCase):
"""Test version message""" """Test version message"""
msg = message.Message.from_bytes(sample_version_msg) msg = message.Message.from_bytes(sample_version_msg)
self.assertEqual(msg.command, b'version') self.assertEqual(msg.command, b'version')
with self.assertRaises(ValueError):
# large time offset
version_packet = message.Version.from_message(msg)
msg.payload = (
msg.payload[:12] + struct.pack('>Q', int(time.time()))
+ msg.payload[20:])
version_packet = message.Version.from_message(msg) version_packet = message.Version.from_message(msg)
self.assertEqual(version_packet.host, '127.0.0.1') self.assertEqual(version_packet.host, '127.0.0.1')
self.assertEqual(version_packet.port, 8444) self.assertEqual(version_packet.port, 8444)

View File

@ -0,0 +1,217 @@
"""Tests for network connections"""
import logging
import os
import random
import unittest
import tempfile
import time
from contextlib import contextmanager
from minode import connection, main, shared
from minode.listener import Listener
from minode.manager import Manager
from .test_process import TestProcessProto
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] [%(levelname)s] %(message)s')
@contextmanager
def time_offset(offset):
"""
Replace time.time() by a mock returning a constant value
with given offset from current time.
"""
started = time.time()
time_call = time.time
try:
time.time = lambda: started + offset
yield time_call
finally:
time.time = time_call
@contextmanager
def run_listener(host='localhost', port=8444):
"""
Run the Listener with zero connection limit and
reset variables in shared after its stop.
"""
connection_limit = shared.connection_limit
shared.connection_limit = 0
try:
listener = Listener(host, port)
listener.start()
yield listener
except OSError:
yield
finally:
shared.connection_limit = connection_limit
shared.connections.clear()
shared.shutting_down = True
time.sleep(1)
class TestNetwork(unittest.TestCase):
"""Test case starting connections"""
@classmethod
def setUpClass(cls):
shared.data_directory = tempfile.gettempdir()
def setUp(self):
shared.core_nodes.clear()
shared.unchecked_node_pool.clear()
shared.objects = {}
try:
os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
except FileNotFoundError:
pass
def _make_initial_nodes(self):
Manager.load_data()
self.assertGreaterEqual(len(shared.core_nodes), 3)
main.bootstrap_from_dns()
self.assertGreaterEqual(len(shared.unchecked_node_pool), 3)
def test_connection(self):
"""Check a normal connection - should receive objects"""
self._make_initial_nodes()
started = time.time()
nodes = list(shared.core_nodes.union(shared.unchecked_node_pool))
random.shuffle(nodes)
for node in nodes:
# unknown = node not in shared.node_pool
# self.assertTrue(unknown)
unknown = True
shared.node_pool.discard(node)
c = connection.Connection(*node)
c.start()
connection_started = time.time()
while c.status not in ('disconnected', 'failed'):
# The addr of established connection is added to nodes pool
if unknown and c.status == 'fully_established':
unknown = False
self.assertIn(node, shared.node_pool)
if shared.objects or time.time() - connection_started > 90:
c.status = 'disconnecting'
if time.time() - started > 300:
c.status = 'disconnecting'
self.fail('Failed to receive an object in %s sec' % 300)
time.sleep(0.2)
if shared.objects: # got some objects
break
else:
self.fail('Failed to establish a proper connection')
def test_time_offset(self):
"""Assert the network bans for large time offset"""
def try_connect(nodes, timeout, call):
started = call()
for node in nodes:
c = connection.Connection(*node)
c.start()
while call() < started + timeout:
if c.status == 'fully_established':
return 'Established a connection'
if c.status in ('disconnected', 'failed'):
break
time.sleep(0.2)
else:
return 'Spent too much time trying to connect'
def time_offset_connections(nodes, offset):
"""Spoof time.time and open connections with given time offset"""
with time_offset(offset) as time_call:
result = try_connect(nodes, 200, time_call)
if result:
self.fail(result)
self._make_initial_nodes()
nodes = random.sample(
tuple(shared.core_nodes.union(shared.unchecked_node_pool)), 5)
time_offset_connections(nodes, 4000)
time_offset_connections(nodes, -4000)
class TestListener(TestProcessProto):
"""A separate test case for Listener with a process with --trusted-peer"""
_process_cmd = ['minode', '--trusted-peer', '127.0.0.1']
def setUp(self):
shared.shutting_down = False
@classmethod
def tearDownClass(cls):
super().tearDownClass()
shared.shutting_down = False
def test_listener(self):
"""Start Listener and try to connect"""
with run_listener() as listener:
if not listener:
self.fail('Failed to start listener')
c = connection.Connection('127.0.0.1', 8444)
shared.connections.add(c)
for _ in range(30):
if len(shared.connections) > 1:
self.fail('The listener ignored connection limit')
time.sleep(0.5)
shared.connection_limit = 2
c.start()
started = time.time()
while c.status not in ('disconnected', 'failed'):
if c.status == 'fully_established':
self.fail('Connected to itself')
if time.time() - started > 90:
c.status = 'disconnecting'
time.sleep(0.2)
server = None
started = time.time()
while not server:
time.sleep(0.2)
if time.time() - started > 90:
self.fail('Failed to establish the connection')
for c in shared.connections:
if c.status == 'fully_established':
server = c
self.assertTrue(server.server)
while not self.process.connections():
time.sleep(0.2)
if time.time() - started > 90:
self.fail('Failed to connect to listener')
client = self.process.connections()[0]
self.assertEqual(client.raddr[0], '127.0.0.1')
self.assertEqual(client.raddr[1], 8444)
self.assertEqual(server.host, client.laddr[0])
# self.assertEqual(server.port, client.laddr[1])
server.status = 'disconnecting'
self.assertFalse(listener.is_alive())
def test_listener_timeoffset(self):
"""Run listener with a large time offset - shouldn't connect"""
with time_offset(4000):
with run_listener() as listener:
if not listener:
self.fail('Failed to start listener')
shared.connection_limit = 2
for _ in range(30):
for c in shared.connections:
if c.status == 'fully_established':
self.fail('Established a connection')
time.sleep(0.5)

View File

@ -9,6 +9,8 @@ import time
import psutil import psutil
from minode.structure import NetAddrNoPrefix
try: try:
socket.socket().bind(('127.0.0.1', 7656)) socket.socket().bind(('127.0.0.1', 7656))
i2p_port_free = True i2p_port_free = True
@ -19,7 +21,7 @@ except (OSError, socket.error):
class TestProcessProto(unittest.TestCase): class TestProcessProto(unittest.TestCase):
"""Test process attributes, common flow""" """Test process attributes, common flow"""
_process_cmd = ['minode'] _process_cmd = ['minode']
_connection_limit = 4 if sys.platform.startswith('win') else 10 _connection_limit = 4 if sys.platform.startswith('win') else 8
_listen = False _listen = False
_listening_port = None _listening_port = None
@ -82,7 +84,7 @@ class TestProcessShutdown(TestProcessProto):
class TestProcess(TestProcessProto): class TestProcess(TestProcessProto):
"""The test case for minode process""" """The test case for minode process"""
_wait_time = 120 _wait_time = 180
_check_limit = False _check_limit = False
def test_connections(self): def test_connections(self):
@ -102,13 +104,19 @@ class TestProcess(TestProcessProto):
time.sleep(1) time.sleep(1)
for _ in range(self._wait_time * 2): for _ in range(self._wait_time * 2):
if len(self.connections()) > self._connection_limit / 2: if len(self.connections()) >= self._connection_limit / 2:
_time_to_connect = round(time.time() - _started) _time_to_connect = round(time.time() - _started)
break break
if '--i2p' not in self._process_cmd:
groups = []
for c in self.connections():
group = NetAddrNoPrefix.network_group(c.raddr[0])
self.assertNotIn(group, groups)
groups.append(group)
time.sleep(0.5) time.sleep(0.5)
else: else:
self.fail( self.fail(
'Failed establish at least %i connections in %s sec' 'Failed to establish at least %i connections in %s sec'
% (int(self._connection_limit / 2), self._wait_time)) % (int(self._connection_limit / 2), self._wait_time))
if self._check_limit: if self._check_limit:

View File

@ -104,6 +104,27 @@ class TestStructure(unittest.TestCase):
addr = structure.NetAddr(1, '2607:5300:201:3000::57ae', 8080, 1) addr = structure.NetAddr(1, '2607:5300:201:3000::57ae', 8080, 1)
self.assertEqual(addr.to_bytes()[8:], sample_addr_data[8:]) self.assertEqual(addr.to_bytes()[8:], sample_addr_data[8:])
def test_network_group(self):
"""Test various types of network groups"""
test_ip = '1.2.3.4'
self.assertEqual(
b'\x01\x02', structure.NetAddrNoPrefix.network_group(test_ip))
self.assertEqual(
structure.NetAddrNoPrefix.network_group('8.8.8.8'),
structure.NetAddrNoPrefix.network_group('8.8.4.4'))
self.assertNotEqual(
structure.NetAddrNoPrefix.network_group('1.1.1.1'),
structure.NetAddrNoPrefix.network_group('8.8.8.8'))
test_ip = '0102:0304:0506:0708:090A:0B0C:0D0E:0F10'
self.assertEqual(
b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C',
structure.NetAddrNoPrefix.network_group(test_ip))
for test_ip in (
'bootstrap8444.bitmessage.org', 'quzwelsuziwqgpt2.onion', None
):
self.assertEqual(
test_ip, structure.NetAddrNoPrefix.network_group(test_ip))
def test_object(self): def test_object(self):
"""Create and check objects""" """Create and check objects"""
obj = structure.Object.from_message( obj = structure.Object.from_message(