Make protocol decoder less recursive

- Apparently, recursion performs poorly in Python, so the decoder is
now flat, except when parsing the "version" command.
This commit is contained in:
Peter Šurda 2017-07-10 07:15:27 +02:00
parent 853c8561ec
commit db2d78c9b6
Signed by untrusted user: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87

View File

@ -152,8 +152,8 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return Node(services, host, port) return Node(services, host, port)
def decode_payload_content(self, pattern = "v"): def decode_payload_content(self, pattern = "v"):
# l = varint indicating the length of the next array # L = varint indicating the length of the next array
# L = varint indicating the length of the next item # l = varint indicating the length of the next item
# v = varint (or array) # v = varint (or array)
# H = uint16 # H = uint16
# I = uint32 # I = uint32
@ -163,68 +163,101 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
# 0-9 = length of the next item # 0-9 = length of the next item
# , = end of array # , = end of array
retval = [] def decode_simple(self, char="v"):
if char == "v":
return self.decode_payload_varint()
if char == "i":
return self.decode_payload_node()
if char == "H":
self.payloadOffset += 2
return struct.unpack(">H", self.payload[self.payloadOffset-2:self.payloadOffset])[0]
if char == "I":
self.payloadOffset += 4
return struct.unpack(">I", self.payload[self.payloadOffset-4:self.payloadOffset])[0]
if char == "Q":
self.payloadOffset += 8
return struct.unpack(">Q", self.payload[self.payloadOffset-8:self.payloadOffset])[0]
size = None size = None
i = 0 isArray = False
try:
sys._getframe(200)
logger.error("Stack depth warning, pattern: %s", pattern)
return
except ValueError:
pass
while i < len(pattern): # size
if pattern[i] in "0123456789" and (i == 0 or pattern[i-1] not in "lL"): # iterator starting from size counting to 0
if size is None: # isArray?
size = 0 # subpattern
size = size * 10 + int(pattern[i]) # position of parser in subpattern
i += 1 # retval (array)
continue parserStack = [[1, 1, False, pattern, 0, []]]
elif pattern[i] == "l" and size is None:
#try:
# sys._getframe(200)
# logger.error("Stack depth warning, pattern: %s", pattern)
# return
#except ValueError:
# pass
while True:
i = parserStack[-1][3][parserStack[-1][4]]
if i in "0123456789" and (size is None or parserStack[-1][3][parserStack[-1][4]-1] not in "lL"):
try:
size = size * 10 + int(i)
except TypeError:
size = int(i)
isArray = False
elif i in "Ll" and size is None:
size = self.decode_payload_varint() size = self.decode_payload_varint()
i += 1 if i == "L":
continue isArray = True
elif pattern[i] == "L" and size is None:
size = self.decode_payload_varint()
i += 1
continue
if size is not None:
if pattern[i] == "s":
retval.append(self.payload[self.payloadOffset:self.payloadOffset + size])
self.payloadOffset += size
i += 1
else: else:
if "," in pattern[i:]: isArray = False
subpattern = pattern[i:pattern.index(",")] elif size is not None:
else: if isArray:
subpattern = pattern[i:] parserStack.append([size, size, isArray, parserStack[-1][3][parserStack[-1][4]:], 0, []])
parserStack[-2][4] = len(parserStack[-2][3])
for j in range(size): else:
if pattern[i-1:i] == "L": for j in range(parserStack[-1][4], len(parserStack[-1][3])):
retval.extend(self.decode_payload_content(subpattern)) if parserStack[-1][3][j] not in "lL0123456789":
else: break
retval.append(self.decode_payload_content(subpattern)) parserStack.append([size, size, isArray, parserStack[-1][3][parserStack[-1][4]:j+1], 0, []])
i += len(subpattern) size = None
continue
elif i == "s":
#if parserStack[-2][2]:
# parserStack[-1][5].append(self.payload[self.payloadOffset:self.payloadOffset + parserStack[-1][0]])
#else:
parserStack[-1][5] = self.payload[self.payloadOffset:self.payloadOffset + parserStack[-1][0]]
self.payloadOffset += parserStack[-1][0]
parserStack[-1][1] = 0
parserStack[-1][2] = True
#del parserStack[-1]
size = None
elif i in "viHIQ":
parserStack[-1][5].append(decode_simple(self, parserStack[-1][3][parserStack[-1][4]]))
size = None size = None
else: else:
if pattern[i] == "v": size = None
retval.append(self.decode_payload_varint()) for depth in range(len(parserStack) - 1, -1, -1):
if pattern[i] == "i": parserStack[depth][4] += 1
retval.append(self.decode_payload_node()) if parserStack[depth][4] >= len(parserStack[depth][3]):
if pattern[i] == "H": parserStack[depth][1] -= 1
retval.append(struct.unpack(">H", self.payload[self.payloadOffset:self.payloadOffset+2])[0]) parserStack[depth][4] = 0
self.payloadOffset += 2 if depth > 0:
if pattern[i] == "I": if parserStack[depth][2]:
retval.append(struct.unpack(">I", self.payload[self.payloadOffset:self.payloadOffset+4])[0]) parserStack[depth - 1][5].append(parserStack[depth][5])
self.payloadOffset += 4 else:
if pattern[i] == "Q": parserStack[depth - 1][5].extend(parserStack[depth][5])
retval.append(struct.unpack(">Q", self.payload[self.payloadOffset:self.payloadOffset+8])[0]) parserStack[depth][5] = []
self.payloadOffset += 8 if parserStack[depth][1] <= 0:
i += 1 if depth == 0:
# we're done, at depth 0 counter is at 0 and pattern is done parsing
return parserStack[depth][5]
del parserStack[-1]
continue
break
break
if self.payloadOffset > self.payloadLength: if self.payloadOffset > self.payloadLength:
logger.debug("Insufficient data %i/%i", self.payloadOffset, self.payloadLength) logger.debug("Insufficient data %i/%i", self.payloadOffset, self.payloadLength)
raise BMProtoInsufficientDataError() raise BMProtoInsufficientDataError()
return retval
def bm_command_error(self): def bm_command_error(self):
fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls") fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls")
@ -232,7 +265,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True return True
def bm_command_getdata(self): def bm_command_getdata(self):
items = self.decode_payload_content("L32s") items = self.decode_payload_content("l32s")
# skip? # skip?
if time.time() < self.skipUntil: if time.time() < self.skipUntil:
return True return True
@ -246,7 +279,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True return True
def bm_command_inv(self): def bm_command_inv(self):
items = self.decode_payload_content("L32s") items = self.decode_payload_content("l32s")
if len(items) >= BMProto.maxObjectCount: if len(items) >= BMProto.maxObjectCount:
logger.error("Too many items in inv message!") logger.error("Too many items in inv message!")
@ -294,7 +327,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True return True
def _decode_addr(self): def _decode_addr(self):
return self.decode_payload_content("lQIQ16sH") return self.decode_payload_content("LQIQ16sH")
def bm_command_addr(self): def bm_command_addr(self):
addresses = self._decode_addr() addresses = self._decode_addr()
@ -344,7 +377,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
def bm_command_version(self): def bm_command_version(self):
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, \ self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, \
self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv") self.userAgent, self.streams = self.decode_payload_content("IQQiiQlsLv")
self.nonce = struct.pack('>Q', self.nonce) self.nonce = struct.pack('>Q', self.nonce)
self.timeOffset = self.timestamp - int(time.time()) self.timeOffset = self.timestamp - int(time.time())
logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion) logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion)