A rough implementation of the objects stored in sqlite db

This commit is contained in:
Lee Miller 2024-09-08 06:48:28 +03:00
parent fba7aa635a
commit 7261103c87
Signed by: lee.miller
GPG Key ID: 4F97A5EA88F4AB63
8 changed files with 200 additions and 77 deletions

View File

@ -278,11 +278,7 @@ class ConnectionBase(threading.Thread):
self.send_queue.put(message.Addr(addr)) self.send_queue.put(message.Addr(addr))
if len(shared.objects) > 0: if len(shared.objects) > 0:
with shared.objects_lock: to_send = shared.objects.vectors_to_send()
to_send = [
vector for vector in shared.objects.keys()
if shared.objects[vector].expires_time > time.time()]
random.shuffle(to_send)
offset = 0 offset = 0
while offset < len(to_send): while offset < len(to_send):
# We limit size of inv messaged to 10000 entries # We limit size of inv messaged to 10000 entries
@ -459,12 +455,10 @@ class ConnectionBase(threading.Thread):
else: else:
to_send = self.vectors_to_send.copy() to_send = self.vectors_to_send.copy()
self.vectors_to_send.clear() self.vectors_to_send.clear()
with shared.objects_lock: for vector in to_send:
for vector in to_send: obj = shared.objects.get(vector)
obj = shared.objects.get(vector, None) if obj:
if obj: self.send_queue.put(message.Message(b'object', obj.data))
self.send_queue.put(
message.Message(b'object', obj.to_bytes()))
class Connection(ConnectionBase): class Connection(ConnectionBase):
@ -484,8 +478,7 @@ class Connection(ConnectionBase):
self.vectors_requested.pop(obj.vector, None) self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector) self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects: if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock: shared.objects[obj.vector] = obj
shared.objects[obj.vector] = obj
if ( if (
obj.object_type == shared.i2p_dest_obj_type obj.object_type == shared.i2p_dest_obj_type
and obj.version == shared.i2p_dest_obj_version and obj.version == shared.i2p_dest_obj_version

View File

@ -8,7 +8,7 @@ import os
import signal import signal
import socket import socket
from . import i2p, shared from . import i2p, shared, sql
from .advertiser import Advertiser from .advertiser import Advertiser
from .manager import Manager from .manager import Manager
from .listener import Listener from .listener import Listener
@ -157,10 +157,9 @@ def start_ip_listener():
def start_i2p_listener(): def start_i2p_listener():
"""Starts I2P threads""" """Starts I2P threads"""
# Grab I2P destinations from old object file # Grab I2P destinations from old object file
for obj in shared.objects.values(): for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type):
if obj.object_type == shared.i2p_dest_obj_type: shared.i2p_unchecked_node_pool.add((
shared.i2p_unchecked_node_pool.add(( base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
dest_priv = b'' dest_priv = b''
@ -241,6 +240,8 @@ def main():
if shared.ip_enabled and not shared.trusted_peer: if shared.ip_enabled and not shared.trusted_peer:
bootstrap_from_dns() bootstrap_from_dns()
shared.objects = sql.Inventory()
if shared.i2p_enabled: if shared.i2p_enabled:
# We are starting it before cleaning expired objects # We are starting it before cleaning expired objects
# so we can collect I2P destination objects # so we can collect I2P destination objects

View File

@ -30,7 +30,7 @@ class Manager(threading.Thread):
def run(self): def run(self):
self.load_data() self.load_data()
self.clean_objects() shared.objects.cleanup()
while True: while True:
time.sleep(0.8) time.sleep(0.8)
now = time.time() now = time.time()
@ -38,14 +38,14 @@ class Manager(threading.Thread):
logging.debug('Shutting down Manager') logging.debug('Shutting down Manager')
break break
if now - self.last_cleaned_objects > 90: if now - self.last_cleaned_objects > 90:
self.clean_objects() shared.objects.cleanup()
self.last_cleaned_objects = now self.last_cleaned_objects = now
if now - self.last_cleaned_connections > 2: if now - self.last_cleaned_connections > 2:
self.manage_connections() self.manage_connections()
self.last_cleaned_connections = now self.last_cleaned_connections = now
if now - self.last_pickled_objects > 100: # if now - self.last_pickled_objects > 100:
self.pickle_objects() # self.pickle_objects()
self.last_pickled_objects = now # self.last_pickled_objects = now
if now - self.last_pickled_nodes > 60: if now - self.last_pickled_nodes > 60:
self.pickle_nodes() self.pickle_nodes()
self.last_pickled_nodes = now self.last_pickled_nodes = now
@ -53,21 +53,6 @@ class Manager(threading.Thread):
self.publish_i2p_destination() self.publish_i2p_destination()
self.last_published_i2p_destination = now self.last_published_i2p_destination = now
@staticmethod
def clean_objects():
for vector in set(shared.objects):
if not shared.objects[vector].is_valid():
if shared.objects[vector].is_expired():
logging.debug(
'Deleted expired object: %s',
base64.b16encode(vector).decode())
else:
logging.warning(
'Deleted invalid object: %s',
base64.b16encode(vector).decode())
with shared.objects_lock:
del shared.objects[vector]
@staticmethod @staticmethod
def manage_connections(): def manage_connections():
hosts = set() hosts = set()
@ -156,7 +141,8 @@ class Manager(threading.Thread):
with open( with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'br' os.path.join(shared.data_directory, 'objects.pickle'), 'br'
) as src: ) as src:
shared.objects = pickle.load(src) for obj in pickle.load(src):
shared.objects[obj.vector] = obj
except FileNotFoundError: except FileNotFoundError:
pass # first start pass # first start
except Exception: except Exception:
@ -250,7 +236,7 @@ class Manager(threading.Thread):
dest_pub_raw = base64.b64decode( dest_pub_raw = base64.b64decode(
shared.i2p_dest_pub, altchars=b'-~') shared.i2p_dest_pub, altchars=b'-~')
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 2 * 3600), int(time.time() + 2 * 3600),
shared.i2p_dest_obj_type, shared.i2p_dest_obj_version, shared.i2p_dest_obj_type, shared.i2p_dest_obj_version,
shared.stream, dest_pub_raw) shared.stream, object_payload=dest_pub_raw)
proofofwork.do_pow_and_publish(obj) proofofwork.do_pow_and_publish(obj)

View File

@ -39,14 +39,14 @@ def _worker(obj):
logging.debug( logging.debug(
'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t) 'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t)
obj = structure.Object( obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version, obj.expires_time, obj.object_type, obj.version, obj.stream_number,
obj.stream_number, obj.object_payload) object_payload=obj.object_payload, nonce=nonce
)
logging.debug( logging.debug(
'Object vector is %s', base64.b16encode(obj.vector).decode()) 'Object vector is %s', base64.b16encode(obj.vector).decode())
with shared.objects_lock: shared.objects[obj.vector] = obj
shared.objects[obj.vector] = obj shared.vector_advertise_queue.put(obj.vector)
shared.vector_advertise_queue.put(obj.vector)
def do_pow_and_publish(obj): def do_pow_and_publish(obj):

134
minode/sql.py Normal file
View File

@ -0,0 +1,134 @@
"""Inventory implementation using sqlite"""
import os
import sqlite3
import time
from . import shared, structure
sqlite3.threadsafety = 3
class Inventory():
"""sqlite inventory"""
def __init__(self):
self._db = sqlite3.connect(
os.path.join(shared.data_directory, 'objects.dat'),
check_same_thread=False
)
self._db.executescript("""
BEGIN;
CREATE TABLE IF NOT EXISTS status (vacuumtime integer);
CREATE TABLE IF NOT EXISTS objects
(vector unique, expires integer, type integer, version integer,
stream integer, tag, data, offset integer);
COMMIT;
""")
self.rowid = len(self) or None
# self._cur.execute('INSERT INTO status VALUES (?)', int(time.time()))
def __objects(self, cur):
return (
structure.Object(
expires, obj_type, version, stream, data, offset,
tag=tag, vector=vector)
for (vector, expires, obj_type, version, stream, tag, data, offset)
in cur.fetchall()
)
def cleanup(self):
with shared.objects_lock:
self._db.execute(
'DELETE FROM objects WHERE expires < ?',
(int(time.time()) - 3 * 3600,)
)
self._db.commit()
# conditional vacuum and validity check
def filter(self, stream=None, object_type=None, tag=None):
clauses = []
if stream:
clauses.append(('stream = ?', stream))
if object_type:
clauses.append(('type = ?', object_type))
if tag:
clauses.append(('tag = ?', tag))
clauses, params = zip(*clauses)
cur = self._db.execute(
'SELECT * FROM objects WHERE ' + ' AND '.join(clauses), params)
return self.__objects(cur)
def vectors_to_send(self, stream=None):
cur = self._db.execute(
'SELECT vector FROM objects WHERE expires > ? AND stream = ?'
' ORDER BY random()',
(int(time.time()), stream or shared.stream)
)
return [v for v, in cur.fetchall()]
def get(self, vector, default=None):
try:
return self[vector]
except KeyError:
return default
def keys(self):
cur = self._db.execute('SELECT vector FROM objects')
return (v for v, in cur.fetchall())
def values(self):
cur = self._db.execute('SELECT * FROM objects')
return self.__objects(cur)
def popitem(self):
if not self.rowid:
raise KeyError('empty')
cur = self._db.execute(
'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,))
vector = cur.fetchone()[0]
obj = self.get(vector)
del self[vector]
return (vector, obj)
def __contains__(self, vector):
cur = self._db.execute(
'SELECT vector FROM objects WHERE vector = ?', (vector,))
return cur.fetchone()
def __getitem__(self, vector):
cur = self._db.execute(
'SELECT * FROM objects WHERE vector = ?', (vector,))
item = cur.fetchone()
if item is None:
raise KeyError(vector)
vector, expires, obj_type, version, stream, tag, data, offset = item
return structure.Object(
expires, obj_type, version, stream, data, offset,
tag=tag, vector=vector
)
def __delitem__(self, vector):
with shared.objects_lock: # KeyError
self._db.execute(
'DELETE FROM objects WHERE vector = ?', (vector,))
self._db.commit()
self.rowid = len(self)
def __setitem__(self, vector, obj):
with shared.objects_lock:
cur = self._db.execute(
'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', (
vector, obj.expires_time, obj.object_type, obj.version,
obj.stream_number, obj.tag, obj.data, obj.offset
))
self._db.commit()
self.rowid = cur.lastrowid
def __len__(self):
cur = self._db.execute('SELECT count(*) FROM objects')
return cur.fetchone()[0]
def __del__(self):
self._db.close()

View File

@ -48,21 +48,26 @@ class VarInt():
class Object(): class Object():
"""The 'object' message payload""" """The 'object' message payload"""
def __init__( def __init__(
self, nonce, expires_time, object_type, version, self, expires_time, object_type, version, stream_number,
stream_number, object_payload data=None, offset=None, object_payload=None,
tag=None, nonce=b'\x00' * 8, vector=None
): ):
self.nonce = nonce self.nonce = nonce
self.expires_time = expires_time self.expires_time = expires_time
self.object_type = object_type self.object_type = object_type
self.version = version self.version = version
self.stream_number = stream_number self.stream_number = stream_number
self.object_payload = object_payload if not data:
self.vector = hashlib.sha512(hashlib.sha512( data, offset = self.to_bytes(object_payload)
self.to_bytes()).digest()).digest()[:32] self.data = data
self.offset = offset
self.vector = vector or hashlib.sha512(hashlib.sha512(
self.data).digest()).digest()[:32]
self.tag = ( self.tag = tag or (
# broadcast from version 5 and pubkey/getpukey from version 4 # broadcast from version 5 and pubkey/getpukey from version 4
self.object_payload[:32] if object_type == 3 and version == 5 (object_payload or self.object_payload)[:32]
if object_type == 3 and version == 5
or (object_type in (0, 1) and version == 4) or (object_type in (0, 1) and version == 4)
else None) else None)
@ -73,20 +78,25 @@ class Object():
@classmethod @classmethod
def from_message(cls, m): def from_message(cls, m):
"""Decode message payload""" """Decode message payload"""
payload = m.payload data = m.payload
nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20]) nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20])
payload = payload[20:] version_varint_length = VarInt.length(data[20])
version_varint_length = VarInt.length(payload[0]) offset = 20 + version_varint_length
version = VarInt.from_bytes(payload[:version_varint_length]).n version = VarInt.from_bytes(data[20:offset]).n
payload = payload[version_varint_length:] stream_number_varint_length = VarInt.length(data[offset])
stream_number_varint_length = VarInt.length(payload[0])
stream_number = VarInt.from_bytes( stream_number = VarInt.from_bytes(
payload[:stream_number_varint_length]).n data[offset:offset+stream_number_varint_length]).n
payload = payload[stream_number_varint_length:] offset += stream_number_varint_length
return cls( return cls(
nonce, expires_time, object_type, version, stream_number, payload) expires_time, object_type, version, stream_number,
data, offset, nonce=nonce
)
def to_bytes(self): @property
def object_payload(self):
return self.data[self.offset:]
def to_bytes(self, object_payload):
"""Serialize to bytes""" """Serialize to bytes"""
payload = b'' payload = b''
payload += self.nonce payload += self.nonce
@ -94,8 +104,7 @@ class Object():
payload += ( payload += (
VarInt(self.version).to_bytes() VarInt(self.version).to_bytes()
+ VarInt(self.stream_number).to_bytes()) + VarInt(self.stream_number).to_bytes())
payload += self.object_payload return payload + object_payload, len(payload)
return payload
def is_expired(self): def is_expired(self):
"""Check if object's TTL is expired""" """Check if object's TTL is expired"""
@ -138,7 +147,7 @@ class Object():
def pow_target(self): def pow_target(self):
"""Compute PoW target""" """Compute PoW target"""
data = self.to_bytes()[8:] data = self.data[8:]
length = len(data) + 8 + shared.payload_length_extra_bytes length = len(data) + 8 + shared.payload_length_extra_bytes
dt = max(self.expires_time - time.time(), 0) dt = max(self.expires_time - time.time(), 0)
return int( return int(
@ -148,7 +157,7 @@ class Object():
def pow_initial_hash(self): def pow_initial_hash(self):
"""Compute the initial hash for PoW""" """Compute the initial hash for PoW"""
return hashlib.sha512(self.to_bytes()[8:]).digest() return hashlib.sha512(self.data[8:]).digest()
class NetAddrNoPrefix(): class NetAddrNoPrefix():

View File

@ -7,7 +7,7 @@ import tempfile
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from minode import connection, main, shared from minode import connection, main, shared, sql
from minode.listener import Listener from minode.listener import Listener
from minode.manager import Manager from minode.manager import Manager
@ -65,11 +65,11 @@ class TestNetwork(unittest.TestCase):
def setUp(self): def setUp(self):
shared.core_nodes.clear() shared.core_nodes.clear()
shared.unchecked_node_pool.clear() shared.unchecked_node_pool.clear()
shared.objects = {}
try: try:
os.remove(os.path.join(shared.data_directory, 'objects.pickle')) os.remove(os.path.join(shared.data_directory, 'objects.dat'))
except FileNotFoundError: except FileNotFoundError:
pass pass
shared.objects = sql.Inventory()
def _make_initial_nodes(self): def _make_initial_nodes(self):
Manager.load_data() Manager.load_data()

View File

@ -135,13 +135,13 @@ class TestStructure(unittest.TestCase):
self.assertEqual(obj.object_payload, b'HELLO') self.assertEqual(obj.object_payload, b'HELLO')
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO') int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO')
self.assertFalse(obj.is_valid()) self.assertFalse(obj.is_valid())
obj.expires_time = int(time.time() - 11000) obj.expires_time = int(time.time() - 11000)
self.assertFalse(obj.is_valid()) self.assertFalse(obj.is_valid())
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO') int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO')
vector = obj.vector vector = obj.vector
proofofwork._worker(obj) # pylint: disable=protected-access proofofwork._worker(obj) # pylint: disable=protected-access
obj = shared.objects.popitem()[1] obj = shared.objects.popitem()[1]
@ -159,8 +159,8 @@ class TestStructure(unittest.TestCase):
"""Check the main proofofwork call and worker""" """Check the main proofofwork call and worker"""
shared.vector_advertise_queue = queue.Queue() shared.vector_advertise_queue = queue.Queue()
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, int(time.time() + 300), 42, 1, shared.stream,
shared.stream, b'HELLO') object_payload=b'HELLO')
start_time = time.time() start_time = time.time()
proofofwork.do_pow_and_publish(obj) proofofwork.do_pow_and_publish(obj)
try: try:
@ -189,6 +189,6 @@ class TestStructure(unittest.TestCase):
self.fail("No nonce found in the queue") self.fail("No nonce found in the queue")
obj = structure.Object( obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version, obj.expires_time, obj.object_type, obj.version, obj.stream_number,
obj.stream_number, obj.object_payload) object_payload=obj.object_payload, nonce=nonce)
self.assertTrue(obj.is_valid()) self.assertTrue(obj.is_valid())