diff --git a/src/connection.py b/src/connection.py index 591aad1..3c13142 100644 --- a/src/connection.py +++ b/src/connection.py @@ -197,11 +197,20 @@ class Connection(threading.Thread): while len(self.buffer_receive) >= self.next_message_size: if self.next_header: self.next_header = False - h = message.Header.from_bytes(self.buffer_receive[:shared.header_length]) + try: + h = message.Header.from_bytes(self.buffer_receive[:shared.header_length]) + except ValueError as e: + self.status = 'disconnecting' + logging.warning('Received malformed message from {}:{}: {}'.format(self.host, self.port, e)) + break self.next_message_size += h.payload_length else: - m = message.Message.from_bytes(self.buffer_receive[:self.next_message_size]) - + try: + m = message.Message.from_bytes(self.buffer_receive[:self.next_message_size]) + except ValueError as e: + self.status = 'disconnecting' + logging.warning('Received malformed message from {}:{}, {}'.format(self.host, self.port, e)) + break self.next_header = True self.buffer_receive = self.buffer_receive[self.next_message_size:] self.next_message_size = shared.header_length diff --git a/src/message.py b/src/message.py index 17ae576..2030b4e 100644 --- a/src/message.py +++ b/src/message.py @@ -31,7 +31,7 @@ class Header(object): magic_bytes, command, payload_length, payload_checksum = struct.unpack('>4s12sL4s', b) if magic_bytes != shared.magic_bytes: - raise IOError('magic_bytes do not match') + raise ValueError('magic_bytes do not match') command = command.rstrip(b'\x00') @@ -63,12 +63,12 @@ class Message(object): payload_length = len(payload) if payload_length != h.payload_length: - raise Exception('wrong payload length, expected {}, got {}'.format(h.payload_length, payload_length)) + raise ValueError('wrong payload length, expected {}, got {}'.format(h.payload_length, payload_length)) payload_checksum = hashlib.sha512(payload).digest()[:4] if payload_checksum != h.payload_checksum: - raise Exception('wrong payload checksum, expected {}, got {}'.format(h.payload_checksum, payload_checksum)) + raise ValueError('wrong payload checksum, expected {}, got {}'.format(h.payload_checksum, payload_checksum)) return cls(h.command, payload)