166 lines
5.3 KiB
Python
166 lines
5.3 KiB
Python
"""Tests for network connections"""
|
|
import logging
|
|
import os
|
|
import random
|
|
import unittest
|
|
import tempfile
|
|
import time
|
|
|
|
from minode import connection, main, message, proofofwork, shared, structure
|
|
|
|
# Timestamped INFO-level output so the progress of long network tests is visible.
logging.basicConfig(
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
|
class PushConnection(connection.ConnectionBase):
    """A connection pushing useless test objects.

    Each ``_request_objects`` call does two things. First it moves every
    vector waiting in ``shared.vector_advertise_queue`` into
    ``self.vectors_to_send``. Then it builds one more dummy object and hands
    it to ``proofofwork.do_pow_and_publish``, until ten have been made.
    After that it asks to disconnect in either of two cases: ten objects
    exist in ``shared.objects`` and nothing is left to send, or more than
    120 seconds have passed since the first object was made.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.objects_setup = 0  # number of dummy objects created so far
        self.started = time.time()  # reset again when the first object is made

    def _request_objects(self):
        """Queue advertised vectors, then create the next dummy object."""
        pending = shared.vector_advertise_queue
        while not pending.empty():
            self.vectors_to_send.add(pending.get())

        made = self.objects_setup
        if made >= 10:
            # all ten objects exist; wait until they are sent or time runs out
            all_pushed = (
                len(shared.objects) == 10 and len(self.vectors_to_send) == 0)
            if all_pushed or time.time() - self.started > 120:
                self.status = 'disconnecting'
            return

        if made == 0:
            # measure the timeout from the first object, not from construction
            self.started = time.time()
        expires = int(time.time() + 1800)
        payload = b'HELLO %i' % made
        proofofwork.do_pow_and_publish(
            structure.Object(b'\x00' * 8, expires, 42, 1, 1, payload))
        self.objects_setup = made + 1
|
|
|
|
|
class SearchConnection(connection.Connection):
    """A connection searching a peer for previously known objects.

    On construction it snapshots the vectors currently held in
    ``shared.objects`` as its search list and then empties
    ``shared.objects``. Vectors from incoming ``inv`` messages that are on
    the search list move from that list into ``self.vectors_to_get``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.started = time.time()
        self.vectors_to_search = set(shared.objects)
        shared.objects.clear()
        logging.info(
            'Have %s objects to search for', len(self.vectors_to_search))

    def _process_msg_inv(self, m):
        """Request only the advertised vectors that are being searched for."""
        advertised = message.Inv.from_message(m).vectors
        found = advertised.intersection(self.vectors_to_search)
        self.vectors_to_get.update(found)
        self.vectors_to_search.difference_update(found)
        logging.info('Found %s vectors', len(found))

    def _send_objects(self):
        """Send nothing; only decide whether the search is finished.

        Disconnects once nothing remains to search for or fetch, or after
        300 seconds.
        """
        nothing_left = (
            not self.vectors_to_search and not self.vectors_to_get)
        if nothing_left or time.time() - self.started > 300:
            self.status = 'disconnecting'
            logging.info('disconnecting!')
|
|
|
|
|
|
class TestNetwork(unittest.TestCase):
    """Test case starting connections"""

    @classmethod
    def setUpClass(cls):
        # Use a private, throwaway data directory rather than the shared
        # system temp dir. setUp() deletes 'objects.pickle' from this
        # directory, which must never touch a location other processes use.
        # A fresh directory also keeps files from earlier runs out of the
        # test.
        cls._saved_data_directory = getattr(shared, 'data_directory', None)
        cls._data_directory = tempfile.mkdtemp(prefix='minode-test-')
        shared.data_directory = cls._data_directory

    @classmethod
    def tearDownClass(cls):
        import shutil

        shared.data_directory = cls._saved_data_directory
        # best effort: connection threads are not joined by the tests
        shutil.rmtree(cls._data_directory, ignore_errors=True)

    def setUp(self):
        shared.core_nodes.clear()
        shared.unchecked_node_pool.clear()
        shared.objects = {}
        try:
            os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
        except FileNotFoundError:
            pass

    def _make_initial_nodes(self):
        """Load data and bootstrap from DNS, requiring at least 3 nodes each."""
        main.load_data()
        self.assertGreaterEqual(len(shared.core_nodes), 3)

        main.bootstrap_from_dns()
        self.assertGreaterEqual(len(shared.unchecked_node_pool), 3)

    def test_connection(self):
        """Check a normal connection - should receive objects"""
        self._make_initial_nodes()

        started = time.time()
        nodes = list(shared.core_nodes.union(shared.unchecked_node_pool))
        random.shuffle(nodes)

        for node in nodes:
            # Forget the node first. Finding it in the pool later then
            # shows that the established connection added it.
            unknown = True
            shared.node_pool.discard(node)

            c = connection.Connection(*node)
            c.start()
            connection_started = time.time()
            while c.status not in ('disconnecting', 'disconnected', 'failed'):
                # The addr of established connection is added to nodes pool
                if unknown and c.status == 'fully_established':
                    unknown = False
                    self.assertIn(node, shared.node_pool)
                # move on once any object arrived or after 90 s on this node
                if shared.objects or time.time() - connection_started > 90:
                    c.status = 'disconnecting'
                if time.time() - started > 300:
                    c.status = 'disconnecting'
                    self.fail('Failed to receive an object in %s sec' % 300)
                time.sleep(0.2)
            if shared.objects:  # got some objects
                break
        else:
            self.fail('Failed to establish a proper connection')

    @unittest.skip('Not ready yet')
    def test_push(self):
        """Make and push useless objects"""
        self._make_initial_nodes()

        nodes = shared.core_nodes.union(shared.unchecked_node_pool)
        sent_to_nodes = set()
        # random.sample() requires a sequence (sets raise TypeError since
        # Python 3.11) and must not ask for more items than exist.
        for node in random.sample(list(nodes), min(5, len(nodes))):
            c = PushConnection(*node)
            c.start()
            c.join()
            sent_to_nodes.add(node)
            if shared.objects:
                break
        else:
            return

        # search on nodes that were not pushed to
        nodes.difference_update(sent_to_nodes)
        self.assertGreaterEqual(len(nodes), 1)
        objects = shared.objects.copy()
        nodes = list(nodes)
        random.shuffle(nodes)

        for node in nodes:
            c = SearchConnection(*node)
            c.start()
            c.join()
            if len(shared.objects) == len(objects):
                break
            else:
                logging.info('Got %s objects', len(shared.objects))
                # SearchConnection empties shared.objects; restore the
                # snapshot so the next node searches for the same vectors
                shared.objects = objects.copy()
        else:
            return

        self.assertNotEqual(shared.objects, objects)
|