WIP: Implementing sqlite objects storage #13

Draft
lee.miller wants to merge 11 commits from lee.miller/MiNode:sqlite into v0.3
4 changed files with 82 additions and 19 deletions
Showing only changes of commit f4f10fe299 - Show all commits

View File

@@ -473,7 +473,7 @@ class Connection(ConnectionBase):
logging.debug('%s:%s -> %s', self.host_print, self.port, obj) logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None) self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector) self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects: if obj.is_valid():
shared.objects[obj.vector] = obj shared.objects[obj.vector] = obj
if ( if (
obj.object_type == shared.i2p_dest_obj_type obj.object_type == shared.i2p_dest_obj_type

View File

@@ -42,6 +42,7 @@ class Manager(threading.Thread):
now = time.time() now = time.time()
if shared.shutting_down: if shared.shutting_down:
logging.debug('Shutting down Manager') logging.debug('Shutting down Manager')
shared.objects.flush()
break break
if now - self.last_cleaned_objects > 90: if now - self.last_cleaned_objects > 90:
shared.objects.cleanup() shared.objects.cleanup()

View File

@@ -1,5 +1,6 @@
"""Inventory implementation using sqlite""" """Inventory implementation using sqlite"""
import base64
import logging import logging
import os import os
import sqlite3 import sqlite3
@@ -16,6 +17,7 @@ class Inventory():
def __init__(self): def __init__(self):
self._lock = threading.Lock() self._lock = threading.Lock()
self._deleted = 0 self._deleted = 0
self._last = {}
self._pending = set() self._pending = set()
self._db = sqlite3.connect( self._db = sqlite3.connect(
os.path.join(shared.data_directory, 'objects.dat'), os.path.join(shared.data_directory, 'objects.dat'),
@@ -66,6 +68,16 @@ class Inventory():
logging.info( logging.info(
'Not cleaning up, %s objects pending', len(self._pending)) 'Not cleaning up, %s objects pending', len(self._pending))
return return
for vector in set(self._last):
if self._last[vector].is_expired():
logging.debug(
'Deleted expired object: %s',
base64.b16encode(vector).decode())
with self._lock:
del self._last[vector]
if len(self._last) > 1000:
self.flush()
return
with self._lock: with self._lock:
now = int(time.time()) now = int(time.time())
cur = self._db.execute( cur = self._db.execute(
@@ -86,8 +98,31 @@ class Inventory():
self._deleted = 0 self._deleted = 0
self.lastvacuumtime = now self.lastvacuumtime = now
def flush(self):
"""Write cached objects to the database"""
with self._lock:
cur = self._db.executemany(
'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)',
((obj.vector, obj.expires_time, obj.object_type,
obj.version, obj.stream_number, obj.tag, obj.data,
obj.offset) for obj in self._last.values()))
self._db.commit()
self.rowid = cur.lastrowid
self._last = {}
def filter(self, stream=None, object_type=None, tag=None): def filter(self, stream=None, object_type=None, tag=None):
"""Generator of objects with the given parameters""" """Generator of objects with the given parameters"""
def fits(obj):
if stream and obj.stream_number != stream:
return False
if object_type is not None and obj.object_type != object_type:
return False
if tag and obj.tag != tag:
return False
return True
yield from filter(fits, self._last.values())
clauses = [] clauses = []
if stream: if stream:
clauses.append(('stream = ?', stream)) clauses.append(('stream = ?', stream))
@@ -98,16 +133,16 @@ class Inventory():
clauses, params = zip(*clauses) clauses, params = zip(*clauses)
cur = self._db.execute( yield from self._db.execute(
'SELECT * FROM objects WHERE ' # nosec B608 'SELECT * FROM objects WHERE ' # nosec B608
+ ' AND '.join(clauses), params) + ' AND '.join(clauses), params)
return cur
def select(self, vectors): def select(self, vectors):
"""Select new vectors from the given set""" """Select new vectors from the given set"""
chunk_size = 999 chunk_size = 999
keys = tuple(vectors)
with self._lock: with self._lock:
vectors.difference_update(self._last)
keys = tuple(vectors)
for i in range(0, len(vectors), chunk_size): for i in range(0, len(vectors), chunk_size):
chunk = keys[i:i+chunk_size] chunk = keys[i:i+chunk_size]
cur = self._db.execute( cur = self._db.execute(
@@ -116,7 +151,7 @@ class Inventory():
chunk) chunk)
for v, in cur: for v, in cur:
vectors.remove(v) vectors.remove(v)
self._pending.update(vectors) self._pending.update(vectors)
return vectors return vectors
def vectors_to_send(self, chunk_size=10000, stream=None): def vectors_to_send(self, chunk_size=10000, stream=None):
@@ -131,6 +166,14 @@ class Inventory():
while True: while True:
vectors = cur.fetchmany() vectors = cur.fetchmany()
if not vectors: if not vectors:
# TODO: append to the last short result,
# check that _last is shorter than the chunk_size
# (should be < 1000)
if self._last:
yield [
obj.vector for obj in self._last.values()
if obj.stream_number == stream
and obj.expires_time > now]
return return
yield [v for v, in vectors] yield [v for v, in vectors]
@@ -141,13 +184,19 @@ class Inventory():
return default return default
def keys(self): def keys(self):
cur = self._db.execute('SELECT vector FROM objects') yield from self._last
return (v for v, in cur) for vector, in self._db.execute('SELECT vector FROM objects'):
yield vector
def values(self): def values(self):
return self._db.execute('SELECT * FROM objects') yield from self._last.values()
yield from self._db.execute('SELECT * FROM objects')
def popitem(self): def popitem(self):
try:
return self._last.popitem()
except KeyError:
pass
if not self.rowid: if not self.rowid:
raise KeyError('empty') raise KeyError('empty')
cur = self._db.execute( cur = self._db.execute(
@@ -158,11 +207,17 @@ class Inventory():
return (vector, obj) return (vector, obj)
def __contains__(self, vector): def __contains__(self, vector):
cur = self._db.execute( if vector in self._last:
'SELECT vector FROM objects WHERE vector = ?', (vector,)) return True
return cur.fetchone() return self._db.execute(
'SELECT vector FROM objects WHERE vector = ?', (vector,)
).fetchone() is not None
def __getitem__(self, vector): def __getitem__(self, vector):
try:
return self._last[vector]
except KeyError:
pass
item = self._db.execute( item = self._db.execute(
'SELECT * FROM objects WHERE vector = ?', (vector,)).fetchone() 'SELECT * FROM objects WHERE vector = ?', (vector,)).fetchone()
if item is None: if item is None:
@@ -170,28 +225,32 @@ class Inventory():
return item return item
def __delitem__(self, vector): def __delitem__(self, vector):
try:
del self._last[vector]
return
except KeyError:
pass
with self._lock: # KeyError with self._lock: # KeyError
self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,)) self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,))
self._db.commit() self._db.commit()
self.rowid = len(self) self.rowid = len(self)
def __setitem__(self, vector, obj): def __setitem__(self, vector, obj):
if vector in self:
return
with self._lock: with self._lock:
cur = self._db.execute( self._last[vector] = obj
'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', (
vector, obj.expires_time, obj.object_type, obj.version,
obj.stream_number, obj.tag, obj.data, obj.offset
))
self._db.commit()
self.rowid = cur.lastrowid
def __bool__(self): def __bool__(self):
if self._last:
return True
return self._db.execute( return self._db.execute(
'SELECT vector from objects LIMIT 1').fetchone() is not None 'SELECT vector from objects LIMIT 1').fetchone() is not None
def __len__(self): def __len__(self):
cur = self._db.execute('SELECT count(*) FROM objects') cur = self._db.execute('SELECT count(*) FROM objects')
return cur.fetchone()[0] return cur.fetchone()[0] + len(self._last)
def __del__(self): def __del__(self):
self.flush()
self._db.close() self._db.close()

View File

@@ -72,6 +72,9 @@ class TestNetwork(unittest.TestCase):
pass pass
shared.objects = sql.Inventory() shared.objects = sql.Inventory()
def tearDown(self):
shared.objects.flush()
def _make_initial_nodes(self): def _make_initial_nodes(self):
Manager.load_data() Manager.load_data()
core_nodes_len = len(shared.core_nodes) core_nodes_len = len(shared.core_nodes)