diff --git a/src/protocol.py b/src/protocol.py index 4f2d0856..e108f400 100644 --- a/src/protocol.py +++ b/src/protocol.py @@ -91,16 +91,16 @@ def isBitSetWithinBitfield(fourByteString, n): return x & 2**n != 0 -# ip addresses +# IP addresses def encodeHost(host): """Encode a given host to be used in low-level socket operations""" if host.find('.onion') > -1: - return '\xfd\x87\xd8\x7e\xeb\x43' + base64.b32decode( + return b'\xfd\x87\xd8\x7e\xeb\x43' + base64.b32decode( host.split(".")[0], True) elif host.find(':') == -1: - return '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + \ + return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + \ socket.inet_aton(host) return socket.inet_pton(socket.AF_INET6, host) @@ -147,10 +147,10 @@ def checkIPAddress(host, private=False): Returns hostStandardFormat if it is a valid IP address, otherwise returns False """ - if host[0:12] == '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF': + if host[0:12] == b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF': hostStandardFormat = socket.inet_ntop(socket.AF_INET, host[12:]) return checkIPv4Address(host[12:], hostStandardFormat, private) - elif host[0:6] == '\xfd\x87\xd8\x7e\xeb\x43': + elif host[0:6] == b'\xfd\x87\xd8\x7e\xeb\x43': # Onion, based on BMD/bitcoind hostStandardFormat = base64.b32encode(host[6:]).lower() + ".onion" if private: @@ -161,7 +161,7 @@ def checkIPAddress(host, private=False): hostStandardFormat = socket.inet_ntop(socket.AF_INET6, host) except ValueError: return False - if hostStandardFormat == "": + if hostStandardFormat: # This can happen on Windows systems which are # not 64-bit compatible so let us drop the IPv6 address. return False @@ -173,23 +173,23 @@ def checkIPv4Address(host, hostStandardFormat, private=False): Returns hostStandardFormat if it is an IPv4 address, otherwise returns False """ - if host[0] == '\x7F': # 127/8 + if host[0:1] == b'\x7F': # 127/8 if not private: logger.debug( 'Ignoring IP address in loopback range: %s', hostStandardFormat) return hostStandardFormat if private else False - if host[0] == '\x0A': # 10/8 + if host[0:1] == b'\x0A': # 10/8 if not private: logger.debug( 'Ignoring IP address in private range: %s', hostStandardFormat) return hostStandardFormat if private else False - if host[0:2] == '\xC0\xA8': # 192.168/16 + if host[0:2] == b'\xC0\xA8': # 192.168/16 if not private: logger.debug( 'Ignoring IP address in private range: %s', hostStandardFormat) return hostStandardFormat if private else False - if host[0:2] >= '\xAC\x10' and host[0:2] < '\xAC\x20': # 172.16/12 + if host[0:2] >= b'\xAC\x10' and host[0:2] < b'\xAC\x20': # 172.16/12 if not private: logger.debug( 'Ignoring IP address in private range: %s', hostStandardFormat) @@ -202,15 +202,19 @@ def checkIPv6Address(host, hostStandardFormat, private=False): Returns hostStandardFormat if it is an IPv6 address, otherwise returns False """ - if host == ('\x00' * 15) + '\x01': + if host == b'\x00' * 15 + b'\x01': if not private: logger.debug('Ignoring loopback address: %s', hostStandardFormat) return False - if host[0] == '\xFE' and (ord(host[1]) & 0xc0) == 0x80: + try: + host = [ord(c) for c in host[:2]] + except TypeError: # python3 has ints already + pass + if host[0:1] == b'\xFE' and host[1] & 0xc0 == 0x80: if not private: logger.debug('Ignoring local address: %s', hostStandardFormat) return hostStandardFormat if private else False - if (ord(host[0]) & 0xfe) == 0xfc: + if host[0] & 0xfe == 0xfc: if not private: logger.debug( 'Ignoring unique local address: %s', hostStandardFormat) @@ -280,7 +284,7 @@ def isProofOfWorkSufficient( # Packet creation -def CreatePacket(command, payload=''): +def CreatePacket(command, payload=b''): """Construct and return a packet""" payload_length = len(payload) checksum = hashlib.sha512(payload).digest()[0:4] @@ -298,14 +302,14 @@ def assembleVersionMessage( Construct the payload of a version message, return the resulting bytes of running `CreatePacket` on it """ - payload = '' + payload = b'' payload += pack('>L', 3) # protocol version. # bitflags of the services I offer. payload += pack( '>q', - NODE_NETWORK | - (NODE_SSL if haveSSL(server) else 0) | - (NODE_DANDELION if state.dandelion else 0) + NODE_NETWORK + | (NODE_SSL if haveSSL(server) else 0) + | (NODE_DANDELION if state.dandelion else 0) ) payload += pack('>q', int(time.time())) @@ -327,13 +331,13 @@ def assembleVersionMessage( # bitflags of the services I offer. payload += pack( '>q', - NODE_NETWORK | - (NODE_SSL if haveSSL(server) else 0) | - (NODE_DANDELION if state.dandelion else 0) + NODE_NETWORK + | (NODE_SSL if haveSSL(server) else 0) + | (NODE_DANDELION if state.dandelion else 0) ) # = 127.0.0.1. This will be ignored by the remote host. # The actual remote connected IP will be used. - payload += '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + pack( + payload += b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + pack( '>L', 2130706433) # we have a separate extPort and incoming over clearnet # or outgoing through clearnet @@ -355,7 +359,7 @@ def assembleVersionMessage( payload += nodeid[0:8] else: payload += eightBytesOfRandomDataUsedToDetectConnectionsToSelf - userAgent = '/PyBitmessage:' + softwareVersion + '/' + userAgent = ('/PyBitmessage:%s/' % softwareVersion).encode('utf-8') payload += encodeVarint(len(userAgent)) payload += userAgent diff --git a/src/tests/test_networkgroup.py b/src/tests/test_networkgroup.py deleted file mode 100644 index 79163402..00000000 --- a/src/tests/test_networkgroup.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Test for network group -""" -import unittest - -from .common import skip_python3 - -skip_python3() - - -class TestNetworkGroup(unittest.TestCase): - """ - Test case for network group - """ - def test_network_group(self): - """Test various types of network groups""" - from pybitmessage.protocol import network_group - - test_ip = '1.2.3.4' - self.assertEqual('\x01\x02', network_group(test_ip)) - - test_ip = '127.0.0.1' - self.assertEqual('IPv4', network_group(test_ip)) - - test_ip = '0102:0304:0506:0708:090A:0B0C:0D0E:0F10' - self.assertEqual( - '\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C', - network_group(test_ip)) - - test_ip = 'bootstrap8444.bitmessage.org' - self.assertEqual( - 'bootstrap8444.bitmessage.org', - network_group(test_ip)) - - test_ip = 'quzwelsuziwqgpt2.onion' - self.assertEqual( - test_ip, - network_group(test_ip)) - - test_ip = None - self.assertEqual( - None, - network_group(test_ip)) diff --git a/src/tests/test_protocol.py b/src/tests/test_protocol.py index a3c73a73..521c90c7 100644 --- a/src/tests/test_protocol.py +++ b/src/tests/test_protocol.py @@ -2,25 +2,97 @@ Tests for common protocol functions """ +import sys import unittest -from .common import skip_python3 - -skip_python3() +from pybitmessage import protocol, state class TestProtocol(unittest.TestCase): """Main protocol test case""" + def test_checkIPv4Address(self): + """Check the results of protocol.checkIPv4Address()""" + token = 'HELLO' + # checking protocol.encodeHost()[12:] + self.assertEqual( # 127.0.0.1 + token, protocol.checkIPv4Address(b'\x7f\x00\x00\x01', token, True)) + self.assertEqual( # 10.42.43.1 + token, protocol.checkIPv4Address(b'\n*+\x01', token, True)) + self.assertEqual( # 192.168.0.254 + token, protocol.checkIPv4Address(b'\xc0\xa8\x00\xfe', token, True)) + self.assertEqual( # 172.31.255.254 + token, protocol.checkIPv4Address(b'\xac\x1f\xff\xfe', token, True)) + self.assertFalse( # 8.8.8.8 + protocol.checkIPv4Address(b'\x08\x08\x08\x08', token, True)) + + def test_checkIPv6Address(self): + """Check the results of protocol.checkIPv6Address()""" + test_ip = '2001:db8::ff00:42:8329' + self.assertEqual( + 'test', protocol.checkIPv6Address( + protocol.encodeHost(test_ip), 'test')) + self.assertFalse( + protocol.checkIPv6Address( + protocol.encodeHost(test_ip), 'test', True)) + def test_check_local(self): """Check the logic of TCPConnection.local""" - from pybitmessage import protocol, state - self.assertTrue( protocol.checkIPAddress(protocol.encodeHost('127.0.0.1'), True)) self.assertTrue( protocol.checkIPAddress(protocol.encodeHost('192.168.0.1'), True)) + self.assertTrue( + protocol.checkIPAddress(protocol.encodeHost('10.42.43.1'), True)) + self.assertTrue( + protocol.checkIPAddress(protocol.encodeHost('172.31.255.2'), True)) + self.assertFalse(protocol.checkIPAddress( + protocol.encodeHost('2001:db8::ff00:42:8329'), True)) + globalhost = protocol.encodeHost('8.8.8.8') + self.assertFalse(protocol.checkIPAddress(globalhost, True)) + self.assertEqual(protocol.checkIPAddress(globalhost), '8.8.8.8') + + @unittest.skipIf( + sys.hexversion >= 0x3000000, 'this is still not working with python3') + def test_check_local_socks(self): + """The SOCKS part of the local check""" self.assertTrue( not protocol.checkSocksIP('127.0.0.1') or state.socksIP) + + def test_network_group(self): + """Test various types of network groups""" + + test_ip = '1.2.3.4' + self.assertEqual(b'\x01\x02', protocol.network_group(test_ip)) + + test_ip = '127.0.0.1' + self.assertEqual('IPv4', protocol.network_group(test_ip)) + + self.assertEqual( + protocol.network_group('8.8.8.8'), + protocol.network_group('8.8.4.4')) + self.assertNotEqual( + protocol.network_group('1.1.1.1'), + protocol.network_group('8.8.8.8')) + + test_ip = '0102:0304:0506:0708:090A:0B0C:0D0E:0F10' + self.assertEqual( + b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C', + protocol.network_group(test_ip)) + + test_ip = 'bootstrap8444.bitmessage.org' + self.assertEqual( + 'bootstrap8444.bitmessage.org', + protocol.network_group(test_ip)) + + test_ip = 'quzwelsuziwqgpt2.onion' + self.assertEqual( + test_ip, + protocol.network_group(test_ip)) + + test_ip = None + self.assertEqual( + None, + protocol.network_group(test_ip))