Improve docstrings in message and structure and add more

This commit is contained in:
Lee Miller 2023-08-12 03:34:20 +03:00
parent 3788b12a28
commit dd2b0b89af
Signed by: lee.miller
GPG Key ID: 4F97A5EA88F4AB63
2 changed files with 16 additions and 6 deletions

View File

@ -9,7 +9,7 @@ from . import shared, structure
class Header(): class Header():
"""Common message header""" """Message header structure"""
def __init__(self, command, payload_length, payload_checksum): def __init__(self, command, payload_length, payload_checksum):
self.command = command self.command = command
self.payload_length = payload_length self.payload_length = payload_length
@ -24,6 +24,7 @@ class Header():
base64.b16encode(self.payload_checksum).decode()) base64.b16encode(self.payload_checksum).decode())
def to_bytes(self): def to_bytes(self):
"""Serialize to bytes"""
b = b'' b = b''
b += shared.magic_bytes b += shared.magic_bytes
b += self.command.ljust(12, b'\x00') b += self.command.ljust(12, b'\x00')
@ -33,6 +34,7 @@ class Header():
@classmethod @classmethod
def from_bytes(cls, b): def from_bytes(cls, b):
"""Parse from bytes"""
magic_bytes, command, payload_length, payload_checksum = struct.unpack( magic_bytes, command, payload_length, payload_checksum = struct.unpack(
'>4s12sL4s', b) '>4s12sL4s', b)
@ -59,6 +61,7 @@ class Message():
base64.b16encode(self.payload_checksum).decode()) base64.b16encode(self.payload_checksum).decode())
def to_bytes(self): def to_bytes(self):
"""Serialize to bytes"""
b = Header( b = Header(
self.command, self.payload_length, self.payload_checksum self.command, self.payload_length, self.payload_checksum
).to_bytes() ).to_bytes()
@ -67,6 +70,7 @@ class Message():
@classmethod @classmethod
def from_bytes(cls, b): def from_bytes(cls, b):
"""Parse from bytes"""
h = Header.from_bytes(b[:24]) h = Header.from_bytes(b[:24])
payload = b[24:] payload = b[24:]
@ -88,7 +92,7 @@ class Message():
class Version(): class Version():
"""The version message""" """The version message payload"""
def __init__( def __init__(
self, host, port, protocol_version=shared.protocol_version, self, host, port, protocol_version=shared.protocol_version,
services=shared.services, nonce=shared.nonce, services=shared.services, nonce=shared.nonce,
@ -159,7 +163,7 @@ class Version():
class Inv(): class Inv():
"""The inv message""" """The inv message payload"""
def __init__(self, vectors): def __init__(self, vectors):
self.vectors = set(vectors) self.vectors = set(vectors)
@ -195,7 +199,7 @@ class Inv():
class GetData(): class GetData():
"""The getdata message""" """The getdata message payload"""
def __init__(self, vectors): def __init__(self, vectors):
self.vectors = set(vectors) self.vectors = set(vectors)
@ -231,7 +235,7 @@ class GetData():
class Addr(): class Addr():
"""The addr message""" """The addr message payload"""
def __init__(self, addresses): def __init__(self, addresses):
self.addresses = addresses self.addresses = addresses

View File

@ -46,7 +46,7 @@ class VarInt():
class Object(): class Object():
"""object message""" """The 'object' message payload"""
def __init__( def __init__(
self, nonce, expires_time, object_type, version, self, nonce, expires_time, object_type, version,
stream_number, object_payload stream_number, object_payload
@ -66,6 +66,7 @@ class Object():
@classmethod @classmethod
def from_message(cls, m): def from_message(cls, m):
"""Decode message payload"""
payload = m.payload payload = m.payload
nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20]) nonce, expires_time, object_type = struct.unpack('>8sQL', payload[:20])
payload = payload[20:] payload = payload[20:]
@ -80,6 +81,7 @@ class Object():
nonce, expires_time, object_type, version, stream_number, payload) nonce, expires_time, object_type, version, stream_number, payload)
def to_bytes(self): def to_bytes(self):
"""Serialize to bytes"""
payload = b'' payload = b''
payload += self.nonce payload += self.nonce
payload += struct.pack('>QL', self.expires_time, self.object_type) payload += struct.pack('>QL', self.expires_time, self.object_type)
@ -90,9 +92,11 @@ class Object():
return payload return payload
def is_expired(self): def is_expired(self):
"""Check if object's TTL is expired"""
return self.expires_time + 3 * 3600 < time.time() return self.expires_time + 3 * 3600 < time.time()
def is_valid(self): def is_valid(self):
"""Checks the object validity"""
if self.is_expired(): if self.is_expired():
logging.debug( logging.debug(
'Invalid object %s, reason: expired', 'Invalid object %s, reason: expired',
@ -129,6 +133,7 @@ class Object():
return True return True
def pow_target(self): def pow_target(self):
"""Compute PoW target"""
data = self.to_bytes()[8:] data = self.to_bytes()[8:]
length = len(data) + 8 + shared.payload_length_extra_bytes length = len(data) + 8 + shared.payload_length_extra_bytes
dt = max(self.expires_time - time.time(), 0) dt = max(self.expires_time - time.time(), 0)
@ -138,6 +143,7 @@ class Object():
length + (dt * length) / (2 ** 16)))) length + (dt * length) / (2 ** 16))))
def pow_initial_hash(self): def pow_initial_hash(self):
"""Compute the initial hash for PoW"""
return hashlib.sha512(self.to_bytes()[8:]).digest() return hashlib.sha512(self.to_bytes()[8:]).digest()