diff --git a/minode/connection.py b/minode/connection.py index 05d552a..fba349c 100644 --- a/minode/connection.py +++ b/minode/connection.py @@ -277,22 +277,15 @@ class ConnectionBase(threading.Thread): if len(addr) != 0: self.send_queue.put(message.Addr(addr)) - with shared.objects_lock: - if len(shared.objects) > 0: - to_send = { - vector for vector in shared.objects.keys() - if shared.objects[vector].expires_time > time.time()} - while len(to_send) > 0: - if len(to_send) > 10000: - # We limit size of inv messaged to 10000 entries - # because they might time out - # in very slow networks (I2P) - pack = random.sample(tuple(to_send), 10000) - self.send_queue.put(message.Inv(pack)) - to_send.difference_update(pack) - else: - self.send_queue.put(message.Inv(to_send)) - to_send.clear() + if shared.objects: + to_send = shared.objects.vectors_to_send() + offset = 0 + while offset < len(to_send): + # We limit size of inv messaged to 10000 entries + # because they might time out + # in very slow networks (I2P) + self.send_queue.put(message.Inv(to_send[offset:offset+10000])) + offset += 10000 self.status = 'fully_established' def _process_queue(self): @@ -463,12 +456,10 @@ class ConnectionBase(threading.Thread): else: to_send = self.vectors_to_send.copy() self.vectors_to_send.clear() - with shared.objects_lock: - for vector in to_send: - obj = shared.objects.get(vector, None) - if obj: - self.send_queue.put( - message.Message(b'object', obj.to_bytes())) + for vector in to_send: + obj = shared.objects.get(vector) + if obj: + self.send_queue.put(message.Message(b'object', obj.data)) class Connection(ConnectionBase): @@ -488,8 +479,7 @@ class Connection(ConnectionBase): self.vectors_requested.pop(obj.vector, None) self.vectors_to_get.discard(obj.vector) if obj.is_valid() and obj.vector not in shared.objects: - with shared.objects_lock: - shared.objects[obj.vector] = obj + shared.objects[obj.vector] = obj if ( obj.object_type == shared.i2p_dest_obj_type and obj.version == shared.i2p_dest_obj_version diff --git a/minode/main.py b/minode/main.py index 72cebe0..868c3f0 100644 --- a/minode/main.py +++ b/minode/main.py @@ -8,7 +8,7 @@ import os import signal import socket -from . import i2p, shared +from . import i2p, shared, sql from .advertiser import Advertiser from .manager import Manager from .listener import Listener @@ -161,10 +161,9 @@ def start_ip_listener(): def start_i2p_listener(): """Starts I2P threads""" # Grab I2P destinations from old object file - for obj in shared.objects.values(): - if obj.object_type == shared.i2p_dest_obj_type: - shared.i2p_unchecked_node_pool.add(( - base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p')) + for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type): + shared.i2p_unchecked_node_pool.add(( + base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p')) dest_priv = b'' @@ -245,6 +244,8 @@ def main(): if shared.ip_enabled and not shared.trusted_peer: bootstrap_from_dns() + shared.objects = sql.Inventory() + if shared.i2p_enabled: # We are starting it before cleaning expired objects # so we can collect I2P destination objects diff --git a/minode/manager.py b/minode/manager.py index aa8bbf9..9f31d2e 100644 --- a/minode/manager.py +++ b/minode/manager.py @@ -23,7 +23,6 @@ class Manager(threading.Thread): self.bootstrap_pool = [] self.last_cleaned_objects = time.time() self.last_cleaned_connections = time.time() - self.last_pickled_objects = time.time() self.last_pickled_nodes = time.time() # Publish destination 5-15 minutes after start self.last_published_i2p_destination = \ @@ -36,7 +35,7 @@ class Manager(threading.Thread): def run(self): self.load_data() - self.clean_objects() + shared.objects.cleanup() self.fill_bootstrap_pool() while True: time.sleep(0.8) @@ -45,14 +44,11 @@ class Manager(threading.Thread): logging.debug('Shutting down Manager') break if now - self.last_cleaned_objects > 90: - self.clean_objects() + shared.objects.cleanup() self.last_cleaned_objects = now if now - self.last_cleaned_connections > 2: self.manage_connections() self.last_cleaned_connections = now - if now - self.last_pickled_objects > 100: - self.pickle_objects() - self.last_pickled_objects = now if now - self.last_pickled_nodes > 60: self.pickle_nodes() self.last_pickled_nodes = now @@ -60,17 +56,6 @@ class Manager(threading.Thread): self.publish_i2p_destination() self.last_published_i2p_destination = now - @staticmethod - def clean_objects(): - for vector in set(shared.objects): - # FIXME: no need to check is_valid() here - if shared.objects[vector].is_expired(): - logging.debug( - 'Deleted expired object: %s', - base64.b16encode(vector).decode()) - with shared.objects_lock: - del shared.objects[vector] - def manage_connections(self): """Open new connections if needed, remove closed ones""" hosts = set() @@ -177,17 +162,6 @@ class Manager(threading.Thread): @staticmethod def load_data(): """Loads initial nodes and data, stored in files between sessions""" - try: - with open( - os.path.join(shared.data_directory, 'objects.pickle'), 'br' - ) as src: - shared.objects = pickle.load(src) - except FileNotFoundError: - pass # first start - except Exception: - logging.warning( - 'Error while loading objects from disk.', exc_info=True) - try: with open( os.path.join(shared.data_directory, 'nodes.pickle'), 'br' @@ -227,18 +201,6 @@ class Manager(threading.Thread): (row[0].encode(), 'i2p') for row in reader} shared.i2p_node_pool.update(shared.i2p_core_nodes) - @staticmethod - def pickle_objects(): - try: - with open( - os.path.join(shared.data_directory, 'objects.pickle'), 'bw' - ) as dst: - with shared.objects_lock: - pickle.dump(shared.objects, dst, protocol=3) - logging.debug('Saved objects') - except Exception: - logging.warning('Error while saving objects', exc_info=True) - @staticmethod def pickle_nodes(): if len(shared.node_pool) > 10000: @@ -275,7 +237,7 @@ class Manager(threading.Thread): dest_pub_raw = base64.b64decode( shared.i2p_dest_pub, altchars=b'-~') obj = structure.Object( - b'\x00' * 8, int(time.time() + 2 * 3600), + int(time.time() + 2 * 3600), shared.i2p_dest_obj_type, shared.i2p_dest_obj_version, - shared.stream, dest_pub_raw) + shared.stream, object_payload=dest_pub_raw) proofofwork.do_pow_and_publish(obj) diff --git a/minode/proofofwork.py b/minode/proofofwork.py index f8aa69f..812fe5d 100644 --- a/minode/proofofwork.py +++ b/minode/proofofwork.py @@ -39,14 +39,14 @@ def _worker(obj): logging.debug( 'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t) obj = structure.Object( - nonce, obj.expires_time, obj.object_type, obj.version, - obj.stream_number, obj.object_payload) + obj.expires_time, obj.object_type, obj.version, obj.stream_number, + object_payload=obj.object_payload, nonce=nonce + ) logging.debug( 'Object vector is %s', base64.b16encode(obj.vector).decode()) - with shared.objects_lock: - shared.objects[obj.vector] = obj - shared.vector_advertise_queue.put(obj.vector) + shared.objects[obj.vector] = obj + shared.vector_advertise_queue.put(obj.vector) def do_pow_and_publish(obj): diff --git a/minode/sql.py b/minode/sql.py new file mode 100644 index 0000000..4ccb6d9 --- /dev/null +++ b/minode/sql.py @@ -0,0 +1,153 @@ +"""Inventory implementation using sqlite""" + +import os +import sqlite3 +import time + +from . import shared, structure + +sqlite3.threadsafety = 3 + + +class Inventory(): + """sqlite inventory""" + def __init__(self): + self._db = sqlite3.connect( + os.path.join(shared.data_directory, 'objects.dat'), + check_same_thread=False + ) + self._db.executescript(""" + BEGIN; + CREATE TABLE IF NOT EXISTS status + (key text, value integer, UNIQUE(key) ON CONFLICT REPLACE); + INSERT INTO status VALUES ('version', 1); + CREATE TABLE IF NOT EXISTS objects + (vector unique, expires integer, type integer, version integer, + stream integer, tag, data, offset integer); + COMMIT; + """) + self.rowid = len(self) or None + cur = self._db.cursor() + cur.execute("SELECT value FROM status WHERE key='lastvacuumtime'") + now = int(time.time()) + try: + vacuumed = cur.fetchone()[0] + except TypeError: + pass + else: + if vacuumed < now - 86400: # 24 hours + cur.execute('VACUUM') + cur.execute("INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,)) + self._db.commit() + + def __objects(self, cur): + return ( + structure.Object( + expires, obj_type, version, stream, data, offset, + tag=tag, vector=vector) + for (vector, expires, obj_type, version, stream, tag, data, offset) + in cur.fetchall() + ) + + def cleanup(self): + """Remove expired objects""" + with shared.objects_lock: + self._db.execute( + 'DELETE FROM objects WHERE expires < ?', + (int(time.time()) - 3 * 3600,) + ) + self._db.commit() + # conditional vacuum and validity check + + def filter(self, stream=None, object_type=None, tag=None): + """Generator of objects with the given parameters""" + clauses = [] + if stream: + clauses.append(('stream = ?', stream)) + if object_type is not None: + clauses.append(('type = ?', object_type)) + if tag: + clauses.append(('tag = ?', tag)) + + clauses, params = zip(*clauses) + + cur = self._db.execute( + 'SELECT * FROM objects WHERE ' # nosec B608 + + ' AND '.join(clauses), params) + return self.__objects(cur) + + def vectors_to_send(self, stream=None): + cur = self._db.execute( + 'SELECT vector FROM objects WHERE expires > ? AND stream = ?' + ' ORDER BY random()', + (int(time.time()), stream or shared.stream) + ) + return [v for v, in cur.fetchall()] + + def get(self, vector, default=None): + try: + return self[vector] + except KeyError: + return default + + def keys(self): + cur = self._db.execute('SELECT vector FROM objects') + return (v for v, in cur.fetchall()) + + def values(self): + cur = self._db.execute('SELECT * FROM objects') + return self.__objects(cur) + + def popitem(self): + if not self.rowid: + raise KeyError('empty') + cur = self._db.execute( + 'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,)) + vector = cur.fetchone()[0] + obj = self.get(vector) + del self[vector] + return (vector, obj) + + def __contains__(self, vector): + cur = self._db.execute( + 'SELECT vector FROM objects WHERE vector = ?', (vector,)) + return cur.fetchone() + + def __getitem__(self, vector): + cur = self._db.execute( + 'SELECT * FROM objects WHERE vector = ?', (vector,)) + item = cur.fetchone() + if item is None: + raise KeyError(vector) + vector, expires, obj_type, version, stream, tag, data, offset = item + return structure.Object( + expires, obj_type, version, stream, data, offset, + tag=tag, vector=vector + ) + + def __delitem__(self, vector): + with shared.objects_lock: # KeyError + self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,)) + self._db.commit() + self.rowid = len(self) + + def __setitem__(self, vector, obj): + with shared.objects_lock: + cur = self._db.execute( + 'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', ( + vector, obj.expires_time, obj.object_type, obj.version, + obj.stream_number, obj.tag, obj.data, obj.offset + )) + self._db.commit() + self.rowid = cur.lastrowid + + def __bool__(self): + return self._db.execute( + 'SELECT vector from objects LIMIT 1').fetchone() is not None + + def __len__(self): + cur = self._db.execute('SELECT count(*) FROM objects') + return cur.fetchone()[0] + + def __del__(self): + self._db.close() diff --git a/minode/structure.py b/minode/structure.py index 405da1a..fcd6c74 100644 --- a/minode/structure.py +++ b/minode/structure.py @@ -48,21 +48,26 @@ class VarInt(): class Object(): """The 'object' message payload""" def __init__( - self, nonce, expires_time, object_type, version, - stream_number, object_payload + self, expires_time, object_type, version, stream_number, + data=None, offset=None, object_payload=None, + tag=None, nonce=b'\x00' * 8, vector=None ): self.nonce = nonce self.expires_time = expires_time self.object_type = object_type self.version = version self.stream_number = stream_number - self.object_payload = object_payload - self.vector = hashlib.sha512(hashlib.sha512( - self.to_bytes()).digest()).digest()[:32] + if not data: + data, offset = self.to_bytes(object_payload) + self.data = data + self.offset = offset + self.vector = vector or hashlib.sha512(hashlib.sha512( + self.data).digest()).digest()[:32] - self.tag = ( + self.tag = tag or ( # broadcast from version 5 and pubkey/getpukey from version 4 - self.object_payload[:32] if object_type == 3 and version == 5 + (object_payload or self.object_payload)[:32] + if object_type == 3 and version == 5 or (object_type in (0, 1) and version == 4) else None) @@ -73,20 +78,25 @@ class Object(): @classmethod def from_message(cls, m): """Decode message payload""" - payload = m.payload - nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20]) - payload = payload[20:] - version_varint_length = VarInt.length(payload[0]) - version = VarInt.from_bytes(payload[:version_varint_length]).n - payload = payload[version_varint_length:] - stream_number_varint_length = VarInt.length(payload[0]) + data = m.payload + nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20]) + version_varint_length = VarInt.length(data[20]) + offset = 20 + version_varint_length + version = VarInt.from_bytes(data[20:offset]).n + stream_number_varint_length = VarInt.length(data[offset]) stream_number = VarInt.from_bytes( - payload[:stream_number_varint_length]).n - payload = payload[stream_number_varint_length:] + data[offset:offset+stream_number_varint_length]).n + offset += stream_number_varint_length return cls( - nonce, expires_time, object_type, version, stream_number, payload) + expires_time, object_type, version, stream_number, + data, offset, nonce=nonce + ) - def to_bytes(self): + @property + def object_payload(self): + return self.data[self.offset:] + + def to_bytes(self, object_payload): """Serialize to bytes""" payload = b'' payload += self.nonce @@ -94,8 +104,7 @@ class Object(): payload += ( VarInt(self.version).to_bytes() + VarInt(self.stream_number).to_bytes()) - payload += self.object_payload - return payload + return payload + object_payload, len(payload) def is_expired(self): """Check if object's TTL is expired""" @@ -138,7 +147,7 @@ class Object(): def pow_target(self): """Compute PoW target""" - data = self.to_bytes()[8:] + data = self.data[8:] length = len(data) + 8 + shared.payload_length_extra_bytes dt = max(self.expires_time - time.time(), 0) return int( @@ -148,7 +157,7 @@ class Object(): def pow_initial_hash(self): """Compute the initial hash for PoW""" - return hashlib.sha512(self.to_bytes()[8:]).digest() + return hashlib.sha512(self.data[8:]).digest() class NetAddrNoPrefix(): diff --git a/minode/tests/test_network.py b/minode/tests/test_network.py index c683637..d588022 100644 --- a/minode/tests/test_network.py +++ b/minode/tests/test_network.py @@ -8,7 +8,7 @@ import tempfile import time from contextlib import contextmanager -from minode import connection, main, shared +from minode import connection, main, shared, sql from minode.listener import Listener from minode.manager import Manager @@ -66,11 +66,11 @@ class TestNetwork(unittest.TestCase): def setUp(self): shared.core_nodes.clear() shared.unchecked_node_pool.clear() - shared.objects = {} try: - os.remove(os.path.join(shared.data_directory, 'objects.pickle')) + os.remove(os.path.join(shared.data_directory, 'objects.dat')) except FileNotFoundError: pass + shared.objects = sql.Inventory() def _make_initial_nodes(self): Manager.load_data() diff --git a/minode/tests/test_structure.py b/minode/tests/test_structure.py index 970c152..c45d2d4 100644 --- a/minode/tests/test_structure.py +++ b/minode/tests/test_structure.py @@ -135,13 +135,13 @@ class TestStructure(unittest.TestCase): self.assertEqual(obj.object_payload, b'HELLO') obj = structure.Object( - b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO') + int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO') self.assertFalse(obj.is_valid()) obj.expires_time = int(time.time() - 11000) self.assertFalse(obj.is_valid()) obj = structure.Object( - b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO') + int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO') vector = obj.vector proofofwork._worker(obj) # pylint: disable=protected-access obj = shared.objects.popitem()[1] @@ -159,8 +159,8 @@ class TestStructure(unittest.TestCase): """Check the main proofofwork call and worker""" shared.vector_advertise_queue = queue.Queue() obj = structure.Object( - b'\x00' * 8, int(time.time() + 300), 42, 1, - shared.stream, b'HELLO') + int(time.time() + 300), 42, 1, shared.stream, + object_payload=b'HELLO') start_time = time.time() proofofwork.do_pow_and_publish(obj) try: @@ -189,6 +189,6 @@ class TestStructure(unittest.TestCase): self.fail("No nonce found in the queue") obj = structure.Object( - nonce, obj.expires_time, obj.object_type, obj.version, - obj.stream_number, obj.object_payload) + obj.expires_time, obj.object_type, obj.version, obj.stream_number, + object_payload=obj.object_payload, nonce=nonce) self.assertTrue(obj.is_valid())