WIP: Implementing sqlite objects storage #13

Draft
lee.miller wants to merge 11 commits from lee.miller/MiNode:sqlite into v0.3
10 changed files with 513 additions and 120 deletions

View File

@ -277,22 +277,11 @@ class ConnectionBase(threading.Thread):
if len(addr) != 0: if len(addr) != 0:
self.send_queue.put(message.Addr(addr)) self.send_queue.put(message.Addr(addr))
with shared.objects_lock: if shared.objects:
if len(shared.objects) > 0: for chunk in shared.objects.biginv_chunks(10000):
to_send = { # We limit size of inv messages to 10000 entries
vector for vector in shared.objects.keys() # because they might time out in very slow networks (I2P)
if shared.objects[vector].expires_time > time.time()} self.send_queue.put(message.Inv(chunk))
while len(to_send) > 0:
if len(to_send) > 10000:
# We limit size of inv messaged to 10000 entries
# because they might time out
# in very slow networks (I2P)
pack = random.sample(tuple(to_send), 10000)
self.send_queue.put(message.Inv(pack))
to_send.difference_update(pack)
else:
self.send_queue.put(message.Inv(to_send))
to_send.clear()
self.status = 'fully_established' self.status = 'fully_established'
def _process_queue(self): def _process_queue(self):
@ -416,7 +405,7 @@ class ConnectionBase(threading.Thread):
def _request_objects(self): def _request_objects(self):
if self.vectors_to_get and len(self.vectors_requested) < 100: if self.vectors_to_get and len(self.vectors_requested) < 100:
self.vectors_to_get.difference_update(shared.objects.keys()) self.vectors_to_get = shared.objects.select(self.vectors_to_get)
if not self.wait_until: if not self.wait_until:
nodes_count = ( nodes_count = (
len(shared.node_pool) + len(shared.unchecked_node_pool)) len(shared.node_pool) + len(shared.unchecked_node_pool))
@ -463,12 +452,10 @@ class ConnectionBase(threading.Thread):
else: else:
to_send = self.vectors_to_send.copy() to_send = self.vectors_to_send.copy()
self.vectors_to_send.clear() self.vectors_to_send.clear()
with shared.objects_lock:
for vector in to_send: for vector in to_send:
obj = shared.objects.get(vector, None) obj = shared.objects.get(vector)
if obj: if obj:
self.send_queue.put( self.send_queue.put(message.Message(b'object', obj.data))
message.Message(b'object', obj.to_bytes()))
class Connection(ConnectionBase): class Connection(ConnectionBase):
@ -476,9 +463,7 @@ class Connection(ConnectionBase):
def _process_msg_inv(self, m): def _process_msg_inv(self, m):
inv = message.Inv.from_message(m) inv = message.Inv.from_message(m)
logging.debug('%s:%s -> %s', self.host_print, self.port, inv) logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
to_get = inv.vectors.copy() self.vectors_to_get.update(shared.objects.select(inv.vectors))
to_get.difference_update(shared.objects.keys())
self.vectors_to_get.update(to_get)
# Do not send objects they already have. # Do not send objects they already have.
self.vectors_to_send.difference_update(inv.vectors) self.vectors_to_send.difference_update(inv.vectors)
@ -487,8 +472,7 @@ class Connection(ConnectionBase):
logging.debug('%s:%s -> %s', self.host_print, self.port, obj) logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
self.vectors_requested.pop(obj.vector, None) self.vectors_requested.pop(obj.vector, None)
self.vectors_to_get.discard(obj.vector) self.vectors_to_get.discard(obj.vector)
if obj.is_valid() and obj.vector not in shared.objects: if obj.is_valid():
with shared.objects_lock:
shared.objects[obj.vector] = obj shared.objects[obj.vector] = obj
if ( if (
obj.object_type == shared.i2p_dest_obj_type obj.object_type == shared.i2p_dest_obj_type
@ -501,6 +485,7 @@ class Connection(ConnectionBase):
logging.debug(dest) logging.debug(dest)
shared.i2p_unchecked_node_pool.add((dest, 'i2p')) shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
shared.vector_advertise_queue.put(obj.vector) shared.vector_advertise_queue.put(obj.vector)
shared.objects.check(obj.vector)
def _process_msg_getdata(self, m): def _process_msg_getdata(self, m):
getdata = message.GetData.from_message(m) getdata = message.GetData.from_message(m)

View File

@ -8,7 +8,7 @@ import os
import signal import signal
import socket import socket
from . import i2p, shared from . import i2p, shared, sql
from .advertiser import Advertiser from .advertiser import Advertiser
from .manager import Manager from .manager import Manager
from .listener import Listener from .listener import Listener
@ -161,8 +161,7 @@ def start_ip_listener():
def start_i2p_listener(): def start_i2p_listener():
"""Starts I2P threads""" """Starts I2P threads"""
# Grab I2P destinations from old object file # Grab I2P destinations from old object file
for obj in shared.objects.values(): for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type):
if obj.object_type == shared.i2p_dest_obj_type:
shared.i2p_unchecked_node_pool.add(( shared.i2p_unchecked_node_pool.add((
base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p')) base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
@ -245,6 +244,8 @@ def main():
if shared.ip_enabled and not shared.trusted_peer: if shared.ip_enabled and not shared.trusted_peer:
bootstrap_from_dns() bootstrap_from_dns()
shared.objects = sql.Inventory()
if shared.i2p_enabled: if shared.i2p_enabled:
# We are starting it before cleaning expired objects # We are starting it before cleaning expired objects
# so we can collect I2P destination objects # so we can collect I2P destination objects

View File

@ -23,7 +23,6 @@ class Manager(threading.Thread):
self.bootstrap_pool = [] self.bootstrap_pool = []
self.last_cleaned_objects = time.time() self.last_cleaned_objects = time.time()
self.last_cleaned_connections = time.time() self.last_cleaned_connections = time.time()
self.last_pickled_objects = time.time()
self.last_pickled_nodes = time.time() self.last_pickled_nodes = time.time()
# Publish destination 5-15 minutes after start # Publish destination 5-15 minutes after start
self.last_published_i2p_destination = \ self.last_published_i2p_destination = \
@ -36,23 +35,21 @@ class Manager(threading.Thread):
def run(self): def run(self):
self.load_data() self.load_data()
self.clean_objects() shared.objects.cleanup()
self.fill_bootstrap_pool() self.fill_bootstrap_pool()
while True: while True:
time.sleep(0.8) time.sleep(0.8)
now = time.time() now = time.time()
if shared.shutting_down: if shared.shutting_down:
logging.debug('Shutting down Manager') logging.debug('Shutting down Manager')
shared.objects.flush()
break break
if now - self.last_cleaned_objects > 90: if now - self.last_cleaned_objects > 90:
self.clean_objects() shared.objects.cleanup()
self.last_cleaned_objects = now self.last_cleaned_objects = now
if now - self.last_cleaned_connections > 2: if now - self.last_cleaned_connections > 2:
self.manage_connections() self.manage_connections()
self.last_cleaned_connections = now self.last_cleaned_connections = now
if now - self.last_pickled_objects > 100:
self.pickle_objects()
self.last_pickled_objects = now
if now - self.last_pickled_nodes > 60: if now - self.last_pickled_nodes > 60:
self.pickle_nodes() self.pickle_nodes()
self.last_pickled_nodes = now self.last_pickled_nodes = now
@ -60,17 +57,6 @@ class Manager(threading.Thread):
self.publish_i2p_destination() self.publish_i2p_destination()
self.last_published_i2p_destination = now self.last_published_i2p_destination = now
@staticmethod
def clean_objects():
for vector in set(shared.objects):
# FIXME: no need to check is_valid() here
if shared.objects[vector].is_expired():
logging.debug(
'Deleted expired object: %s',
base64.b16encode(vector).decode())
with shared.objects_lock:
del shared.objects[vector]
def manage_connections(self): def manage_connections(self):
"""Open new connections if needed, remove closed ones""" """Open new connections if needed, remove closed ones"""
hosts = set() hosts = set()
@ -100,6 +86,8 @@ class Manager(threading.Thread):
outgoing_connections = 0 outgoing_connections = 0
for c in shared.connections.copy(): for c in shared.connections.copy():
if not c.is_alive() or c.status == 'disconnected': if not c.is_alive() or c.status == 'disconnected':
shared.objects.check(
*(c.vectors_to_get | c.vectors_requested.keys()))
with shared.connections_lock: with shared.connections_lock:
shared.connections.remove(c) shared.connections.remove(c)
else: else:
@ -177,17 +165,6 @@ class Manager(threading.Thread):
@staticmethod @staticmethod
def load_data(): def load_data():
"""Loads initial nodes and data, stored in files between sessions""" """Loads initial nodes and data, stored in files between sessions"""
try:
with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'br'
) as src:
shared.objects = pickle.load(src)
except FileNotFoundError:
pass # first start
except Exception:
logging.warning(
'Error while loading objects from disk.', exc_info=True)
try: try:
with open( with open(
os.path.join(shared.data_directory, 'nodes.pickle'), 'br' os.path.join(shared.data_directory, 'nodes.pickle'), 'br'
@ -227,18 +204,6 @@ class Manager(threading.Thread):
(row[0].encode(), 'i2p') for row in reader} (row[0].encode(), 'i2p') for row in reader}
shared.i2p_node_pool.update(shared.i2p_core_nodes) shared.i2p_node_pool.update(shared.i2p_core_nodes)
@staticmethod
def pickle_objects():
try:
with open(
os.path.join(shared.data_directory, 'objects.pickle'), 'bw'
) as dst:
with shared.objects_lock:
pickle.dump(shared.objects, dst, protocol=3)
logging.debug('Saved objects')
except Exception:
logging.warning('Error while saving objects', exc_info=True)
@staticmethod @staticmethod
def pickle_nodes(): def pickle_nodes():
if len(shared.node_pool) > 10000: if len(shared.node_pool) > 10000:
@ -275,7 +240,7 @@ class Manager(threading.Thread):
dest_pub_raw = base64.b64decode( dest_pub_raw = base64.b64decode(
shared.i2p_dest_pub, altchars=b'-~') shared.i2p_dest_pub, altchars=b'-~')
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 2 * 3600), int(time.time() + 2 * 3600),
shared.i2p_dest_obj_type, shared.i2p_dest_obj_version, shared.i2p_dest_obj_type, shared.i2p_dest_obj_version,
shared.stream, dest_pub_raw) shared.stream, object_payload=dest_pub_raw)
proofofwork.do_pow_and_publish(obj) proofofwork.do_pow_and_publish(obj)

View File

@ -39,12 +39,12 @@ def _worker(obj):
logging.debug( logging.debug(
'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t) 'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t)
obj = structure.Object( obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version, obj.expires_time, obj.object_type, obj.version, obj.stream_number,
obj.stream_number, obj.object_payload) object_payload=obj.object_payload, nonce=nonce
)
logging.debug( logging.debug(
'Object vector is %s', base64.b16encode(obj.vector).decode()) 'Object vector is %s', base64.b16encode(obj.vector).decode())
with shared.objects_lock:
shared.objects[obj.vector] = obj shared.objects[obj.vector] = obj
shared.vector_advertise_queue.put(obj.vector) shared.vector_advertise_queue.put(obj.vector)

View File

@ -21,7 +21,7 @@ protocol_version = 3
services = 3 # NODE_NETWORK, NODE_SSL services = 3 # NODE_NETWORK, NODE_SSL
stream = 1 stream = 1
nonce = os.urandom(8) nonce = os.urandom(8)
user_agent = b'/MiNode:0.3.3/' user_agent = b'/MiNode:0.3.5/'
timeout = 600 timeout = 600
header_length = 24 header_length = 24
i2p_dest_obj_type = 0x493250 i2p_dest_obj_type = 0x493250
@ -63,4 +63,3 @@ outgoing_connections = 8
connection_limit = 250 connection_limit = 250
objects = {} objects = {}
objects_lock = threading.Lock()

257
minode/sql.py Normal file
View File

@ -0,0 +1,257 @@
"""Inventory implementation using sqlite"""
import base64
import logging
import os
import sqlite3
import threading
import time
from . import shared, structure
# NOTE(review): `sqlite3.threadsafety` is documented as a module constant
# describing how the underlying SQLite library was built; rebinding it here
# presumably does not change SQLite's threading mode - confirm and consider
# asserting on the value instead of assigning it.
sqlite3.threadsafety = 3
class Inventory():
    """sqlite inventory

    A dict-like storage of objects keyed by their vector.  Newly added
    objects are kept in the in-memory cache ``_last`` until :meth:`flush`
    writes them into the ``objects`` table of the sqlite database.
    """

    def __init__(self):
        # guards _last, _pending and write access to the database
        self._lock = threading.Lock()
        # number of rows deleted since the last VACUUM
        self._deleted = 0
        # cache of not yet flushed objects: vector -> object
        self._last = {}
        # vectors returned by select() and not yet passed to check()
        self._pending = set()
        self._db = sqlite3.connect(
            os.path.join(shared.data_directory, 'objects.dat'),
            check_same_thread=False
        )
        self._db.executescript("""
            BEGIN;
            CREATE TABLE IF NOT EXISTS status
            (key text, value integer, UNIQUE(key) ON CONFLICT REPLACE);
            INSERT INTO status VALUES ('version', 1);
            CREATE TABLE IF NOT EXISTS objects
            (vector unique, expires integer, type integer, version integer,
            stream integer, tag, data, offset integer);
            COMMIT;
        """)
        # ROWID of the most recently stored row, None for the empty table
        self.rowid = self._max_rowid()
        row = self._db.execute(
            "SELECT value FROM status WHERE key='lastvacuumtime'"
        ).fetchone()
        if row is None:
            # first start: remember the current time as the last vacuum
            self.lastvacuumtime = int(time.time())
            self._db.execute(
                "INSERT INTO status VALUES ('lastvacuumtime', ?)",
                (self.lastvacuumtime,)
            )
            self._db.commit()
        else:
            self.lastvacuumtime = row[0]
        self._db.row_factory = self.__object

    @staticmethod
    def __object(cursor, row):
        """Row factory: make an Object from a full row of the objects table

        Rows having other than 8 columns are returned unchanged.
        """
        if len(cursor.description) != 8:
            return row
        vector, expires, obj_type, version, stream, tag, data, offset = row
        return structure.Object(
            expires, obj_type, version, stream, data, offset,
            tag=tag, vector=vector)

    def _max_rowid(self):
        """Return the greatest ROWID in the objects table or None"""
        return self._db.execute(
            'SELECT max(ROWID) FROM objects').fetchone()[0]

    def check(self, *vectors):
        """Remove given vectors from pending"""
        with self._lock:
            for vector in vectors:
                self._pending.discard(vector)

    def cleanup(self):
        """Remove expired objects

        Does nothing while more than 100 vectors are pending. Expired
        objects are dropped from the cache first; if the cache still holds
        more than 1000 objects it is flushed and the database is cleaned
        on a later call.
        """
        if len(self._pending) > 100:
            logging.info(
                'Not cleaning up, %s objects pending', len(self._pending))
            return
        # Find and drop the expired cached objects in one locked step,
        # so a concurrent flush() cannot cause a KeyError here.
        with self._lock:
            expired = [
                vector for vector, obj in self._last.items()
                if obj.is_expired()]
            for vector in expired:
                del self._last[vector]
        for vector in expired:
            logging.debug(
                'Deleted expired object: %s',
                base64.b16encode(vector).decode())
        if len(self._last) > 1000:
            self.flush()
            return
        with self._lock:
            now = int(time.time())
            cur = self._db.execute(
                'DELETE FROM objects WHERE expires < ?', (now - 3 * 3600,))
            self._db.commit()
            self._deleted += cur.rowcount
            (logging.info if self._pending else logging.debug)(
                'Deleted %s expired objects, %s pending',
                cur.rowcount, len(self._pending))
            # conditional vacuum and validity check (TODO)
            # every 24 hours or after deleting a lot of items
            if self._deleted > 10000 or self.lastvacuumtime < now - 86400:
                logging.info('Doing VACUUM for objects')
                cur.execute('VACUUM')
                cur.execute(
                    "INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,))
                self._db.commit()
                self._deleted = 0
                self.lastvacuumtime = now

    def flush(self):
        """Write cached objects to the database"""
        with self._lock:
            self._db.executemany(
                'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)',
                ((obj.vector, obj.expires_time, obj.object_type,
                  obj.version, obj.stream_number, obj.tag, obj.data,
                  obj.offset) for obj in self._last.values()))
            self._db.commit()
            # Cursor.lastrowid is not updated by executemany(),
            # so ask the database for the last ROWID instead.
            self.rowid = self._max_rowid()
            self._last = {}

    def filter(self, stream=None, object_type=None, tag=None):
        """Generator of objects with the given parameters

        Criteria left at their defaults are not applied; with no criteria
        at all every stored object is produced.
        """
        def fits(obj):
            if stream and obj.stream_number != stream:
                return False
            if object_type is not None and obj.object_type != object_type:
                return False
            if tag and obj.tag != tag:
                return False
            return True

        # iterate a snapshot, the cache may change while we are suspended
        yield from filter(fits, tuple(self._last.values()))

        conditions = []
        params = []
        if stream:
            conditions.append('stream = ?')
            params.append(stream)
        if object_type is not None:
            conditions.append('type = ?')
            params.append(object_type)
        if tag:
            conditions.append('tag = ?')
            params.append(tag)
        query = 'SELECT * FROM objects'
        if conditions:
            # only fixed column expressions are joined here,
            # the values are bound as parameters
            query += ' WHERE ' + ' AND '.join(conditions)  # nosec B608
        yield from self._db.execute(query, params)

    def select(self, vectors):
        """Select new vectors from the given set

        Returns a new set of the vectors found neither in the cache nor in
        the database and registers them as pending. The argument itself is
        left untouched, so callers may keep using the complete set.
        """
        chunk_size = 999  # number of "?" placeholders per query
        unknown = set(vectors)
        with self._lock:
            unknown.difference_update(self._last)
            keys = tuple(unknown)
            for i in range(0, len(keys), chunk_size):
                chunk = keys[i:i + chunk_size]
                cur = self._db.execute(
                    'SELECT vector FROM objects WHERE vector IN'  # nosec B608
                    ' ({})'.format(','.join('?' * len(chunk))),
                    chunk)
                for v, in cur:
                    unknown.discard(v)
            self._pending.update(unknown)
        return unknown

    def biginv_chunks(self, chunk_size=10000, stream=None):
        """Generator of vector lists for making the biginv

        Yields lists of at most ``chunk_size`` unexpired vectors of the
        given stream (``shared.stream`` by default) from the database,
        followed by one list of fitting vectors from the cache.
        """
        if stream is None:
            stream = shared.stream
        now = int(time.time())
        cur = self._db.execute(
            'SELECT vector FROM objects WHERE expires > ? AND stream = ?'
            ' ORDER BY random()', (now, stream)
        )
        cur.arraysize = chunk_size
        while True:
            vectors = cur.fetchmany()
            if not vectors:
                # TODO: append to the last short result,
                # check that _last is shorter than the chunk_size
                # (should be < 1000)
                cached = [
                    obj.vector for obj in tuple(self._last.values())
                    if obj.stream_number == stream
                    and obj.expires_time > now]
                # do not produce an empty chunk
                if cached:
                    yield cached
                return
            yield [v for v, in vectors]

    def get(self, vector, default=None):
        """Return the object for the vector or default if it's missing"""
        try:
            return self[vector]
        except KeyError:
            return default

    def keys(self):
        """Generator of all vectors: cached first, then stored"""
        yield from self._last
        for vector, in self._db.execute('SELECT vector FROM objects'):
            yield vector

    def values(self):
        """Generator of all objects: cached first, then stored"""
        yield from self._last.values()
        yield from self._db.execute('SELECT * FROM objects')

    def popitem(self):
        """Remove and return a (vector, object) pair

        Cached objects are popped first, then the row with the greatest
        ROWID. Raises KeyError if the inventory is empty.
        """
        try:
            return self._last.popitem()
        except KeyError:
            pass
        # Look the last row up directly: a remembered ROWID may be stale
        # after the rows were deleted by cleanup().
        row = self._db.execute(
            'SELECT vector FROM objects ORDER BY ROWID DESC LIMIT 1'
        ).fetchone()
        if row is None:
            raise KeyError('empty')
        vector = row[0]
        obj = self.get(vector)
        del self[vector]
        return (vector, obj)

    def __contains__(self, vector):
        if vector in self._last:
            return True
        return self._db.execute(
            'SELECT vector FROM objects WHERE vector = ?', (vector,)
        ).fetchone() is not None

    def __getitem__(self, vector):
        try:
            return self._last[vector]
        except KeyError:
            pass
        item = self._db.execute(
            'SELECT * FROM objects WHERE vector = ?', (vector,)).fetchone()
        if item is None:
            raise KeyError(vector)
        return item

    def __delitem__(self, vector):
        # NOTE: a missing vector is silently ignored (no KeyError)
        with self._lock:
            if self._last.pop(vector, None) is not None:
                return
            self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,))
            self._db.commit()
            self.rowid = self._max_rowid()

    def __setitem__(self, vector, obj):
        # An already known vector is never overwritten. The check is done
        # under the lock to not race with flush() or another writer.
        with self._lock:
            if vector in self:
                return
            self._last[vector] = obj

    def __bool__(self):
        if self._last:
            return True
        return self._db.execute(
            'SELECT vector from objects LIMIT 1').fetchone() is not None

    def __len__(self):
        cur = self._db.execute('SELECT count(*) FROM objects')
        return cur.fetchone()[0] + len(self._last)

    def __del__(self):
        # __init__ could fail before the connection was made
        db = getattr(self, '_db', None)
        if db is None:
            return
        try:
            self.flush()
        finally:
            db.close()

View File

@ -48,21 +48,26 @@ class VarInt():
class Object(): class Object():
"""The 'object' message payload""" """The 'object' message payload"""
def __init__( def __init__(
self, nonce, expires_time, object_type, version, self, expires_time, object_type, version, stream_number,
stream_number, object_payload data=None, offset=None, object_payload=None,
tag=None, nonce=b'\x00' * 8, vector=None
): ):
self.nonce = nonce self.nonce = nonce
self.expires_time = expires_time self.expires_time = expires_time
self.object_type = object_type self.object_type = object_type
self.version = version self.version = version
self.stream_number = stream_number self.stream_number = stream_number
self.object_payload = object_payload if not data:
self.vector = hashlib.sha512(hashlib.sha512( data, offset = self.to_bytes(object_payload)
self.to_bytes()).digest()).digest()[:32] self.data = data
self.offset = offset
self.vector = vector or hashlib.sha512(hashlib.sha512(
self.data).digest()).digest()[:32]
self.tag = ( self.tag = tag or (
# broadcast from version 5 and pubkey/getpubkey from version 4 # broadcast from version 5 and pubkey/getpubkey from version 4
self.object_payload[:32] if object_type == 3 and version == 5 (object_payload or self.object_payload)[:32]
if object_type == 3 and version == 5
or (object_type in (0, 1) and version == 4) or (object_type in (0, 1) and version == 4)
else None) else None)
@ -73,20 +78,25 @@ class Object():
@classmethod @classmethod
def from_message(cls, m): def from_message(cls, m):
"""Decode message payload""" """Decode message payload"""
payload = m.payload data = m.payload
nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20]) nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20])
payload = payload[20:] version_varint_length = VarInt.length(data[20])
version_varint_length = VarInt.length(payload[0]) offset = 20 + version_varint_length
version = VarInt.from_bytes(payload[:version_varint_length]).n version = VarInt.from_bytes(data[20:offset]).n
payload = payload[version_varint_length:] stream_number_varint_length = VarInt.length(data[offset])
stream_number_varint_length = VarInt.length(payload[0])
stream_number = VarInt.from_bytes( stream_number = VarInt.from_bytes(
payload[:stream_number_varint_length]).n data[offset:offset+stream_number_varint_length]).n
payload = payload[stream_number_varint_length:] offset += stream_number_varint_length
return cls( return cls(
nonce, expires_time, object_type, version, stream_number, payload) expires_time, object_type, version, stream_number,
data, offset, nonce=nonce
)
def to_bytes(self): @property
def object_payload(self):
return self.data[self.offset:]
def to_bytes(self, object_payload):
"""Serialize to bytes""" """Serialize to bytes"""
payload = b'' payload = b''
payload += self.nonce payload += self.nonce
@ -94,8 +104,7 @@ class Object():
payload += ( payload += (
VarInt(self.version).to_bytes() VarInt(self.version).to_bytes()
+ VarInt(self.stream_number).to_bytes()) + VarInt(self.stream_number).to_bytes())
payload += self.object_payload return payload + object_payload, len(payload)
return payload
def is_expired(self): def is_expired(self):
"""Check if object's TTL is expired""" """Check if object's TTL is expired"""
@ -138,7 +147,7 @@ class Object():
def pow_target(self): def pow_target(self):
"""Compute PoW target""" """Compute PoW target"""
data = self.to_bytes()[8:] data = self.data[8:]
length = len(data) + 8 + shared.payload_length_extra_bytes length = len(data) + 8 + shared.payload_length_extra_bytes
dt = max(self.expires_time - time.time(), 0) dt = max(self.expires_time - time.time(), 0)
return int( return int(
@ -148,7 +157,7 @@ class Object():
def pow_initial_hash(self): def pow_initial_hash(self):
"""Compute the initial hash for PoW""" """Compute the initial hash for PoW"""
return hashlib.sha512(self.to_bytes()[8:]).digest() return hashlib.sha512(self.data[8:]).digest()
class NetAddrNoPrefix(): class NetAddrNoPrefix():

View File

@ -8,7 +8,7 @@ import tempfile
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from minode import connection, main, shared from minode import connection, main, shared, sql
from minode.listener import Listener from minode.listener import Listener
from minode.manager import Manager from minode.manager import Manager
@ -66,11 +66,14 @@ class TestNetwork(unittest.TestCase):
def setUp(self): def setUp(self):
shared.core_nodes.clear() shared.core_nodes.clear()
shared.unchecked_node_pool.clear() shared.unchecked_node_pool.clear()
shared.objects = {}
try: try:
os.remove(os.path.join(shared.data_directory, 'objects.pickle')) os.remove(os.path.join(shared.data_directory, 'objects.dat'))
except FileNotFoundError: except FileNotFoundError:
pass pass
shared.objects = sql.Inventory()
def tearDown(self):
shared.objects.flush()
def _make_initial_nodes(self): def _make_initial_nodes(self):
Manager.load_data() Manager.load_data()

View File

@ -0,0 +1,165 @@
"""Tests for the Inventory implementation"""
import os
import random
import tempfile
import time
import unittest
from minode import sql, shared, structure
# + __bool__
# + __contains__
# + __getitem__
# + __setitem__
# = cleanup
# + get
# + filter
# = select
# + biginv_chunks
class TestObjects():
    """
    Mixin with the tests for a shared.objects implementation.

    It covers the methods used directly in the code and expects
    a subclass to provide the storage as the `objects` attribute.
    """
    # pylint: disable=no-member
    # A possibility of abstract test cases was rejected:
    # https://bugs.python.org/issue17519

    def test_set_get(self):
        """Store an object twice, check its presence and retrieval"""
        sample = structure.Object(
            int(time.time()), 42, 1, 1, object_payload=b'HELLO')
        self.assertFalse(sample.vector in self.objects)
        with self.assertRaises(KeyError):
            self.objects[sample.vector]  # pylint: disable=pointless-statement
        self.assertIsNone(self.objects.get(sample.vector))
        size_before = len(self.objects)
        # storing the same object again should not add an entry
        for _ in range(2):
            self.objects[sample.vector] = sample
        self.assertTrue(self.objects)
        self.assertEqual(len(self.objects), size_before + 1)
        self.assertTrue(sample.vector in self.objects)
        stored = self.objects[sample.vector]
        self.assertEqual(sample.vector, stored.vector)
        self.assertEqual(sample.data, stored.data)

    def test_biginv_chunks(self):
        """Check the vectors produced by biginv_chunks()"""
        expected = set()
        for _ in range(10):
            # an object in a wrong stream
            other_stream = structure.Object(
                int(time.time()) + 10, 42, 1, random.randint(1, 3),
                object_payload=os.urandom(32))
            self.objects[other_stream.vector] = other_stream
            # an expired one
            outdated = structure.Object(
                int(time.time()) - 10, 42, 1, 4, object_payload=os.urandom(32))
            self.objects[outdated.vector] = outdated
            # the one which should be listed
            wanted = structure.Object(
                int(time.time()) + 10, 42, 1, 4, object_payload=os.urandom(32))
            self.objects[wanted.vector] = wanted
            expected.add(wanted.vector)

        first_chunk = set(next(self.objects.biginv_chunks(stream=4)))
        self.assertEqual(first_chunk, expected)
        default_chunk = set(next(self.objects.biginv_chunks()))
        self.assertTrue(default_chunk.difference(expected))

    def test_filter(self):
        """Check filtering by stream, type and tag"""
        expected = set()
        with_tag = set()
        tag = b'@' * 32
        for _ in range(10):
            # a wrong type
            item = structure.Object(
                int(time.time()), 0, 1, 5, object_payload=os.urandom(64))
            self.objects[item.vector] = item
            # a wrong type again, but carrying the proper tag
            item = structure.Object(
                int(time.time()) - 11000, 0, 4, random.choice([1, 2, 3, 5]),
                object_payload=tag + os.urandom(32))
            self.objects[item.vector] = item
            with_tag.add(item.vector)
            # a wrong stream
            item = structure.Object(
                int(time.time()), 33, 1, 1, object_payload=os.urandom(32))
            self.objects[item.vector] = item
            # the one which should be found
            item = structure.Object(
                int(time.time()) - 11000, 33, 1, 5,
                object_payload=os.urandom(32))
            self.objects[item.vector] = item
            expected.add(item.vector)

        # by stream and type
        self.assertTrue(expected)
        for found in self.objects.filter(5, 33):
            expected.remove(found.vector)
        self.assertFalse(expected)

        # by tag
        self.assertTrue(with_tag)
        for found in self.objects.filter(tag=tag):
            with_tag.remove(found.vector)
        self.assertFalse(with_tag)

    def test_cleanup(self):
        """Check that cleanup() leaves only not expired objects"""
        for _ in range(10):
            old = structure.Object(
                int(time.time()) - random.randint(4, 5) * 3600,
                42, 1, 6, object_payload=os.urandom(32))
            self.objects[old.vector] = old
            recent = structure.Object(
                int(time.time()) - 2 * 3600,
                42, 1, 6, object_payload=os.urandom(32))
            self.objects[recent.vector] = recent

        if not any(item.is_expired() for item in self.objects.values()):
            self.fail('No objects found to delete')

        self.objects.cleanup()
        self.assertTrue(self.objects)
        for item in self.objects.values():
            self.assertFalse(item.is_expired())

    def test_select(self):
        """Check that select() returns only the unknown vectors"""
        unknown = set()
        asked = set()
        for _ in range(5):
            stored = structure.Object(
                int(time.time()) - 10, 42, 1, 7, object_payload=os.urandom(32))
            asked.add(stored.vector)
            self.objects[stored.vector] = stored
            missing = structure.Object(
                int(time.time()) + 10, 42, 1, 7, object_payload=os.urandom(32))
            asked.add(missing.vector)
            unknown.add(missing.vector)
        self.assertEqual(self.objects.select(asked), unknown)
class TestObjectsSQL(TestObjects, unittest.TestCase):
    """A test case for the sqlite inventory"""

    @classmethod
    def setUpClass(cls):
        """Create a fresh inventory in the system temporary directory"""
        shared.data_directory = tempfile.gettempdir()
        # remove a database possibly left by a previous run
        cls.tearDownClass()
        cls.objects = sql.Inventory()

    @classmethod
    def tearDownClass(cls):
        """Drop the inventory and delete its database file"""
        cls.objects = None
        try:
            os.remove(os.path.join(shared.data_directory, 'objects.dat'))
        except FileNotFoundError:
            # nothing to delete, e.g. when called from setUpClass()
            # on a clean system
            pass

View File

@ -18,11 +18,12 @@ sample_addr_data = unhexlify(
'260753000201300000000000000057ae1f90') '260753000201300000000000000057ae1f90')
# data for an object with expires_time 1697063939 # data for an object with expires_time 1697063939
# structure.Object( # structure.Object(expires_time, 42, 1, 2, object_payload=b'HELLO').data
# b'\x00' * 8, expires_time, 42, 1, 2, b'HELLO').to_bytes()
sample_object_data = unhexlify( sample_object_data = unhexlify(
'000000000000000000000000652724030000002a010248454c4c4f') '000000000000000000000000652724030000002a010248454c4c4f')
sample_object_expires = 1697063939
logging.basicConfig( logging.basicConfig(
level=shared.log_level, level=shared.log_level,
format='[%(asctime)s] [%(levelname)s] %(message)s') format='[%(asctime)s] [%(levelname)s] %(message)s')
@ -131,17 +132,24 @@ class TestStructure(unittest.TestCase):
message.Message(b'object', sample_object_data)) message.Message(b'object', sample_object_data))
self.assertEqual(obj.object_type, 42) self.assertEqual(obj.object_type, 42)
self.assertEqual(obj.stream_number, 2) self.assertEqual(obj.stream_number, 2)
self.assertEqual(obj.expires_time, 1697063939) self.assertEqual(obj.expires_time, sample_object_expires)
self.assertEqual(obj.object_payload, b'HELLO') self.assertEqual(obj.object_payload, b'HELLO')
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO') sample_object_expires, 42, 1, 2, object_payload=b'HELLO')
self.assertEqual(obj.data, sample_object_data)
self.assertEqual(obj.offset, 22)
self.assertEqual(obj.nonce, b'\x00' * 8)
self.assertTrue(obj.is_expired())
obj = structure.Object(
int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO')
self.assertFalse(obj.is_valid()) self.assertFalse(obj.is_valid())
obj.expires_time = int(time.time() - 11000) obj.expires_time = int(time.time() - 11000)
self.assertFalse(obj.is_valid()) self.assertFalse(obj.is_valid())
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO') int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO')
vector = obj.vector vector = obj.vector
proofofwork._worker(obj) # pylint: disable=protected-access proofofwork._worker(obj) # pylint: disable=protected-access
obj = shared.objects.popitem()[1] obj = shared.objects.popitem()[1]
@ -151,16 +159,17 @@ class TestStructure(unittest.TestCase):
shared.stream = 2 shared.stream = 2
self.assertTrue(obj.is_valid()) self.assertTrue(obj.is_valid())
obj.object_payload = \ # obj.data = struct.pack(...
b'TIGER, tiger, burning bright. In the forests of the night' # obj.object_payload = \
self.assertFalse(obj.is_valid()) # b'TIGER, tiger, burning bright. In the forests of the night'
# self.assertFalse(obj.is_valid())
def test_proofofwork(self): def test_proofofwork(self):
"""Check the main proofofwork call and worker""" """Check the main proofofwork call and worker"""
shared.vector_advertise_queue = queue.Queue() shared.vector_advertise_queue = queue.Queue()
obj = structure.Object( obj = structure.Object(
b'\x00' * 8, int(time.time() + 300), 42, 1, int(time.time() + 300), 42, 1, shared.stream,
shared.stream, b'HELLO') object_payload=b'HELLO')
start_time = time.time() start_time = time.time()
proofofwork.do_pow_and_publish(obj) proofofwork.do_pow_and_publish(obj)
try: try:
@ -189,6 +198,6 @@ class TestStructure(unittest.TestCase):
self.fail("No nonce found in the queue") self.fail("No nonce found in the queue")
obj = structure.Object( obj = structure.Object(
nonce, obj.expires_time, obj.object_type, obj.version, obj.expires_time, obj.object_type, obj.version, obj.stream_number,
obj.stream_number, obj.object_payload) object_payload=obj.object_payload, nonce=nonce)
self.assertTrue(obj.is_valid()) self.assertTrue(obj.is_valid())