Refactor packet header extraction and generation
Demote payloadLength from class instance variable to processData local variable as no other function was using it Improve processData: -Utilise shared.Header -Use a memoryview to reduce memory overhead -Clean up variables before a recursive call -Strip null bytes from command Refactor sendData Various functions: -Use shared.CreatePacket to generate packets Fix typo in _checkIPv4Address
This commit is contained in:
parent
ccfbe8da47
commit
0f9625aac7
|
@ -47,7 +47,6 @@ class receiveDataThread(threading.Thread):
|
|||
self.sock = sock
|
||||
self.peer = shared.Peer(HOST, port)
|
||||
self.streamNumber = streamNumber
|
||||
self.payloadLength = 0 # This is the protocol payload length thus it doesn't include the 24 byte message header
|
||||
self.objectsThatWeHaveYetToGetFromThisPeer = {}
|
||||
self.selfInitiatedConnections = selfInitiatedConnections
|
||||
self.sendDataThreadQueue = sendDataThreadQueue # used to send commands and data to the sendDataThread
|
||||
|
@ -113,28 +112,39 @@ class receiveDataThread(threading.Thread):
|
|||
# with shared.printLock:
|
||||
# print 'self.data is currently ', repr(self.data)
|
||||
#
|
||||
if len(self.data) < 24: # if so little of the data has arrived that we can't even read the checksum then wait for more data.
|
||||
if len(self.data) < shared.Header.size: # if so little of the data has arrived that we can't even read the checksum then wait for more data.
|
||||
return
|
||||
if self.data[0:4] != '\xe9\xbe\xb4\xd9':
|
||||
#Use a memoryview so we don't copy data unnecessarily
|
||||
view = memoryview(self.data)
|
||||
magic,command,payloadLength,checksum = shared.Header.unpack(view[:shared.Header.size])
|
||||
view = view[shared.Header.size:]
|
||||
if magic != 0xE9BEB4D9:
|
||||
#if shared.verbose >= 1:
|
||||
# with shared.printLock:
|
||||
# print 'The magic bytes were not correct. First 40 bytes of data: ' + repr(self.data[0:40])
|
||||
|
||||
self.data = ""
|
||||
return
|
||||
self.payloadLength, = unpack('>L', self.data[16:20])
|
||||
if self.payloadLength > 20000000:
|
||||
logger.info('The incoming message, which we have not yet download, is too large. Ignoring it. (unfortunately there is no way to tell the other node to stop sending it except to disconnect.) Message size: %s' % self.payloadLength)
|
||||
self.data = self.data[self.payloadLength + 24:]
|
||||
if payloadLength > 20000000:
|
||||
logger.info('The incoming message, which we have not yet download, is too large. Ignoring it. (unfortunately there is no way to tell the other node to stop sending it except to disconnect.) Message size: %s' % payloadLength)
|
||||
self.data = view[payloadLength:].tobytes()
|
||||
del view,magic,command,payloadLength,checksum #we don't need these anymore and better to clean them now before the recursive call rather than after
|
||||
self.processData()
|
||||
return
|
||||
if len(self.data) < self.payloadLength + 24: # check if the whole message has arrived yet.
|
||||
if len(view) < payloadLength: # check if the whole message has arrived yet.
|
||||
return
|
||||
if self.data[20:24] != hashlib.sha512(self.data[24:self.payloadLength + 24]).digest()[0:4]: # test the checksum in the message. If it is correct...
|
||||
payload = view[:payloadLength]
|
||||
if checksum != hashlib.sha512(payload).digest()[0:4]: # test the checksum in the message. If it is correct...
|
||||
print 'Checksum incorrect. Clearing this message.'
|
||||
self.data = self.data[self.payloadLength + 24:]
|
||||
self.data = view[payloadLength:].tobytes()
|
||||
del view,magic,command,payloadLength,checksum,payload #again better to clean up before the recursive call
|
||||
self.processData()
|
||||
return
|
||||
|
||||
#We can now revert back to bytestrings and take this message out
|
||||
payload = payload.tobytes()
|
||||
self.data = view[payloadLength:].tobytes()
|
||||
del view,magic,payloadLength,checksum
|
||||
# The time we've last seen this node is obviously right now since we
|
||||
# just received valid data from it. So update the knownNodes list so
|
||||
# that other peers can be made aware of its existance.
|
||||
|
@ -143,37 +153,39 @@ class receiveDataThread(threading.Thread):
|
|||
shared.knownNodes[self.streamNumber][self.peer] = int(time.time())
|
||||
shared.knownNodesLock.release()
|
||||
|
||||
remoteCommand = self.data[4:16]
|
||||
#Strip the nulls
|
||||
command = command.rstrip('\x00')
|
||||
with shared.printLock:
|
||||
print 'remoteCommand', repr(remoteCommand.replace('\x00', '')), ' from', self.peer
|
||||
print 'remoteCommand', repr(command), ' from', self.peer
|
||||
|
||||
if remoteCommand == 'version\x00\x00\x00\x00\x00' and not self.connectionIsOrWasFullyEstablished:
|
||||
self.recversion(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'verack\x00\x00\x00\x00\x00\x00' and not self.connectionIsOrWasFullyEstablished:
|
||||
#TODO: Use a dispatcher here
|
||||
if not self.connectionIsOrWasFullyEstablished:
|
||||
if command == 'version':
|
||||
self.recversion(payload)
|
||||
elif command == 'verack':
|
||||
self.recverack()
|
||||
elif remoteCommand == 'addr\x00\x00\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.recaddr(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'getpubkey\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
shared.checkAndSharegetpubkeyWithPeers(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'pubkey\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.recpubkey(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'inv\x00\x00\x00\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.recinv(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'getdata\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.recgetdata(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'msg\x00\x00\x00\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.recmsg(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'broadcast\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.recbroadcast(self.data[24:self.payloadLength + 24])
|
||||
elif remoteCommand == 'ping\x00\x00\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
self.sendpong()
|
||||
elif remoteCommand == 'pong\x00\x00\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
pass
|
||||
elif remoteCommand == 'alert\x00\x00\x00\x00\x00\x00\x00' and self.connectionIsOrWasFullyEstablished:
|
||||
pass
|
||||
else:
|
||||
if command == 'addr':
|
||||
self.recaddr(payload)
|
||||
elif command == 'getpubkey':
|
||||
shared.checkAndSharegetpubkeyWithPeers(payload)
|
||||
elif command == 'pubkey':
|
||||
self.recpubkey(payload)
|
||||
elif command == 'inv':
|
||||
self.recinv(payload)
|
||||
elif command == 'getdata':
|
||||
self.recgetdata(payload)
|
||||
elif command == 'msg':
|
||||
self.recmsg(payload)
|
||||
elif command == 'broadcast':
|
||||
self.recbroadcast(payload)
|
||||
elif command == 'ping':
|
||||
self.sendpong(payload)
|
||||
#elif command == 'pong':
|
||||
# pass
|
||||
#elif command == 'alert':
|
||||
# pass
|
||||
|
||||
self.data = self.data[
|
||||
self.payloadLength + 24:] # take this message out and then process the next message
|
||||
if self.data == '':
|
||||
while len(self.objectsThatWeHaveYetToGetFromThisPeer) > 0:
|
||||
shared.numberOfInventoryLookupsPerformed += 1
|
||||
|
@ -226,7 +238,7 @@ class receiveDataThread(threading.Thread):
|
|||
|
||||
def sendpong(self):
|
||||
print 'Sending pong'
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', '\xE9\xBE\xB4\xD9\x70\x6F\x6E\x67\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xcf\x83\xe1\x35'))
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', shared.CreatePacket('pong')))
|
||||
|
||||
|
||||
def recverack(self):
|
||||
|
@ -297,7 +309,7 @@ class receiveDataThread(threading.Thread):
|
|||
for hash, storedValue in bigInvList.items():
|
||||
payload += hash
|
||||
numberOfObjectsInInvMessage += 1
|
||||
if numberOfObjectsInInvMessage >= 50000: # We can only send a max of 50000 items per inv message but we may have more objects to advertise. They must be split up into multiple inv messages.
|
||||
if numberOfObjectsInInvMessage == 50000: # We can only send a max of 50000 items per inv message but we may have more objects to advertise. They must be split up into multiple inv messages.
|
||||
self.sendinvMessageToJustThisOnePeer(
|
||||
numberOfObjectsInInvMessage, payload)
|
||||
payload = ''
|
||||
|
@ -311,13 +323,9 @@ class receiveDataThread(threading.Thread):
|
|||
# function for broadcasting invs to everyone in our stream.
|
||||
def sendinvMessageToJustThisOnePeer(self, numberOfObjects, payload):
|
||||
payload = encodeVarint(numberOfObjects) + payload
|
||||
headerData = '\xe9\xbe\xb4\xd9' # magic bits, slighly different from Bitcoin's magic bits.
|
||||
headerData += 'inv\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
headerData += pack('>L', len(payload))
|
||||
headerData += hashlib.sha512(payload).digest()[:4]
|
||||
with shared.printLock:
|
||||
print 'Sending huge inv message with', numberOfObjects, 'objects to just this one peer'
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', headerData + payload))
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', shared.CreatePacket('inv', payload)))
|
||||
|
||||
def _sleepForTimingAttackMitigation(self, sleepTime):
|
||||
# We don't need to do the timing attack mitigation if we are
|
||||
|
@ -457,12 +465,7 @@ class receiveDataThread(threading.Thread):
|
|||
print 'sending getdata to retrieve object with hash:', hash.encode('hex')
|
||||
|
||||
payload = '\x01' + hash
|
||||
headerData = '\xe9\xbe\xb4\xd9' # magic bits, slighly different from Bitcoin's magic bits.
|
||||
headerData += 'getdata\x00\x00\x00\x00\x00'
|
||||
headerData += pack('>L', len(
|
||||
payload)) # payload length. Note that we add an extra 8 for the nonce.
|
||||
headerData += hashlib.sha512(payload).digest()[:4]
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', headerData + payload))
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', shared.CreatePacket('getdata', payload)))
|
||||
|
||||
|
||||
# We have received a getdata request from our peer
|
||||
|
@ -499,34 +502,17 @@ class receiveDataThread(threading.Thread):
|
|||
|
||||
# Our peer has requested (in a getdata message) that we send an object.
|
||||
def sendData(self, objectType, payload):
|
||||
headerData = '\xe9\xbe\xb4\xd9' # magic bits, slighly different from Bitcoin's magic bits.
|
||||
if objectType == 'pubkey':
|
||||
with shared.printLock:
|
||||
print 'sending pubkey'
|
||||
|
||||
headerData += 'pubkey\x00\x00\x00\x00\x00\x00'
|
||||
elif objectType == 'getpubkey' or objectType == 'pubkeyrequest':
|
||||
with shared.printLock:
|
||||
print 'sending getpubkey'
|
||||
|
||||
headerData += 'getpubkey\x00\x00\x00'
|
||||
elif objectType == 'msg':
|
||||
with shared.printLock:
|
||||
print 'sending msg'
|
||||
|
||||
headerData += 'msg\x00\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
elif objectType == 'broadcast':
|
||||
with shared.printLock:
|
||||
print 'sending broadcast'
|
||||
|
||||
headerData += 'broadcast\x00\x00\x00'
|
||||
else:
|
||||
# pubkeyrequest??
|
||||
if objectType == 'pubkeyrequest':
|
||||
objectType = 'getpubkey'
|
||||
elif objectType != 'pubkey' and objectType != 'getpubkey' and
|
||||
objectType != 'msg' and objectType != 'broadcast':
|
||||
sys.stderr.write(
|
||||
'Error: sendData has been asked to send a strange objectType: %s\n' % str(objectType))
|
||||
return
|
||||
headerData += pack('>L', len(payload)) # payload length.
|
||||
headerData += hashlib.sha512(payload).digest()[:4]
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', headerData + payload))
|
||||
with shared.printLock:
|
||||
print 'sending', objectType
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', shared.CreatePacket(objectType, payload)))
|
||||
|
||||
|
||||
# Advertise this object to all of our peers
|
||||
|
@ -545,7 +531,7 @@ class receiveDataThread(threading.Thread):
|
|||
if host[0] == '\x0A':
|
||||
print 'Ignoring IP address in private range:', hostFromAddrMessage
|
||||
return False
|
||||
if host[0:2] == '\xC0A8':
|
||||
if host[0:2] == '\xC0\xA8':
|
||||
print 'Ignoring IP address in private range:', hostFromAddrMessage
|
||||
return False
|
||||
return True
|
||||
|
@ -745,11 +731,7 @@ class receiveDataThread(threading.Thread):
|
|||
payload += pack('>H', PORT) # remote port
|
||||
|
||||
payload = encodeVarint(numberOfAddressesInAddrMessage) + payload
|
||||
datatosend = '\xE9\xBE\xB4\xD9addr\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||
datatosend = datatosend + pack('>L', len(payload)) # payload length
|
||||
datatosend = datatosend + hashlib.sha512(payload).digest()[0:4]
|
||||
datatosend = datatosend + payload
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', datatosend))
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', shared.CreatePacket('addr', payload)))
|
||||
|
||||
|
||||
# We have received a version message
|
||||
|
@ -832,7 +814,7 @@ class receiveDataThread(threading.Thread):
|
|||
def sendverack(self):
|
||||
with shared.printLock:
|
||||
print 'Sending verack'
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', '\xE9\xBE\xB4\xD9\x76\x65\x72\x61\x63\x6B\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xcf\x83\xe1\x35'))
|
||||
self.sendDataThreadQueue.put((0, 'sendRawData', shared.CreatePacket('verack')))
|
||||
self.verackSent = True
|
||||
if self.verackReceived:
|
||||
self.connectionFullyEstablished()
|
||||
|
|
Reference in New Issue
Block a user