From 3f61bd694b92629b434e9eaf9e09b6542bf5eba4 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 14 Aug 2023 03:23:01 +0300 Subject: [PATCH] Define a helper function to read a varint and trim payload --- minode/message.py | 45 +++++++++++++++------------------------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/minode/message.py b/minode/message.py index a8fc204..546e1f3 100644 --- a/minode/message.py +++ b/minode/message.py @@ -91,6 +91,13 @@ class Message(): return cls(h.command, payload) +def _payload_read_int(data): + varint_length = structure.VarInt.length(data[0]) + return ( + structure.VarInt.from_bytes(data[:varint_length]).n, + data[varint_length:]) + + class Version(): """The version message payload""" def __init__( @@ -149,29 +156,18 @@ class Version(): payload = payload[80:] - user_agent_varint_length = structure.VarInt.length(payload[0]) - user_agent_length = structure.VarInt.from_bytes( - payload[:user_agent_varint_length]).n - - payload = payload[user_agent_varint_length:] - + user_agent_length, payload = _payload_read_int(payload) user_agent = payload[:user_agent_length] - payload = payload[user_agent_length:] - streams_varint_length = structure.VarInt.length(payload[0]) - streams_count = structure.VarInt.from_bytes( - payload[:streams_varint_length]).n - payload = payload[streams_varint_length:] + streams_count, payload = _payload_read_int(payload) if streams_count > 160000: raise ValueError('malformed Version message, to many streams') streams = [] while payload: - stream_length = structure.VarInt.length(payload[0]) - streams.append( - structure.VarInt.from_bytes(payload[:stream_length]).n) - payload = payload[stream_length:] + stream, payload = _payload_read_int(payload) + streams.append(stream) if streams_count != len(streams): raise ValueError('malformed Version message, wrong streams_count') @@ -198,11 +194,7 @@ class Inv(): def from_message(cls, m): payload = m.payload - vector_count_varint_length = structure.VarInt.length(payload[0]) - vector_count = structure.VarInt.from_bytes( - payload[:vector_count_varint_length]).n - - payload = payload[vector_count_varint_length:] + vector_count, payload = _payload_read_int(payload) vectors = set() @@ -234,11 +226,7 @@ class GetData(): def from_message(cls, m): payload = m.payload - vector_count_varint_length = structure.VarInt.length(payload[0]) - vector_count = structure.VarInt.from_bytes( - payload[:vector_count_varint_length]).n - - payload = payload[vector_count_varint_length:] + vector_count, payload = _payload_read_int(payload) vectors = set() @@ -270,11 +258,8 @@ class Addr(): def from_message(cls, m): payload = m.payload - addr_count_varint_length = structure.VarInt.length(payload[0]) - # addr_count = structure.VarInt.from_bytes( - # payload[:addr_count_varint_length]).n - - payload = payload[addr_count_varint_length:] + # not validating addr_count + _, payload = _payload_read_int(payload) addresses = set()