WIP: Implementing sqlite objects storage #13
|
@ -277,22 +277,11 @@ class ConnectionBase(threading.Thread):
|
|||
if len(addr) != 0:
|
||||
self.send_queue.put(message.Addr(addr))
|
||||
|
||||
with shared.objects_lock:
|
||||
if len(shared.objects) > 0:
|
||||
to_send = {
|
||||
vector for vector in shared.objects.keys()
|
||||
if shared.objects[vector].expires_time > time.time()}
|
||||
while len(to_send) > 0:
|
||||
if len(to_send) > 10000:
|
||||
# We limit size of inv messaged to 10000 entries
|
||||
# because they might time out
|
||||
# in very slow networks (I2P)
|
||||
pack = random.sample(tuple(to_send), 10000)
|
||||
self.send_queue.put(message.Inv(pack))
|
||||
to_send.difference_update(pack)
|
||||
else:
|
||||
self.send_queue.put(message.Inv(to_send))
|
||||
to_send.clear()
|
||||
if shared.objects:
|
||||
for chunk in shared.objects.biginv_chunks(10000):
|
||||
# We limit size of inv messages to 10000 entries
|
||||
# because they might time out in very slow networks (I2P)
|
||||
self.send_queue.put(message.Inv(chunk))
|
||||
self.status = 'fully_established'
|
||||
|
||||
def _process_queue(self):
|
||||
|
@ -416,7 +405,7 @@ class ConnectionBase(threading.Thread):
|
|||
|
||||
def _request_objects(self):
|
||||
if self.vectors_to_get and len(self.vectors_requested) < 100:
|
||||
self.vectors_to_get.difference_update(shared.objects.keys())
|
||||
self.vectors_to_get = shared.objects.select(self.vectors_to_get)
|
||||
if not self.wait_until:
|
||||
nodes_count = (
|
||||
len(shared.node_pool) + len(shared.unchecked_node_pool))
|
||||
|
@ -463,12 +452,10 @@ class ConnectionBase(threading.Thread):
|
|||
else:
|
||||
to_send = self.vectors_to_send.copy()
|
||||
self.vectors_to_send.clear()
|
||||
with shared.objects_lock:
|
||||
for vector in to_send:
|
||||
obj = shared.objects.get(vector, None)
|
||||
obj = shared.objects.get(vector)
|
||||
if obj:
|
||||
self.send_queue.put(
|
||||
message.Message(b'object', obj.to_bytes()))
|
||||
self.send_queue.put(message.Message(b'object', obj.data))
|
||||
|
||||
|
||||
class Connection(ConnectionBase):
|
||||
|
@ -476,9 +463,7 @@ class Connection(ConnectionBase):
|
|||
def _process_msg_inv(self, m):
|
||||
inv = message.Inv.from_message(m)
|
||||
logging.debug('%s:%s -> %s', self.host_print, self.port, inv)
|
||||
to_get = inv.vectors.copy()
|
||||
to_get.difference_update(shared.objects.keys())
|
||||
self.vectors_to_get.update(to_get)
|
||||
self.vectors_to_get.update(shared.objects.select(inv.vectors))
|
||||
# Do not send objects they already have.
|
||||
self.vectors_to_send.difference_update(inv.vectors)
|
||||
|
||||
|
@ -487,8 +472,7 @@ class Connection(ConnectionBase):
|
|||
logging.debug('%s:%s -> %s', self.host_print, self.port, obj)
|
||||
self.vectors_requested.pop(obj.vector, None)
|
||||
self.vectors_to_get.discard(obj.vector)
|
||||
if obj.is_valid() and obj.vector not in shared.objects:
|
||||
with shared.objects_lock:
|
||||
if obj.is_valid():
|
||||
shared.objects[obj.vector] = obj
|
||||
if (
|
||||
obj.object_type == shared.i2p_dest_obj_type
|
||||
|
@ -501,6 +485,7 @@ class Connection(ConnectionBase):
|
|||
logging.debug(dest)
|
||||
shared.i2p_unchecked_node_pool.add((dest, 'i2p'))
|
||||
shared.vector_advertise_queue.put(obj.vector)
|
||||
shared.objects.check(obj.vector)
|
||||
|
||||
def _process_msg_getdata(self, m):
|
||||
getdata = message.GetData.from_message(m)
|
||||
|
|
|
@ -8,7 +8,7 @@ import os
|
|||
import signal
|
||||
import socket
|
||||
|
||||
from . import i2p, shared
|
||||
from . import i2p, shared, sql
|
||||
from .advertiser import Advertiser
|
||||
from .manager import Manager
|
||||
from .listener import Listener
|
||||
|
@ -161,8 +161,7 @@ def start_ip_listener():
|
|||
def start_i2p_listener():
|
||||
"""Starts I2P threads"""
|
||||
# Grab I2P destinations from old object file
|
||||
for obj in shared.objects.values():
|
||||
if obj.object_type == shared.i2p_dest_obj_type:
|
||||
for obj in shared.objects.filter(object_type=shared.i2p_dest_obj_type):
|
||||
shared.i2p_unchecked_node_pool.add((
|
||||
base64.b64encode(obj.object_payload, altchars=b'-~'), 'i2p'))
|
||||
|
||||
|
@ -245,6 +244,8 @@ def main():
|
|||
if shared.ip_enabled and not shared.trusted_peer:
|
||||
bootstrap_from_dns()
|
||||
|
||||
shared.objects = sql.Inventory()
|
||||
|
||||
if shared.i2p_enabled:
|
||||
# We are starting it before cleaning expired objects
|
||||
# so we can collect I2P destination objects
|
||||
|
|
|
@ -23,7 +23,6 @@ class Manager(threading.Thread):
|
|||
self.bootstrap_pool = []
|
||||
self.last_cleaned_objects = time.time()
|
||||
self.last_cleaned_connections = time.time()
|
||||
self.last_pickled_objects = time.time()
|
||||
self.last_pickled_nodes = time.time()
|
||||
# Publish destination 5-15 minutes after start
|
||||
self.last_published_i2p_destination = \
|
||||
|
@ -36,23 +35,21 @@ class Manager(threading.Thread):
|
|||
|
||||
def run(self):
|
||||
self.load_data()
|
||||
self.clean_objects()
|
||||
shared.objects.cleanup()
|
||||
self.fill_bootstrap_pool()
|
||||
while True:
|
||||
time.sleep(0.8)
|
||||
now = time.time()
|
||||
if shared.shutting_down:
|
||||
logging.debug('Shutting down Manager')
|
||||
shared.objects.flush()
|
||||
break
|
||||
if now - self.last_cleaned_objects > 90:
|
||||
self.clean_objects()
|
||||
shared.objects.cleanup()
|
||||
self.last_cleaned_objects = now
|
||||
if now - self.last_cleaned_connections > 2:
|
||||
self.manage_connections()
|
||||
self.last_cleaned_connections = now
|
||||
if now - self.last_pickled_objects > 100:
|
||||
self.pickle_objects()
|
||||
self.last_pickled_objects = now
|
||||
if now - self.last_pickled_nodes > 60:
|
||||
self.pickle_nodes()
|
||||
self.last_pickled_nodes = now
|
||||
|
@ -60,17 +57,6 @@ class Manager(threading.Thread):
|
|||
self.publish_i2p_destination()
|
||||
self.last_published_i2p_destination = now
|
||||
|
||||
@staticmethod
|
||||
def clean_objects():
|
||||
for vector in set(shared.objects):
|
||||
# FIXME: no need to check is_valid() here
|
||||
if shared.objects[vector].is_expired():
|
||||
logging.debug(
|
||||
'Deleted expired object: %s',
|
||||
base64.b16encode(vector).decode())
|
||||
with shared.objects_lock:
|
||||
del shared.objects[vector]
|
||||
|
||||
def manage_connections(self):
|
||||
"""Open new connections if needed, remove closed ones"""
|
||||
hosts = set()
|
||||
|
@ -100,6 +86,8 @@ class Manager(threading.Thread):
|
|||
outgoing_connections = 0
|
||||
for c in shared.connections.copy():
|
||||
if not c.is_alive() or c.status == 'disconnected':
|
||||
shared.objects.check(
|
||||
*(c.vectors_to_get | c.vectors_requested.keys()))
|
||||
with shared.connections_lock:
|
||||
shared.connections.remove(c)
|
||||
else:
|
||||
|
@ -177,17 +165,6 @@ class Manager(threading.Thread):
|
|||
@staticmethod
|
||||
def load_data():
|
||||
"""Loads initial nodes and data, stored in files between sessions"""
|
||||
try:
|
||||
with open(
|
||||
os.path.join(shared.data_directory, 'objects.pickle'), 'br'
|
||||
) as src:
|
||||
shared.objects = pickle.load(src)
|
||||
except FileNotFoundError:
|
||||
pass # first start
|
||||
except Exception:
|
||||
logging.warning(
|
||||
'Error while loading objects from disk.', exc_info=True)
|
||||
|
||||
try:
|
||||
with open(
|
||||
os.path.join(shared.data_directory, 'nodes.pickle'), 'br'
|
||||
|
@ -227,18 +204,6 @@ class Manager(threading.Thread):
|
|||
(row[0].encode(), 'i2p') for row in reader}
|
||||
shared.i2p_node_pool.update(shared.i2p_core_nodes)
|
||||
|
||||
@staticmethod
|
||||
def pickle_objects():
|
||||
try:
|
||||
with open(
|
||||
os.path.join(shared.data_directory, 'objects.pickle'), 'bw'
|
||||
) as dst:
|
||||
with shared.objects_lock:
|
||||
pickle.dump(shared.objects, dst, protocol=3)
|
||||
logging.debug('Saved objects')
|
||||
except Exception:
|
||||
logging.warning('Error while saving objects', exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def pickle_nodes():
|
||||
if len(shared.node_pool) > 10000:
|
||||
|
@ -275,7 +240,7 @@ class Manager(threading.Thread):
|
|||
dest_pub_raw = base64.b64decode(
|
||||
shared.i2p_dest_pub, altchars=b'-~')
|
||||
obj = structure.Object(
|
||||
b'\x00' * 8, int(time.time() + 2 * 3600),
|
||||
int(time.time() + 2 * 3600),
|
||||
shared.i2p_dest_obj_type, shared.i2p_dest_obj_version,
|
||||
shared.stream, dest_pub_raw)
|
||||
shared.stream, object_payload=dest_pub_raw)
|
||||
proofofwork.do_pow_and_publish(obj)
|
||||
|
|
|
@ -39,12 +39,12 @@ def _worker(obj):
|
|||
logging.debug(
|
||||
'Finished doing POW, nonce: %s, time: %ss', nonce, time.time() - t)
|
||||
obj = structure.Object(
|
||||
nonce, obj.expires_time, obj.object_type, obj.version,
|
||||
obj.stream_number, obj.object_payload)
|
||||
obj.expires_time, obj.object_type, obj.version, obj.stream_number,
|
||||
object_payload=obj.object_payload, nonce=nonce
|
||||
)
|
||||
logging.debug(
|
||||
'Object vector is %s', base64.b16encode(obj.vector).decode())
|
||||
|
||||
with shared.objects_lock:
|
||||
shared.objects[obj.vector] = obj
|
||||
shared.vector_advertise_queue.put(obj.vector)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ protocol_version = 3
|
|||
services = 3 # NODE_NETWORK, NODE_SSL
|
||||
stream = 1
|
||||
nonce = os.urandom(8)
|
||||
user_agent = b'/MiNode:0.3.3/'
|
||||
user_agent = b'/MiNode:0.3.5/'
|
||||
timeout = 600
|
||||
header_length = 24
|
||||
i2p_dest_obj_type = 0x493250
|
||||
|
@ -63,4 +63,3 @@ outgoing_connections = 8
|
|||
connection_limit = 250
|
||||
|
||||
objects = {}
|
||||
objects_lock = threading.Lock()
|
||||
|
|
257
minode/sql.py
Normal file
257
minode/sql.py
Normal file
|
@ -0,0 +1,257 @@
|
|||
"""Inventory implementation using sqlite"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
|
||||
from . import shared, structure
|
||||
|
||||
sqlite3.threadsafety = 3
|
||||
|
||||
|
||||
class Inventory():
|
||||
"""sqlite inventory"""
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._deleted = 0
|
||||
self._last = {}
|
||||
self._pending = set()
|
||||
self._db = sqlite3.connect(
|
||||
os.path.join(shared.data_directory, 'objects.dat'),
|
||||
check_same_thread=False
|
||||
)
|
||||
self._db.executescript("""
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS status
|
||||
(key text, value integer, UNIQUE(key) ON CONFLICT REPLACE);
|
||||
INSERT INTO status VALUES ('version', 1);
|
||||
CREATE TABLE IF NOT EXISTS objects
|
||||
(vector unique, expires integer, type integer, version integer,
|
||||
stream integer, tag, data, offset integer);
|
||||
COMMIT;
|
||||
""")
|
||||
self.rowid = len(self) or None
|
||||
try:
|
||||
self.lastvacuumtime = self._db.execute(
|
||||
"SELECT value FROM status WHERE key='lastvacuumtime'"
|
||||
).fetchone()[0]
|
||||
except TypeError:
|
||||
self.lastvacuumtime = int(time.time())
|
||||
self._db.execute(
|
||||
"INSERT INTO status VALUES ('lastvacuumtime', ?)",
|
||||
(self.lastvacuumtime,)
|
||||
)
|
||||
self._db.commit()
|
||||
self._db.row_factory = self.__object
|
||||
|
||||
@staticmethod
|
||||
def __object(cursor, row):
|
||||
if len(cursor.description) != 8:
|
||||
return row
|
||||
vector, expires, obj_type, version, stream, tag, data, offset = row
|
||||
return structure.Object(
|
||||
expires, obj_type, version, stream, data, offset,
|
||||
tag=tag, vector=vector)
|
||||
|
||||
def check(self, *vectors):
|
||||
"""Remove given vectors from pending"""
|
||||
with self._lock:
|
||||
for vector in vectors:
|
||||
self._pending.discard(vector)
|
||||
|
||||
def cleanup(self):
|
||||
"""Remove expired objects"""
|
||||
if len(self._pending) > 100:
|
||||
logging.info(
|
||||
'Not cleaning up, %s objects pending', len(self._pending))
|
||||
return
|
||||
for vector in set(self._last):
|
||||
if self._last[vector].is_expired():
|
||||
logging.debug(
|
||||
'Deleted expired object: %s',
|
||||
base64.b16encode(vector).decode())
|
||||
with self._lock:
|
||||
del self._last[vector]
|
||||
if len(self._last) > 1000:
|
||||
self.flush()
|
||||
return
|
||||
with self._lock:
|
||||
now = int(time.time())
|
||||
cur = self._db.execute(
|
||||
'DELETE FROM objects WHERE expires < ?', (now - 3 * 3600,))
|
||||
self._db.commit()
|
||||
self._deleted += cur.rowcount
|
||||
(logging.info if self._pending else logging.debug)(
|
||||
'Deleted %s expired objects, %s pending',
|
||||
cur.rowcount, len(self._pending))
|
||||
# conditional vacuum and validity check (TODO)
|
||||
# every 24 hours or after deleting a lot of items
|
||||
if self._deleted > 10000 or self.lastvacuumtime < now - 86400:
|
||||
logging.info('Doing VACUUM for objects')
|
||||
cur.execute('VACUUM')
|
||||
cur.execute(
|
||||
"INSERT INTO status VALUES ('lastvacuumtime', ?)", (now,))
|
||||
self._db.commit()
|
||||
self._deleted = 0
|
||||
self.lastvacuumtime = now
|
||||
|
||||
def flush(self):
|
||||
"""Write cached objects to the database"""
|
||||
with self._lock:
|
||||
cur = self._db.executemany(
|
||||
'INSERT INTO objects VALUES (?,?,?,?,?,?,?,?)',
|
||||
((obj.vector, obj.expires_time, obj.object_type,
|
||||
obj.version, obj.stream_number, obj.tag, obj.data,
|
||||
obj.offset) for obj in self._last.values()))
|
||||
self._db.commit()
|
||||
self.rowid = cur.lastrowid
|
||||
self._last = {}
|
||||
|
||||
def filter(self, stream=None, object_type=None, tag=None):
|
||||
"""Generator of objects with the given parameters"""
|
||||
def fits(obj):
|
||||
if stream and obj.stream_number != stream:
|
||||
return False
|
||||
if object_type is not None and obj.object_type != object_type:
|
||||
return False
|
||||
if tag and obj.tag != tag:
|
||||
return False
|
||||
return True
|
||||
|
||||
yield from filter(fits, self._last.values())
|
||||
|
||||
clauses = []
|
||||
if stream:
|
||||
clauses.append(('stream = ?', stream))
|
||||
if object_type is not None:
|
||||
clauses.append(('type = ?', object_type))
|
||||
if tag:
|
||||
clauses.append(('tag = ?', tag))
|
||||
|
||||
clauses, params = zip(*clauses)
|
||||
|
||||
yield from self._db.execute(
|
||||
'SELECT * FROM objects WHERE ' # nosec B608
|
||||
+ ' AND '.join(clauses), params)
|
||||
|
||||
def select(self, vectors):
|
||||
"""Select new vectors from the given set"""
|
||||
chunk_size = 999
|
||||
with self._lock:
|
||||
vectors.difference_update(self._last)
|
||||
keys = tuple(vectors)
|
||||
for i in range(0, len(vectors), chunk_size):
|
||||
chunk = keys[i:i+chunk_size]
|
||||
cur = self._db.execute(
|
||||
'SELECT vector FROM objects WHERE vector IN' # nosec B608
|
||||
' ({})'.format(','.join('?' * len(chunk))),
|
||||
chunk)
|
||||
for v, in cur:
|
||||
vectors.remove(v)
|
||||
self._pending.update(vectors)
|
||||
return vectors
|
||||
|
||||
def biginv_chunks(self, chunk_size=10000, stream=None):
|
||||
"""Generator of vector lists for making the biginv"""
|
||||
if stream is None:
|
||||
stream = shared.stream
|
||||
now = int(time.time())
|
||||
cur = self._db.execute(
|
||||
'SELECT vector FROM objects WHERE expires > ? AND stream = ?'
|
||||
' ORDER BY random()', (now, stream)
|
||||
)
|
||||
cur.arraysize = chunk_size
|
||||
while True:
|
||||
vectors = cur.fetchmany()
|
||||
if not vectors:
|
||||
# TODO: append to the last short result,
|
||||
# check that _last is shorter than the chunk_size
|
||||
# (should be < 1000)
|
||||
if self._last:
|
||||
yield [
|
||||
obj.vector for obj in self._last.values()
|
||||
if obj.stream_number == stream
|
||||
and obj.expires_time > now]
|
||||
return
|
||||
yield [v for v, in vectors]
|
||||
|
||||
def get(self, vector, default=None):
|
||||
try:
|
||||
return self[vector]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def keys(self):
|
||||
yield from self._last
|
||||
for vector, in self._db.execute('SELECT vector FROM objects'):
|
||||
yield vector
|
||||
|
||||
def values(self):
|
||||
yield from self._last.values()
|
||||
yield from self._db.execute('SELECT * FROM objects')
|
||||
|
||||
def popitem(self):
|
||||
try:
|
||||
return self._last.popitem()
|
||||
except KeyError:
|
||||
pass
|
||||
if not self.rowid:
|
||||
raise KeyError('empty')
|
||||
cur = self._db.execute(
|
||||
'SELECT vector FROM objects WHERE ROWID = ?', (self.rowid,))
|
||||
vector = cur.fetchone()[0]
|
||||
obj = self.get(vector)
|
||||
del self[vector]
|
||||
return (vector, obj)
|
||||
|
||||
def __contains__(self, vector):
|
||||
if vector in self._last:
|
||||
return True
|
||||
return self._db.execute(
|
||||
'SELECT vector FROM objects WHERE vector = ?', (vector,)
|
||||
).fetchone() is not None
|
||||
|
||||
def __getitem__(self, vector):
|
||||
try:
|
||||
return self._last[vector]
|
||||
except KeyError:
|
||||
pass
|
||||
item = self._db.execute(
|
||||
'SELECT * FROM objects WHERE vector = ?', (vector,)).fetchone()
|
||||
if item is None:
|
||||
raise KeyError(vector)
|
||||
return item
|
||||
|
||||
def __delitem__(self, vector):
|
||||
try:
|
||||
del self._last[vector]
|
||||
return
|
||||
except KeyError:
|
||||
pass
|
||||
with self._lock: # KeyError
|
||||
self._db.execute('DELETE FROM objects WHERE vector = ?', (vector,))
|
||||
self._db.commit()
|
||||
self.rowid = len(self)
|
||||
|
||||
def __setitem__(self, vector, obj):
|
||||
if vector in self:
|
||||
return
|
||||
with self._lock:
|
||||
self._last[vector] = obj
|
||||
|
||||
def __bool__(self):
|
||||
if self._last:
|
||||
return True
|
||||
return self._db.execute(
|
||||
'SELECT vector from objects LIMIT 1').fetchone() is not None
|
||||
|
||||
def __len__(self):
|
||||
cur = self._db.execute('SELECT count(*) FROM objects')
|
||||
return cur.fetchone()[0] + len(self._last)
|
||||
|
||||
def __del__(self):
|
||||
self.flush()
|
||||
self._db.close()
|
|
@ -48,21 +48,26 @@ class VarInt():
|
|||
class Object():
|
||||
"""The 'object' message payload"""
|
||||
def __init__(
|
||||
self, nonce, expires_time, object_type, version,
|
||||
stream_number, object_payload
|
||||
self, expires_time, object_type, version, stream_number,
|
||||
data=None, offset=None, object_payload=None,
|
||||
tag=None, nonce=b'\x00' * 8, vector=None
|
||||
):
|
||||
self.nonce = nonce
|
||||
self.expires_time = expires_time
|
||||
self.object_type = object_type
|
||||
self.version = version
|
||||
self.stream_number = stream_number
|
||||
self.object_payload = object_payload
|
||||
self.vector = hashlib.sha512(hashlib.sha512(
|
||||
self.to_bytes()).digest()).digest()[:32]
|
||||
if not data:
|
||||
data, offset = self.to_bytes(object_payload)
|
||||
self.data = data
|
||||
self.offset = offset
|
||||
self.vector = vector or hashlib.sha512(hashlib.sha512(
|
||||
self.data).digest()).digest()[:32]
|
||||
|
||||
self.tag = (
|
||||
self.tag = tag or (
|
||||
# broadcast from version 5 and pubkey/getpukey from version 4
|
||||
self.object_payload[:32] if object_type == 3 and version == 5
|
||||
(object_payload or self.object_payload)[:32]
|
||||
if object_type == 3 and version == 5
|
||||
or (object_type in (0, 1) and version == 4)
|
||||
else None)
|
||||
|
||||
|
@ -73,20 +78,25 @@ class Object():
|
|||
@classmethod
|
||||
def from_message(cls, m):
|
||||
"""Decode message payload"""
|
||||
payload = m.payload
|
||||
nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20])
|
||||
payload = payload[20:]
|
||||
version_varint_length = VarInt.length(payload[0])
|
||||
version = VarInt.from_bytes(payload[:version_varint_length]).n
|
||||
payload = payload[version_varint_length:]
|
||||
stream_number_varint_length = VarInt.length(payload[0])
|
||||
data = m.payload
|
||||
nonce, expires_time, object_type = struct.unpack('>8sQL', data[:20])
|
||||
version_varint_length = VarInt.length(data[20])
|
||||
offset = 20 + version_varint_length
|
||||
version = VarInt.from_bytes(data[20:offset]).n
|
||||
stream_number_varint_length = VarInt.length(data[offset])
|
||||
stream_number = VarInt.from_bytes(
|
||||
payload[:stream_number_varint_length]).n
|
||||
payload = payload[stream_number_varint_length:]
|
||||
data[offset:offset+stream_number_varint_length]).n
|
||||
offset += stream_number_varint_length
|
||||
return cls(
|
||||
nonce, expires_time, object_type, version, stream_number, payload)
|
||||
expires_time, object_type, version, stream_number,
|
||||
data, offset, nonce=nonce
|
||||
)
|
||||
|
||||
def to_bytes(self):
|
||||
@property
|
||||
def object_payload(self):
|
||||
return self.data[self.offset:]
|
||||
|
||||
def to_bytes(self, object_payload):
|
||||
"""Serialize to bytes"""
|
||||
payload = b''
|
||||
payload += self.nonce
|
||||
|
@ -94,8 +104,7 @@ class Object():
|
|||
payload += (
|
||||
VarInt(self.version).to_bytes()
|
||||
+ VarInt(self.stream_number).to_bytes())
|
||||
payload += self.object_payload
|
||||
return payload
|
||||
return payload + object_payload, len(payload)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if object's TTL is expired"""
|
||||
|
@ -138,7 +147,7 @@ class Object():
|
|||
|
||||
def pow_target(self):
|
||||
"""Compute PoW target"""
|
||||
data = self.to_bytes()[8:]
|
||||
data = self.data[8:]
|
||||
length = len(data) + 8 + shared.payload_length_extra_bytes
|
||||
dt = max(self.expires_time - time.time(), 0)
|
||||
return int(
|
||||
|
@ -148,7 +157,7 @@ class Object():
|
|||
|
||||
def pow_initial_hash(self):
|
||||
"""Compute the initial hash for PoW"""
|
||||
return hashlib.sha512(self.to_bytes()[8:]).digest()
|
||||
return hashlib.sha512(self.data[8:]).digest()
|
||||
|
||||
|
||||
class NetAddrNoPrefix():
|
||||
|
|
|
@ -8,7 +8,7 @@ import tempfile
|
|||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
from minode import connection, main, shared
|
||||
from minode import connection, main, shared, sql
|
||||
from minode.listener import Listener
|
||||
from minode.manager import Manager
|
||||
|
||||
|
@ -66,11 +66,14 @@ class TestNetwork(unittest.TestCase):
|
|||
def setUp(self):
|
||||
shared.core_nodes.clear()
|
||||
shared.unchecked_node_pool.clear()
|
||||
shared.objects = {}
|
||||
try:
|
||||
os.remove(os.path.join(shared.data_directory, 'objects.pickle'))
|
||||
os.remove(os.path.join(shared.data_directory, 'objects.dat'))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
shared.objects = sql.Inventory()
|
||||
|
||||
def tearDown(self):
|
||||
shared.objects.flush()
|
||||
|
||||
def _make_initial_nodes(self):
|
||||
Manager.load_data()
|
||||
|
|
165
minode/tests/test_objects.py
Normal file
165
minode/tests/test_objects.py
Normal file
|
@ -0,0 +1,165 @@
|
|||
"""Tests for the Inventory implementation"""
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from minode import sql, shared, structure
|
||||
|
||||
|
||||
# + __bool__
|
||||
# + __contains__
|
||||
# + __getitem__
|
||||
# + __setitem__
|
||||
# = cleanup
|
||||
# + get
|
||||
# + filter
|
||||
# = select
|
||||
# + biginv_chunks
|
||||
|
||||
|
||||
class TestObjects():
|
||||
"""
|
||||
A base class for the test case for shared.objects,
|
||||
containing tests for all the methods directly used in code.
|
||||
"""
|
||||
# pylint: disable=no-member
|
||||
# A possibility of abstract test cases was rejected:
|
||||
# https://bugs.python.org/issue17519
|
||||
|
||||
def test_set_get(self):
|
||||
"""Put some objects and check presence and getting"""
|
||||
obj = structure.Object(
|
||||
int(time.time()), 42, 1, 1, object_payload=b'HELLO')
|
||||
self.assertFalse(obj.vector in self.objects)
|
||||
with self.assertRaises(KeyError):
|
||||
self.objects[obj.vector] # pylint: disable=pointless-statement
|
||||
self.assertIsNone(self.objects.get(obj.vector))
|
||||
prev_len = len(self.objects)
|
||||
self.objects[obj.vector] = obj
|
||||
self.objects[obj.vector] = obj
|
||||
self.assertTrue(self.objects)
|
||||
self.assertEqual(len(self.objects), prev_len + 1)
|
||||
self.assertTrue(obj.vector in self.objects)
|
||||
obj1 = self.objects[obj.vector]
|
||||
self.assertEqual(obj.vector, obj1.vector)
|
||||
self.assertEqual(obj.data, obj1.data)
|
||||
|
||||
def test_biginv_chunks(self):
|
||||
"""Check vectors_to_send method"""
|
||||
needed = set()
|
||||
for _ in range(10):
|
||||
# wrong stream
|
||||
obj = structure.Object(
|
||||
int(time.time()) + 10, 42, 1, random.randint(1, 3),
|
||||
object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
# expired
|
||||
obj = structure.Object(
|
||||
int(time.time()) - 10, 42, 1, 4, object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
# interesting
|
||||
obj = structure.Object(
|
||||
int(time.time()) + 10, 42, 1, 4, object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
needed.add(obj.vector)
|
||||
|
||||
self.assertEqual(
|
||||
set(next(self.objects.biginv_chunks(stream=4))), needed)
|
||||
self.assertTrue(
|
||||
set(next(self.objects.biginv_chunks())).difference(needed))
|
||||
|
||||
def test_filter(self):
|
||||
"""Check the objects filtering"""
|
||||
needed = set()
|
||||
tagged = set()
|
||||
tag = b'@' * 32
|
||||
for _ in range(10):
|
||||
# wrong type
|
||||
obj = structure.Object(
|
||||
int(time.time()), 0, 1, 5, object_payload=os.urandom(64))
|
||||
self.objects[obj.vector] = obj
|
||||
# wrong type, but the proper tag
|
||||
obj = structure.Object(
|
||||
int(time.time()) - 11000, 0, 4, random.choice([1, 2, 3, 5]),
|
||||
object_payload=tag + os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
tagged.add(obj.vector)
|
||||
# wrong stream
|
||||
obj = structure.Object(
|
||||
int(time.time()), 33, 1, 1, object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
# interesting
|
||||
obj = structure.Object(
|
||||
int(time.time()) - 11000, 33, 1, 5,
|
||||
object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
needed.add(obj.vector)
|
||||
|
||||
# stream and type
|
||||
self.assertTrue(needed)
|
||||
for obj in self.objects.filter(5, 33):
|
||||
needed.remove(obj.vector)
|
||||
self.assertFalse(needed)
|
||||
|
||||
# tag
|
||||
self.assertTrue(tagged)
|
||||
for obj in self.objects.filter(tag=tag):
|
||||
tagged.remove(obj.vector)
|
||||
self.assertFalse(tagged)
|
||||
|
||||
def test_cleanup(self):
|
||||
"""Check cleaning up"""
|
||||
for _ in range(10):
|
||||
obj = structure.Object(
|
||||
int(time.time()) - random.randint(4, 5) * 3600,
|
||||
42, 1, 6, object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
obj = structure.Object(
|
||||
int(time.time()) - 2 * 3600,
|
||||
42, 1, 6, object_payload=os.urandom(32))
|
||||
self.objects[obj.vector] = obj
|
||||
|
||||
for obj in self.objects.values():
|
||||
if obj.is_expired():
|
||||
break
|
||||
else:
|
||||
self.fail('No objects found to delete')
|
||||
|
||||
self.objects.cleanup()
|
||||
self.assertTrue(self.objects)
|
||||
for obj in self.objects.values():
|
||||
self.assertFalse(obj.is_expired())
|
||||
|
||||
def test_select(self):
|
||||
"""Check the select method"""
|
||||
pending = set()
|
||||
questionable = set()
|
||||
|
||||
for _ in range(5):
|
||||
obj = structure.Object(
|
||||
int(time.time()) - 10, 42, 1, 7, object_payload=os.urandom(32))
|
||||
questionable.add(obj.vector)
|
||||
self.objects[obj.vector] = obj
|
||||
obj = structure.Object(
|
||||
int(time.time()) + 10, 42, 1, 7, object_payload=os.urandom(32))
|
||||
questionable.add(obj.vector)
|
||||
pending.add(obj.vector)
|
||||
|
||||
self.assertEqual(self.objects.select(questionable), pending)
|
||||
|
||||
|
||||
class TestObjectsSQL(TestObjects, unittest.TestCase):
|
||||
"""A test case for the sqlite inventory"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
shared.data_directory = tempfile.gettempdir()
|
||||
cls.tearDownClass()
|
||||
cls.objects = sql.Inventory()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.objects = None
|
||||
os.remove(os.path.join(shared.data_directory, 'objects.dat'))
|
|
@ -18,11 +18,12 @@ sample_addr_data = unhexlify(
|
|||
'260753000201300000000000000057ae1f90')
|
||||
|
||||
# data for an object with expires_time 1697063939
|
||||
# structure.Object(
|
||||
# b'\x00' * 8, expires_time, 42, 1, 2, b'HELLO').to_bytes()
|
||||
# structure.Object(expires_time, 42, 1, 2, object_payload=b'HELLO').data
|
||||
sample_object_data = unhexlify(
|
||||
'000000000000000000000000652724030000002a010248454c4c4f')
|
||||
|
||||
sample_object_expires = 1697063939
|
||||
|
||||
logging.basicConfig(
|
||||
level=shared.log_level,
|
||||
format='[%(asctime)s] [%(levelname)s] %(message)s')
|
||||
|
@ -131,17 +132,24 @@ class TestStructure(unittest.TestCase):
|
|||
message.Message(b'object', sample_object_data))
|
||||
self.assertEqual(obj.object_type, 42)
|
||||
self.assertEqual(obj.stream_number, 2)
|
||||
self.assertEqual(obj.expires_time, 1697063939)
|
||||
self.assertEqual(obj.expires_time, sample_object_expires)
|
||||
self.assertEqual(obj.object_payload, b'HELLO')
|
||||
|
||||
obj = structure.Object(
|
||||
b'\x00' * 8, int(time.time() + 3000000), 42, 1, 1, b'HELLO')
|
||||
sample_object_expires, 42, 1, 2, object_payload=b'HELLO')
|
||||
self.assertEqual(obj.data, sample_object_data)
|
||||
self.assertEqual(obj.offset, 22)
|
||||
self.assertEqual(obj.nonce, b'\x00' * 8)
|
||||
self.assertTrue(obj.is_expired())
|
||||
|
||||
obj = structure.Object(
|
||||
int(time.time() + 3000000), 42, 1, 1, object_payload=b'HELLO')
|
||||
self.assertFalse(obj.is_valid())
|
||||
obj.expires_time = int(time.time() - 11000)
|
||||
self.assertFalse(obj.is_valid())
|
||||
|
||||
obj = structure.Object(
|
||||
b'\x00' * 8, int(time.time() + 300), 42, 1, 2, b'HELLO')
|
||||
int(time.time() + 300), 42, 1, 2, object_payload=b'HELLO')
|
||||
vector = obj.vector
|
||||
proofofwork._worker(obj) # pylint: disable=protected-access
|
||||
obj = shared.objects.popitem()[1]
|
||||
|
@ -151,16 +159,17 @@ class TestStructure(unittest.TestCase):
|
|||
shared.stream = 2
|
||||
self.assertTrue(obj.is_valid())
|
||||
|
||||
obj.object_payload = \
|
||||
b'TIGER, tiger, burning bright. In the forests of the night'
|
||||
self.assertFalse(obj.is_valid())
|
||||
# obj.data = struct.pack(...
|
||||
# obj.object_payload = \
|
||||
# b'TIGER, tiger, burning bright. In the forests of the night'
|
||||
# self.assertFalse(obj.is_valid())
|
||||
|
||||
def test_proofofwork(self):
|
||||
"""Check the main proofofwork call and worker"""
|
||||
shared.vector_advertise_queue = queue.Queue()
|
||||
obj = structure.Object(
|
||||
b'\x00' * 8, int(time.time() + 300), 42, 1,
|
||||
shared.stream, b'HELLO')
|
||||
int(time.time() + 300), 42, 1, shared.stream,
|
||||
object_payload=b'HELLO')
|
||||
start_time = time.time()
|
||||
proofofwork.do_pow_and_publish(obj)
|
||||
try:
|
||||
|
@ -189,6 +198,6 @@ class TestStructure(unittest.TestCase):
|
|||
self.fail("No nonce found in the queue")
|
||||
|
||||
obj = structure.Object(
|
||||
nonce, obj.expires_time, obj.object_type, obj.version,
|
||||
obj.stream_number, obj.object_payload)
|
||||
obj.expires_time, obj.object_type, obj.version, obj.stream_number,
|
||||
object_payload=obj.object_payload, nonce=nonce)
|
||||
self.assertTrue(obj.is_valid())
|
||||
|
|
Loading…
Reference in New Issue
Block a user