Add Error message class, handle fatal
This commit is contained in:
parent
3f61bd694b
commit
b806906af4
|
@ -422,8 +422,12 @@ class Connection(threading.Thread):
|
|||
self.send_queue.put(message.Message(b'pong', b''))
|
||||
|
||||
elif m.command == b'error':
|
||||
error = message.Error.from_message(m)
|
||||
logging.warning(
|
||||
'%s:%s -> error: %s', self.host_print, self.port, m.payload)
|
||||
'%s:%s -> %s', self.host_print, self.port, error)
|
||||
if error.fatal == 2:
|
||||
# reduce probability to connect soon
|
||||
shared.unchecked_node_pool.discard((self.host, self.port))
|
||||
|
||||
else:
|
||||
logging.debug('%s:%s -> %s', self.host_print, self.port, m)
|
||||
|
|
|
@ -268,3 +268,37 @@ class Addr():
|
|||
payload = payload[38:]
|
||||
|
||||
return cls(addresses)
|
||||
|
||||
|
||||
class Error():
|
||||
"""The error message payload"""
|
||||
def __init__(self, error_text=b'', fatal=0, ban_time=0, vector=b''):
|
||||
self.error_text = error_text
|
||||
self.fatal = fatal
|
||||
self.ban_time = ban_time
|
||||
self.vector = vector
|
||||
|
||||
def __repr__(self):
|
||||
return 'error, text: {}'.format(self.error_text)
|
||||
|
||||
def to_bytes(self):
|
||||
return Message(
|
||||
b'error', structure.VarInt(self.fatal).to_bytes()
|
||||
+ structure.VarInt(self.ban_time).to_bytes()
|
||||
+ structure.VarInt(len(self.vector)).to_bytes() + self.vector
|
||||
+ structure.VarInt(len(self.error_text)).to_bytes()
|
||||
+ self.error_text
|
||||
).to_bytes()
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, m):
|
||||
payload = m.payload
|
||||
fatal, payload = _payload_read_int(payload)
|
||||
ban_time, payload = _payload_read_int(payload)
|
||||
vector_length, payload = _payload_read_int(payload)
|
||||
vector = payload[:vector_length]
|
||||
payload = payload[vector_length:]
|
||||
error_text_length, payload = _payload_read_int(payload)
|
||||
error_text = payload[:error_text_length]
|
||||
|
||||
return cls(error_text, fatal, ban_time, vector)
|
||||
|
|
|
@ -35,6 +35,10 @@ sample_version_msg = unhexlify(
|
|||
'4269746d6573736167653a302e362e332e322f03010203'
|
||||
)
|
||||
|
||||
#
|
||||
sample_error_data = \
|
||||
b'\x02\x00\x006Too many connections from your IP. Closing connection.'
|
||||
|
||||
|
||||
class TestMessage(unittest.TestCase):
|
||||
"""Test assembling and disassembling of network mesages"""
|
||||
|
@ -83,3 +87,15 @@ class TestMessage(unittest.TestCase):
|
|||
# omit header and timestamp
|
||||
self.assertEqual(msg[24:36], sample_version_msg[24:36])
|
||||
self.assertEqual(msg[44:], sample_version_msg[44:])
|
||||
|
||||
def test_error(self):
|
||||
"""Test error message"""
|
||||
msg = message.Error.from_message(
|
||||
message.Message(b'error', sample_error_data))
|
||||
self.assertEqual(msg.fatal, 2)
|
||||
self.assertEqual(msg.ban_time, 0)
|
||||
self.assertEqual(msg.vector, b'')
|
||||
|
||||
msg = message.Error(
|
||||
b'Too many connections from your IP. Closing connection.', 2)
|
||||
self.assertEqual(msg.to_bytes()[24:], sample_error_data)
|
||||
|
|
Loading…
Reference in New Issue
Block a user