diff --git a/minode/connection.py b/minode/connection.py index 60cced1..fc470b0 100644 --- a/minode/connection.py +++ b/minode/connection.py @@ -473,7 +473,7 @@ class Connection(ConnectionBase): logging.debug('%s:%s -> %s', self.host_print, self.port, obj) self.vectors_requested.pop(obj.vector, None) self.vectors_to_get.discard(obj.vector) - if obj.is_valid() and obj.vector not in shared.objects: + if obj.is_valid(): shared.objects[obj.vector] = obj if ( obj.object_type == shared.i2p_dest_obj_type diff --git a/minode/manager.py b/minode/manager.py index 7dc0800..62b5806 100644 --- a/minode/manager.py +++ b/minode/manager.py @@ -42,6 +42,7 @@ class Manager(threading.Thread): now = time.time() if shared.shutting_down: logging.debug('Shutting down Manager') + shared.objects.flush() break if now - self.last_cleaned_objects > 90: shared.objects.cleanup() diff --git a/minode/sql.py b/minode/sql.py index 27c2471..55536d0 100644 --- a/minode/sql.py +++ b/minode/sql.py @@ -1,5 +1,6 @@ """Inventory implementation using sqlite""" +import base64 import logging import os import sqlite3 @@ -16,6 +17,7 @@ class Inventory(): def __init__(self): self._lock = threading.Lock() self._deleted = 0 + self._last = {} self._pending = set() self._db = sqlite3.connect( os.path.join(shared.data_directory, 'objects.dat'), @@ -66,6 +68,16 @@ class Inventory(): logging.info( 'Not cleaning up, %s objects pending', len(self._pending)) return + for vector in set(self._last): + if self._last[vector].is_expired(): + logging.debug( + 'Deleted expired object: %s', + base64.b16encode(vector).decode()) + with self._lock: + del self._last[vector] + if len(self._last) > 1000: + self.flush() + return with self._lock: now = int(time.time()) cur = self._db.execute( @@ -86,8 +98,31 @@ class Inventory(): self._deleted = 0 self.lastvacuumtime = now + def flush(self): + """Write cached objects to the database""" + with self._lock: + cur = self._db.executemany( + 'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', + ((obj.vector, obj.expires_time, obj.object_type, + obj.version, obj.stream_number, obj.tag, obj.data, + obj.offset) for obj in self._last.values())) + self._db.commit() + self.rowid = cur.lastrowid + self._last = {} + def filter(self, stream=None, object_type=None, tag=None): """Generator of objects with the given parameters""" + def fits(obj): + if stream and obj.stream_number != stream: + return False + if object_type is not None and obj.object_type != object_type: + return False + if tag and obj.tag != tag: + return False + return True + + yield from filter(fits, self._last.values()) + clauses = [] if stream: clauses.append(('stream = ?', stream)) @@ -98,16 +133,16 @@ class Inventory(): clauses, params = zip(*clauses) - cur = self._db.execute( + yield from self._db.execute( 'SELECT * FROM objects WHERE ' # nosec B608 + ' AND '.join(clauses), params) - return cur def select(self, vectors): """Select new vectors from the given set""" chunk_size = 999 - keys = tuple(vectors) with self._lock: + vectors.difference_update(self._last) + keys = tuple(vectors) for i in range(0, len(vectors), chunk_size): chunk = keys[i:i+chunk_size] cur = self._db.execute( @@ -116,7 +151,7 @@ class Inventory(): chunk) for v, in cur: vectors.remove(v) - self._pending.update(vectors) + self._pending.update(vectors) return vectors def vectors_to_send(self, chunk_size=10000, stream=None): @@ -131,6 +166,14 @@ class Inventory(): while True: vectors = cur.fetchmany() if not vectors: + # TODO: append to the last short result, + # check that _last is shorter than the chunk_size + # (should be < 1000) + if self._last: + yield [ + obj.vector for obj in self._last.values() + if obj.stream_number == stream + and obj.expires_time > now] return yield [v for v, in vectors] @@ -141,13 +184,19 @@ class Inventory(): return default def keys(self): - cur = self._db.execute('SELECT vector FROM objects') - return (v for v, in cur) + yield from self._last + for vector, in self._db.execute('SELECT vector FROM objects'): + yield vector def values(self): - return self._db.execute('SELECT * FROM objects') + yield from self._last.values() + yield from self._db.execute('SELECT * FROM objects') def popitem(self): + try: + return self._last.popitem() + except KeyError: + pass if not self.rowid: raise KeyError('empty') cur = self._db.execute( @@ -158,11 +207,17 @@ class Inventory(): return (vector, obj) def __contains__(self, vector): - cur = self._db.execute( - 'SELECT vector FROM objects WHERE vector = ?', (vector,)) - return cur.fetchone() + if vector in self._last: + return True + return self._db.execute( + 'SELECT vector FROM objects WHERE vector = ?', (vector,) + ).fetchone() is not None def __getitem__(self, vector): + try: + return self._last[vector] + except KeyError: + pass item = self._db.execute( 'SELECT * FROM objects WHERE vector = ?', (vector,)).fetchone() if item is None: @@ -170,28 +225,32 @@ class Inventory(): return item def __delitem__(self, vector): + try: + del self._last[vector] + return + except KeyError: + pass with self._lock: # KeyError self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,)) self._db.commit() self.rowid = len(self) def __setitem__(self, vector, obj): + if vector in self: + return with self._lock: - cur = self._db.execute( - 'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', ( - vector, obj.expires_time, obj.object_type, obj.version, - obj.stream_number, obj.tag, obj.data, obj.offset - )) - self._db.commit() - self.rowid = cur.lastrowid + self._last[vector] = obj def __bool__(self): + if self._last: + return True return self._db.execute( 'SELECT vector from objects LIMIT 1').fetchone() is not None def __len__(self): cur = self._db.execute('SELECT count(*) FROM objects') - return cur.fetchone()[0] + return cur.fetchone()[0] + len(self._last) def __del__(self): + self.flush() self._db.close() diff --git a/minode/tests/test_network.py b/minode/tests/test_network.py index d588022..1a9085e 100644 --- a/minode/tests/test_network.py +++ b/minode/tests/test_network.py @@ -72,6 +72,9 @@ class TestNetwork(unittest.TestCase): pass shared.objects = sql.Inventory() + def tearDown(self): + shared.objects.flush() + def _make_initial_nodes(self): Manager.load_data() core_nodes_len = len(shared.core_nodes)