WIP: Implementing sqlite objects storage #13

Draft
lee.miller wants to merge 11 commits from lee.miller/MiNode:sqlite into v0.3
8 changed files with 222 additions and 107 deletions
Showing only changes of commit d62fb01fa1 - Show all commits

View File

@ -277,22 +277,15 @@ class ConnectionBase(threading.Thread):
if len(addr) != 0: if len(addr) != 0:
self.send_queue.put(message.Addr(addr)) self.send_queue.put(message.Addr(addr))
with shared.objects_lock: if shared.objects:
if len(shared.objects) > 0: to_send = shared.objects.vectors_to_send()
to_send = { offset = 0
vector for vector in shared.objects.keys() while offset < len(to_send):
if shared.objects[vector].expires_time > time.time()}
while len(to_send) > 0:
if len(to_send) > 10000:
# We limit size of inv messaged to 10000 entries # We limit size of inv messaged to 10000 entries
# because they might time out # because they might time out
# in very slow networks (I2P) # in very slow networks (I2P)
pack = random.sample(tuple(to_send), 10000) self.send_queue.put(message.Inv(to_send[offset:offset+10000]))
self.send_queue.put(message.Inv(pack)) offset += 10000
to_send.difference_update(pack)
else:
self.send_queue.put(message.Inv(to_send))
to_send.clear()
self.status = 'fully_established' self.status = 'fully_established'
def _process_queue(self): def _process_queue(self):
@ -463,12 +456,10 @@ class ConnectionBase(threading.Thread):
else: else:
to_send = self.vectors_to_send.copy() to_send = self.vectors_to_send.copy()
self.vectors_to_send.clear() self.vectors_to_send.clear()
with shared.objects_lock:
for vector in to_send: for vector in to_send:
obj = shared.objects.get(vector, None) obj = shared.objects.get(vector)
if obj: if obj:
self.send_queue.put( self.send_queue.put(message.Message(b'object', obj.data))
message.Message(b'object', obj.to_bytes()))
class Connection(ConnectionBase): class Connection(ConnectionBase):
@ -488,7 +479,6 @@ class Connection(ConnectionBase):
self.vectors_requested.pop(obj.vector, None) self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector) self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects: if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock:
shared.objects[obj.vector] = obj shared.objects[obj.vector] = obj
if ( if (
obj.object_type == shared.i2p_dest_obj_type obj.object_type == shared.i2p_dest_obj_type

View File

@ -8,7 +8,7 @@ import os
import signal import signal
import socket import socket
from . import i2p, shared from . import i2p, shared, sql
from .advertiser import Advertiser from .advertiser import Advertiser
from .manager import Manager from .manager import Manager
from .listener import Listener from .listener import Listener
@ -161,8 +161,7 @@ def start_ip_listener():
def start_i2p_listener(): def start_i2p_listener():
"""Starts I2P threads""" """Starts I2P threads"""
# Grab I2P destinations from old object file # Grab I2P destinations from old object file
for obj in shared.objects.values(): for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type):
if obj.object_type == shared.i2p_dest_obj_type:
shared.i2p_unchecked_node_pool.add(( shared.i2p_unchecked_node_pool.add((
base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p')) base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
@ -245,6 +244,8 @@ def main():
if shared.ip_enabled and not shared.trusted_peer: if shared.ip_enabled and not shared.trusted_peer:
bootstrap_from_dns() bootstrap_from_dns()
shared.objects = sql.Inventory()
if shared.i2p_enabled: if shared.i2p_enabled:
# We are starting it before cleaning expired objects # We are starting it before cleaning expired objects
# so we can collect I2P destination objects # so we can collect I2P destination objects

View File

@ -23,7 +23,6 @@ class Manager(threading.Thread):
self.bootstrap_pool = [] self.bootstrap_pool = []
self.last_cleaned_objects = time.time() self.last_cleaned_objects = time.time()
self.last_cleaned_connections = time.time() self.last_cleaned_connections = time.time()
self.last_pickled_objects = time.time()
self.last_pickled_nodes = time.time() self.last_pickled_nodes = time.time()
# Publish destination 5-15 minutes after start # Publish destination 5-15 minutes after start
self.last_published_i2p_destination = \ self.last_published_i2p_destination = \
@ -36,7 +35,7 @@ class Manager(threading.Thread):
def run(self): def run(self):
self.load_data() self.load_data()
self.clean_objects() shared.objects.cleanup()
self.fill_bootstrap_pool() self.fill_bootstrap_pool()
while True: while True:
time.sleep(0.8) time.sleep(0.8)
@ -45,14 +44,11 @@ class Manager(threading.Thread):
logging.debug('Shutting down Manager') logging.debug('Shutting down Manager')
break break
if now - self.last_cleaned_objects > 90: if now - self.last_cleaned_objects > 90:
self.clean_objects() shared.objects.cleanup()
self.last_cleaned_objects = now self.last_cleaned_objects = now
if now - self.last_cleaned_connections > 2: if now - self.last_cleaned_connections > 2:
self.manage_connections() self.manage_connections()
self.last_cleaned_connections = now self.last_cleaned_connections = now
if now - self.last_pickled_objects > 100:
self.pickle_objects()
self.last_pickled_objects = now
if now - self.last_pickled_nodes > 60: if now - self.last_pickled_nodes > 60:
self.pickle_nodes() self.pickle_nodes()
self.last_pickled_nodes = now self.last_pickled_nodes = now
@ -60,17 +56,6 @@ class Manager(threading.Thread):
self.publish_i2p_destination() self.publish_i2p_destination()
self.last_published_i2p_destination = now self.last_published_i2p_destination = now
@staticmethod
def clean_objects():
for vector in set(shared.objects):
# FIXME: no need to check is_valid() here
if shared.objects[vector].is_expired():
logging.debug(
'Deleted expired object: %s',
base64.b16encode(vector).decode())
with shared.objects_lock:
del shared.objects[vector]
def manage_connections(self): def manage_connections(self):
"""Open new connections if needed, remove closed ones""" """Open new connections if needed, remove closed ones"""
hosts = set() hosts = set()
@ -177,17 +162,6 @@ class Manager(threading.Thread):
@staticmethod @staticmethod
def load_data(): def load_data():
"""Loads initial nodes and data, stored in files between sessions""" """Loads initial nodes and data, stored in files between sessions"""
try:
with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'br'
) as src:
shared.objects = pickle.load(src)
except FileNotFoundError:
pass # first start
except Exception:
logging.warning(
'Error while loading objects from disk.', exc_info=True)
try: try:
with open( with open(
os.path.join(shared.data_directory, 'nodes.pickle'), 'br' os.path.join(shared.data_directory, 'nodes.pickle'), 'br'
@ -227,18 +201,6 @@ class Manager(threading.Thread):
(row[0].encode(), 'i2p') for row in reader} (row[0].encode(), 'i2p') for row in reader}
shared.i2p_node_pool.update(shared.i2p_core_nodes) shared.i2p_node_pool.update(shared.i2p_core_nodes)
@staticmethod
def pickle_objects():
try:
with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'bw'
) as dst:
with shared.objects_lock:
pickle.dump(shared.objects, dst, protocol=3)
logging.debug('Saved objects')
except Exception:
logging.warning('Error while saving objects', exc_info=True)
@staticmethod @staticmethod
def pickle_nodes(): def pickle_nodes():
if len(shared.node_pool) > 10000: if len(shared.node_pool) > 10000:
@ -275,7 +237,7 @@ class Manager(threading.Thread):
dest_pub_raw = base64.b64decode( dest_pub_raw = base64.b64decode(
shared.i2p_dest_pub, altchars=b'-~') shared.i2p_dest_pub, altchars=b'-~')
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 2 * 3600), int(time.time() + 2 * 3600),
shared.i2p_dest_obj_type, shared.i2p_dest_obj_version, shared.i2p_dest_obj_type, shared.i2p_dest_obj_version,
shared.stream, dest_pub_raw) shared.stream, object_payload=dest_pub_raw)
proofofwork.do_pow_and_publish(obj) proofofwork.do_pow_and_publish(obj)

View File

@ -39,12 +39,12 @@ def _worker(obj):
logging.debug( logging.debug(
'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t) 'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t)
obj = structure.Object( obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version, obj.expires_time, obj.object_type, obj.version, obj.stream_number,
obj.stream_number, obj.object_payload) object_payload=obj.object_payload, nonce=nonce
)
logging.debug( logging.debug(
'Object vector is %s', base64.b16encode(obj.vector).decode()) 'Object vector is %s', base64.b16encode(obj.vector).decode())
with shared.objects_lock:
shared.objects[obj.vector] = obj shared.objects[obj.vector] = obj
shared.vector_advertise_queue.put(obj.vector) shared.vector_advertise_queue.put(obj.vector)

153
minode/sql.py Normal file
View File

@ -0,0 +1,153 @@
"""Inventory implementation using sqlite"""
import os
import sqlite3
import time
from . import shared, structure
sqlite3.threadsafety = 3
class Inventory():
"""sqlite inventory"""
def __init__(self):
self._db = sqlite3.connect(
os.path.join(shared.data_directory, 'objects.dat'),
check_same_thread=False
)
self._db.executescript("""
BEGIN;
CREATE TABLE IF NOT EXISTS status
(key text, value integer, UNIQUE(key) ON CONFLICT REPLACE);
INSERT INTO status VALUES ('version', 1);
CREATE TABLE IF NOT EXISTS objects
(vector unique, expires integer, type integer, version integer,
stream integer, tag, data, offset integer);
COMMIT;
""")
self.rowid = len(self) or None
cur = self._db.cursor()
cur.execute("SELECT value FROM status WHERE key='lastvacuumtime'")
now = int(time.time())
try:
vacuumed = cur.fetchone()[0]
except TypeError:
pass
else:
if vacuumed < now - 86400: # 24 hours
cur.execute('VACUUM')
cur.execute("INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,))
self._db.commit()
def __objects(self, cur):
return (
structure.Object(
expires, obj_type, version, stream, data, offset,
tag=tag, vector=vector)
for (vector, expires, obj_type, version, stream, tag, data, offset)
in cur.fetchall()
)
def cleanup(self):
"""Remove expired objects"""
with shared.objects_lock:
self._db.execute(
'DELETE FROM objects WHERE expires < ?',
(int(time.time()) - 3 * 3600,)
)
self._db.commit()
# conditional vacuum and validity check
def filter(self, stream=None, object_type=None, tag=None):
"""Generator of objects with the given parameters"""
clauses = []
if stream:
clauses.append(('stream = ?', stream))
if object_type is not None:
clauses.append(('type = ?', object_type))
if tag:
clauses.append(('tag = ?', tag))
clauses, params = zip(*clauses)
cur = self._db.execute(
'SELECT * FROM objects WHERE ' # nosec B608
+ ' AND '.join(clauses), params)
return self.__objects(cur)
def vectors_to_send(self, stream=None):
cur = self._db.execute(
'SELECT vector FROM objects WHERE expires > ? AND stream = ?'
' ORDER BY random()',
(int(time.time()), stream or shared.stream)
)
return [v for v, in cur.fetchall()]
def get(self, vector, default=None):
try:
return self[vector]
except KeyError:
return default
def keys(self):
cur = self._db.execute('SELECT vector FROM objects')
return (v for v, in cur.fetchall())
def values(self):
cur = self._db.execute('SELECT * FROM objects')
return self.__objects(cur)
def popitem(self):
if not self.rowid:
raise KeyError('empty')
cur = self._db.execute(
'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,))
vector = cur.fetchone()[0]
obj = self.get(vector)
del self[vector]
return (vector, obj)
def __contains__(self, vector):
cur = self._db.execute(
'SELECT vector FROM objects WHERE vector = ?', (vector,))
return cur.fetchone()
def __getitem__(self, vector):
cur = self._db.execute(
'SELECT * FROM objects WHERE vector = ?', (vector,))
item = cur.fetchone()
if item is None:
raise KeyError(vector)
vector, expires, obj_type, version, stream, tag, data, offset = item
return structure.Object(
expires, obj_type, version, stream, data, offset,
tag=tag, vector=vector
)
def __delitem__(self, vector):
with shared.objects_lock: # KeyError
self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,))
self._db.commit()
self.rowid = len(self)
def __setitem__(self, vector, obj):
with shared.objects_lock:
cur = self._db.execute(
'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', (
vector, obj.expires_time, obj.object_type, obj.version,
obj.stream_number, obj.tag, obj.data, obj.offset
))
self._db.commit()
self.rowid = cur.lastrowid
def __bool__(self):
return self._db.execute(
'SELECT vector from objects LIMIT 1').fetchone() is not None
def __len__(self):
cur = self._db.execute('SELECT count(*) FROM objects')
return cur.fetchone()[0]
def __del__(self):
self._db.close()

View File

@ -48,21 +48,26 @@ class VarInt():
class Object(): class Object():
"""The 'object' message payload""" """The 'object' message payload"""
def __init__( def __init__(
self, nonce, expires_time, object_type, version, self, expires_time, object_type, version, stream_number,
stream_number, object_payload data=None, offset=None, object_payload=None,
tag=None, nonce=b'\x00' * 8, vector=None
): ):
self.nonce = nonce self.nonce = nonce
self.expires_time = expires_time self.expires_time = expires_time
self.object_type = object_type self.object_type = object_type
self.version = version self.version = version
self.stream_number = stream_number self.stream_number = stream_number
self.object_payload = object_payload if not data:
self.vector = hashlib.sha512(hashlib.sha512( data, offset = self.to_bytes(object_payload)
self.to_bytes()).digest()).digest()[:32] self.data = data
self.offset = offset
self.vector = vector or hashlib.sha512(hashlib.sha512(
self.data).digest()).digest()[:32]
self.tag = ( self.tag = tag or (
# broadcast from version 5 and pubkey/getpukey from version 4 # broadcast from version 5 and pubkey/getpukey from version 4
self.object_payload[:32] if object_type == 3 and version == 5 (object_payload or self.object_payload)[:32]
if object_type == 3 and version == 5
or (object_type in (0, 1) and version == 4) or (object_type in (0, 1) and version == 4)
else None) else None)
@ -73,20 +78,25 @@ class Object():
@classmethod @classmethod
def from_message(cls, m): def from_message(cls, m):
"""Decode message payload""" """Decode message payload"""
payload = m.payload data = m.payload
nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20]) nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20])
payload = payload[20:] version_varint_length = VarInt.length(data[20])
version_varint_length = VarInt.length(payload[0]) offset = 20 + version_varint_length
version = VarInt.from_bytes(payload[:version_varint_length]).n version = VarInt.from_bytes(data[20:offset]).n
payload = payload[version_varint_length:] stream_number_varint_length = VarInt.length(data[offset])
stream_number_varint_length = VarInt.length(payload[0])
stream_number = VarInt.from_bytes( stream_number = VarInt.from_bytes(
payload[:stream_number_varint_length]).n data[offset:offset+stream_number_varint_length]).n
payload = payload[stream_number_varint_length:] offset += stream_number_varint_length
return cls( return cls(
nonce, expires_time, object_type, version, stream_number, payload) expires_time, object_type, version, stream_number,
data, offset, nonce=nonce
)
def to_bytes(self): @property
def object_payload(self):
return self.data[self.offset:]
def to_bytes(self, object_payload):
"""Serialize to bytes""" """Serialize to bytes"""
payload = b'' payload = b''
payload += self.nonce payload += self.nonce
@ -94,8 +104,7 @@ class Object():
payload += ( payload += (
VarInt(self.version).to_bytes() VarInt(self.version).to_bytes()
+ VarInt(self.stream_number).to_bytes()) + VarInt(self.stream_number).to_bytes())
payload += self.object_payload return payload + object_payload, len(payload)
return payload
def is_expired(self): def is_expired(self):
"""Check if object's TTL is expired""" """Check if object's TTL is expired"""
@ -138,7 +147,7 @@ class Object():
def pow_target(self): def pow_target(self):
"""Compute PoW target""" """Compute PoW target"""
data = self.to_bytes()[8:] data = self.data[8:]
length = len(data) + 8 + shared.payload_length_extra_bytes length = len(data) + 8 + shared.payload_length_extra_bytes
dt = max(self.expires_time - time.time(), 0) dt = max(self.expires_time - time.time(), 0)
return int( return int(
@ -148,7 +157,7 @@ class Object():
def pow_initial_hash(self): def pow_initial_hash(self):
"""Compute the initial hash for PoW""" """Compute the initial hash for PoW"""
return hashlib.sha512(self.to_bytes()[8:]).digest() return hashlib.sha512(self.data[8:]).digest()
class NetAddrNoPrefix(): class NetAddrNoPrefix():

View File

@ -8,7 +8,7 @@ import tempfile
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from minode import connection, main, shared from minode import connection, main, shared, sql
from minode.listener import Listener from minode.listener import Listener
from minode.manager import Manager from minode.manager import Manager
@ -66,11 +66,11 @@ class TestNetwork(unittest.TestCase):
def setUp(self): def setUp(self):
shared.core_nodes.clear() shared.core_nodes.clear()
shared.unchecked_node_pool.clear() shared.unchecked_node_pool.clear()
shared.objects = {}
try: try:
os.remove(os.path.join(shared.data_directory, 'objects.pickle')) os.remove(os.path.join(shared.data_directory, 'objects.dat'))
except FileNotFoundError: except FileNotFoundError:
pass pass
shared.objects = sql.Inventory()
def _make_initial_nodes(self): def _make_initial_nodes(self):
Manager.load_data() Manager.load_data()

View File

@ -135,13 +135,13 @@ class TestStructure(unittest.TestCase):
self.assertEqual(obj.object_payload, b'HELLO') self.assertEqual(obj.object_payload, b'HELLO')
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO') int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO')
self.assertFalse(obj.is_valid()) self.assertFalse(obj.is_valid())
obj.expires_time = int(time.time() - 11000) obj.expires_time = int(time.time() - 11000)
self.assertFalse(obj.is_valid()) self.assertFalse(obj.is_valid())
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO') int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO')
vector = obj.vector vector = obj.vector
proofofwork._worker(obj) # pylint: disable=protected-access proofofwork._worker(obj) # pylint: disable=protected-access
obj = shared.objects.popitem()[1] obj = shared.objects.popitem()[1]
@ -159,8 +159,8 @@ class TestStructure(unittest.TestCase):
"""Check the main proofofwork call and worker""" """Check the main proofofwork call and worker"""
shared.vector_advertise_queue = queue.Queue() shared.vector_advertise_queue = queue.Queue()
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, int(time.time() + 300), 42, 1, shared.stream,
shared.stream, b'HELLO') object_payload=b'HELLO')
start_time = time.time() start_time = time.time()
proofofwork.do_pow_and_publish(obj) proofofwork.do_pow_and_publish(obj)
try: try:
@ -189,6 +189,6 @@ class TestStructure(unittest.TestCase):
self.fail("No nonce found in the queue") self.fail("No nonce found in the queue")
obj = structure.Object( obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version, obj.expires_time, obj.object_type, obj.version, obj.stream_number,
obj.stream_number, obj.object_payload) object_payload=obj.object_payload, nonce=nonce)
self.assertTrue(obj.is_valid()) self.assertTrue(obj.is_valid())