Make protocol decoder less recursive
- apparently, recursion has bad performance in Python, so the decoder is now flat, except when parsing "version" command
This commit is contained in:
parent
853c8561ec
commit
db2d78c9b6
|
@ -152,8 +152,8 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
|
||||||
return Node(services, host, port)
|
return Node(services, host, port)
|
||||||
|
|
||||||
def decode_payload_content(self, pattern = "v"):
|
def decode_payload_content(self, pattern = "v"):
|
||||||
# l = varint indicating the length of the next array
|
# L = varint indicating the length of the next array
|
||||||
# L = varint indicating the length of the next item
|
# l = varint indicating the length of the next item
|
||||||
# v = varint (or array)
|
# v = varint (or array)
|
||||||
# H = uint16
|
# H = uint16
|
||||||
# I = uint32
|
# I = uint32
|
||||||
|
@ -163,68 +163,101 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
|
||||||
# 0-9 = length of the next item
|
# 0-9 = length of the next item
|
||||||
# , = end of array
|
# , = end of array
|
||||||
|
|
||||||
retval = []
|
def decode_simple(self, char="v"):
|
||||||
|
if char == "v":
|
||||||
|
return self.decode_payload_varint()
|
||||||
|
if char == "i":
|
||||||
|
return self.decode_payload_node()
|
||||||
|
if char == "H":
|
||||||
|
self.payloadOffset += 2
|
||||||
|
return struct.unpack(">H", self.payload[self.payloadOffset-2:self.payloadOffset])[0]
|
||||||
|
if char == "I":
|
||||||
|
self.payloadOffset += 4
|
||||||
|
return struct.unpack(">I", self.payload[self.payloadOffset-4:self.payloadOffset])[0]
|
||||||
|
if char == "Q":
|
||||||
|
self.payloadOffset += 8
|
||||||
|
return struct.unpack(">Q", self.payload[self.payloadOffset-8:self.payloadOffset])[0]
|
||||||
|
|
||||||
size = None
|
size = None
|
||||||
i = 0
|
isArray = False
|
||||||
try:
|
|
||||||
sys._getframe(200)
|
|
||||||
logger.error("Stack depth warning, pattern: %s", pattern)
|
|
||||||
return
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
while i < len(pattern):
|
# size
|
||||||
if pattern[i] in "0123456789" and (i == 0 or pattern[i-1] not in "lL"):
|
# iterator starting from size counting to 0
|
||||||
if size is None:
|
# isArray?
|
||||||
size = 0
|
# subpattern
|
||||||
size = size * 10 + int(pattern[i])
|
# position of parser in subpattern
|
||||||
i += 1
|
# retval (array)
|
||||||
continue
|
parserStack = [[1, 1, False, pattern, 0, []]]
|
||||||
elif pattern[i] == "l" and size is None:
|
|
||||||
|
#try:
|
||||||
|
# sys._getframe(200)
|
||||||
|
# logger.error("Stack depth warning, pattern: %s", pattern)
|
||||||
|
# return
|
||||||
|
#except ValueError:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
while True:
|
||||||
|
i = parserStack[-1][3][parserStack[-1][4]]
|
||||||
|
if i in "0123456789" and (size is None or parserStack[-1][3][parserStack[-1][4]-1] not in "lL"):
|
||||||
|
try:
|
||||||
|
size = size * 10 + int(i)
|
||||||
|
except TypeError:
|
||||||
|
size = int(i)
|
||||||
|
isArray = False
|
||||||
|
elif i in "Ll" and size is None:
|
||||||
size = self.decode_payload_varint()
|
size = self.decode_payload_varint()
|
||||||
i += 1
|
if i == "L":
|
||||||
continue
|
isArray = True
|
||||||
elif pattern[i] == "L" and size is None:
|
|
||||||
size = self.decode_payload_varint()
|
|
||||||
i += 1
|
|
||||||
continue
|
|
||||||
if size is not None:
|
|
||||||
if pattern[i] == "s":
|
|
||||||
retval.append(self.payload[self.payloadOffset:self.payloadOffset + size])
|
|
||||||
self.payloadOffset += size
|
|
||||||
i += 1
|
|
||||||
else:
|
else:
|
||||||
if "," in pattern[i:]:
|
isArray = False
|
||||||
subpattern = pattern[i:pattern.index(",")]
|
elif size is not None:
|
||||||
else:
|
if isArray:
|
||||||
subpattern = pattern[i:]
|
parserStack.append([size, size, isArray, parserStack[-1][3][parserStack[-1][4]:], 0, []])
|
||||||
|
parserStack[-2][4] = len(parserStack[-2][3])
|
||||||
for j in range(size):
|
else:
|
||||||
if pattern[i-1:i] == "L":
|
for j in range(parserStack[-1][4], len(parserStack[-1][3])):
|
||||||
retval.extend(self.decode_payload_content(subpattern))
|
if parserStack[-1][3][j] not in "lL0123456789":
|
||||||
else:
|
break
|
||||||
retval.append(self.decode_payload_content(subpattern))
|
parserStack.append([size, size, isArray, parserStack[-1][3][parserStack[-1][4]:j+1], 0, []])
|
||||||
i += len(subpattern)
|
size = None
|
||||||
|
continue
|
||||||
|
elif i == "s":
|
||||||
|
#if parserStack[-2][2]:
|
||||||
|
# parserStack[-1][5].append(self.payload[self.payloadOffset:self.payloadOffset + parserStack[-1][0]])
|
||||||
|
#else:
|
||||||
|
parserStack[-1][5] = self.payload[self.payloadOffset:self.payloadOffset + parserStack[-1][0]]
|
||||||
|
self.payloadOffset += parserStack[-1][0]
|
||||||
|
parserStack[-1][1] = 0
|
||||||
|
parserStack[-1][2] = True
|
||||||
|
#del parserStack[-1]
|
||||||
|
size = None
|
||||||
|
elif i in "viHIQ":
|
||||||
|
parserStack[-1][5].append(decode_simple(self, parserStack[-1][3][parserStack[-1][4]]))
|
||||||
size = None
|
size = None
|
||||||
else:
|
else:
|
||||||
if pattern[i] == "v":
|
size = None
|
||||||
retval.append(self.decode_payload_varint())
|
for depth in range(len(parserStack) - 1, -1, -1):
|
||||||
if pattern[i] == "i":
|
parserStack[depth][4] += 1
|
||||||
retval.append(self.decode_payload_node())
|
if parserStack[depth][4] >= len(parserStack[depth][3]):
|
||||||
if pattern[i] == "H":
|
parserStack[depth][1] -= 1
|
||||||
retval.append(struct.unpack(">H", self.payload[self.payloadOffset:self.payloadOffset+2])[0])
|
parserStack[depth][4] = 0
|
||||||
self.payloadOffset += 2
|
if depth > 0:
|
||||||
if pattern[i] == "I":
|
if parserStack[depth][2]:
|
||||||
retval.append(struct.unpack(">I", self.payload[self.payloadOffset:self.payloadOffset+4])[0])
|
parserStack[depth - 1][5].append(parserStack[depth][5])
|
||||||
self.payloadOffset += 4
|
else:
|
||||||
if pattern[i] == "Q":
|
parserStack[depth - 1][5].extend(parserStack[depth][5])
|
||||||
retval.append(struct.unpack(">Q", self.payload[self.payloadOffset:self.payloadOffset+8])[0])
|
parserStack[depth][5] = []
|
||||||
self.payloadOffset += 8
|
if parserStack[depth][1] <= 0:
|
||||||
i += 1
|
if depth == 0:
|
||||||
|
# we're done, at depth 0 counter is at 0 and pattern is done parsing
|
||||||
|
return parserStack[depth][5]
|
||||||
|
del parserStack[-1]
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
break
|
||||||
if self.payloadOffset > self.payloadLength:
|
if self.payloadOffset > self.payloadLength:
|
||||||
logger.debug("Insufficient data %i/%i", self.payloadOffset, self.payloadLength)
|
logger.debug("Insufficient data %i/%i", self.payloadOffset, self.payloadLength)
|
||||||
raise BMProtoInsufficientDataError()
|
raise BMProtoInsufficientDataError()
|
||||||
return retval
|
|
||||||
|
|
||||||
def bm_command_error(self):
|
def bm_command_error(self):
|
||||||
fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls")
|
fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls")
|
||||||
|
@ -232,7 +265,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def bm_command_getdata(self):
|
def bm_command_getdata(self):
|
||||||
items = self.decode_payload_content("L32s")
|
items = self.decode_payload_content("l32s")
|
||||||
# skip?
|
# skip?
|
||||||
if time.time() < self.skipUntil:
|
if time.time() < self.skipUntil:
|
||||||
return True
|
return True
|
||||||
|
@ -246,7 +279,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def bm_command_inv(self):
|
def bm_command_inv(self):
|
||||||
items = self.decode_payload_content("L32s")
|
items = self.decode_payload_content("l32s")
|
||||||
|
|
||||||
if len(items) >= BMProto.maxObjectCount:
|
if len(items) >= BMProto.maxObjectCount:
|
||||||
logger.error("Too many items in inv message!")
|
logger.error("Too many items in inv message!")
|
||||||
|
@ -294,7 +327,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _decode_addr(self):
|
def _decode_addr(self):
|
||||||
return self.decode_payload_content("lQIQ16sH")
|
return self.decode_payload_content("LQIQ16sH")
|
||||||
|
|
||||||
def bm_command_addr(self):
|
def bm_command_addr(self):
|
||||||
addresses = self._decode_addr()
|
addresses = self._decode_addr()
|
||||||
|
@ -344,7 +377,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
|
||||||
|
|
||||||
def bm_command_version(self):
|
def bm_command_version(self):
|
||||||
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, \
|
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, \
|
||||||
self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv")
|
self.userAgent, self.streams = self.decode_payload_content("IQQiiQlsLv")
|
||||||
self.nonce = struct.pack('>Q', self.nonce)
|
self.nonce = struct.pack('>Q', self.nonce)
|
||||||
self.timeOffset = self.timestamp - int(time.time())
|
self.timeOffset = self.timestamp - int(time.time())
|
||||||
logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion)
|
logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion)
|
||||||
|
|
Reference in New Issue
Block a user