Make protocol decoder less recursive

- apparently, recursion has bad performance in Python, so the decoder is
now flat, except when parsing "version" command
This commit is contained in:
Peter Šurda 2017-07-10 07:15:27 +02:00
parent 853c8561ec
commit db2d78c9b6
Signed by untrusted user: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87
1 changed files with 92 additions and 59 deletions

View File

@ -152,8 +152,8 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return Node(services, host, port)
def decode_payload_content(self, pattern = "v"):
# l = varint indicating the length of the next array
# L = varint indicating the length of the next item
# L = varint indicating the length of the next array
# l = varint indicating the length of the next item
# v = varint (or array)
# H = uint16
# I = uint32
@ -163,68 +163,101 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
# 0-9 = length of the next item
# , = end of array
retval = []
def decode_simple(self, char="v"):
if char == "v":
return self.decode_payload_varint()
if char == "i":
return self.decode_payload_node()
if char == "H":
self.payloadOffset += 2
return struct.unpack(">H", self.payload[self.payloadOffset-2:self.payloadOffset])[0]
if char == "I":
self.payloadOffset += 4
return struct.unpack(">I", self.payload[self.payloadOffset-4:self.payloadOffset])[0]
if char == "Q":
self.payloadOffset += 8
return struct.unpack(">Q", self.payload[self.payloadOffset-8:self.payloadOffset])[0]
size = None
i = 0
try:
sys._getframe(200)
logger.error("Stack depth warning, pattern: %s", pattern)
return
except ValueError:
pass
isArray = False
while i < len(pattern):
if pattern[i] in "0123456789" and (i == 0 or pattern[i-1] not in "lL"):
if size is None:
size = 0
size = size * 10 + int(pattern[i])
i += 1
continue
elif pattern[i] == "l" and size is None:
# size
# iterator starting from size counting to 0
# isArray?
# subpattern
# position of parser in subpattern
# retval (array)
parserStack = [[1, 1, False, pattern, 0, []]]
#try:
# sys._getframe(200)
# logger.error("Stack depth warning, pattern: %s", pattern)
# return
#except ValueError:
# pass
while True:
i = parserStack[-1][3][parserStack[-1][4]]
if i in "0123456789" and (size is None or parserStack[-1][3][parserStack[-1][4]-1] not in "lL"):
try:
size = size * 10 + int(i)
except TypeError:
size = int(i)
isArray = False
elif i in "Ll" and size is None:
size = self.decode_payload_varint()
i += 1
continue
elif pattern[i] == "L" and size is None:
size = self.decode_payload_varint()
i += 1
continue
if size is not None:
if pattern[i] == "s":
retval.append(self.payload[self.payloadOffset:self.payloadOffset + size])
self.payloadOffset += size
i += 1
if i == "L":
isArray = True
else:
if "," in pattern[i:]:
subpattern = pattern[i:pattern.index(",")]
else:
subpattern = pattern[i:]
for j in range(size):
if pattern[i-1:i] == "L":
retval.extend(self.decode_payload_content(subpattern))
else:
retval.append(self.decode_payload_content(subpattern))
i += len(subpattern)
isArray = False
elif size is not None:
if isArray:
parserStack.append([size, size, isArray, parserStack[-1][3][parserStack[-1][4]:], 0, []])
parserStack[-2][4] = len(parserStack[-2][3])
else:
for j in range(parserStack[-1][4], len(parserStack[-1][3])):
if parserStack[-1][3][j] not in "lL0123456789":
break
parserStack.append([size, size, isArray, parserStack[-1][3][parserStack[-1][4]:j+1], 0, []])
size = None
continue
elif i == "s":
#if parserStack[-2][2]:
# parserStack[-1][5].append(self.payload[self.payloadOffset:self.payloadOffset + parserStack[-1][0]])
#else:
parserStack[-1][5] = self.payload[self.payloadOffset:self.payloadOffset + parserStack[-1][0]]
self.payloadOffset += parserStack[-1][0]
parserStack[-1][1] = 0
parserStack[-1][2] = True
#del parserStack[-1]
size = None
elif i in "viHIQ":
parserStack[-1][5].append(decode_simple(self, parserStack[-1][3][parserStack[-1][4]]))
size = None
else:
if pattern[i] == "v":
retval.append(self.decode_payload_varint())
if pattern[i] == "i":
retval.append(self.decode_payload_node())
if pattern[i] == "H":
retval.append(struct.unpack(">H", self.payload[self.payloadOffset:self.payloadOffset+2])[0])
self.payloadOffset += 2
if pattern[i] == "I":
retval.append(struct.unpack(">I", self.payload[self.payloadOffset:self.payloadOffset+4])[0])
self.payloadOffset += 4
if pattern[i] == "Q":
retval.append(struct.unpack(">Q", self.payload[self.payloadOffset:self.payloadOffset+8])[0])
self.payloadOffset += 8
i += 1
size = None
for depth in range(len(parserStack) - 1, -1, -1):
parserStack[depth][4] += 1
if parserStack[depth][4] >= len(parserStack[depth][3]):
parserStack[depth][1] -= 1
parserStack[depth][4] = 0
if depth > 0:
if parserStack[depth][2]:
parserStack[depth - 1][5].append(parserStack[depth][5])
else:
parserStack[depth - 1][5].extend(parserStack[depth][5])
parserStack[depth][5] = []
if parserStack[depth][1] <= 0:
if depth == 0:
# we're done, at depth 0 counter is at 0 and pattern is done parsing
return parserStack[depth][5]
del parserStack[-1]
continue
break
break
if self.payloadOffset > self.payloadLength:
logger.debug("Insufficient data %i/%i", self.payloadOffset, self.payloadLength)
raise BMProtoInsufficientDataError()
return retval
def bm_command_error(self):
fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls")
@ -232,7 +265,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True
def bm_command_getdata(self):
items = self.decode_payload_content("L32s")
items = self.decode_payload_content("l32s")
# skip?
if time.time() < self.skipUntil:
return True
@ -246,7 +279,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True
def bm_command_inv(self):
items = self.decode_payload_content("L32s")
items = self.decode_payload_content("l32s")
if len(items) >= BMProto.maxObjectCount:
logger.error("Too many items in inv message!")
@ -294,7 +327,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True
def _decode_addr(self):
return self.decode_payload_content("lQIQ16sH")
return self.decode_payload_content("LQIQ16sH")
def bm_command_addr(self):
addresses = self._decode_addr()
@ -344,7 +377,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
def bm_command_version(self):
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, \
self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv")
self.userAgent, self.streams = self.decode_payload_content("IQQiiQlsLv")
self.nonce = struct.pack('>Q', self.nonce)
self.timeOffset = self.timestamp - int(time.time())
logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion)