Improve malformed message handling

This commit is contained in:
TheKysek 2016-10-15 17:00:18 +02:00
parent 90a9cee834
commit c525d4f2e1
2 changed files with 15 additions and 6 deletions

View File

@ -197,11 +197,20 @@ class Connection(threading.Thread):
while len(self.buffer_receive) >= self.next_message_size: while len(self.buffer_receive) >= self.next_message_size:
if self.next_header: if self.next_header:
self.next_header = False self.next_header = False
try:
h = message.Header.from_bytes(self.buffer_receive[:shared.header_length]) h = message.Header.from_bytes(self.buffer_receive[:shared.header_length])
except ValueError as e:
self.status = 'disconnecting'
logging.warning('Received malformed message from {}:{}: {}'.format(self.host, self.port, e))
break
self.next_message_size += h.payload_length self.next_message_size += h.payload_length
else: else:
try:
m = message.Message.from_bytes(self.buffer_receive[:self.next_message_size]) m = message.Message.from_bytes(self.buffer_receive[:self.next_message_size])
except ValueError as e:
self.status = 'disconnecting'
logging.warning('Received malformed message from {}:{}, {}'.format(self.host, self.port, e))
break
self.next_header = True self.next_header = True
self.buffer_receive = self.buffer_receive[self.next_message_size:] self.buffer_receive = self.buffer_receive[self.next_message_size:]
self.next_message_size = shared.header_length self.next_message_size = shared.header_length

View File

@ -31,7 +31,7 @@ class Header(object):
magic_bytes, command, payload_length, payload_checksum = struct.unpack('>4s12sL4s', b) magic_bytes, command, payload_length, payload_checksum = struct.unpack('>4s12sL4s', b)
if magic_bytes != shared.magic_bytes: if magic_bytes != shared.magic_bytes:
raise IOError('magic_bytes do not match') raise ValueError('magic_bytes do not match')
command = command.rstrip(b'\x00') command = command.rstrip(b'\x00')
@ -63,12 +63,12 @@ class Message(object):
payload_length = len(payload) payload_length = len(payload)
if payload_length != h.payload_length: if payload_length != h.payload_length:
raise Exception('wrong payload length, expected {}, got {}'.format(h.payload_length, payload_length)) raise ValueError('wrong payload length, expected {}, got {}'.format(h.payload_length, payload_length))
payload_checksum = hashlib.sha512(payload).digest()[:4] payload_checksum = hashlib.sha512(payload).digest()[:4]
if payload_checksum != h.payload_checksum: if payload_checksum != h.payload_checksum:
raise Exception('wrong payload checksum, expected {}, got {}'.format(h.payload_checksum, payload_checksum)) raise ValueError('wrong payload checksum, expected {}, got {}'.format(h.payload_checksum, payload_checksum))
return cls(h.command, payload) return cls(h.command, payload)