From 0cc8589b27e1b91c0c8eb0356efe70d1f3ada575 Mon Sep 17 00:00:00 2001 From: Peter Surda Date: Wed, 21 Jun 2017 12:16:33 +0200 Subject: [PATCH] Asyncore updates - should prevent the same object being re-requested indefinitely - locking for object tracking - move SSL-specific error handling to TLSDispatcher - observe maximum connection limit when accepting a new connection - stack depth test (for debugging purposes) - separate download thread - connection pool init moved to main thread --- src/bitmessagemain.py | 6 +++- src/network/advanceddispatcher.py | 8 ++++- src/network/announcethread.py | 1 - src/network/asyncore_pollchoose.py | 21 +++++-------- src/network/bmproto.py | 21 ++++++------- src/network/downloadthread.py | 48 ++++++++++++++++++++++++++++++ src/network/invthread.py | 9 +++--- src/network/networkthread.py | 1 - src/network/objectracker.py | 17 +++++------ src/network/receivequeuethread.py | 15 +++++++--- src/network/tcp.py | 6 ++++ src/network/tls.py | 18 +++++++++++ 12 files changed, 125 insertions(+), 46 deletions(-) create mode 100644 src/network/downloadthread.py diff --git a/src/bitmessagemain.py b/src/bitmessagemain.py index 57eab27a..d9578748 100755 --- a/src/bitmessagemain.py +++ b/src/bitmessagemain.py @@ -58,7 +58,7 @@ from network.networkthread import BMNetworkThread from network.receivequeuethread import ReceiveQueueThread from network.announcethread import AnnounceThread from network.invthread import InvThread -#from network.downloadthread import DownloadThread +from network.downloadthread import DownloadThread # Helper Functions import helper_bootstrap @@ -261,6 +261,7 @@ class Main: singleAPIThread.start() if BMConfigParser().get("network", "asyncore"): + BMConnectionPool() asyncoreThread = BMNetworkThread() asyncoreThread.daemon = True asyncoreThread.start() @@ -273,6 +274,9 @@ class Main: state.invThread = InvThread() state.invThread.daemon = True state.invThread.start() + downloadThread = DownloadThread() + downloadThread.daemon = True + downloadThread.start() connectToStream(1) diff --git a/src/network/advanceddispatcher.py b/src/network/advanceddispatcher.py index dadb625b..80762aa5 100644 --- a/src/network/advanceddispatcher.py +++ b/src/network/advanceddispatcher.py @@ -1,5 +1,6 @@ -import socket import Queue +import socket +import sys import time import asyncore_pollchoose as asyncore @@ -42,6 +43,11 @@ class AdvancedDispatcher(asyncore.dispatcher): if not self.connected: return maxLoop = 20 + try: + sys._getframe(200) + logger.error("Stack depth warning") + except ValueError: + pass while maxLoop > 0: try: # print "Trying to handle state \"%s\"" % (self.state) diff --git a/src/network/announcethread.py b/src/network/announcethread.py index 85e69877..37b7e628 100644 --- a/src/network/announcethread.py +++ b/src/network/announcethread.py @@ -16,7 +16,6 @@ class AnnounceThread(threading.Thread, StoppableThread): threading.Thread.__init__(self, name="AnnounceThread") self.initStop() self.name = "AnnounceThread" - BMConnectionPool() logger.info("init announce thread") def run(self): diff --git a/src/network/asyncore_pollchoose.py b/src/network/asyncore_pollchoose.py index c212effe..94cbfaf3 100644 --- a/src/network/asyncore_pollchoose.py +++ b/src/network/asyncore_pollchoose.py @@ -57,7 +57,7 @@ import warnings import os from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \ - ECONNREFUSED, EHOSTUNREACH, ENETUNREACH, ENOTSOCK, EINTR, \ + ECONNREFUSED, EHOSTUNREACH, ENETUNREACH, ENOTSOCK, EINTR, ETIMEDOUT, \ errorcode try: from errno import WSAEWOULDBLOCK @@ -68,10 +68,8 @@ try: except (ImportError, AttributeError): WSAENOTSOCK = ENOTSOCK -from ssl import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE - _DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, - EBADF, ECONNREFUSED, EHOSTUNREACH, ENETUNREACH)) + EBADF, ECONNREFUSED, EHOSTUNREACH, ENETUNREACH, ETIMEDOUT)) OP_READ = 1 OP_WRITE = 2 @@ -563,11 +561,6 @@ class dispatcher: try: result = self.socket.send(data) return result - except SSLError as err: - if err.errno == SSL_ERROR_WANT_WRITE: - return 0 - else: - raise except socket.error as why: if why.args[0] in (EAGAIN, EWOULDBLOCK, WSAEWOULDBLOCK): return 0 @@ -587,11 +580,6 @@ class dispatcher: return b'' else: return data - except SSLError as err: - if err.errno == SSL_ERROR_WANT_READ: - return b'' - else: - raise except socket.error as why: # winsock sometimes raises ENOTCONN if why.args[0] in (EAGAIN, EWOULDBLOCK, WSAEWOULDBLOCK): @@ -616,6 +604,11 @@ class dispatcher: # cheap inheritance, used to pass all other attribute # references to the underlying socket object. def __getattr__(self, attr): + try: + sys._getframe(200) + logger.error("Stack depth warning") + except ValueError: + pass try: retattr = getattr(self.socket, attr) except AttributeError: diff --git a/src/network/bmproto.py b/src/network/bmproto.py index cd36726e..39464845 100644 --- a/src/network/bmproto.py +++ b/src/network/bmproto.py @@ -119,7 +119,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): except BMObjectInvalidError: logger.debug("object invalid, skipping") except BMObjectAlreadyHaveError: - logger.debug("already got object, skipping") + logger.debug("%s:%i already got object, skipping", self.destination.host, self.destination.port) except struct.error: logger.debug("decoding error, skipping") else: @@ -260,10 +260,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): for i in items: self.receiveQueue.put(("inv", i)) - self.handleReceivedInventory(i) - payload = addresses.encodeVarint(len(self.objectsNewToMe)) + ''.join(self.objectsNewToMe.keys()) - self.writeQueue.put(protocol.CreatePacket('getdata', payload)) return True def bm_command_object(self): @@ -279,19 +276,23 @@ class BMProto(AdvancedDispatcher, ObjectTracker): try: self.object.checkEOLSanity() self.object.checkStream() - except (BMObjectExpiredError, BMObjectUnwantedStreamError): + self.object.checkAlreadyHave() + except (BMObjectExpiredError, BMObjectUnwantedStreamError, BMObjectAlreadyHaveError) as e: for connection in network.connectionpool.BMConnectionPool().inboundConnections.values() + network.connectionpool.BMConnectionPool().outboundConnections.values(): try: - del connection.objectsNewtoThem[self.object.inventoryHash] + with connection.objectsNewToThemLock: + del connection.objectsNewToThem[self.object.inventoryHash] except KeyError: pass try: - del connection.objectsNewToMe[self.object.inventoryHash] + with connection.objectsNewToMeLock: + del connection.objectsNewToMe[self.object.inventoryHash] except KeyError: pass - if not BMConfigParser().get("inventory", "acceptmismatch"): - raise - self.object.checkAlreadyHave() + if not BMConfigParser().get("inventory", "acceptmismatch") or \ + isinstance(e, BMObjectAlreadyHaveError) or \ + isinstance(e, BMObjectExpiredError): + raise e if self.object.objectType == protocol.OBJECT_GETPUBKEY: self.object.checkGetpubkey() diff --git a/src/network/downloadthread.py b/src/network/downloadthread.py new file mode 100644 index 00000000..e8bd44a7 --- /dev/null +++ b/src/network/downloadthread.py @@ -0,0 +1,48 @@ +import Queue +import threading + +import addresses +#from bmconfigparser import BMConfigParser +from debug import logger +from helper_threading import StoppableThread +#from inventory import Inventory +from network.connectionpool import BMConnectionPool +import protocol + +class DownloadThread(threading.Thread, StoppableThread): + maxPending = 500 + requestChunk = 1000 + + def __init__(self): + threading.Thread.__init__(self, name="DownloadThread") + self.initStop() + self.name = "DownloadThread" + logger.info("init download thread") + + def run(self): + while not self._stopped: + requested = 0 + for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values(): + # this may take a while, but it needs a consistency so I think it's better + with i.objectsNewToMeLock: + downloadPending = len(list((k for k, v in i.objectsNewToMe.iteritems() if not v))) + if downloadPending >= DownloadThread.maxPending: + continue + # keys with True values in the dict + request = list((k for k, v in i.objectsNewToMe.iteritems() if v)) + if len(request) == 0: + continue + if len(request) > DownloadThread.requestChunk - downloadPending: + request = request[:DownloadThread.requestChunk - downloadPending] + # mark them as pending + for k in request: + i.objectsNewToMe[k] = False + + payload = addresses.encodeVarint(len(request)) + ''.join(request) + i.writeQueue.put(protocol.CreatePacket('getdata', payload)) + logger.debug("%s:%i Requesting %i objects", i.destination.host, i.destination.port, len(request)) + requested += len(request) + self.stop.wait(1) + + def stopThread(self): + super(DownloadThread, self).stopThread() diff --git a/src/network/invthread.py b/src/network/invthread.py index 6d1828e1..63107a1f 100644 --- a/src/network/invthread.py +++ b/src/network/invthread.py @@ -40,7 +40,6 @@ class InvThread(threading.Thread, StoppableThread): else: BMConnectionPool().handleReceivedObject(data[0], data[1], data[2]) self.holdHash (data[0], data[1]) - #print "Holding hash %i, %s" % (stream, hexlify(hash)) except Queue.Empty: break @@ -50,13 +49,15 @@ class InvThread(threading.Thread, StoppableThread): for stream in connection.streams: try: for hashId in self.collectionOfInvs[iterator][stream]: - if hashId in connection.objectsNewToThem: + try: + with connection.objectsNewToThemLock: + del connection.objectsNewToThem[hashId] hashes.append(hashId) - del connection.objectsNewToThem[hashId] + except KeyError: + pass except KeyError: continue if len(hashes) > 0: - #print "sending inv of %i" % (len(hashes)) connection.writeQueue.put(protocol.CreatePacket('inv', addresses.encodeVarint(len(hashes)) + "".join(hashes))) self.collectionOfInvs[iterator] = {} iterator += 1 diff --git a/src/network/networkthread.py b/src/network/networkthread.py index 7e98bcc0..a4a23103 100644 --- a/src/network/networkthread.py +++ b/src/network/networkthread.py @@ -12,7 +12,6 @@ class BMNetworkThread(threading.Thread, StoppableThread): threading.Thread.__init__(self, name="AsyncoreThread") self.initStop() self.name = "AsyncoreThread" - BMConnectionPool() logger.info("init asyncore thread") def run(self): diff --git a/src/network/objectracker.py b/src/network/objectracker.py index 0c4a8d56..5c0ad147 100644 --- a/src/network/objectracker.py +++ b/src/network/objectracker.py @@ -2,6 +2,7 @@ from Queue import Queue import time from threading import RLock +from debug import logger from inventory import Inventory from network.downloadqueue import DownloadQueue from network.uploadqueue import UploadQueue @@ -33,8 +34,6 @@ class ObjectTracker(object): self.objectsNewToMeLock = RLock() self.objectsNewToThem = {} self.objectsNewToThemLock = RLock() - self.downloadPending = 0 - self.downloadQueue = Queue() self.initInvBloom() self.initAddrBloom() self.lastCleaned = time.time() @@ -77,16 +76,14 @@ class ObjectTracker(object): def handleReceivedInventory(self, hashId): if haveBloom: self.invBloom.add(hashId) - elif hashId in Inventory(): - try: - with self.objectsNewToThemLock: - del self.objectsNewToThem[hashId] - except KeyError: - pass - else: + try: + with self.objectsNewToThemLock: + del self.objectsNewToThem[hashId] + except KeyError: + pass + if hashId not in Inventory(): with self.objectsNewToMeLock: self.objectsNewToMe[hashId] = True -# self.DownloadQueue.put(hashId) def hasAddr(self, addr): if haveBloom: diff --git a/src/network/receivequeuethread.py b/src/network/receivequeuethread.py index b31b82b4..6662a078 100644 --- a/src/network/receivequeuethread.py +++ b/src/network/receivequeuethread.py @@ -1,4 +1,5 @@ import Queue +import sys import threading import time @@ -17,7 +18,6 @@ class ReceiveQueueThread(threading.Thread, StoppableThread): threading.Thread.__init__(self, name="ReceiveQueueThread") self.initStop() self.name = "ReceiveQueueThread" - BMConnectionPool() logger.info("init receive queue thread") def run(self): @@ -25,6 +25,11 @@ class ReceiveQueueThread(threading.Thread, StoppableThread): while not self._stopped and state.shutdown == 0: if lastprinted < int(time.time()): lastprinted = int(time.time()) + try: + sys._getframe(200) + logger.error("Stack depth warning") + except ValueError: + pass processed = 0 for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values(): if self._stopped: @@ -61,9 +66,11 @@ class ReceiveQueueThread(threading.Thread, StoppableThread): # Select all hashes for objects in this stream. bigInvList = {} for stream in connection.streams: - for objHash in Inventory().unexpired_hashes_by_stream(stream): - bigInvList[objHash] = 0 - connection.objectsNewToThem[objHash] = True + # may lock for a long time, but I think it's better than thousands of small locks + with connection.objectsNewToThemLock: + for objHash in Inventory().unexpired_hashes_by_stream(stream): + bigInvList[objHash] = 0 + connection.objectsNewToThem[objHash] = True objectCount = 0 payload = b'' # Now let us start appending all of these hashes together. They will be diff --git a/src/network/tcp.py b/src/network/tcp.py index badb6774..8a6b5705 100644 --- a/src/network/tcp.py +++ b/src/network/tcp.py @@ -212,6 +212,12 @@ class TCPServer(AdvancedDispatcher): if pair is not None: sock, addr = pair state.ownAddresses[state.Peer(sock.getsockname()[0], sock.getsockname()[1])] = True + if len(network.connectionpool.BMConnectionPool().inboundConnections) + \ + len(network.connectionpool.BMConnectionPool().outboundConnections) > \ + BMConfigParser().safeGetInt("bitmessagesettings", "maxtotalconnections") + \ + BMConfigParser().safeGetInt("bitmessagesettings", "maxbootstrapconnections"): + close(sock) + return try: network.connectionpool.BMConnectionPool().addConnection(TCPConnection(sock=sock)) except socket.error: diff --git a/src/network/tls.py b/src/network/tls.py index 115f3faa..9694b4b9 100644 --- a/src/network/tls.py +++ b/src/network/tls.py @@ -13,6 +13,8 @@ import network.asyncore_pollchoose as asyncore import paths import protocol +_DISCONNECTED_SSL = frozenset((ssl.SSL_ERROR_EOF,)) + class TLSDispatcher(AdvancedDispatcher): def __init__(self, address=None, sock=None, certfile=None, keyfile=None, server_side=False, ciphers=protocol.sslProtocolCiphers): @@ -90,6 +92,14 @@ class TLSDispatcher(AdvancedDispatcher): return AdvancedDispatcher.handle_read(self) except AttributeError: return AdvancedDispatcher.handle_read(self) + except ssl.SSLError as err: + if err.errno == ssl.SSL_ERROR_WANT_READ: + return + elif err.errno in _DISCONNECTED_SSL: + self.handle_close() + return + else: + raise def handle_write(self): try: @@ -102,6 +112,14 @@ class TLSDispatcher(AdvancedDispatcher): return AdvancedDispatcher.handle_write(self) except AttributeError: return AdvancedDispatcher.handle_read(self) + except ssl.SSLError as err: + if err.errno == ssl.SSL_ERROR_WANT_WRITE: + return 0 + elif err.errno in _DISCONNECTED_SSL: + self.handle_close() + return 0 + else: + raise def tls_handshake(self): # wait for flush