WIP: Implementing sqlite objects storage #13

Draft
lee.miller wants to merge 12 commits from lee.miller/MiNode:sqlite into v0.3
9 changed files with 296 additions and 131 deletions

minode/connection.py

@@ -277,22 +277,15 @@ class ConnectionBase(threading.Thread):
if len(addr) != 0:
self.send_queue.put(message.Addr(addr))
with shared.objects_lock:
if len(shared.objects) > 0:
to_send = {
vector for vector in shared.objects.keys()
if shared.objects[vector].expires_time > time.time()}
while len(to_send) > 0:
if len(to_send) > 10000:
# We limit the size of inv messages to 10000 entries
# because they might time out
# in very slow networks (I2P)
pack = random.sample(tuple(to_send), 10000)
self.send_queue.put(message.Inv(pack))
to_send.difference_update(pack)
else:
self.send_queue.put(message.Inv(to_send))
to_send.clear()
if shared.objects:
to_send = shared.objects.vectors_to_send()
offset = 0
while offset < len(to_send):
# We limit the size of inv messages to 10000 entries
# because they might time out
# in very slow networks (I2P)
self.send_queue.put(message.Inv(to_send[offset:offset+10000]))
offset += 10000
self.status = 'fully_established'
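The rewritten loop above advertises the whole inventory in slices of at most 10000 vectors. The same pattern as a standalone sketch (integers stand in for the 32-byte vectors, print() for send_queue.put()):

    to_send = list(range(25000))   # stand-in for vectors_to_send()
    offset = 0
    while offset < len(to_send):
        batch = to_send[offset:offset + 10000]  # one inv message
        print('inv with %s entries' % len(batch))
        offset += 10000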
def _process_queue(self):
@@ -415,7 +408,7 @@ class ConnectionBase(threading.Thread):
def _request_objects(self):
if self.vectors_to_get and len(self.vectors_requested) < 100:
self.vectors_to_get.difference_update(shared.objects.keys())
self.vectors_to_get = shared.objects.select(self.vectors_to_get)
if not self.wait_until:
nodes_count = (
len(shared.node_pool) + len(shared.unchecked_node_pool))
@@ -462,12 +455,10 @@ class ConnectionBase(threading.Thread):
else:
to_send = self.vectors_to_send.copy()
self.vectors_to_send.clear()
with shared.objects_lock:
for vector in to_send:
obj = shared.objects.get(vector, None)
if obj:
self.send_queue.put(
message.Message(b'object', obj.to_bytes()))
for vector in to_send:
obj = shared.objects.get(vector)
if obj:
self.send_queue.put(message.Message(b'object', obj.data))
class Connection(ConnectionBase):
@@ -475,9 +466,7 @@ class Connection(ConnectionBase):
def _process_msg_inv(self, m):
inv = message.Inv.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
to_get = inv.vectors.copy()
to_get.difference_update(shared.objects.keys())
self.vectors_to_get.update(to_get)
self.vectors_to_get.update(shared.objects.select(inv.vectors))
# Do not send objects they already have.
self.vectors_to_send.difference_update(inv.vectors)
@@ -486,9 +475,8 @@ class Connection(ConnectionBase):
logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects:
with shared.objects_lock:
shared.objects[obj.vector] = obj
if obj.is_valid():
shared.objects[obj.vector] = obj
if (
obj.object_type == shared.i2p_dest_obj_type
and obj.version == shared.i2p_dest_obj_version
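Both call sites above lean on Inventory.select(), which takes the set of offered vectors and hands back only those missing from the database (marking them pending). In terms of the old set arithmetic it replaces, a rough sketch:

    known = {b'aa', b'bb'}           # vectors already in the objects table
    offered = {b'aa', b'cc', b'dd'}  # vectors from an incoming inv
    to_get = offered - known         # what select(offered) returns
    print(sorted(to_get))            # [b'cc', b'dd']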

minode/main.py

@@ -8,7 +8,7 @@ import os
import signal
import socket
from . import i2p, shared
from . import i2p, shared, sql
from .advertiser import Advertiser
from .manager import Manager
from .listener import Listener
@@ -157,10 +157,9 @@ def start_ip_listener():
def start_i2p_listener():
"""Starts I2P threads"""
# Grab I2P destinations from old object file
for obj in shared.objects.values():
if obj.object_type == shared.i2p_dest_obj_type:
shared.i2p_unchecked_node_pool.add((
base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type):
shared.i2p_unchecked_node_pool.add((
base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
dest_priv = b''
@@ -241,6 +240,8 @@ def main():
if shared.ip_enabled and not shared.trusted_peer:
bootstrap_from_dns()
shared.objects = sql.Inventory()
if shared.i2p_enabled:
# We are starting it before cleaning expired objects
# so we can collect I2P destination objects
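filter() turns the old scan over shared.objects.values() into a single parameterized query. A sketch of how the WHERE clause for this call is assembled (values shown for illustration; the real builder is in minode/sql.py below):

    clauses = [('type = ?', 0x493250)]  # object_type=shared.i2p_dest_obj_type
    where, params = zip(*clauses)
    query = 'SELECT * FROM objects WHERE ' + ' AND '.join(where)
    print(query, params)  # SELECT * FROM objects WHERE type = ? (4797008,)

(As written, filter() needs at least one criterion; with none given, the zip(*clauses) unpacking raises ValueError.)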

minode/manager.py

@@ -22,7 +22,7 @@ class Manager(threading.Thread):
self.q = queue.Queue()
self.last_cleaned_objects = time.time()
self.last_cleaned_connections = time.time()
self.last_pickled_objects = time.time()
# self.last_pickled_objects = time.time()
self.last_pickled_nodes = time.time()
# Publish destination 5-15 minutes after start
self.last_published_i2p_destination = \
@@ -30,7 +30,7 @@
def run(self):
self.load_data()
self.clean_objects()
shared.objects.cleanup()
while True:
time.sleep(0.8)
now = time.time()
@@ -38,14 +38,14 @@
logging.debug('Shutting down Manager')
break
if now - self.last_cleaned_objects > 90:
self.clean_objects()
shared.objects.cleanup()
self.last_cleaned_objects = now
if now - self.last_cleaned_connections > 2:
self.manage_connections()
self.last_cleaned_connections = now
if now - self.last_pickled_objects > 100:
self.pickle_objects()
self.last_pickled_objects = now
# if now - self.last_pickled_objects > 100:
# self.pickle_objects()
# self.last_pickled_objects = now
if now - self.last_pickled_nodes > 60:
self.pickle_nodes()
self.last_pickled_nodes = now
@@ -53,21 +53,6 @@
self.publish_i2p_destination()
self.last_published_i2p_destination = now
@staticmethod
def clean_objects():
for vector in set(shared.objects):
if not shared.objects[vector].is_valid():
if shared.objects[vector].is_expired():
logging.debug(
'Deleted expired object: %s',
base64.b16encode(vector).decode())
else:
logging.warning(
'Deleted invalid object: %s',
base64.b16encode(vector).decode())
with shared.objects_lock:
del shared.objects[vector]
@staticmethod
def manage_connections():
hosts = set()
@@ -152,17 +137,6 @@
@staticmethod
def load_data():
"""Loads initial nodes and data, stored in files between sessions"""
try:
with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'br'
) as src:
shared.objects = pickle.load(src)
except FileNotFoundError:
pass # first start
except Exception:
logging.warning(
'Error while loading objects from disk.', exc_info=True)
try:
with open(
os.path.join(shared.data_directory, 'nodes.pickle'), 'br'
@@ -202,17 +176,17 @@
(row[0].encode(), 'i2p') for row in reader}
shared.i2p_node_pool.update(shared.i2p_core_nodes)
@staticmethod
def pickle_objects():
try:
with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'bw'
) as dst:
with shared.objects_lock:
pickle.dump(shared.objects, dst, protocol=3)
logging.debug('Saved objects')
except Exception:
logging.warning('Error while saving objects', exc_info=True)
# @staticmethod
# def pickle_objects():
# try:
# with open(
# os.path.join(shared.data_directory, 'objects.pickle'), 'bw'
# ) as dst:
# with shared.objects_lock:
# pickle.dump(shared.objects, dst, protocol=3)
# logging.debug('Saved objects')
# except Exception:
# logging.warning('Error while saving objects', exc_info=True)
@staticmethod
def pickle_nodes():
@@ -250,7 +224,7 @@
dest_pub_raw = base64.b64decode(
shared.i2p_dest_pub, altchars=b'-~')
obj = structure.Object(
b'\x00' * 8, int(time.time() + 2 * 3600),
int(time.time() + 2 * 3600),
shared.i2p_dest_obj_type, shared.i2p_dest_obj_version,
shared.stream, dest_pub_raw)
shared.stream, object_payload=dest_pub_raw)
proofofwork.do_pow_and_publish(obj)
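Stepping back to the run() loop changes above: each periodic task keeps its own timestamp and fires once its interval has passed, with cleanup now handed to shared.objects.cleanup() every 90 seconds and the periodic object pickling commented out. The timer pattern itself, reduced to a runnable sketch with shortened intervals:

    import time

    last_cleaned = time.time()
    for _ in range(5):                 # stands in for `while True`
        time.sleep(0.1)
        now = time.time()
        if now - last_cleaned > 0.25:  # 90 seconds in Manager.run()
            print('cleanup at', round(now, 2))
            last_cleaned = now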

minode/proofofwork.py

@@ -39,14 +39,14 @@ def _worker(obj):
logging.debug(
'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t)
obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version,
obj.stream_number, obj.object_payload)
obj.expires_time, obj.object_type, obj.version, obj.stream_number,
object_payload=obj.object_payload, nonce=nonce
)
logging.debug(
'Object vector is %s', base64.b16encode(obj.vector).decode())
with shared.objects_lock:
shared.objects[obj.vector] = obj
shared.vector_advertise_queue.put(obj.vector)
shared.objects[obj.vector] = obj
shared.vector_advertise_queue.put(obj.vector)
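_worker() now rebuilds the object with the found nonce passed as a keyword, and the inventory vector comes from the serialized bytes: double SHA-512 over the full data (nonce included), truncated to 32 bytes. Standalone, with placeholder bytes:

    import hashlib

    data = b'\x00' * 8 + b'serialized object payload'  # nonce + the rest
    vector = hashlib.sha512(hashlib.sha512(data).digest()).digest()[:32]
    print(vector.hex())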
def do_pow_and_publish(obj):

minode/shared.py

@@ -21,7 +21,7 @@ protocol_version = 3
services = 3 # NODE_NETWORK, NODE_SSL
stream = 1
nonce = os.urandom(8)
user_agent = b'/MiNode:0.3.3/'
user_agent = b'/MiNode:0.3.5/'
timeout = 600
header_length = 24
i2p_dest_obj_type = 0x493250
@@ -63,4 +63,3 @@ outgoing_connections = 8
connection_limit = 250
objects = {}
objects_lock = threading.Lock()

minode/sql.py (new file, 187 lines)

@@ -0,0 +1,187 @@
"""Inventory implementation using sqlite"""
import logging
import os
import sqlite3
import threading
import time
from . import shared, structure
sqlite3.threadsafety = 3
class Inventory():
"""sqlite inventory"""
def __init__(self):
self._lock = threading.Lock()
self._deleted = 0
self._pending = set()
self._db = sqlite3.connect(
os.path.join(shared.data_directory, 'objects.dat'),
check_same_thread=False
)
self._db.executescript("""
BEGIN;
CREATE TABLE IF NOT EXISTS status
(key text, value integer, UNIQUE(key) ON CONFLICT REPLACE);
INSERT INTO status VALUES ('version', 1);
CREATE TABLE IF NOT EXISTS objects
(vector unique, expires integer, type integer, version integer,
stream integer, tag, data, offset integer);
COMMIT;
""")
self.rowid = len(self) or None
cur = self._db.cursor()
cur.execute("SELECT value FROM status WHERE key='lastvacuumtime'")
now = int(time.time())
try:
vacuumed = cur.fetchone()[0]
except TypeError:
pass
else:
if vacuumed < now - 86400: # 24 hours
cur.execute('VACUUM')
cur.execute("INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,))
self._db.commit()
self._db.row_factory = self.__object
    @staticmethod
    def __object(cursor, row):
        if len(cursor.description) != 8:
            return row
        vector, expires, obj_type, version, stream, tag, data, offset = row
        return structure.Object(
            expires, obj_type, version, stream, data, offset,
            tag=tag, vector=vector)

    def cleanup(self):
        if len(self._pending) > 100:
            logging.warning(
                'Not cleaning up, %s objects pending', len(self._pending))
            return
        with self._lock:
            cur = self._db.execute(
                'DELETE FROM objects WHERE expires < ?',
                (int(time.time()) - 3 * 3600,)
            )
            self._db.commit()
            self._deleted += cur.rowcount
            # conditional vacuum and validity check
            if self._deleted > 10000:
                logging.info('Doing VACUUM for objects')
                cur.execute('VACUUM')
                cur.execute(
                    "INSERT INTO status VALUES ('lastvacuumtime', ?)",
                    (int(time.time()),))
                self._db.commit()
                self._deleted = 0
        logging.info(
            'Deleted %s expired objects, %s pending',
            cur.rowcount, len(self._pending))
    def filter(self, stream=None, object_type=None, tag=None):
        clauses = []
        if stream:
            clauses.append(('stream = ?', stream))
        if object_type:
            clauses.append(('type = ?', object_type))
        if tag:
            clauses.append(('tag = ?', tag))
        clauses, params = zip(*clauses)
        cur = self._db.execute(
            'SELECT * FROM objects WHERE ' + ' AND '.join(clauses), params)
        return cur

    def select(self, vectors):
        chunk_size = 999
        keys = tuple(vectors)
        with self._lock:
            for i in range(0, len(vectors), chunk_size):
                chunk = keys[i:i+chunk_size]
                cur = self._db.execute(
                    'SELECT vector FROM objects WHERE vector IN ({})'.format(
                        ','.join('?' * len(chunk))), chunk)
                for v, in cur:
                    vectors.remove(v)
            self._pending.update(vectors)
        return vectors
    def vectors_to_send(self, stream=None):
        cur = self._db.execute(
            'SELECT vector FROM objects WHERE expires > ? AND stream = ?'
            ' ORDER BY random()',
            (int(time.time()), stream or shared.stream)
        )
        return [v for v, in cur]

    def get(self, vector, default=None):
        try:
            return self[vector]
        except KeyError:
            return default

    def keys(self):
        cur = self._db.execute('SELECT vector FROM objects')
        return (v for v, in cur)

    def values(self):
        return self._db.execute('SELECT * FROM objects')

    def popitem(self):
        if not self.rowid:
            raise KeyError('empty')
        cur = self._db.execute(
            'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,))
        vector = cur.fetchone()[0]
        obj = self.get(vector)
        del self[vector]
        return (vector, obj)

    def __contains__(self, vector):
        cur = self._db.execute(
            'SELECT vector FROM objects WHERE vector = ?', (vector,))
        return cur.fetchone()

    def __getitem__(self, vector):
        cur = self._db.execute(
            'SELECT * FROM objects WHERE vector = ?', (vector,))
        cur.row_factory = self.__object
        item = cur.fetchone()
        if item is None:
            raise KeyError(vector)
        return item

    def __delitem__(self, vector):
        with self._lock:  # KeyError
            self._db.execute(
                'DELETE FROM objects WHERE vector = ?', (vector,))
            self._db.commit()
            self.rowid = len(self)

    def __setitem__(self, vector, obj):
        with self._lock:
            try:
                cur = self._db.execute(
                    'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)', (
                        vector, obj.expires_time, obj.object_type,
                        obj.version, obj.stream_number, obj.tag,
                        obj.data, obj.offset
                    ))
            except (sqlite3.DatabaseError, sqlite3.IntegrityError):
                return
            self._db.commit()
            self.rowid = cur.lastrowid
            self._pending.discard(vector)

    def __bool__(self):
        return self._db.execute(
            'SELECT vector from objects LIMIT 1').fetchone() is not None

    def __len__(self):
        cur = self._db.execute('SELECT count(*) FROM objects')
        return cur.fetchone()[0]

    def __del__(self):
        self._db.close()
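Two notes on the class above. First, the status table's UNIQUE(key) ON CONFLICT REPLACE constraint makes a plain INSERT behave as an upsert, which is how 'lastvacuumtime' is refreshed without an UPDATE; a self-contained demo:

    import sqlite3

    db = sqlite3.connect(':memory:')
    db.execute("CREATE TABLE status"
               " (key text, value integer, UNIQUE(key) ON CONFLICT REPLACE)")
    db.execute("INSERT INTO status VALUES ('version', 1)")
    db.execute("INSERT INTO status VALUES ('version', 2)")  # replaces the row
    print(db.execute('SELECT * FROM status').fetchall())    # [('version', 2)]

Second, select() chunks its IN (...) lists at 999 placeholders, matching SQLite's historical default for SQLITE_MAX_VARIABLE_NUMBER; longer lists would raise "too many SQL variables" on older builds. The mapping-style methods (get, keys, __getitem__, __setitem__ and friends) preserve the dict interface the rest of the code expects, so call sites only need shared.objects to become an Inventory.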

minode/structure.py

@@ -48,21 +48,26 @@ class VarInt():
class Object():
"""The 'object' message payload"""
def __init__(
self, nonce, expires_time, object_type, version,
stream_number, object_payload
self, expires_time, object_type, version, stream_number,
data=None, offset=None, object_payload=None,
tag=None, nonce=b'\x00' * 8, vector=None
):
self.nonce = nonce
self.expires_time = expires_time
self.object_type = object_type
self.version = version
self.stream_number = stream_number
self.object_payload = object_payload
self.vector = hashlib.sha512(hashlib.sha512(
self.to_bytes()).digest()).digest()[:32]
if not data:
data, offset = self.to_bytes(object_payload)
self.data = data
self.offset = offset
self.vector = vector or hashlib.sha512(hashlib.sha512(
self.data).digest()).digest()[:32]
self.tag = (
self.tag = tag or (
# broadcast from version 5 and pubkey/getpubkey from version 4
self.object_payload[:32] if object_type == 3 and version == 5
(object_payload or self.object_payload)[:32]
if object_type == 3 and version == 5
or (object_type in (0, 1) and version == 4)
else None)
@@ -73,20 +78,25 @@ class Object():
@classmethod
def from_message(cls, m):
"""Decode message payload"""
payload = m.payload
nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20])
payload = payload[20:]
version_varint_length = VarInt.length(payload[0])
version = VarInt.from_bytes(payload[:version_varint_length]).n
payload = payload[version_varint_length:]
stream_number_varint_length = VarInt.length(payload[0])
data = m.payload
nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20])
version_varint_length = VarInt.length(data[20])
offset = 20 + version_varint_length
version = VarInt.from_bytes(data[20:offset]).n
stream_number_varint_length = VarInt.length(data[offset])
stream_number = VarInt.from_bytes(
payload[:stream_number_varint_length]).n
payload = payload[stream_number_varint_length:]
data[offset:offset+stream_number_varint_length]).n
offset += stream_number_varint_length
return cls(
nonce, expires_time, object_type, version, stream_number, payload)
expires_time, object_type, version, stream_number,
data, offset, nonce=nonce
)
def to_bytes(self):
@property
def object_payload(self):
return self.data[self.offset:]
def to_bytes(self, object_payload):
"""Serialize to bytes"""
payload = b''
payload += self.nonce
@@ -94,8 +104,7 @@ class Object():
payload += (
VarInt(self.version).to_bytes()
+ VarInt(self.stream_number).to_bytes())
payload += self.object_payload
return payload
return payload + object_payload, len(payload)
def is_expired(self):
"""Check if object's TTL is expired"""
@@ -138,7 +147,7 @@ class Object():
def pow_target(self):
"""Compute PoW target"""
data = self.to_bytes()[8:]
data = self.data[8:]
length = len(data) + 8 + shared.payload_length_extra_bytes
dt = max(self.expires_time - time.time(), 0)
return int(
@@ -148,7 +157,7 @@ class Object():
def pow_initial_hash(self):
"""Compute the initial hash for PoW"""
return hashlib.sha512(self.to_bytes()[8:]).digest()
return hashlib.sha512(self.data[8:]).digest()
class NetAddrNoPrefix():
@@ -162,14 +171,25 @@ class NetAddrNoPrefix():
return 'net_addr_no_prefix, services: {}, host: {}, port {}'.format(
self.services, self.host, self.port)
@staticmethod
def decode_host(host):
if host.startswith(
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF'):
return socket.inet_ntop(socket.AF_INET, host[-4:])
return socket.inet_ntop(socket.AF_INET6, host)
@staticmethod
def encode_host(host):
try:
host = socket.inet_pton(socket.AF_INET, host)
return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + host
except socket.error:
return socket.inet_pton(socket.AF_INET6, host)
def to_bytes(self):
b = b''
b += struct.pack('>Q', self.services)
try:
host = socket.inet_pton(socket.AF_INET, self.host)
b += b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + host
except socket.error:
b += socket.inet_pton(socket.AF_INET6, self.host)
b += self.encode_host(self.host)
b += struct.pack('>H', int(self.port))
return b
@@ -191,12 +211,7 @@ class NetAddrNoPrefix():
@classmethod
def from_bytes(cls, b):
services, host, port = struct.unpack('>Q16sH', b)
if host.startswith(
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF'):
host = socket.inet_ntop(socket.AF_INET, host[-4:])
else:
host = socket.inet_ntop(socket.AF_INET6, host)
return cls(services, host, port)
return cls(services, cls.decode_host(host), port)
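decode_host()/encode_host() factor the IPv4-mapped-IPv6 handling out of to_bytes()/from_bytes(). The same logic as a standalone round-trip check (socket module only):

    import socket

    V4_PREFIX = b'\x00' * 10 + b'\xFF\xFF'   # IPv4-mapped IPv6 prefix

    def encode_host(host):
        try:
            return V4_PREFIX + socket.inet_pton(socket.AF_INET, host)
        except socket.error:
            return socket.inet_pton(socket.AF_INET6, host)

    def decode_host(host):
        if host.startswith(V4_PREFIX):
            return socket.inet_ntop(socket.AF_INET, host[-4:])
        return socket.inet_ntop(socket.AF_INET6, host)

    assert decode_host(encode_host('127.0.0.1')) == '127.0.0.1'
    assert decode_host(encode_host('2001:db8::1')) == '2001:db8::1'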
class NetAddr():

test_network.py

@@ -7,7 +7,7 @@ import tempfile
import time
from contextlib import contextmanager
from minode import connection, main, shared
from minode import connection, main, shared, sql
from minode.listener import Listener
from minode.manager import Manager
@@ -65,11 +65,11 @@ class TestNetwork(unittest.TestCase):
def setUp(self):
shared.core_nodes.clear()
shared.unchecked_node_pool.clear()
shared.objects = {}
try:
os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
os.remove(os.path.join(shared.data_directory, 'objects.dat'))
except FileNotFoundError:
pass
shared.objects = sql.Inventory()
def _make_initial_nodes(self):
Manager.load_data()

test_structure.py

@@ -135,13 +135,13 @@ class TestStructure(unittest.TestCase):
self.assertEqual(obj.object_payload, b'HELLO')
obj = structure.Object(
b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO')
int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO')
self.assertFalse(obj.is_valid())
obj.expires_time = int(time.time() - 11000)
self.assertFalse(obj.is_valid())
obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO')
int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO')
vector = obj.vector
proofofwork._worker(obj) # pylint: disable=protected-access
obj = shared.objects.popitem()[1]
@@ -151,16 +151,17 @@
shared.stream = 2
self.assertTrue(obj.is_valid())
obj.object_payload = \
b'TIGER, tiger, burning bright. In the forests of the night'
self.assertFalse(obj.is_valid())
# obj.data = struct.pack(...
# obj.object_payload = \
# b'TIGER, tiger, burning bright. In the forests of the night'
# self.assertFalse(obj.is_valid())
def test_proofofwork(self):
"""Check the main proofofwork call and worker"""
shared.vector_advertise_queue = queue.Queue()
obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1,
shared.stream, b'HELLO')
int(time.time() + 300), 42, 1, shared.stream,
object_payload=b'HELLO')
start_time = time.time()
proofofwork.do_pow_and_publish(obj)
try:
@@ -189,6 +190,6 @@
self.fail("No nonce found in the queue")
obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version,
obj.stream_number, obj.object_payload)
obj.expires_time, obj.object_type, obj.version, obj.stream_number,
object_payload=obj.object_payload, nonce=nonce)
self.assertTrue(obj.is_valid())