Use bytes in all the comparisons in the protocol.checkIP*Address(),

concatenate bytes in protocol.encodeHost().
This commit is contained in:
Dmitri Bogomolov 2021-07-21 22:27:30 +03:00
parent 8a5d8ecd4c
commit 1e7be609ed
Signed by untrusted user: g1itch
GPG Key ID: 720A756F18DEED13

View File

@ -91,16 +91,16 @@ def isBitSetWithinBitfield(fourByteString, n):
return x & 2**n != 0
# ip addresses
# IP addresses
def encodeHost(host):
"""Encode a given host to be used in low-level socket operations"""
if host.find('.onion') > -1:
return '\xfd\x87\xd8\x7e\xeb\x43' + base64.b32decode(
return b'\xfd\x87\xd8\x7e\xeb\x43' + base64.b32decode(
host.split(".")[0], True)
elif host.find(':') == -1:
return '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + \
return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + \
socket.inet_aton(host)
return socket.inet_pton(socket.AF_INET6, host)
@ -147,10 +147,10 @@ def checkIPAddress(host, private=False):
Returns hostStandardFormat if it is a valid IP address,
otherwise returns False
"""
if host[0:12] == '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF':
if host[0:12] == b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF':
hostStandardFormat = socket.inet_ntop(socket.AF_INET, host[12:])
return checkIPv4Address(host[12:], hostStandardFormat, private)
elif host[0:6] == '\xfd\x87\xd8\x7e\xeb\x43':
elif host[0:6] == b'\xfd\x87\xd8\x7e\xeb\x43':
# Onion, based on BMD/bitcoind
hostStandardFormat = base64.b32encode(host[6:]).lower() + ".onion"
if private:
@ -161,7 +161,7 @@ def checkIPAddress(host, private=False):
hostStandardFormat = socket.inet_ntop(socket.AF_INET6, host)
except ValueError:
return False
if hostStandardFormat == "":
if len(hostStandardFormat) == 0:
# This can happen on Windows systems which are
# not 64-bit compatible so let us drop the IPv6 address.
return False
@ -173,23 +173,23 @@ def checkIPv4Address(host, hostStandardFormat, private=False):
Returns hostStandardFormat if it is an IPv4 address,
otherwise returns False
"""
if host[0] == '\x7F': # 127/8
if host[0] == b'\x7F': # 127/8
if not private:
logger.debug(
'Ignoring IP address in loopback range: %s',
hostStandardFormat)
return hostStandardFormat if private else False
if host[0] == '\x0A': # 10/8
if host[0] == b'\x0A': # 10/8
if not private:
logger.debug(
'Ignoring IP address in private range: %s', hostStandardFormat)
return hostStandardFormat if private else False
if host[0:2] == '\xC0\xA8': # 192.168/16
if host[0:2] == b'\xC0\xA8': # 192.168/16
if not private:
logger.debug(
'Ignoring IP address in private range: %s', hostStandardFormat)
return hostStandardFormat if private else False
if host[0:2] >= '\xAC\x10' and host[0:2] < '\xAC\x20': # 172.16/12
if host[0:2] >= b'\xAC\x10' and host[0:2] < b'\xAC\x20': # 172.16/12
if not private:
logger.debug(
'Ignoring IP address in private range: %s', hostStandardFormat)
@ -202,11 +202,11 @@ def checkIPv6Address(host, hostStandardFormat, private=False):
Returns hostStandardFormat if it is an IPv6 address,
otherwise returns False
"""
if host == ('\x00' * 15) + '\x01':
if host == (b'\x00' * 15) + b'\x01':
if not private:
logger.debug('Ignoring loopback address: %s', hostStandardFormat)
return False
if host[0] == '\xFE' and (ord(host[1]) & 0xc0) == 0x80:
if host[0] == b'\xFE' and (ord(host[1]) & 0xc0) == 0x80:
if not private:
logger.debug('Ignoring local address: %s', hostStandardFormat)
return hostStandardFormat if private else False