Asyncore updates

- should prevent the same object being re-requested indefinitely
- locking for object tracking
- move SSL-specific error handling to TLSDispatcher
- observe maximum connection limit when accepting a new connection
- stack depth test (for debugging purposes)
- separate download thread
- connection pool init moved to main thread
This commit is contained in:
Peter Šurda 2017-06-21 12:16:33 +02:00
parent 76fed78211
commit 0cc8589b27
Signed by untrusted user: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87
12 changed files with 125 additions and 46 deletions

View File

@ -58,7 +58,7 @@ from network.networkthread import BMNetworkThread
from network.receivequeuethread import ReceiveQueueThread from network.receivequeuethread import ReceiveQueueThread
from network.announcethread import AnnounceThread from network.announcethread import AnnounceThread
from network.invthread import InvThread from network.invthread import InvThread
#from network.downloadthread import DownloadThread from network.downloadthread import DownloadThread
# Helper Functions # Helper Functions
import helper_bootstrap import helper_bootstrap
@ -261,6 +261,7 @@ class Main:
singleAPIThread.start() singleAPIThread.start()
if BMConfigParser().get("network", "asyncore"): if BMConfigParser().get("network", "asyncore"):
BMConnectionPool()
asyncoreThread = BMNetworkThread() asyncoreThread = BMNetworkThread()
asyncoreThread.daemon = True asyncoreThread.daemon = True
asyncoreThread.start() asyncoreThread.start()
@ -273,6 +274,9 @@ class Main:
state.invThread = InvThread() state.invThread = InvThread()
state.invThread.daemon = True state.invThread.daemon = True
state.invThread.start() state.invThread.start()
downloadThread = DownloadThread()
downloadThread.daemon = True
downloadThread.start()
connectToStream(1) connectToStream(1)

View File

@ -1,5 +1,6 @@
import socket
import Queue import Queue
import socket
import sys
import time import time
import asyncore_pollchoose as asyncore import asyncore_pollchoose as asyncore
@ -42,6 +43,11 @@ class AdvancedDispatcher(asyncore.dispatcher):
if not self.connected: if not self.connected:
return return
maxLoop = 20 maxLoop = 20
try:
sys._getframe(200)
logger.error("Stack depth warning")
except ValueError:
pass
while maxLoop > 0: while maxLoop > 0:
try: try:
# print "Trying to handle state \"%s\"" % (self.state) # print "Trying to handle state \"%s\"" % (self.state)

View File

@ -16,7 +16,6 @@ class AnnounceThread(threading.Thread, StoppableThread):
threading.Thread.__init__(self, name="AnnounceThread") threading.Thread.__init__(self, name="AnnounceThread")
self.initStop() self.initStop()
self.name = "AnnounceThread" self.name = "AnnounceThread"
BMConnectionPool()
logger.info("init announce thread") logger.info("init announce thread")
def run(self): def run(self):

View File

@ -57,7 +57,7 @@ import warnings
import os import os
from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \
ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \ ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \
ECONNREFUSED, EHOSTUNREACH, ENETUNREACH, ENOTSOCK, EINTR, \ ECONNREFUSED, EHOSTUNREACH, ENETUNREACH, ENOTSOCK, EINTR, ETIMEDOUT, \
errorcode errorcode
try: try:
from errno import WSAEWOULDBLOCK from errno import WSAEWOULDBLOCK
@ -68,10 +68,8 @@ try:
except (ImportError, AttributeError): except (ImportError, AttributeError):
WSAENOTSOCK = ENOTSOCK WSAENOTSOCK = ENOTSOCK
from ssl import SSLError, SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE
_DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, _DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE,
EBADF, ECONNREFUSED, EHOSTUNREACH, ENETUNREACH)) EBADF, ECONNREFUSED, EHOSTUNREACH, ENETUNREACH, ETIMEDOUT))
OP_READ = 1 OP_READ = 1
OP_WRITE = 2 OP_WRITE = 2
@ -563,11 +561,6 @@ class dispatcher:
try: try:
result = self.socket.send(data) result = self.socket.send(data)
return result return result
except SSLError as err:
if err.errno == SSL_ERROR_WANT_WRITE:
return 0
else:
raise
except socket.error as why: except socket.error as why:
if why.args[0] in (EAGAIN, EWOULDBLOCK, WSAEWOULDBLOCK): if why.args[0] in (EAGAIN, EWOULDBLOCK, WSAEWOULDBLOCK):
return 0 return 0
@ -587,11 +580,6 @@ class dispatcher:
return b'' return b''
else: else:
return data return data
except SSLError as err:
if err.errno == SSL_ERROR_WANT_READ:
return b''
else:
raise
except socket.error as why: except socket.error as why:
# winsock sometimes raises ENOTCONN # winsock sometimes raises ENOTCONN
if why.args[0] in (EAGAIN, EWOULDBLOCK, WSAEWOULDBLOCK): if why.args[0] in (EAGAIN, EWOULDBLOCK, WSAEWOULDBLOCK):
@ -616,6 +604,11 @@ class dispatcher:
# cheap inheritance, used to pass all other attribute # cheap inheritance, used to pass all other attribute
# references to the underlying socket object. # references to the underlying socket object.
def __getattr__(self, attr): def __getattr__(self, attr):
try:
sys._getframe(200)
logger.error("Stack depth warning")
except ValueError:
pass
try: try:
retattr = getattr(self.socket, attr) retattr = getattr(self.socket, attr)
except AttributeError: except AttributeError:

View File

@ -119,7 +119,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
except BMObjectInvalidError: except BMObjectInvalidError:
logger.debug("object invalid, skipping") logger.debug("object invalid, skipping")
except BMObjectAlreadyHaveError: except BMObjectAlreadyHaveError:
logger.debug("already got object, skipping") logger.debug("%s:%i already got object, skipping", self.destination.host, self.destination.port)
except struct.error: except struct.error:
logger.debug("decoding error, skipping") logger.debug("decoding error, skipping")
else: else:
@ -260,10 +260,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
for i in items: for i in items:
self.receiveQueue.put(("inv", i)) self.receiveQueue.put(("inv", i))
self.handleReceivedInventory(i)
payload = addresses.encodeVarint(len(self.objectsNewToMe)) + ''.join(self.objectsNewToMe.keys())
self.writeQueue.put(protocol.CreatePacket('getdata', payload))
return True return True
def bm_command_object(self): def bm_command_object(self):
@ -279,19 +276,23 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
try: try:
self.object.checkEOLSanity() self.object.checkEOLSanity()
self.object.checkStream() self.object.checkStream()
except (BMObjectExpiredError, BMObjectUnwantedStreamError): self.object.checkAlreadyHave()
except (BMObjectExpiredError, BMObjectUnwantedStreamError, BMObjectAlreadyHaveError) as e:
for connection in network.connectionpool.BMConnectionPool().inboundConnections.values() + network.connectionpool.BMConnectionPool().outboundConnections.values(): for connection in network.connectionpool.BMConnectionPool().inboundConnections.values() + network.connectionpool.BMConnectionPool().outboundConnections.values():
try: try:
del connection.objectsNewtoThem[self.object.inventoryHash] with connection.objectsNewToThemLock:
del connection.objectsNewToThem[self.object.inventoryHash]
except KeyError: except KeyError:
pass pass
try: try:
del connection.objectsNewToMe[self.object.inventoryHash] with connection.objectsNewToMeLock:
del connection.objectsNewToMe[self.object.inventoryHash]
except KeyError: except KeyError:
pass pass
if not BMConfigParser().get("inventory", "acceptmismatch"): if not BMConfigParser().get("inventory", "acceptmismatch") or \
raise isinstance(e, BMObjectAlreadyHaveError) or \
self.object.checkAlreadyHave() isinstance(e, BMObjectExpiredError):
raise e
if self.object.objectType == protocol.OBJECT_GETPUBKEY: if self.object.objectType == protocol.OBJECT_GETPUBKEY:
self.object.checkGetpubkey() self.object.checkGetpubkey()

View File

@ -0,0 +1,48 @@
import Queue
import threading
import addresses
#from bmconfigparser import BMConfigParser
from debug import logger
from helper_threading import StoppableThread
#from inventory import Inventory
from network.connectionpool import BMConnectionPool
import protocol
class DownloadThread(threading.Thread, StoppableThread):
maxPending = 500
requestChunk = 1000
def __init__(self):
threading.Thread.__init__(self, name="DownloadThread")
self.initStop()
self.name = "DownloadThread"
logger.info("init download thread")
def run(self):
while not self._stopped:
requested = 0
for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values():
# this may take a while, but it needs a consistency so I think it's better
with i.objectsNewToMeLock:
downloadPending = len(list((k for k, v in i.objectsNewToMe.iteritems() if not v)))
if downloadPending >= DownloadThread.maxPending:
continue
# keys with True values in the dict
request = list((k for k, v in i.objectsNewToMe.iteritems() if v))
if len(request) == 0:
continue
if len(request) > DownloadThread.requestChunk - downloadPending:
request = request[:DownloadThread.requestChunk - downloadPending]
# mark them as pending
for k in request:
i.objectsNewToMe[k] = False
payload = addresses.encodeVarint(len(request)) + ''.join(request)
i.writeQueue.put(protocol.CreatePacket('getdata', payload))
logger.debug("%s:%i Requesting %i objects", i.destination.host, i.destination.port, len(request))
requested += len(request)
self.stop.wait(1)
def stopThread(self):
super(DownloadThread, self).stopThread()

View File

@ -40,7 +40,6 @@ class InvThread(threading.Thread, StoppableThread):
else: else:
BMConnectionPool().handleReceivedObject(data[0], data[1], data[2]) BMConnectionPool().handleReceivedObject(data[0], data[1], data[2])
self.holdHash (data[0], data[1]) self.holdHash (data[0], data[1])
#print "Holding hash %i, %s" % (stream, hexlify(hash))
except Queue.Empty: except Queue.Empty:
break break
@ -50,13 +49,15 @@ class InvThread(threading.Thread, StoppableThread):
for stream in connection.streams: for stream in connection.streams:
try: try:
for hashId in self.collectionOfInvs[iterator][stream]: for hashId in self.collectionOfInvs[iterator][stream]:
if hashId in connection.objectsNewToThem: try:
with connection.objectsNewToThemLock:
del connection.objectsNewToThem[hashId]
hashes.append(hashId) hashes.append(hashId)
del connection.objectsNewToThem[hashId] except KeyError:
pass
except KeyError: except KeyError:
continue continue
if len(hashes) > 0: if len(hashes) > 0:
#print "sending inv of %i" % (len(hashes))
connection.writeQueue.put(protocol.CreatePacket('inv', addresses.encodeVarint(len(hashes)) + "".join(hashes))) connection.writeQueue.put(protocol.CreatePacket('inv', addresses.encodeVarint(len(hashes)) + "".join(hashes)))
self.collectionOfInvs[iterator] = {} self.collectionOfInvs[iterator] = {}
iterator += 1 iterator += 1

View File

@ -12,7 +12,6 @@ class BMNetworkThread(threading.Thread, StoppableThread):
threading.Thread.__init__(self, name="AsyncoreThread") threading.Thread.__init__(self, name="AsyncoreThread")
self.initStop() self.initStop()
self.name = "AsyncoreThread" self.name = "AsyncoreThread"
BMConnectionPool()
logger.info("init asyncore thread") logger.info("init asyncore thread")
def run(self): def run(self):

View File

@ -2,6 +2,7 @@ from Queue import Queue
import time import time
from threading import RLock from threading import RLock
from debug import logger
from inventory import Inventory from inventory import Inventory
from network.downloadqueue import DownloadQueue from network.downloadqueue import DownloadQueue
from network.uploadqueue import UploadQueue from network.uploadqueue import UploadQueue
@ -33,8 +34,6 @@ class ObjectTracker(object):
self.objectsNewToMeLock = RLock() self.objectsNewToMeLock = RLock()
self.objectsNewToThem = {} self.objectsNewToThem = {}
self.objectsNewToThemLock = RLock() self.objectsNewToThemLock = RLock()
self.downloadPending = 0
self.downloadQueue = Queue()
self.initInvBloom() self.initInvBloom()
self.initAddrBloom() self.initAddrBloom()
self.lastCleaned = time.time() self.lastCleaned = time.time()
@ -77,16 +76,14 @@ class ObjectTracker(object):
def handleReceivedInventory(self, hashId): def handleReceivedInventory(self, hashId):
if haveBloom: if haveBloom:
self.invBloom.add(hashId) self.invBloom.add(hashId)
elif hashId in Inventory(): try:
try: with self.objectsNewToThemLock:
with self.objectsNewToThemLock: del self.objectsNewToThem[hashId]
del self.objectsNewToThem[hashId] except KeyError:
except KeyError: pass
pass if hashId not in Inventory():
else:
with self.objectsNewToMeLock: with self.objectsNewToMeLock:
self.objectsNewToMe[hashId] = True self.objectsNewToMe[hashId] = True
# self.DownloadQueue.put(hashId)
def hasAddr(self, addr): def hasAddr(self, addr):
if haveBloom: if haveBloom:

View File

@ -1,4 +1,5 @@
import Queue import Queue
import sys
import threading import threading
import time import time
@ -17,7 +18,6 @@ class ReceiveQueueThread(threading.Thread, StoppableThread):
threading.Thread.__init__(self, name="ReceiveQueueThread") threading.Thread.__init__(self, name="ReceiveQueueThread")
self.initStop() self.initStop()
self.name = "ReceiveQueueThread" self.name = "ReceiveQueueThread"
BMConnectionPool()
logger.info("init receive queue thread") logger.info("init receive queue thread")
def run(self): def run(self):
@ -25,6 +25,11 @@ class ReceiveQueueThread(threading.Thread, StoppableThread):
while not self._stopped and state.shutdown == 0: while not self._stopped and state.shutdown == 0:
if lastprinted < int(time.time()): if lastprinted < int(time.time()):
lastprinted = int(time.time()) lastprinted = int(time.time())
try:
sys._getframe(200)
logger.error("Stack depth warning")
except ValueError:
pass
processed = 0 processed = 0
for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values(): for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values():
if self._stopped: if self._stopped:
@ -61,9 +66,11 @@ class ReceiveQueueThread(threading.Thread, StoppableThread):
# Select all hashes for objects in this stream. # Select all hashes for objects in this stream.
bigInvList = {} bigInvList = {}
for stream in connection.streams: for stream in connection.streams:
for objHash in Inventory().unexpired_hashes_by_stream(stream): # may lock for a long time, but I think it's better than thousands of small locks
bigInvList[objHash] = 0 with connection.objectsNewToThemLock:
connection.objectsNewToThem[objHash] = True for objHash in Inventory().unexpired_hashes_by_stream(stream):
bigInvList[objHash] = 0
connection.objectsNewToThem[objHash] = True
objectCount = 0 objectCount = 0
payload = b'' payload = b''
# Now let us start appending all of these hashes together. They will be # Now let us start appending all of these hashes together. They will be

View File

@ -212,6 +212,12 @@ class TCPServer(AdvancedDispatcher):
if pair is not None: if pair is not None:
sock, addr = pair sock, addr = pair
state.ownAddresses[state.Peer(sock.getsockname()[0], sock.getsockname()[1])] = True state.ownAddresses[state.Peer(sock.getsockname()[0], sock.getsockname()[1])] = True
if len(network.connectionpool.BMConnectionPool().inboundConnections) + \
len(network.connectionpool.BMConnectionPool().outboundConnections) > \
BMConfigParser().safeGetInt("bitmessagesettings", "maxtotalconnections") + \
BMConfigParser().safeGetInt("bitmessagesettings", "maxbootstrapconnections"):
close(sock)
return
try: try:
network.connectionpool.BMConnectionPool().addConnection(TCPConnection(sock=sock)) network.connectionpool.BMConnectionPool().addConnection(TCPConnection(sock=sock))
except socket.error: except socket.error:

View File

@ -13,6 +13,8 @@ import network.asyncore_pollchoose as asyncore
import paths import paths
import protocol import protocol
_DISCONNECTED_SSL = frozenset((ssl.SSL_ERROR_EOF,))
class TLSDispatcher(AdvancedDispatcher): class TLSDispatcher(AdvancedDispatcher):
def __init__(self, address=None, sock=None, def __init__(self, address=None, sock=None,
certfile=None, keyfile=None, server_side=False, ciphers=protocol.sslProtocolCiphers): certfile=None, keyfile=None, server_side=False, ciphers=protocol.sslProtocolCiphers):
@ -90,6 +92,14 @@ class TLSDispatcher(AdvancedDispatcher):
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
except AttributeError: except AttributeError:
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
except ssl.SSLError as err:
if err.errno == ssl.SSL_ERROR_WANT_READ:
return
elif err.errno in _DISCONNECTED_SSL:
self.handle_close()
return
else:
raise
def handle_write(self): def handle_write(self):
try: try:
@ -102,6 +112,14 @@ class TLSDispatcher(AdvancedDispatcher):
return AdvancedDispatcher.handle_write(self) return AdvancedDispatcher.handle_write(self)
except AttributeError: except AttributeError:
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
except ssl.SSLError as err:
if err.errno == ssl.SSL_ERROR_WANT_WRITE:
return 0
elif err.errno in _DISCONNECTED_SSL:
self.handle_close()
return 0
else:
raise
def tls_handshake(self): def tls_handshake(self):
# wait for flush # wait for flush