From ba4162d7fe1a81952ce105431a63d9ede84021ba Mon Sep 17 00:00:00 2001 From: Peter Surda Date: Thu, 6 Jul 2017 19:45:36 +0200 Subject: [PATCH] Asyncore update - get rid of per-connection writeQueue/receiveQueue, and instead use strings and locking - minor code cleanup - all state handlers now should set expectBytes - almost all data processing happens in ReceiveDataThread, and AsyncoreThread is almost only I/O (plus TLS). AsyncoreThread simply puts the connection object into the queue when it has some data for processing - allow poll, epoll and kqueue handlers. kqueue is untested and unoptimised, poll and epoll seem to work ok (linux) - stack depth threshold handler in decode_payload_content, this is recursive and I think was causing occasional RuntimeErrors. Fixes #964 - longer asyncore loops, as now data is handled in ReceiveDataThread - randomise node order when deciding what to download. Should prevent retries being stuck to the same node - socks cleanup (socks5 works ok, socks4a untested but should work too) --- src/network/advanceddispatcher.py | 66 +++++++++--------------- src/network/announcethread.py | 2 +- src/network/asyncore_pollchoose.py | 74 +++++++++++++++------------ src/network/bmproto.py | 63 ++++++++++++----------- src/network/connectionpool.py | 11 ++-- src/network/downloadthread.py | 8 ++- src/network/invthread.py | 2 +- src/network/receivequeuethread.py | 81 ++++++------------------------ src/network/socks4a.py | 41 ++++++++------- src/network/socks5.py | 77 +++++++++++++--------------- src/network/tcp.py | 51 ++++++++++++++++--- src/network/tls.py | 36 ++++++++----- src/network/udp.py | 18 +++---- src/queues.py | 1 + 14 files changed, 257 insertions(+), 274 deletions(-) diff --git a/src/network/advanceddispatcher.py b/src/network/advanceddispatcher.py index 0945d764..57bd4f41 100644 --- a/src/network/advanceddispatcher.py +++ b/src/network/advanceddispatcher.py @@ -1,6 +1,7 @@ import Queue import socket import sys +import threading import time import asyncore_pollchoose as asyncore @@ -14,41 +15,43 @@ class AdvancedDispatcher(asyncore.dispatcher): asyncore.dispatcher.__init__(self, sock) self.read_buf = b"" self.write_buf = b"" - self.writeQueue = Queue.Queue() - self.receiveQueue = Queue.Queue() self.state = "init" self.lastTx = time.time() self.sentBytes = 0 self.receivedBytes = 0 self.expectBytes = 0 + self.readLock = threading.RLock() + self.writeLock = threading.RLock() + + def append_write_buf(self, data): + if data: + with self.writeLock: + self.write_buf += data def slice_write_buf(self, length=0): if length > 0: - self.write_buf = self.write_buf[length:] + with self.writeLock: + self.write_buf = self.write_buf[length:] def slice_read_buf(self, length=0): if length > 0: - self.read_buf = self.read_buf[length:] - - def read_buf_sufficient(self, length=0): - if len(self.read_buf) < length: - return False - return True + with self.readLock: + self.read_buf = self.read_buf[length:] def process(self): - if self.state != "tls_handshake" and not self.read_buf: - return if not self.connected: return - maxLoop = 20 - while maxLoop > 0: + loop = 0 + while len(self.read_buf) >= self.expectBytes: + loop += 1 + if loop > 1000: + logger.error("Stuck at state %s, report this bug please", self.state) + break try: if getattr(self, "state_" + str(self.state))() is False: break except AttributeError: - # missing state raise - maxLoop -= 1 def set_state(self, state, length=0, expectBytes=0): self.expectBytes = expectBytes @@ -57,7 +60,7 @@ class AdvancedDispatcher(asyncore.dispatcher): def writable(self): return asyncore.dispatcher.writable(self) and \ - (self.connecting or len(self.write_buf) > 0 or not self.writeQueue.empty()) + (self.connecting or self.write_buf) def readable(self): return asyncore.dispatcher.readable(self) and \ @@ -68,28 +71,20 @@ class AdvancedDispatcher(asyncore.dispatcher): downloadBytes = AdvancedDispatcher._buf_len if asyncore.maxDownloadRate > 0: downloadBytes = asyncore.downloadBucket - if self.expectBytes > 0 and downloadBytes > self.expectBytes: - downloadBytes = self.expectBytes + if self.expectBytes > 0 and downloadBytes > self.expectBytes - len(self.read_buf): + downloadBytes = self.expectBytes - len(self.read_buf) if downloadBytes > 0: newData = self.recv(downloadBytes) self.receivedBytes += len(newData) - if self.expectBytes > 0: - self.expectBytes -= len(newData) asyncore.update_received(len(newData)) - self.read_buf += newData - self.process() + with self.readLock: + self.read_buf += newData def handle_write(self): self.lastTx = time.time() bufSize = AdvancedDispatcher._buf_len if asyncore.maxUploadRate > 0: bufSize = asyncore.uploadBucket - while len(self.write_buf) < bufSize: - try: - self.write_buf += self.writeQueue.get(False) - self.writeQueue.task_done() - except Queue.Empty: - break if bufSize <= 0: return if self.write_buf: @@ -107,25 +102,12 @@ class AdvancedDispatcher(asyncore.dispatcher): def handle_connect(self): self.lastTx = time.time() - self.process() def state_close(self): - pass + return False def handle_close(self): self.read_buf = b"" self.write_buf = b"" self.state = "close" - while True: - try: - self.writeQueue.get(False) - self.writeQueue.task_done() - except Queue.Empty: - break - while True: - try: - self.receiveQueue.get(False) - self.receiveQueue.task_done() - except Queue.Empty: - break asyncore.dispatcher.close(self) diff --git a/src/network/announcethread.py b/src/network/announcethread.py index 3adcae48..354cfaa8 100644 --- a/src/network/announcethread.py +++ b/src/network/announcethread.py @@ -30,4 +30,4 @@ class AnnounceThread(threading.Thread, StoppableThread): for connection in BMConnectionPool().udpSockets.values(): for stream in state.streamsInWhichIAmParticipating: addr = (stream, state.Peer('127.0.0.1', BMConfigParser().safeGetInt("bitmessagesettings", "port")), time.time()) - connection.writeQueue.put(BMProto.assembleAddr([addr])) + connection.append_write_buf(BMProto.assembleAddr([addr])) diff --git a/src/network/asyncore_pollchoose.py b/src/network/asyncore_pollchoose.py index 3f188812..07b2c120 100644 --- a/src/network/asyncore_pollchoose.py +++ b/src/network/asyncore_pollchoose.py @@ -249,13 +249,16 @@ def poll_poller(timeout=0.0, map=None): newflags |= OP_WRITE else: newflags &= ~ OP_WRITE - if newflags != obj.flags: - obj.flags = newflags - if obj.poller_registered: - poll_poller.pollster.modify(fd, flags) - else: - poll_poller.pollster.register(fd, flags) - obj.poller_registered = True + if newflags != obj.poller_flags: + obj.poller_flags = newflags + try: + if obj.poller_registered: + poll_poller.pollster.modify(fd, flags) + else: + poll_poller.pollster.register(fd, flags) + obj.poller_registered = True + except IOError: + pass try: r = poll_poller.pollster.poll(timeout) except KeyboardInterrupt: @@ -292,16 +295,19 @@ def epoll_poller(timeout=0.0, map=None): newflags |= OP_WRITE else: newflags &= ~ OP_WRITE - if newflags != obj.flags: - obj.flags = newflags + if newflags != obj.poller_flags: + obj.poller_flags = newflags # Only check for exceptions if object was either readable # or writable. flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL - if obj.poller_registered: - epoll_poller.pollster.modify(fd, flags) - else: - epoll_poller.pollster.register(fd, flags) - obj.poller_registered = True + try: + if obj.poller_registered: + epoll_poller.pollster.modify(fd, flags) + else: + epoll_poller.pollster.register(fd, flags) + obj.poller_registered = True + except IOError: + pass try: r = epoll_poller.pollster.poll(timeout) except select.error, err: @@ -329,9 +335,12 @@ def kqueue_poller(timeout=0.0, map=None): if obj.writable(): filter |= select.KQ_FILTER_WRITE if filter: - ev = select.kevent(fd, filter=filter, flags=flags) - kqueue.control([ev], 0) - selectables += 1 + try: + ev = select.kevent(fd, filter=filter, flags=flags) + kqueue.control([ev], 0) + selectables += 1 + except IOError: + pass events = kqueue.control(None, selectables, timeout) for event in random.sample(events, len(events)): @@ -347,25 +356,23 @@ def kqueue_poller(timeout=0.0, map=None): def loop(timeout=30.0, use_poll=False, map=None, count=None, - poller=select_poller): + poller=None): if map is None: map = socket_map # code which grants backward compatibility with "use_poll" # argument which should no longer be used in favor of # "poller" - if hasattr(select, 'epoll'): - poller = epoll_poller - elif hasattr(select, 'kqueue'): - poller = kqueue_poller - elif hasattr(select, 'poll'): - poller = poll_poller - elif hasattr(select, 'select'): - poller = select_poller - poller = select_poller - -# print "Poll loop using %s" % (poller.__name__) + if poller is None: + if hasattr(select, 'epoll'): + poller = epoll_poller + elif hasattr(select, 'kqueue'): + poller = kqueue_poller + elif hasattr(select, 'poll'): + poller = poll_poller + elif hasattr(select, 'select'): + poller = select_poller if count is None: while map: @@ -400,7 +407,7 @@ class dispatcher: addr = None ignore_log_types = frozenset(['warning']) poller_registered = False - flags = 0 + poller_flags = 0 # don't do network IO with a smaller bucket than this minTx = 1500 @@ -456,23 +463,26 @@ class dispatcher: if map is None: map = self._map map[self._fileno] = self + self.poller_flags = 0 def del_channel(self, map=None): fd = self._fileno if map is None: map = self._map + self.poller_flags = 0 + self.poller_registered = False if fd in map: #self.log_info('closing channel %d:%s' % (fd, self)) del map[fd] self._fileno = None try: epoll_poller.pollster.unregister(fd) - except (AttributeError, KeyError, TypeError): + except (AttributeError, KeyError, TypeError, IOError): # no epoll used, or not registered pass try: poll_poller.pollster.unregister(fd) - except (AttributeError, KeyError, TypeError): + except (AttributeError, KeyError, TypeError, IOError): # no poll used, or not registered pass diff --git a/src/network/bmproto.py b/src/network/bmproto.py index abf2f0ec..e00d6d0a 100644 --- a/src/network/bmproto.py +++ b/src/network/bmproto.py @@ -5,6 +5,7 @@ import math import time import socket import struct +import sys from addresses import calculateInventoryHash from bmconfigparser import BMConfigParser @@ -67,15 +68,12 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.object = None def state_bm_header(self): - if len(self.read_buf) < protocol.Header.size: - #print "Length below header size" - return False self.magic, self.command, self.payloadLength, self.checksum = protocol.Header.unpack(self.read_buf[:protocol.Header.size]) self.command = self.command.rstrip('\x00') if self.magic != 0xE9BEB4D9: # skip 1 byte in order to sync + self.set_state("bm_header", length=1) self.bm_proto_reset() - self.set_state("bm_header", length=1, expectBytes=protocol.Header.size) logger.debug("Bad magic") self.handle_close("Bad magic") return False @@ -85,10 +83,6 @@ class BMProto(AdvancedDispatcher, ObjectTracker): return True def state_bm_command(self): - if len(self.read_buf) < self.payloadLength: - #print "Length below announced object length" - return False - #logger.debug("%s:%i: command %s (%ib)", self.destination.host, self.destination.port, self.command, self.payloadLength) self.payload = self.read_buf[:self.payloadLength] if self.checksum != hashlib.sha512(self.payload).digest()[0:4]: logger.debug("Bad checksum, ignoring") @@ -127,7 +121,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.handle_close("Invalid command %s" % (self.command)) return False if retval: - self.set_state("bm_header", length=self.payloadLength, expectBytes=protocol.Header.size) + self.set_state("bm_header", length=self.payloadLength) self.bm_proto_reset() # else assume the command requires a different state to follow return True @@ -173,6 +167,12 @@ class BMProto(AdvancedDispatcher, ObjectTracker): retval = [] size = None i = 0 + try: + sys._getframe(200) + logger.error("Stack depth warning, pattern: %s", pattern) + return + except ValueError: + pass while i < len(pattern): if pattern[i] in "0123456789" and (i == 0 or pattern[i-1] not in "lL"): @@ -237,8 +237,13 @@ class BMProto(AdvancedDispatcher, ObjectTracker): # skip? if time.time() < self.skipUntil: return True + #TODO make this more asynchronous and allow reordering for i in items: - self.receiveQueue.put(("object", i)) + try: + self.append_write_buf(protocol.CreatePacket('object', Inventory()[i].payload)) + except KeyError: + self.antiIntersectionDelay() + logger.info('%s asked for an object we don\'t have.', self.destination) return True def bm_command_inv(self): @@ -251,7 +256,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): pass for i in items: - self.receiveQueue.put(("inv", i)) + self.handleReceivedInventory(i) return True @@ -321,7 +326,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): return True def bm_command_ping(self): - self.writeQueue.put(protocol.CreatePacket('pong')) + self.append_write_buf(protocol.CreatePacket('pong')) return True def bm_command_pong(self): @@ -332,11 +337,10 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.verackReceived = True if self.verackSent: if self.isSSL: - self.set_state("tls_init", self.payloadLength) - self.bm_proto_reset() + self.set_state("tls_init", length=self.payloadLength, expectBytes=0) return False - self.set_connection_fully_established() - return True + self.set_state("connection_fully_established", length=self.payloadLength, expectBytes=0) + return False return True def bm_command_version(self): @@ -345,20 +349,20 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.nonce = struct.pack('>Q', self.nonce) self.timeOffset = self.timestamp - int(time.time()) logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion) - logger.debug("services: %08X", self.services) + logger.debug("services: 0x%08X", self.services) logger.debug("time offset: %i", self.timestamp - int(time.time())) logger.debug("my external IP: %s", self.sockNode.host) - logger.debug("remote node incoming port: %i", self.peerNode.port) + logger.debug("remote node incoming address: %s:%i", self.destination.host, self.peerNode.port) logger.debug("user agent: %s", self.userAgent) logger.debug("streams: [%s]", ",".join(map(str,self.streams))) if not self.peerValidityChecks(): # TODO ABORT return True #shared.connectedHostsList[self.destination] = self.streams[0] - self.writeQueue.put(protocol.CreatePacket('verack')) + self.append_write_buf(protocol.CreatePacket('verack')) self.verackSent = True if not self.isOutbound: - self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ + self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ network.connectionpool.BMConnectionPool().streams, True)) #print "%s:%i: Sending version" % (self.destination.host, self.destination.port) if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and @@ -366,29 +370,28 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.isSSL = True if self.verackReceived: if self.isSSL: - self.set_state("tls_init", self.payloadLength) - self.bm_proto_reset() + self.set_state("tls_init", length=self.payloadLength, expectBytes=0) return False - self.set_connection_fully_established() - return True + self.set_state("connection_fully_established", length=self.payloadLength, expectBytes=0) + return False return True def peerValidityChecks(self): if self.remoteProtocolVersion < 3: - self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, errorText="Your is using an old protocol. Closing connection.")) logger.debug ('Closing connection to old protocol version %s, node: %s', str(self.remoteProtocolVersion), str(self.destination)) return False if self.timeOffset > BMProto.maxTimeOffset: - self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, errorText="Your time is too far in the future compared to mine. Closing connection.")) logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.", self.destination, self.timeOffset) shared.timeOffsetWrongCount += 1 return False elif self.timeOffset < -BMProto.maxTimeOffset: - self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, errorText="Your time is too far in the past compared to mine. Closing connection.")) logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.", self.destination, self.timeOffset) @@ -397,7 +400,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): else: shared.timeOffsetWrongCount = 0 if not self.streams: - self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, errorText="We don't have shared stream interests. Closing connection.")) logger.debug ('Closed connection to %s because there is no overlapping interest in streams.', str(self.destination)) @@ -405,7 +408,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): if self.destination in network.connectionpool.BMConnectionPool().inboundConnections: try: if not protocol.checkSocksIP(self.destination.host): - self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, errorText="Too many connections from your IP. Closing connection.")) logger.debug ('Closed connection to %s because we are already connected to that IP.', str(self.destination)) @@ -413,7 +416,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): except: pass if self.nonce == protocol.eightBytesOfRandomDataUsedToDetectConnectionsToSelf: - self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, errorText="I'm connected to myself. Closing connection.")) logger.debug ("Closed connection to %s because I'm connected to myself.", str(self.destination)) diff --git a/src/network/connectionpool.py b/src/network/connectionpool.py index c4aa53d8..4c2b4c6c 100644 --- a/src/network/connectionpool.py +++ b/src/network/connectionpool.py @@ -19,7 +19,6 @@ import state @Singleton class BMConnectionPool(object): - def __init__(self): asyncore.set_rates( BMConfigParser().safeGetInt("bitmessagesettings", "maxdownloadrate") * 1024, @@ -30,7 +29,7 @@ class BMConnectionPool(object): self.udpSockets = {} self.streams = [] self.lastSpawned = 0 - self.spawnWait = 0.3 + self.spawnWait = 2 self.bootstrapped = False def handleReceivedObject(self, streamNumber, hashid, connection = None): @@ -187,12 +186,10 @@ class BMConnectionPool(object): i.handle_close() logger.info('Stopped udp sockets.') -# while len(asyncore.socket_map) > 0 and state.shutdown == 0: -# print "loop, state = %s" % (proxy.state) loopTime = float(self.spawnWait) if self.lastSpawned < time.time() - self.spawnWait: - loopTime = 1.0 - asyncore.loop(timeout=loopTime, count=10) + loopTime = 2.0 + asyncore.loop(timeout=loopTime, count=1000) reaper = [] for i in self.inboundConnections.values() + self.outboundConnections.values(): @@ -201,7 +198,7 @@ class BMConnectionPool(object): minTx -= 300 - 20 if i.lastTx < minTx: if i.fullyEstablished: - i.writeQueue.put(protocol.CreatePacket('ping')) + i.append_write_buf(protocol.CreatePacket('ping')) else: i.handle_close("Timeout (%is)" % (time.time() - i.lastTx)) for i in self.inboundConnections.values() + self.outboundConnections.values() + self.listeningSockets.values() + self.udpSockets.values(): diff --git a/src/network/downloadthread.py b/src/network/downloadthread.py index 9c7e92da..c42d7e1c 100644 --- a/src/network/downloadthread.py +++ b/src/network/downloadthread.py @@ -1,3 +1,4 @@ +import random import threading import time @@ -32,7 +33,10 @@ class DownloadThread(threading.Thread, StoppableThread): def run(self): while not self._stopped: requested = 0 - for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values(): + # Choose downloading peers randomly + connections = BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values() + random.shuffle(connections) + for i in connections: now = time.time() timedOut = now - DownloadThread.requestTimeout # this may take a while, but it needs a consistency so I think it's better to lock a bigger chunk @@ -52,7 +56,7 @@ class DownloadThread(threading.Thread, StoppableThread): self.pending[k] = now payload = addresses.encodeVarint(len(request)) + ''.join(request) - i.writeQueue.put(protocol.CreatePacket('getdata', payload)) + i.append_write_buf(protocol.CreatePacket('getdata', payload)) logger.debug("%s:%i Requesting %i objects", i.destination.host, i.destination.port, len(request)) requested += len(request) if time.time() >= self.lastCleaned + DownloadThread.cleanInterval: diff --git a/src/network/invthread.py b/src/network/invthread.py index 398fecf0..e5ab890a 100644 --- a/src/network/invthread.py +++ b/src/network/invthread.py @@ -42,7 +42,7 @@ class InvThread(threading.Thread, StoppableThread): except KeyError: continue if hashes: - connection.writeQueue.put(protocol.CreatePacket('inv', \ + connection.append_write_buf(protocol.CreatePacket('inv', \ addresses.encodeVarint(len(hashes)) + "".join(hashes))) invQueue.iterate() self.stop.wait(1) diff --git a/src/network/receivequeuethread.py b/src/network/receivequeuethread.py index 442c755a..120d15e2 100644 --- a/src/network/receivequeuethread.py +++ b/src/network/receivequeuethread.py @@ -10,6 +10,7 @@ from helper_threading import StoppableThread from inventory import Inventory from network.connectionpool import BMConnectionPool from network.bmproto import BMProto +from queues import receiveDataQueue import protocol import state @@ -21,73 +22,23 @@ class ReceiveQueueThread(threading.Thread, StoppableThread): logger.info("init receive queue thread") def run(self): - lastprinted = int(time.time()) while not self._stopped and state.shutdown == 0: - if lastprinted < int(time.time()): - lastprinted = int(time.time()) -# try: -# sys._getframe(200) -# logger.error("Stack depth warning") -# except ValueError: -# pass - processed = 0 - for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values(): - if self._stopped: - break - try: - command, args = i.receiveQueue.get(False) - except Queue.Empty: - continue - processed += 1 - try: - getattr(self, "command_" + str(command))(i, args) - i.receiveQueue.task_done() - except AttributeError: - i.receiveQueue.task_done() - # missing command - raise - if processed == 0: - self.stop.wait(2) + try: + connection = receiveDataQueue.get(block=True, timeout=1) + receiveDataQueue.task_done() + except Queue.Empty: + continue - def command_object(self, connection, objHash): - try: - connection.writeQueue.put(protocol.CreatePacket('object', Inventory()[objHash].payload)) - except KeyError: - connection.antiIntersectionDelay() - logger.info('%s asked for an object we don\'t have.', connection.destination) - - def command_biginv(self, connection, dummy): - def sendChunk(): - if objectCount == 0: - return - logger.debug('Sending huge inv message with %i objects to just this one peer', objectCount) - connection.writeQueue.put(protocol.CreatePacket('inv', addresses.encodeVarint(objectCount) + payload)) - - # Select all hashes for objects in this stream. - bigInvList = {} - for stream in connection.streams: - # may lock for a long time, but I think it's better than thousands of small locks - with connection.objectsNewToThemLock: - for objHash in Inventory().unexpired_hashes_by_stream(stream): - bigInvList[objHash] = 0 - connection.objectsNewToThem[objHash] = time.time() - objectCount = 0 - payload = b'' - # Now let us start appending all of these hashes together. They will be - # sent out in a big inv message to our new peer. - for hash, storedValue in bigInvList.items(): - payload += hash - objectCount += 1 - if objectCount >= BMProto.maxObjectCount: - self.sendChunk() - payload = b'' - objectCount = 0 - - # flush - sendChunk() - - def command_inv(self, connection, hashId): - connection.handleReceivedInventory(hashId) + if self._stopped: + break + # cycle as long as there is data + # methods should return False if there isn't enough data, or the connection is to be aborted + try: + while connection.process(): + pass + except AttributeError: + # missing command + logger.error("Unknown state %s, ignoring", connection.state) def stopThread(self): super(ReceiveQueueThread, self).stopThread() diff --git a/src/network/socks4a.py b/src/network/socks4a.py index 350e163e..d6cf2ad8 100644 --- a/src/network/socks4a.py +++ b/src/network/socks4a.py @@ -19,10 +19,9 @@ class Socks4a(Proxy): def state_init(self): self.set_state("auth_done", 0) + return True def state_pre_connect(self): - if not self.read_buf_sufficient(8): - return False # Get the response if self.read_buf[0:1] != chr(0x00).encode(): # bad data @@ -44,14 +43,12 @@ class Socks4a(Proxy): self.__proxypeername = (socket.inet_ntoa(self.ipaddr), self.destination[1]) else: self.__proxypeername = (self.destination[0], self.destport) - self.set_state("proxy_handshake_done", 8) + self.set_state("proxy_handshake_done", length=8) + return True def proxy_sock_name(self): return socket.inet_ntoa(self.__proxysockname[0]) - def state_socks_handshake_done(self): - return False - class Socks4aConnection(Socks4a): def __init__(self, address): @@ -60,33 +57,34 @@ class Socks4aConnection(Socks4a): def state_auth_done(self): # Now we can request the actual connection rmtrslv = False - self.writeQueue.put(struct.pack('>BBH', 0x04, 0x01, self.destination[1])) + self.append_write_buf(struct.pack('>BBH', 0x04, 0x01, self.destination[1])) # If the given destination address is an IP address, we'll # use the IPv4 address request even if remote resolving was specified. try: self.ipaddr = socket.inet_aton(self.destination[0]) - self.writeQueue.put(self.ipaddr) + self.append_write_buf(self.ipaddr) except socket.error: # Well it's not an IP number, so it's probably a DNS name. if Proxy._remote_dns: # Resolve remotely rmtrslv = True self.ipaddr = None - self.writeQueue.put(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) + self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) else: # Resolve locally self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) - self.writeQueue.put(self.ipaddr) + self.append_write_buf(self.ipaddr) if self._auth: - self.writeQueue.put(self._auth[0]) - self.writeQueue.put(chr(0x00).encode()) + self.append_write_buf(self._auth[0]) + self.append_write_buf(chr(0x00).encode()) if rmtrslv: - self.writeQueue.put(self.destination[0] + chr(0x00).encode()) - self.set_state("pre_connect", 0) + self.append_write_buf(self.destination[0] + chr(0x00).encode()) + self.set_state("pre_connect", length=0, expectBytes=8) + return True def state_pre_connect(self): try: - Socks4a.state_pre_connect(self) + return Socks4a.state_pre_connect(self) except Socks4aError as e: self.handle_close(e.message) @@ -99,13 +97,14 @@ class Socks4aResolver(Socks4a): def state_auth_done(self): # Now we can request the actual connection - self.writeQueue.put(struct.pack('>BBH', 0x04, 0xF0, self.destination[1])) - self.writeQueue.put(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) + self.append_write_buf(struct.pack('>BBH', 0x04, 0xF0, self.destination[1])) + self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) if self._auth: - self.writeQueue.put(self._auth[0]) - self.writeQueue.put(chr(0x00).encode()) - self.writeQueue.put(self.host + chr(0x00).encode()) - self.set_state("pre_connect", 0) + self.append_write_buf(self._auth[0]) + self.append_write_buf(chr(0x00).encode()) + self.append_write_buf(self.host + chr(0x00).encode()) + self.set_state("pre_connect", length=0, expectBytes=8) + return True def resolved(self): print "Resolved %s as %s" % (self.host, self.proxy_sock_name()) diff --git a/src/network/socks5.py b/src/network/socks5.py index f2bc83e4..e57e7c6a 100644 --- a/src/network/socks5.py +++ b/src/network/socks5.py @@ -32,28 +32,26 @@ class Socks5(Proxy): def state_init(self): if self._auth: - self.writeQueue.put(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)) + self.append_write_buf(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)) else: - self.writeQueue.put(struct.pack('BBB', 0x05, 0x01, 0x00)) - self.set_state("auth_1", 0) + self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00)) + self.set_state("auth_1", length=0, expectBytes=2) + return True def state_auth_1(self): - if not self.read_buf_sufficient(2): - return False ret = struct.unpack('BB', self.read_buf) - self.read_buf = self.read_buf[2:] if ret[0] != 5: # general error raise GeneralProxyError(1) elif ret[1] == 0: # no auth required - self.set_state("auth_done", 2) + self.set_state("auth_done", length=2) elif ret[1] == 2: # username/password - self.writeQueue.put(struct.pack('BB', 1, len(self._auth[0])) + \ + self.append_write_buf(struct.pack('BB', 1, len(self._auth[0])) + \ self._auth[0] + struct.pack('B', len(self._auth[1])) + \ self._auth[1]) - self.set_state("auth_needed", 2) + self.set_state("auth_needed", length=2, expectBytes=2) else: if ret[1] == 0xff: # auth error @@ -61,11 +59,10 @@ class Socks5(Proxy): else: # other error raise GeneralProxyError(1) + return True def state_auth_needed(self): - if not self.read_buf_sufficient(2): - return False - ret = struct.unpack('BB', self.read_buf) + ret = struct.unpack('BB', self.read_buf[0:2]) if ret[0] != 1: # general error raise GeneralProxyError(1) @@ -73,11 +70,10 @@ class Socks5(Proxy): # auth error raise Socks5AuthError(3) # all ok - self.set_state = ("auth_done", 2) + self.set_state("auth_done", length=2) + return True def state_pre_connect(self): - if not self.read_buf_sufficient(4): - return False # Get the response if self.read_buf[0:1] != chr(0x05).encode(): self.close() @@ -91,41 +87,38 @@ class Socks5(Proxy): raise Socks5Error(9) # Get the bound address/port elif self.read_buf[3:4] == chr(0x01).encode(): - self.set_state("proxy_addr_1", 4) + self.set_state("proxy_addr_1", length=4, expectBytes=4) elif self.read_buf[3:4] == chr(0x03).encode(): - self.set_state("proxy_addr_2_1", 4) + self.set_state("proxy_addr_2_1", length=4, expectBytes=1) else: self.close() raise GeneralProxyError(1) + return True def state_proxy_addr_1(self): - if not self.read_buf_sufficient(4): - return False self.boundaddr = self.read_buf[0:4] - self.set_state("proxy_port", 4) + self.set_state("proxy_port", length=4, expectBytes=2) + return True def state_proxy_addr_2_1(self): - if not self.read_buf_sufficient(1): - return False self.address_length = ord(self.read_buf[0:1]) - self.set_state("proxy_addr_2_2", 1) + self.set_state("proxy_addr_2_2", length=1, expectBytes=self.address_length) + return True def state_proxy_addr_2_2(self): - if not self.read_buf_sufficient(self.address_length): - return False - self.boundaddr = self.read_buf - self.set_state("proxy_port", self.address_length) + self.boundaddr = self.read_buf[0:self.address_length] + self.set_state("proxy_port", length=self.address_length, expectBytes=2) + return True def state_proxy_port(self): - if not self.read_buf_sufficient(2): - return False self.boundport = struct.unpack(">H", self.read_buf[0:2])[0] self.__proxysockname = (self.boundaddr, self.boundport) if self.ipaddr is not None: self.__proxypeername = (socket.inet_ntoa(self.ipaddr), self.destination[1]) else: self.__proxypeername = (self.destination[0], self.destport) - self.set_state("proxy_handshake_done", 2) + self.set_state("proxy_handshake_done", length=2) + return True def proxy_sock_name(self): return socket.inet_ntoa(self.__proxysockname[0]) @@ -137,28 +130,29 @@ class Socks5Connection(Socks5): def state_auth_done(self): # Now we can request the actual connection - self.writeQueue.put(struct.pack('BBB', 0x05, 0x01, 0x00)) + self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00)) # If the given destination address is an IP address, we'll # use the IPv4 address request even if remote resolving was specified. try: self.ipaddr = socket.inet_aton(self.destination[0]) - self.writeQueue.put(chr(0x01).encode() + self.ipaddr) + self.append_write_buf(chr(0x01).encode() + self.ipaddr) except socket.error: # Well it's not an IP number, so it's probably a DNS name. if Proxy._remote_dns: # Resolve remotely self.ipaddr = None - self.writeQueue.put(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0]) + self.append_write_buf(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0]) else: # Resolve locally self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) - self.writeQueue.put(chr(0x01).encode() + self.ipaddr) - self.writeQueue.put(struct.pack(">H", self.destination[1])) - self.set_state("pre_connect", 0) + self.append_write_buf(chr(0x01).encode() + self.ipaddr) + self.append_write_buf(struct.pack(">H", self.destination[1])) + self.set_state("pre_connect", length=0, expectBytes=4) + return True def state_pre_connect(self): try: - Socks5.state_pre_connect(self) + return Socks5.state_pre_connect(self) except Socks5Error as e: self.handle_close(e.message) @@ -171,10 +165,11 @@ class Socks5Resolver(Socks5): def state_auth_done(self): # Now we can request the actual connection - self.writeQueue.put(struct.pack('BBB', 0x05, 0xF0, 0x00)) - self.writeQueue.put(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host)) - self.writeQueue.put(struct.pack(">H", self.port)) - self.set_state("pre_connect", 0) + self.append_write_buf(struct.pack('BBB', 0x05, 0xF0, 0x00)) + self.append_write_buf(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host)) + self.append_write_buf(struct.pack(">H", self.port)) + self.set_state("pre_connect", length=0, expectBytes=4) + return True def resolved(self): print "Resolved %s as %s" % (self.host, self.proxy_sock_name()) diff --git a/src/network/tcp.py b/src/network/tcp.py index c2052df1..8ef3d1e1 100644 --- a/src/network/tcp.py +++ b/src/network/tcp.py @@ -29,7 +29,7 @@ from network.tls import TLSDispatcher import addresses from bmconfigparser import BMConfigParser -from queues import invQueue, objectProcessorQueue, portCheckerQueue, UISignalQueue +from queues import invQueue, objectProcessorQueue, portCheckerQueue, UISignalQueue, receiveDataQueue import shared import state import protocol @@ -91,6 +91,12 @@ class TCPConnection(BMProto, TLSDispatcher): logger.debug("Skipping processing getdata due to missing object for %.2fs", self.skipUntil - time.time()) self.skipUntil = time.time() + delay + def state_connection_fully_established(self): + self.set_connection_fully_established() + self.set_state("bm_header") + self.bm_proto_reset() + return True + def set_connection_fully_established(self): if not self.isOutbound and not self.local: shared.clientHasReceivedIncomingConnections = True @@ -144,10 +150,37 @@ class TCPConnection(BMProto, TLSDispatcher): for peer, params in addrs[substream]: templist.append((substream, peer, params["lastseen"])) if len(templist) > 0: - self.writeQueue.put(BMProto.assembleAddr(templist)) + self.append_write_buf(BMProto.assembleAddr(templist)) def sendBigInv(self): - self.receiveQueue.put(("biginv", None)) + def sendChunk(): + if objectCount == 0: + return + logger.debug('Sending huge inv message with %i objects to just this one peer', objectCount) + self.append_write_buf(protocol.CreatePacket('inv', addresses.encodeVarint(objectCount) + payload)) + + # Select all hashes for objects in this stream. + bigInvList = {} + for stream in self.streams: + # may lock for a long time, but I think it's better than thousands of small locks + with self.objectsNewToThemLock: + for objHash in Inventory().unexpired_hashes_by_stream(stream): + bigInvList[objHash] = 0 + self.objectsNewToThem[objHash] = time.time() + objectCount = 0 + payload = b'' + # Now let us start appending all of these hashes together. They will be + # sent out in a big inv message to our new peer. + for hash, storedValue in bigInvList.items(): + payload += hash + objectCount += 1 + if objectCount >= BMProto.maxObjectCount: + self.sendChunk() + payload = b'' + objectCount = 0 + + # flush + sendChunk() def handle_connect(self): try: @@ -156,9 +189,10 @@ class TCPConnection(BMProto, TLSDispatcher): if e.errno in asyncore._DISCONNECTED: logger.debug("%s:%i: Connection failed: %s" % (self.destination.host, self.destination.port, str(e))) return - self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, network.connectionpool.BMConnectionPool().streams, False)) + self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, network.connectionpool.BMConnectionPool().streams, False)) #print "%s:%i: Sending version" % (self.destination.host, self.destination.port) self.connectedAt = time.time() + receiveDataQueue.put(self) def handle_read(self): TLSDispatcher.handle_read(self) @@ -169,6 +203,7 @@ class TCPConnection(BMProto, TLSDispatcher): knownnodes.knownNodes[s][self.destination]["lastseen"] = time.time() except KeyError: pass + receiveDataQueue.put(self) def handle_write(self): TLSDispatcher.handle_write(self) @@ -187,10 +222,10 @@ class Socks5BMConnection(Socks5Connection, TCPConnection): def state_proxy_handshake_done(self): Socks5Connection.state_proxy_handshake_done(self) - self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ + self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ network.connectionpool.BMConnectionPool().streams, False)) self.set_state("bm_header", expectBytes=protocol.Header.size) - return False + return True class Socks4aBMConnection(Socks4aConnection, TCPConnection): @@ -201,10 +236,10 @@ class Socks4aBMConnection(Socks4aConnection, TCPConnection): def state_proxy_handshake_done(self): Socks4aConnection.state_proxy_handshake_done(self) - self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ + self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ network.connectionpool.BMConnectionPool().streams, False)) self.set_state("bm_header", expectBytes=protocol.Header.size) - return False + return True class TCPServer(AdvancedDispatcher): diff --git a/src/network/tls.py b/src/network/tls.py index 9dafaab2..69fc2c20 100644 --- a/src/network/tls.py +++ b/src/network/tls.py @@ -10,6 +10,7 @@ import sys from debug import logger from network.advanceddispatcher import AdvancedDispatcher import network.asyncore_pollchoose as asyncore +from queues import receiveDataQueue import paths import protocol @@ -58,14 +59,17 @@ class TLSDispatcher(AdvancedDispatcher): do_handshake_on_connect=False) self.sslSocket.setblocking(0) self.want_read = self.want_write = True - self.set_state("bm_header") + self.set_state("tls_handshake") + return False # if hasattr(self.socket, "context"): # self.socket.context.set_ecdh_curve("secp256k1") + def state_tls_handshake(self): + return False + def writable(self): try: - if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): - #print "tls writable, %r" % (self.want_write) + if self.tlsStarted and not self.tlsDone and not self.write_buf: return self.want_write return AdvancedDispatcher.writable(self) except AttributeError: @@ -73,9 +77,13 @@ class TLSDispatcher(AdvancedDispatcher): def readable(self): try: - if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): + # during TLS handshake, and after flushing write buffer, return status of last handshake attempt + if self.tlsStarted and not self.tlsDone and not self.write_buf: #print "tls readable, %r" % (self.want_read) return self.want_read + # prior to TLS handshake, receiveDataThread should emulate synchronous behaviour + elif not self.fullyEstablished and (self.expectBytes == 0 or not self.write_buf_empty()): + return False return AdvancedDispatcher.readable(self) except AttributeError: return AdvancedDispatcher.readable(self) @@ -83,11 +91,11 @@ class TLSDispatcher(AdvancedDispatcher): def handle_read(self): try: # wait for write buffer flush - if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): - #print "handshaking (read)" + if self.tlsStarted and not self.tlsDone and not self.write_buf: + #logger.debug("%s:%i TLS handshaking (read)", self.destination.host, self.destination.port) self.tls_handshake() else: - #print "not handshaking (read)" + #logger.debug("%s:%i Not TLS handshaking (read)", self.destination.host, self.destination.port) return AdvancedDispatcher.handle_read(self) except AttributeError: return AdvancedDispatcher.handle_read(self) @@ -104,14 +112,14 @@ class TLSDispatcher(AdvancedDispatcher): def handle_write(self): try: # wait for write buffer flush - if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): - #print "handshaking (write)" + if self.tlsStarted and not self.tlsDone and not self.write_buf: + #logger.debug("%s:%i TLS handshaking (write)", self.destination.host, self.destination.port) self.tls_handshake() else: - #print "not handshaking (write)" + #logger.debug("%s:%i Not TLS handshaking (write)", self.destination.host, self.destination.port) return AdvancedDispatcher.handle_write(self) except AttributeError: - return AdvancedDispatcher.handle_read(self) + return AdvancedDispatcher.handle_write(self) except ssl.SSLError as err: if err.errno == ssl.SSL_ERROR_WANT_WRITE: return 0 @@ -158,6 +166,8 @@ class TLSDispatcher(AdvancedDispatcher): self.del_channel() self.set_socket(self.sslSocket) self.tlsDone = True - self.set_state("bm_header") - self.set_connection_fully_established() + + self.bm_proto_reset() + self.set_state("connection_fully_established") + receiveDataQueue.put(self) return False diff --git a/src/network/udp.py b/src/network/udp.py index 6770e5a0..824c9bfa 100644 --- a/src/network/udp.py +++ b/src/network/udp.py @@ -9,7 +9,7 @@ from network.bmobject import BMObject, BMObjectInsufficientPOWError, BMObjectInv import network.asyncore_pollchoose as asyncore from network.objectracker import ObjectTracker -from queues import objectProcessorQueue, peerDiscoveryQueue, UISignalQueue +from queues import objectProcessorQueue, peerDiscoveryQueue, UISignalQueue, receiveDataQueue import state import protocol @@ -80,7 +80,7 @@ class UDPSocket(BMProto): addresses = self._decode_addr() # only allow peer discovery from private IPs in order to avoid attacks from random IPs on the internet if not self.local: - return + return True remoteport = False for i in addresses: seenTime, stream, services, ip, port = i @@ -93,7 +93,7 @@ class UDPSocket(BMProto): # if the address isn't local, interpret it as the hosts' own announcement remoteport = port if remoteport is False: - return + return True logger.debug("received peer discovery from %s:%i (port %i):", self.destination.host, self.destination.port, remoteport) if self.local: peerDiscoveryQueue.put(state.Peer(self.destination.host, remoteport)) @@ -118,7 +118,7 @@ class UDPSocket(BMProto): return def writable(self): - return not self.writeQueue.empty() + return self.write_buf def readable(self): return len(self.read_buf) < AdvancedDispatcher._buf_len @@ -139,18 +139,14 @@ class UDPSocket(BMProto): # overwrite the old buffer to avoid mixing data and so that self.local works correctly self.read_buf = recdata self.bm_proto_reset() - self.process() + receiveDataQueue.put(self) def handle_write(self): try: - data = self.writeQueue.get(False) - except Queue.Empty: - return - try: - retval = self.socket.sendto(data, ('', UDPSocket.port)) + retval = self.socket.sendto(self.write_buf, ('', UDPSocket.port)) except socket.error as e: logger.error("socket error on sendato: %s", str(e)) - self.writeQueue.task_done() + self.slice_write_buf(retval) if __name__ == "__main__": diff --git a/src/queues.py b/src/queues.py index 223c7c3b..f768c59f 100644 --- a/src/queues.py +++ b/src/queues.py @@ -12,5 +12,6 @@ invQueue = MultiQueue() addrQueue = MultiQueue() portCheckerQueue = Queue.Queue() peerDiscoveryQueue = Queue.Queue() +receiveDataQueue = Queue.Queue() apiAddressGeneratorReturnQueue = Queue.Queue( ) # The address generator thread uses this queue to get information back to the API thread.