Compare commits


7 Commits

3 changed files with 318 additions and 78 deletions

View File

@@ -15,8 +15,11 @@ import time
 from . import message, shared, structure


-class Connection(threading.Thread):
-    """The connection object"""
+class ConnectionBase(threading.Thread):
+    """
+    Common code for the connection thread
+    with minimum command handlers to reuse
+    """
     def __init__(
         self, host, port, s=None, network='ip', server=False,
         i2p_remote_dest=b''
@@ -334,87 +337,13 @@ class Connection(threading.Thread):
                 break

     def _process_message(self, m):
-        if m.command == b'version':
-            version = message.Version.from_bytes(m.to_bytes())
-            logging.debug('%s:%s -> %s', self.host_print, self.port, version)
-            if (
-                version.protocol_version != shared.protocol_version
-                or version.nonce == shared.nonce
-            ):
-                self.status = 'disconnecting'
-                self.send_queue.put(None)
-            else:
-                logging.info(
-                    '%s:%s claims to be %s',
-                    self.host_print, self.port, version.user_agent)
-                self.send_queue.put(message.Message(b'verack', b''))
-                self.verack_sent = True
-                self.remote_version = version
-                if not self.server:
-                    self.send_queue.put('fully_established')
-                    if self.network == 'ip':
-                        shared.address_advertise_queue.put(structure.NetAddr(
-                            version.services, self.host, self.port))
-                        shared.node_pool.add((self.host, self.port))
-                    elif self.network == 'i2p':
-                        shared.i2p_node_pool.add((self.host, 'i2p'))
-                if self.network == 'ip':
-                    shared.address_advertise_queue.put(structure.NetAddr(
-                        shared.services, version.host, shared.listening_port))
-                if self.server:
-                    if self.network == 'ip':
-                        self.send_queue.put(
-                            message.Version(self.host, self.port))
-                    else:
-                        self.send_queue.put(message.Version('127.0.0.1', 7656))
-        elif m.command == b'verack':
+        if m.command == b'verack':
             self.verack_received = True
             logging.debug(
                 '%s:%s -> %s', self.host_print, self.port, 'verack')
             if self.server:
                 self.send_queue.put('fully_established')
-        elif m.command == b'inv':
-            inv = message.Inv.from_message(m)
-            logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
-            to_get = inv.vectors.copy()
-            to_get.difference_update(shared.objects.keys())
-            self.vectors_to_get.update(to_get)
-            # Do not send objects they already have.
-            self.vectors_to_send.difference_update(inv.vectors)
-        elif m.command == b'object':
-            obj = structure.Object.from_message(m)
-            logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
-            self.vectors_requested.pop(obj.vector, None)
-            self.vectors_to_get.discard(obj.vector)
-            if obj.is_valid() and obj.vector not in shared.objects:
-                with shared.objects_lock:
-                    shared.objects[obj.vector] = obj
-                if (
-                    obj.object_type == shared.i2p_dest_obj_type
-                    and obj.version == shared.i2p_dest_obj_version
-                ):
-                    dest = base64.b64encode(obj.object_payload, altchars=b'-~')
-                    logging.debug(
-                        'Received I2P destination object,'
-                        ' adding to i2p_unchecked_node_pool')
-                    logging.debug(dest)
-                    shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
-                shared.vector_advertise_queue.put(obj.vector)
-        elif m.command == b'getdata':
-            getdata = message.GetData.from_message(m)
-            logging.debug('%s:%s -> %s', self.host_print, self.port, getdata)
-            self.vectors_to_send.update(getdata.vectors)
-        elif m.command == b'addr':
-            addr = message.Addr.from_message(m)
-            logging.debug('%s:%s -> %s', self.host_print, self.port, addr)
-            for a in addr.addresses:
-                shared.unchecked_node_pool.add((a.host, a.port))
         elif m.command == b'ping':
             logging.debug('%s:%s -> ping', self.host_print, self.port)
             self.send_queue.put(message.Message(b'pong', b''))
@@ -424,7 +353,49 @@ class Connection(threading.Thread):
                 '%s:%s -> error: %s', self.host_print, self.port, m.payload)
         else:
-            logging.debug('%s:%s -> %s', self.host_print, self.port, m)
+            try:
+                getattr(self, '_process_msg_{}'.format(m.command.decode()))(m)
+            except (AttributeError, UnicodeDecodeError):
+                logging.debug('%s:%s -> %s', self.host_print, self.port, m)
+
+    def _process_msg_version(self, m):
+        version = message.Version.from_bytes(m.to_bytes())
+        logging.debug('%s:%s -> %s', self.host_print, self.port, version)
+        if (
+            version.protocol_version != shared.protocol_version
+            or version.nonce == shared.nonce
+        ):
+            self.status = 'disconnecting'
+            self.send_queue.put(None)
+        else:
+            logging.info(
+                '%s:%s claims to be %s',
+                self.host_print, self.port, version.user_agent)
+            self.send_queue.put(message.Message(b'verack', b''))
+            self.verack_sent = True
+            self.remote_version = version
+            if not self.server:
+                self.send_queue.put('fully_established')
+                if self.network == 'ip':
+                    shared.address_advertise_queue.put(structure.NetAddr(
+                        version.services, self.host, self.port))
+                    shared.node_pool.add((self.host, self.port))
+                elif self.network == 'i2p':
+                    shared.i2p_node_pool.add((self.host, 'i2p'))
+            if self.network == 'ip':
+                shared.address_advertise_queue.put(structure.NetAddr(
+                    shared.services, version.host, shared.listening_port))
+            if self.server:
+                if self.network == 'ip':
+                    self.send_queue.put(message.Version(self.host, self.port))
+                else:
+                    self.send_queue.put(message.Version('127.0.0.1', 7656))
+
+    def _process_msg_addr(self, m):
+        addr = message.Addr.from_message(m)
+        logging.debug('%s:%s -> %s', self.host_print, self.port, addr)
+        for a in addr.addresses:
+            shared.unchecked_node_pool.add((a.host, a.port))

     def _request_objects(self):
         if self.vectors_to_get and len(self.vectors_requested) < 100:
@@ -483,4 +454,61 @@ class Connection(threading.Thread):
                     message.Message(b'object', obj.to_bytes()))


+class Connection(ConnectionBase):
+    """The connection with all commands implementation"""
+    def _process_msg_inv(self, m):
+        inv = message.Inv.from_message(m)
+        logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
+        to_get = inv.vectors.copy()
+        to_get.difference_update(shared.objects.keys())
+        self.vectors_to_get.update(to_get)
+        # Do not send objects they already have.
+        self.vectors_to_send.difference_update(inv.vectors)
+
+    def _process_msg_object(self, m):
+        obj = structure.Object.from_message(m)
+        logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
+        self.vectors_requested.pop(obj.vector, None)
+        self.vectors_to_get.discard(obj.vector)
+        if obj.is_valid() and obj.vector not in shared.objects:
+            with shared.objects_lock:
+                shared.objects[obj.vector] = obj
+            if (
+                obj.object_type == shared.i2p_dest_obj_type
+                and obj.version == shared.i2p_dest_obj_version
+            ):
+                dest = base64.b64encode(obj.object_payload, altchars=b'-~')
+                logging.debug(
+                    'Received I2P destination object,'
+                    ' adding to i2p_unchecked_node_pool')
+                logging.debug(dest)
+                shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
+            shared.vector_advertise_queue.put(obj.vector)
+
+    def _process_msg_getdata(self, m):
+        getdata = message.GetData.from_message(m)
+        logging.debug('%s:%s -> %s', self.host_print, self.port, getdata)
+        self.vectors_to_send.update(getdata.vectors)
+
+
+class TrackingConnection(Connection):
+    """A helper to track number of servers in the network"""
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.connected_time = None
+        self.active = False
+
+    def _on_connection_fully_established(self):
+        self.connected_time = time.time()
+        super()._on_connection_fully_established()
+
+    def _process_message(self, m):
+        super()._process_message(m)
+        if not self.active and m.command in (b'object', b'getdata'):
+            self.active = True
+
+    # def run(self):
+    #     super().run()
+
+
 shared.connection = Connection

View File

@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 """The main thread, managing connections, nodes and objects"""
 import base64
+import csv
 import logging
 import os
 import pickle
@@ -26,6 +27,10 @@ class Manager(threading.Thread):
         # Publish destination 5-15 minutes after start
         self.last_published_i2p_destination = \
             time.time() - 50 * 60 + random.uniform(-1, 1) * 300  # nosec
+        self.last_dumped_counts = time.time()
+        self.connection_counts = []
+        self.connections_avg = 0
+        self.connections_max = 0

     def run(self):
         self.clean_objects()
@@ -40,6 +45,7 @@ class Manager(threading.Thread):
                 self.last_cleaned_objects = now
             if now - self.last_cleaned_connections > 2:
                 self.manage_connections()
+                self.check_connections_count()
                 self.last_cleaned_connections = now
             if now - self.last_pickled_objects > 100:
                 self.pickle_objects()
@@ -47,9 +53,50 @@ class Manager(threading.Thread):
             if now - self.last_pickled_nodes > 60:
                 self.pickle_nodes()
                 self.last_pickled_nodes = now
+            if now - self.last_dumped_counts > 300:
+                self.dump_counts()
+                self.last_dumped_counts = now
             if now - self.last_published_i2p_destination > 3600:
                 self.publish_i2p_destination()
                 self.last_published_i2p_destination = now

+        logging.info(
+            'Connections number max: %s, average: %.2f',
+            self.connections_max, self.connections_avg)
+
+    def dump_counts(self):
+        """Write connection counts into a csv file"""
+        if not self.connection_counts:
+            return
+        total = 0
+        with open(
+            os.path.join(shared.data_directory, 'connection_counts.csv'),
+            'a', newline='', encoding='ascii'
+        ) as dst:
+            writer = csv.writer(dst)
+            for timestamp, number in self.connection_counts:
+                total += number
+                if number > self.connections_max:
+                    self.connections_max = number
+                writer.writerow((time.ctime(timestamp), number))
+        avg = total / len(self.connection_counts)
+        self.connections_avg = (
+            (self.connections_avg + avg) / 2 if self.connections_avg else avg)
+        self.connection_counts = self.connection_counts[-1:]
+
+    def check_connections_count(self):
+        """Make a record for changed connections count"""
+        try:
+            previous_count = self.connection_counts[-1][1]
+        except IndexError:
+            previous_count = 0
+        count = sum((
+            1 for c in shared.connections.copy()
+            if c.status == 'fully_established'))
+        if count != previous_count:
+            self.connection_counts.append((time.time(), count))
+            if count == shared.outgoing_connections:
+                logging.info('Max outgoing connections reached!')
+
     @staticmethod
     def clean_objects():

View File

@@ -0,0 +1,165 @@
+"""Tests for network connections"""
+import logging
+import os
+import random
+import unittest
+import tempfile
+import time
+
+from minode import connection, main, message, proofofwork, shared, structure
+
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='[%(asctime)s] [%(levelname)s] %(message)s')
+
+
+class PushConnection(connection.ConnectionBase):
+    """A connection pushing useless objects"""
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.objects_setup = 0
+        self.started = time.time()
+
+    def _request_objects(self):
+        """Make objects to send"""
+        while not shared.vector_advertise_queue.empty():
+            vector = shared.vector_advertise_queue.get()
+            self.vectors_to_send.add(vector)
+        if self.objects_setup < 10:
+            if self.objects_setup == 0:
+                self.started = time.time()
+            obj = structure.Object(
+                b'\x00' * 8, int(time.time() + 1800), 42, 1, 1,
+                b'HELLO %i' % self.objects_setup)
+            proofofwork.do_pow_and_publish(obj)
+            self.objects_setup += 1
+        elif (
+            len(shared.objects) == 10 and len(self.vectors_to_send) == 0
+            or time.time() - self.started > 120
+        ):
+            self.status = 'disconnecting'
+
+
+class SearchConnection(connection.Connection):
+    """A connection searching for objects"""
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.started = time.time()
+        self.vectors_to_search = {v for v in shared.objects.keys()}
+        shared.objects.clear()
+        logging.info(
+            'Have %s objects to search for', len(self.vectors_to_search))
+
+    def _process_msg_inv(self, m):
+        inv = message.Inv.from_message(m)
+        to_get = inv.vectors.intersection(self.vectors_to_search)
+        self.vectors_to_get.update(to_get)
+        self.vectors_to_search.difference_update(to_get)
+        logging.info('Found %s vectors', len(to_get))
+
+    def _send_objects(self):
+        if (
+            len(self.vectors_to_search) == len(self.vectors_to_get) == 0
+            or time.time() - self.started > 300
+        ):
+            self.status = 'disconnecting'
+            logging.info('disconnecting!')
+
+
+class TestNetwork(unittest.TestCase):
+    """Test case starting connections"""
+    @classmethod
+    def setUpClass(cls):
+        shared.data_directory = tempfile.gettempdir()
+
+    def setUp(self):
+        shared.core_nodes.clear()
+        shared.unchecked_node_pool.clear()
+        shared.objects = {}
+        try:
+            os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
+        except FileNotFoundError:
+            pass
+
+    def _make_initial_nodes(self):
+        main.load_data()
+        self.assertGreaterEqual(len(shared.core_nodes), 3)
+        main.bootstrap_from_dns()
+        self.assertGreaterEqual(len(shared.unchecked_node_pool), 3)
+
+    def test_connection(self):
+        """Check a normal connection - should receive objects"""
+        self._make_initial_nodes()
+        started = time.time()
+        nodes = list(shared.core_nodes.union(shared.unchecked_node_pool))
+        random.shuffle(nodes)
+        for node in nodes:
+            # unknown = node not in shared.node_pool
+            # self.assertTrue(unknown)
+            unknown = True
+            shared.node_pool.discard(node)
+
+            c = connection.Connection(*node)
+            c.start()
+            connection_started = time.time()
+            while c.status not in ('disconnecting', 'disconnected', 'failed'):
+                # The addr of established connection is added to nodes pool
+                if unknown and c.status == 'fully_established':
+                    unknown = False
+                    self.assertIn(node, shared.node_pool)
+                if shared.objects or time.time() - connection_started > 90:
+                    c.status = 'disconnecting'
+                if time.time() - started > 300:
+                    c.status = 'disconnecting'
+                    self.fail('Failed to receive an object in %s sec' % 300)
+                time.sleep(0.2)
+            if shared.objects:  # got some objects
+                break
+        else:
+            self.fail('Failed to establish a proper connection')
+
+    @unittest.skip('Not ready yet')
+    def test_push(self):
+        """Make and push useless objects"""
+        self._make_initial_nodes()
+        nodes = shared.core_nodes.union(shared.unchecked_node_pool)
+        # logging.info('Nodes: %s', nodes)
+        sent_to_nodes = set()
+        for node in random.sample(nodes, 5):
+            c = PushConnection(*node)
+            c.start()
+            c.join()
+            sent_to_nodes.add(node)
+            if shared.objects:
+                break
+        else:
+            return
+        nodes.difference_update(sent_to_nodes)
+        self.assertGreaterEqual(len(nodes), 1)
+        objects = shared.objects.copy()
+        nodes = list(nodes)
+        random.shuffle(nodes)
+        # logging.info('Nodes: %s', nodes)
+        for node in nodes:
+            c = SearchConnection(*node)
+            c.start()
+            c.join()
+            if len(shared.objects) == len(objects):
+                break
+            else:
+                logging.info('Got %s objects', len(shared.objects))
+                shared.objects = objects.copy()
+        else:
+            return
+        self.assertNotEqual(shared.objects, objects)