diff --git a/minode/connection.py b/minode/connection.py index b7f9a75..5d35017 100644 --- a/minode/connection.py +++ b/minode/connection.py @@ -277,22 +277,15 @@ class ConnectionBase(threading.Thread): if len(addr) != 0: self.send_queue.put(message.Addr(addr)) - with shared.objects_lock: - if len(shared.objects) > 0: - to_send = { - vector for vector in shared.objects.keys() - if shared.objects[vector].expires_time > time.time()} - while len(to_send) > 0: - if len(to_send) > 10000: - # We limit size of inv messaged to 10000 entries - # because they might time out - # in very slow networks (I2P) - pack = random.sample(tuple(to_send), 10000) - self.send_queue.put(message.Inv(pack)) - to_send.difference_update(pack) - else: - self.send_queue.put(message.Inv(to_send)) - to_send.clear() + if shared.objects: + to_send = shared.objects.vectors_to_send() + offset = 0 + while offset < len(to_send): + # We limit size of inv messaged to 10000 entries + # because they might time out + # in very slow networks (I2P) + self.send_queue.put(message.Inv(to_send[offset:offset+10000])) + offset += 10000 self.status = 'fully_established' def _process_queue(self): @@ -415,7 +408,7 @@ class ConnectionBase(threading.Thread): def _request_objects(self): if self.vectors_to_get and len(self.vectors_requested) < 100: - self.vectors_to_get.difference_update(shared.objects.keys()) + self.vectors_to_get = shared.objects.select(self.vectors_to_get) if not self.wait_until: nodes_count = ( len(shared.node_pool) + len(shared.unchecked_node_pool)) @@ -462,12 +455,10 @@ class ConnectionBase(threading.Thread): else: to_send = self.vectors_to_send.copy() self.vectors_to_send.clear() - with shared.objects_lock: - for vector in to_send: - obj = shared.objects.get(vector, None) - if obj: - self.send_queue.put( - message.Message(b'object', obj.to_bytes())) + for vector in to_send: + obj = shared.objects.get(vector) + if obj: + self.send_queue.put(message.Message(b'object', obj.data)) class Connection(ConnectionBase): @@ -475,9 +466,7 @@ class Connection(ConnectionBase): def _process_msg_inv(self, m): inv = message.Inv.from_message(m) logging.debug('%s:%s -> %s', self.host_print, self.port, inv) - to_get = inv.vectors.copy() - to_get.difference_update(shared.objects.keys()) - self.vectors_to_get.update(to_get) + self.vectors_to_get.update(shared.objects.select(inv.vectors)) # Do not send objects they already have. self.vectors_to_send.difference_update(inv.vectors) @@ -486,9 +475,8 @@ class Connection(ConnectionBase): logging.debug('%s:%s -> %s', self.host_print, self.port, obj) self.vectors_requested.pop(obj.vector, None) self.vectors_to_get.discard(obj.vector) - if obj.is_valid() and obj.vector not in shared.objects: - with shared.objects_lock: - shared.objects[obj.vector] = obj + if obj.is_valid(): + shared.objects[obj.vector] = obj if ( obj.object_type == shared.i2p_dest_obj_type and obj.version == shared.i2p_dest_obj_version diff --git a/minode/main.py b/minode/main.py index 7b52796..22b517a 100644 --- a/minode/main.py +++ b/minode/main.py @@ -8,7 +8,7 @@ import os import signal import socket -from . import i2p, shared +from . import i2p, shared, sql from .advertiser import Advertiser from .manager import Manager from .listener import Listener @@ -157,10 +157,9 @@ def start_ip_listener(): def start_i2p_listener(): """Starts I2P threads""" # Grab I2P destinations from old object file - for obj in shared.objects.values(): - if obj.object_type == shared.i2p_dest_obj_type: - shared.i2p_unchecked_node_pool.add(( - base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p')) + for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type): + shared.i2p_unchecked_node_pool.add(( + base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p')) dest_priv = b'' @@ -241,6 +240,8 @@ def main(): if shared.ip_enabled and not shared.trusted_peer: bootstrap_from_dns() + shared.objects = sql.Inventory() + if shared.i2p_enabled: # We are starting it before cleaning expired objects # so we can collect I2P destination objects diff --git a/minode/manager.py b/minode/manager.py index a03f9d6..b483129 100644 --- a/minode/manager.py +++ b/minode/manager.py @@ -22,7 +22,7 @@ class Manager(threading.Thread): self.q = queue.Queue() self.last_cleaned_objects = time.time() self.last_cleaned_connections = time.time() - self.last_pickled_objects = time.time() + # self.last_pickled_objects = time.time() self.last_pickled_nodes = time.time() # Publish destination 5-15 minutes after start self.last_published_i2p_destination = \ @@ -30,7 +30,7 @@ class Manager(threading.Thread): def run(self): self.load_data() - self.clean_objects() + shared.objects.cleanup() while True: time.sleep(0.8) now = time.time() @@ -38,14 +38,14 @@ class Manager(threading.Thread): logging.debug('Shutting down Manager') break if now - self.last_cleaned_objects > 90: - self.clean_objects() + shared.objects.cleanup() self.last_cleaned_objects = now if now - self.last_cleaned_connections > 2: self.manage_connections() self.last_cleaned_connections = now - if now - self.last_pickled_objects > 100: - self.pickle_objects() - self.last_pickled_objects = now + # if now - self.last_pickled_objects > 100: + # self.pickle_objects() + # self.last_pickled_objects = now if now - self.last_pickled_nodes > 60: self.pickle_nodes() self.last_pickled_nodes = now @@ -53,21 +53,6 @@ class Manager(threading.Thread): self.publish_i2p_destination() self.last_published_i2p_destination = now - @staticmethod - def clean_objects(): - for vector in set(shared.objects): - if not shared.objects[vector].is_valid(): - if shared.objects[vector].is_expired(): - logging.debug( - 'Deleted expired object: %s', - base64.b16encode(vector).decode()) - else: - logging.warning( - 'Deleted invalid object: %s', - base64.b16encode(vector).decode()) - with shared.objects_lock: - del shared.objects[vector] - @staticmethod def manage_connections(): hosts = set() @@ -152,17 +137,6 @@ class Manager(threading.Thread): @staticmethod def load_data(): """Loads initial nodes and data, stored in files between sessions""" - try: - with open( - os.path.join(shared.data_directory, 'objects.pickle'), 'br' - ) as src: - shared.objects = pickle.load(src) - except FileNotFoundError: - pass # first start - except Exception: - logging.warning( - 'Error while loading objects from disk.', exc_info=True) - try: with open( os.path.join(shared.data_directory, 'nodes.pickle'), 'br' @@ -202,17 +176,17 @@ class Manager(threading.Thread): (row[0].encode(), 'i2p') for row in reader} shared.i2p_node_pool.update(shared.i2p_core_nodes) - @staticmethod - def pickle_objects(): - try: - with open( - os.path.join(shared.data_directory, 'objects.pickle'), 'bw' - ) as dst: - with shared.objects_lock: - pickle.dump(shared.objects, dst, protocol=3) - logging.debug('Saved objects') - except Exception: - logging.warning('Error while saving objects', exc_info=True) + # @staticmethod + # def pickle_objects(): + # try: + # with open( + # os.path.join(shared.data_directory, 'objects.pickle'), 'bw' + # ) as dst: + # with shared.objects_lock: + # pickle.dump(shared.objects, dst, protocol=3) + # logging.debug('Saved objects') + # except Exception: + # logging.warning('Error while saving objects', exc_info=True) @staticmethod def pickle_nodes(): @@ -250,7 +224,7 @@ class Manager(threading.Thread): dest_pub_raw = base64.b64decode( shared.i2p_dest_pub, altchars=b'-~') obj = structure.Object( - b'\x00' * 8, int(time.time() + 2 * 3600), + int(time.time() + 2 * 3600), shared.i2p_dest_obj_type, shared.i2p_dest_obj_version, - shared.stream, dest_pub_raw) + shared.stream, object_payload=dest_pub_raw) proofofwork.do_pow_and_publish(obj) diff --git a/minode/proofofwork.py b/minode/proofofwork.py index f8aa69f..812fe5d 100644 --- a/minode/proofofwork.py +++ b/minode/proofofwork.py @@ -39,14 +39,14 @@ def _worker(obj): logging.debug( 'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t) obj = structure.Object( - nonce, obj.expires_time, obj.object_type, obj.version, - obj.stream_number, obj.object_payload) + obj.expires_time, obj.object_type, obj.version, obj.stream_number, + object_payload=obj.object_payload, nonce=nonce + ) logging.debug( 'Object vector is %s', base64.b16encode(obj.vector).decode()) - with shared.objects_lock: - shared.objects[obj.vector] = obj - shared.vector_advertise_queue.put(obj.vector) + shared.objects[obj.vector] = obj + shared.vector_advertise_queue.put(obj.vector) def do_pow_and_publish(obj): diff --git a/minode/shared.py b/minode/shared.py index 72864ec..cc892d4 100644 --- a/minode/shared.py +++ b/minode/shared.py @@ -21,7 +21,7 @@ protocol_version = 3 services = 3 # NODE_NETWORK, NODE_SSL stream = 1 nonce = os.urandom(8) -user_agent = b'/MiNode:0.3.3/' +user_agent = b'/MiNode:0.3.5/' timeout = 600 header_length = 24 i2p_dest_obj_type = 0x493250 @@ -63,4 +63,3 @@ outgoing_connections = 8 connection_limit = 250 objects = {} -objects_lock = threading.Lock() diff --git a/minode/sql.py b/minode/sql.py new file mode 100644 index 0000000..f8fd79d --- /dev/null +++ b/minode/sql.py @@ -0,0 +1,187 @@ +"""Inventory implementation using sqlite""" + +import logging +import os +import sqlite3 +import threading +import time + +from . import shared, structure + +sqlite3.threadsafety = 3 + + +class Inventory(): + """sqlite inventory""" + def __init__(self): + self._lock = threading.Lock() + self._deleted = 0 + self._pending = set() + self._db = sqlite3.connect( + os.path.join(shared.data_directory, 'objects.dat'), + check_same_thread=False + ) + self._db.executescript(""" + BEGIN; + CREATE TABLE IF NOT EXISTS status + (key text, value integer, UNIQUE(key) ON CONFLICT REPLACE); + INSERT INTO status VALUES ('version', 1); + CREATE TABLE IF NOT EXISTS objects + (vector unique, expires integer, type integer, version integer, + stream integer, tag, data, offset integer); + COMMIT; + """) + self.rowid = len(self) or None + cur = self._db.cursor() + cur.execute("SELECT value FROM status WHERE key='lastvacuumtime'") + now = int(time.time()) + try: + vacuumed = cur.fetchone()[0] + except TypeError: + pass + else: + if vacuumed < now - 86400: # 24 hours + cur.execute('VACUUM') + cur.execute("INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,)) + self._db.commit() + self._db.row_factory = self.__object + + @staticmethod + def __object(cursor, row): + if len(cursor.description) != 8: + return row + vector, expires, obj_type, version, stream, tag, data, offset = row + return structure.Object( + expires, obj_type, version, stream, data, offset, + tag=tag, vector=vector) + + def cleanup(self): + if len(self._pending) > 100: + logging.warning( + 'Not cleaning up, %s objects pending', len(self._pending)) + return + with self._lock: + cur = self._db.execute( + 'DELETE FROM objects WHERE expires < ?', + (int(time.time()) - 3 * 3600,) + ) + self._db.commit() + self._deleted += cur.rowcount + # conditional vacuum and validity check + if self._deleted > 10000: + logging.info('Doing VACUUM for objects') + cur.execute('VACUUM') + cur.execute( + "INSERT INTO status VALUES ('lastvacuumtime', ?)", + (int(time.time()),)) + self._db.commit() + self._deleted = 0 + logging.info( + 'Deleted %s expired objects, %s pending', + cur.rowcount, len(self._pending)) + + def filter(self, stream=None, object_type=None, tag=None): + clauses = [] + if stream: + clauses.append(('stream = ?', stream)) + if object_type: + clauses.append(('type = ?', object_type)) + if tag: + clauses.append(('tag = ?', tag)) + + clauses, params = zip(*clauses) + + cur = self._db.execute( + 'SELECT * FROM objects WHERE ' + ' AND '.join(clauses), params) + return cur + + def select(self, vectors): + chunk_size = 999 + keys = tuple(vectors) + with self._lock: + for i in range(0, len(vectors), chunk_size): + chunk = keys[i:i+chunk_size] + cur = self._db.execute( + 'SELECT vector FROM objects WHERE vector IN ({})'.format( + ','.join('?' * len(chunk))), chunk) + for v, in cur: + vectors.remove(v) + self._pending.update(vectors) + return vectors + + def vectors_to_send(self, stream=None): + cur = self._db.execute( + 'SELECT vector FROM objects WHERE expires > ? AND stream = ?' + ' ORDER BY random()', + (int(time.time()), stream or shared.stream) + ) + return [v for v, in cur] + + def get(self, vector, default=None): + try: + return self[vector] + except KeyError: + return default + + def keys(self): + cur = self._db.execute('SELECT vector FROM objects') + return (v for v, in cur) + + def values(self): + return self._db.execute('SELECT * FROM objects') + + def popitem(self): + if not self.rowid: + raise KeyError('empty') + cur = self._db.execute( + 'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,)) + vector = cur.fetchone()[0] + obj = self.get(vector) + del self[vector] + return (vector, obj) + + def __contains__(self, vector): + cur = self._db.execute( + 'SELECT vector FROM objects WHERE vector = ?', (vector,)) + return cur.fetchone() + + def __getitem__(self, vector): + cur = self._db.execute( + 'SELECT * FROM objects WHERE vector = ?', (vector,)) + cur.row_factory = self.__object + item = cur.fetchone() + if item is None: + raise KeyError(vector) + return item + + def __delitem__(self, vector): + with self._lock: # KeyError + self._db.execute( + 'DELETE FROM objects WHERE vector = ?', (vector,)) + self._db.commit() + self.rowid = len(self) + + def __setitem__(self, vector, obj): + with self._lock: + try: + cur = self._db.execute( + 'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', ( + vector, obj.expires_time, obj.object_type, obj.version, + obj.stream_number, obj.tag, obj.data, obj.offset + )) + except (sqlite3.DatabaseError, sqlite3.IntegrityError): + return + self._db.commit() + self.rowid = cur.lastrowid + self._pending.discard(vector) + + def __bool__(self): + return self._db.execute( + 'SELECT vector from objects LIMIT 1').fetchone() is not None + + def __len__(self): + cur = self._db.execute('SELECT count(*) FROM objects') + return cur.fetchone()[0] + + def __del__(self): + self._db.close() diff --git a/minode/structure.py b/minode/structure.py index 405da1a..986fccf 100644 --- a/minode/structure.py +++ b/minode/structure.py @@ -48,21 +48,26 @@ class VarInt(): class Object(): """The 'object' message payload""" def __init__( - self, nonce, expires_time, object_type, version, - stream_number, object_payload + self, expires_time, object_type, version, stream_number, + data=None, offset=None, object_payload=None, + tag=None, nonce=b'\x00' * 8, vector=None ): self.nonce = nonce self.expires_time = expires_time self.object_type = object_type self.version = version self.stream_number = stream_number - self.object_payload = object_payload - self.vector = hashlib.sha512(hashlib.sha512( - self.to_bytes()).digest()).digest()[:32] + if not data: + data, offset = self.to_bytes(object_payload) + self.data = data + self.offset = offset + self.vector = vector or hashlib.sha512(hashlib.sha512( + self.data).digest()).digest()[:32] - self.tag = ( + self.tag = tag or ( # broadcast from version 5 and pubkey/getpukey from version 4 - self.object_payload[:32] if object_type == 3 and version == 5 + (object_payload or self.object_payload)[:32] + if object_type == 3 and version == 5 or (object_type in (0, 1) and version == 4) else None) @@ -73,20 +78,25 @@ class Object(): @classmethod def from_message(cls, m): """Decode message payload""" - payload = m.payload - nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20]) - payload = payload[20:] - version_varint_length = VarInt.length(payload[0]) - version = VarInt.from_bytes(payload[:version_varint_length]).n - payload = payload[version_varint_length:] - stream_number_varint_length = VarInt.length(payload[0]) + data = m.payload + nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20]) + version_varint_length = VarInt.length(data[20]) + offset = 20 + version_varint_length + version = VarInt.from_bytes(data[20:offset]).n + stream_number_varint_length = VarInt.length(data[offset]) stream_number = VarInt.from_bytes( - payload[:stream_number_varint_length]).n - payload = payload[stream_number_varint_length:] + data[offset:offset+stream_number_varint_length]).n + offset += stream_number_varint_length return cls( - nonce, expires_time, object_type, version, stream_number, payload) + expires_time, object_type, version, stream_number, + data, offset, nonce=nonce + ) - def to_bytes(self): + @property + def object_payload(self): + return self.data[self.offset:] + + def to_bytes(self, object_payload): """Serialize to bytes""" payload = b'' payload += self.nonce @@ -94,8 +104,7 @@ class Object(): payload += ( VarInt(self.version).to_bytes() + VarInt(self.stream_number).to_bytes()) - payload += self.object_payload - return payload + return payload + object_payload, len(payload) def is_expired(self): """Check if object's TTL is expired""" @@ -138,7 +147,7 @@ class Object(): def pow_target(self): """Compute PoW target""" - data = self.to_bytes()[8:] + data = self.data[8:] length = len(data) + 8 + shared.payload_length_extra_bytes dt = max(self.expires_time - time.time(), 0) return int( @@ -148,7 +157,7 @@ class Object(): def pow_initial_hash(self): """Compute the initial hash for PoW""" - return hashlib.sha512(self.to_bytes()[8:]).digest() + return hashlib.sha512(self.data[8:]).digest() class NetAddrNoPrefix(): @@ -162,14 +171,25 @@ class NetAddrNoPrefix(): return 'net_addr_no_prefix, services: {}, host: {}, port {}'.format( self.services, self.host, self.port) + @staticmethod + def decode_host(host): + if host.startswith( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF'): + return socket.inet_ntop(socket.AF_INET, host[-4:]) + return socket.inet_ntop(socket.AF_INET6, host) + + @staticmethod + def encode_host(host): + try: + host = socket.inet_pton(socket.AF_INET, host) + return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + host + except socket.error: + return socket.inet_pton(socket.AF_INET6, host) + def to_bytes(self): b = b'' b += struct.pack('>Q', self.services) - try: - host = socket.inet_pton(socket.AF_INET, self.host) - b += b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + host - except socket.error: - b += socket.inet_pton(socket.AF_INET6, self.host) + b += self.encode_host(self.host) b += struct.pack('>H', int(self.port)) return b @@ -191,12 +211,7 @@ class NetAddrNoPrefix(): @classmethod def from_bytes(cls, b): services, host, port = struct.unpack('>Q16sH', b) - if host.startswith( - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF'): - host = socket.inet_ntop(socket.AF_INET, host[-4:]) - else: - host = socket.inet_ntop(socket.AF_INET6, host) - return cls(services, host, port) + return cls(services, cls.decode_host(host), port) class NetAddr(): diff --git a/minode/tests/test_network.py b/minode/tests/test_network.py index c8cdd31..bbcac4d 100644 --- a/minode/tests/test_network.py +++ b/minode/tests/test_network.py @@ -7,7 +7,7 @@ import tempfile import time from contextlib import contextmanager -from minode import connection, main, shared +from minode import connection, main, shared, sql from minode.listener import Listener from minode.manager import Manager @@ -65,11 +65,11 @@ class TestNetwork(unittest.TestCase): def setUp(self): shared.core_nodes.clear() shared.unchecked_node_pool.clear() - shared.objects = {} try: - os.remove(os.path.join(shared.data_directory, 'objects.pickle')) + os.remove(os.path.join(shared.data_directory, 'objects.dat')) except FileNotFoundError: pass + shared.objects = sql.Inventory() def _make_initial_nodes(self): Manager.load_data() diff --git a/minode/tests/test_structure.py b/minode/tests/test_structure.py index 970c152..c6d7eb2 100644 --- a/minode/tests/test_structure.py +++ b/minode/tests/test_structure.py @@ -135,13 +135,13 @@ class TestStructure(unittest.TestCase): self.assertEqual(obj.object_payload, b'HELLO') obj = structure.Object( - b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO') + int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO') self.assertFalse(obj.is_valid()) obj.expires_time = int(time.time() - 11000) self.assertFalse(obj.is_valid()) obj = structure.Object( - b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO') + int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO') vector = obj.vector proofofwork._worker(obj) # pylint: disable=protected-access obj = shared.objects.popitem()[1] @@ -151,16 +151,17 @@ class TestStructure(unittest.TestCase): shared.stream = 2 self.assertTrue(obj.is_valid()) - obj.object_payload = \ - b'TIGER, tiger, burning bright. In the forests of the night' - self.assertFalse(obj.is_valid()) + # obj.data = struct.pack(... + # obj.object_payload = \ + # b'TIGER, tiger, burning bright. In the forests of the night' + # self.assertFalse(obj.is_valid()) def test_proofofwork(self): """Check the main proofofwork call and worker""" shared.vector_advertise_queue = queue.Queue() obj = structure.Object( - b'\x00' * 8, int(time.time() + 300), 42, 1, - shared.stream, b'HELLO') + int(time.time() + 300), 42, 1, shared.stream, + object_payload=b'HELLO') start_time = time.time() proofofwork.do_pow_and_publish(obj) try: @@ -189,6 +190,6 @@ class TestStructure(unittest.TestCase): self.fail("No nonce found in the queue") obj = structure.Object( - nonce, obj.expires_time, obj.object_type, obj.version, - obj.stream_number, obj.object_payload) + obj.expires_time, obj.object_type, obj.version, obj.stream_number, + object_payload=obj.object_payload, nonce=nonce) self.assertTrue(obj.is_valid())