Define a helper function to read a varint and trim payload

This commit is contained in:
Lee Miller 2023-08-14 03:23:01 +03:00
parent 7812e4bbc2
commit 3f61bd694b
Signed by: lee.miller
GPG Key ID: 4F97A5EA88F4AB63

View File

@ -91,6 +91,13 @@ class Message():
return cls(h.command, payload) return cls(h.command, payload)
def _payload_read_int(data):
varint_length = structure.VarInt.length(data[0])
return (
structure.VarInt.from_bytes(data[:varint_length]).n,
data[varint_length:])
class Version(): class Version():
"""The version message payload""" """The version message payload"""
def __init__( def __init__(
@ -149,29 +156,18 @@ class Version():
payload = payload[80:] payload = payload[80:]
user_agent_varint_length = structure.VarInt.length(payload[0]) user_agent_length, payload = _payload_read_int(payload)
user_agent_length = structure.VarInt.from_bytes(
payload[:user_agent_varint_length]).n
payload = payload[user_agent_varint_length:]
user_agent = payload[:user_agent_length] user_agent = payload[:user_agent_length]
payload = payload[user_agent_length:] payload = payload[user_agent_length:]
streams_varint_length = structure.VarInt.length(payload[0]) streams_count, payload = _payload_read_int(payload)
streams_count = structure.VarInt.from_bytes(
payload[:streams_varint_length]).n
payload = payload[streams_varint_length:]
if streams_count > 160000: if streams_count > 160000:
raise ValueError('malformed Version message, to many streams') raise ValueError('malformed Version message, to many streams')
streams = [] streams = []
while payload: while payload:
stream_length = structure.VarInt.length(payload[0]) stream, payload = _payload_read_int(payload)
streams.append( streams.append(stream)
structure.VarInt.from_bytes(payload[:stream_length]).n)
payload = payload[stream_length:]
if streams_count != len(streams): if streams_count != len(streams):
raise ValueError('malformed Version message, wrong streams_count') raise ValueError('malformed Version message, wrong streams_count')
@ -198,11 +194,7 @@ class Inv():
def from_message(cls, m): def from_message(cls, m):
payload = m.payload payload = m.payload
vector_count_varint_length = structure.VarInt.length(payload[0]) vector_count, payload = _payload_read_int(payload)
vector_count = structure.VarInt.from_bytes(
payload[:vector_count_varint_length]).n
payload = payload[vector_count_varint_length:]
vectors = set() vectors = set()
@ -234,11 +226,7 @@ class GetData():
def from_message(cls, m): def from_message(cls, m):
payload = m.payload payload = m.payload
vector_count_varint_length = structure.VarInt.length(payload[0]) vector_count, payload = _payload_read_int(payload)
vector_count = structure.VarInt.from_bytes(
payload[:vector_count_varint_length]).n
payload = payload[vector_count_varint_length:]
vectors = set() vectors = set()
@ -270,11 +258,8 @@ class Addr():
def from_message(cls, m): def from_message(cls, m):
payload = m.payload payload = m.payload
addr_count_varint_length = structure.VarInt.length(payload[0]) # not validating addr_count
# addr_count = structure.VarInt.from_bytes( _, payload = _payload_read_int(payload)
# payload[:addr_count_varint_length]).n
payload = payload[addr_count_varint_length:]
addresses = set() addresses = set()