Compare commits

...

7 Commits

3 changed files with 318 additions and 78 deletions

View File

@ -15,8 +15,11 @@ import time
from . import message, shared, structure
class Connection(threading.Thread):
"""The connection object"""
class ConnectionBase(threading.Thread):
"""
Common code for the connection thread
with minimum command handlers to reuse
"""
def __init__(
self, host, port, s=None, network='ip', server=False,
i2p_remote_dest=b''
@ -334,87 +337,13 @@ class Connection(threading.Thread):
break
def _process_message(self, m):
if m.command == b'version':
version = message.Version.from_bytes(m.to_bytes())
logging.debug('%s:%s -> %s', self.host_print, self.port, version)
if (
version.protocol_version != shared.protocol_version
or version.nonce == shared.nonce
):
self.status = 'disconnecting'
self.send_queue.put(None)
else:
logging.info(
'%s:%s claims to be %s',
self.host_print, self.port, version.user_agent)
self.send_queue.put(message.Message(b'verack', b''))
self.verack_sent = True
self.remote_version = version
if not self.server:
self.send_queue.put('fully_established')
if self.network == 'ip':
shared.address_advertise_queue.put(structure.NetAddr(
version.services, self.host, self.port))
shared.node_pool.add((self.host, self.port))
elif self.network == 'i2p':
shared.i2p_node_pool.add((self.host, 'i2p'))
if self.network == 'ip':
shared.address_advertise_queue.put(structure.NetAddr(
shared.services, version.host, shared.listening_port))
if self.server:
if self.network == 'ip':
self.send_queue.put(
message.Version(self.host, self.port))
else:
self.send_queue.put(message.Version('127.0.0.1', 7656))
elif m.command == b'verack':
if m.command == b'verack':
self.verack_received = True
logging.debug(
'%s:%s -> %s', self.host_print, self.port, 'verack')
if self.server:
self.send_queue.put('fully_established')
elif m.command == b'inv':
inv = message.Inv.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
to_get = inv.vectors.copy()
to_get.difference_update(shared.objects.keys())
self.vectors_to_get.update(to_get)
# Do not send objects they already have.
self.vectors_to_send.difference_update(inv.vectors)
elif m.command == b'object':
obj = structure.Object.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock:
shared.objects[obj.vector] = obj
if (
obj.object_type == shared.i2p_dest_obj_type
and obj.version == shared.i2p_dest_obj_version
):
dest = base64.b64encode(obj.object_payload, altchars=b'-~')
logging.debug(
'Received I2P destination object,'
' adding to i2p_unchecked_node_pool')
logging.debug(dest)
shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
shared.vector_advertise_queue.put(obj.vector)
elif m.command == b'getdata':
getdata = message.GetData.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, getdata)
self.vectors_to_send.update(getdata.vectors)
elif m.command == b'addr':
addr = message.Addr.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, addr)
for a in addr.addresses:
shared.unchecked_node_pool.add((a.host, a.port))
elif m.command == b'ping':
logging.debug('%s:%s -> ping', self.host_print, self.port)
self.send_queue.put(message.Message(b'pong', b''))
@ -424,7 +353,49 @@ class Connection(threading.Thread):
'%s:%s -> error: %s', self.host_print, self.port, m.payload)
else:
logging.debug('%s:%s -> %s', self.host_print, self.port, m)
try:
getattr(self, '_process_msg_{}'.format(m.command.decode()))(m)
except (AttributeError, UnicodeDecodeError):
logging.debug('%s:%s -> %s', self.host_print, self.port, m)
def _process_msg_version(self, m):
version = message.Version.from_bytes(m.to_bytes())
logging.debug('%s:%s -> %s', self.host_print, self.port, version)
if (
version.protocol_version != shared.protocol_version
or version.nonce == shared.nonce
):
self.status = 'disconnecting'
self.send_queue.put(None)
else:
logging.info(
'%s:%s claims to be %s',
self.host_print, self.port, version.user_agent)
self.send_queue.put(message.Message(b'verack', b''))
self.verack_sent = True
self.remote_version = version
if not self.server:
self.send_queue.put('fully_established')
if self.network == 'ip':
shared.address_advertise_queue.put(structure.NetAddr(
version.services, self.host, self.port))
shared.node_pool.add((self.host, self.port))
elif self.network == 'i2p':
shared.i2p_node_pool.add((self.host, 'i2p'))
if self.network == 'ip':
shared.address_advertise_queue.put(structure.NetAddr(
shared.services, version.host, shared.listening_port))
if self.server:
if self.network == 'ip':
self.send_queue.put(message.Version(self.host, self.port))
else:
self.send_queue.put(message.Version('127.0.0.1', 7656))
def _process_msg_addr(self, m):
addr = message.Addr.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, addr)
for a in addr.addresses:
shared.unchecked_node_pool.add((a.host, a.port))
def _request_objects(self):
if self.vectors_to_get and len(self.vectors_requested) < 100:
@ -483,4 +454,61 @@ class Connection(threading.Thread):
message.Message(b'object', obj.to_bytes()))
class Connection(ConnectionBase):
"""The connection with all commands implementation"""
def _process_msg_inv(self, m):
inv = message.Inv.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
to_get = inv.vectors.copy()
to_get.difference_update(shared.objects.keys())
self.vectors_to_get.update(to_get)
# Do not send objects they already have.
self.vectors_to_send.difference_update(inv.vectors)
def _process_msg_object(self, m):
obj = structure.Object.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock:
shared.objects[obj.vector] = obj
if (
obj.object_type == shared.i2p_dest_obj_type
and obj.version == shared.i2p_dest_obj_version
):
dest = base64.b64encode(obj.object_payload, altchars=b'-~')
logging.debug(
'Received I2P destination object,'
' adding to i2p_unchecked_node_pool')
logging.debug(dest)
shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
shared.vector_advertise_queue.put(obj.vector)
def _process_msg_getdata(self, m):
getdata = message.GetData.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, getdata)
self.vectors_to_send.update(getdata.vectors)
class TrackingConnection(Connection):
"""A helper to track number of servers in the network"""
def __init__(self, *args):
super().__init__(*args)
self.connected_time = None
self.active = False
def _on_connection_fully_established(self):
self.connected_time = time.time()
super()._on_connection_fully_established()
def _process_message(self, m):
super()._process_message(m)
if not self.active and m.command in (b'object', b'getdata'):
self.active = True
# def run(self):
# super().run()
shared.connection = Connection

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
"""The main thread, managing connections, nodes and objects"""
import base64
import csv
import logging
import os
import pickle
@ -26,6 +27,10 @@ class Manager(threading.Thread):
# Publish destination 5-15 minutes after start
self.last_published_i2p_destination = \
time.time() - 50 * 60 + random.uniform(-1, 1) * 300 # nosec
self.last_dumped_counts = time.time()
self.connection_counts = []
self.connections_avg = 0
self.connections_max = 0
def run(self):
self.clean_objects()
@ -40,6 +45,7 @@ class Manager(threading.Thread):
self.last_cleaned_objects = now
if now - self.last_cleaned_connections > 2:
self.manage_connections()
self.check_connections_count()
self.last_cleaned_connections = now
if now - self.last_pickled_objects > 100:
self.pickle_objects()
@ -47,9 +53,50 @@ class Manager(threading.Thread):
if now - self.last_pickled_nodes > 60:
self.pickle_nodes()
self.last_pickled_nodes = now
if now - self.last_dumped_counts > 300:
self.dump_counts()
self.last_dumped_counts = now
if now - self.last_published_i2p_destination > 3600:
self.publish_i2p_destination()
self.last_published_i2p_destination = now
logging.info(
'Connections number max: %s, average: %.2f',
self.connections_max, self.connections_avg)
def dump_counts(self):
"""Write connection counts into a csv file"""
total = 0
with open(
os.path.join(shared.data_directory, 'connection_counts.csv'),
'a', newline='', encoding='ascii'
) as dst:
writer = csv.writer(dst)
for timestamp, number in self.connection_counts:
total += number
if number > self.connections_max:
self.connections_max = number
writer.writerow((time.ctime(timestamp), number))
avg = total / len(self.connection_counts)
self.connections_avg = (
(self.connections_avg + avg) / 2 if self.connections_avg else avg)
self.connection_counts = self.connection_counts[-1:]
def check_connections_count(self):
"""Make a record for changed connections count"""
try:
previous_count = self.connection_counts[-1]
except IndexError:
previous_count = 0
count = sum((
1 for c in shared.connections.copy()
if c.status == 'fully_established'))
if count != previous_count:
self.connection_counts.append((time.time(), count))
if count == shared.outgoing_connections:
logging.info('Max outgoing connections reached!')
@staticmethod
def clean_objects():

View File

@ -0,0 +1,165 @@
"""Tests for network connections"""
import logging
import os
import random
import unittest
import tempfile
import time
from minode import connection, main, message, proofofwork, shared, structure
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] [%(levelname)s] %(message)s')
class PushConnection(connection.ConnectionBase):
"""A connection pushing useless objects"""
def __init__(self, *args):
super().__init__(*args)
self.objects_setup = 0
self.started = time.time()
def _request_objects(self):
"""Make objects to send"""
while not shared.vector_advertise_queue.empty():
vector = shared.vector_advertise_queue.get()
self.vectors_to_send.add(vector)
if self.objects_setup < 10:
if self.objects_setup == 0:
self.started = time.time()
obj = structure.Object(
b'\x00' * 8, int(time.time() + 1800), 42, 1, 1,
b'HELLO %i' % self.objects_setup)
proofofwork.do_pow_and_publish(obj)
self.objects_setup += 1
elif (
len(shared.objects) == 10 and len(self.vectors_to_send) == 0
or time.time() - self.started > 120
):
self.status = 'disconnecting'
class SearchConnection(connection.Connection):
"""A connection searching for objects"""
def __init__(self, *args):
super().__init__(*args)
self.started = time.time()
self.vectors_to_search = {v for v in shared.objects.keys()}
shared.objects.clear()
logging.info(
'Have %s objects to search for', len(self.vectors_to_search))
def _process_msg_inv(self, m):
inv = message.Inv.from_message(m)
to_get = inv.vectors.intersection(self.vectors_to_search)
self.vectors_to_get.update(to_get)
self.vectors_to_search.difference_update(to_get)
logging.info('Found %s vectors', len(to_get))
def _send_objects(self):
if (
len(self.vectors_to_search) == len(self.vectors_to_get) == 0
or time.time() - self.started > 300
):
self.status = 'disconnecting'
logging.info('disconnecting!')
class TestNetwork(unittest.TestCase):
"""Test case starting connections"""
@classmethod
def setUpClass(cls):
shared.data_directory = tempfile.gettempdir()
def setUp(self):
shared.core_nodes.clear()
shared.unchecked_node_pool.clear()
shared.objects = {}
try:
os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
except FileNotFoundError:
pass
def _make_initial_nodes(self):
main.load_data()
self.assertGreaterEqual(len(shared.core_nodes), 3)
main.bootstrap_from_dns()
self.assertGreaterEqual(len(shared.unchecked_node_pool), 3)
def test_connection(self):
"""Check a normal connection - should receive objects"""
self._make_initial_nodes()
started = time.time()
nodes = list(shared.core_nodes.union(shared.unchecked_node_pool))
random.shuffle(nodes)
for node in nodes:
# unknown = node not in shared.node_pool
# self.assertTrue(unknown)
unknown = True
shared.node_pool.discard(node)
c = connection.Connection(*node)
c.start()
connection_started = time.time()
while c.status not in ('disconnecting', 'disconnected', 'failed'):
# The addr of established connection is added to nodes pool
if unknown and c.status == 'fully_established':
unknown = False
self.assertIn(node, shared.node_pool)
if shared.objects or time.time() - connection_started > 90:
c.status = 'disconnecting'
if time.time() - started > 300:
c.status = 'disconnecting'
self.fail('Failed to receive an object in %s sec' % 300)
time.sleep(0.2)
if shared.objects: # got some objects
break
else:
self.fail('Failed to establish a proper connection')
@unittest.skip('Not ready yet')
def test_push(self):
"""Make and push useless objects"""
self._make_initial_nodes()
nodes = shared.core_nodes.union(shared.unchecked_node_pool)
# logging.info('Nodes: %s', nodes)
sent_to_nodes = set()
for node in random.sample(nodes, 5):
c = PushConnection(*node)
c.start()
c.join()
sent_to_nodes.add(node)
if shared.objects:
break
else:
return
nodes.difference_update(sent_to_nodes)
self.assertGreaterEqual(len(nodes), 1)
objects = shared.objects.copy()
nodes = list(nodes)
random.shuffle(nodes)
# logging.info('Nodes: %s', nodes)
for node in nodes:
c = SearchConnection(*node)
c.start()
c.join()
if len(shared.objects) == len(objects):
break
else:
logging.info('Got %s objects', len(shared.objects))
shared.objects = objects.copy()
else:
return
self.assertNotEqual(shared.objects, objects)