Asyncore update

- get rid of per-connection writeQueue/receiveQueue, and instead use
strings and locking
- minor code cleanup
- all state handlers now should set expectBytes
- almost all data processing happens in ReceiveDataThread, and
AsyncoreThread is almost only I/O (plus TLS). AsyncoreThread simply
puts the connection object into the queue when it has some data for
processing
- allow poll, epoll and kqueue handlers. kqueue is untested and
unoptimised, poll and epoll seem to work ok (linux)
- stack depth threshold handler  in decode_payload_content, this is
recursive and I think was causing occasional RuntimeErrors. Fixes #964
- longer asyncore loops, as now data is handled in ReceiveDataThread
- randomise node order when deciding what to download. Should prevent
retries being stuck to the same node
- socks cleanup (socks5 works ok, socks4a untested but should work too)
This commit is contained in:
Peter Šurda 2017-07-06 19:45:36 +02:00
parent a2b8867c1a
commit ba4162d7fe
Signed by untrusted user: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87
14 changed files with 257 additions and 274 deletions

View File

@ -1,6 +1,7 @@
import Queue import Queue
import socket import socket
import sys import sys
import threading
import time import time
import asyncore_pollchoose as asyncore import asyncore_pollchoose as asyncore
@ -14,41 +15,43 @@ class AdvancedDispatcher(asyncore.dispatcher):
asyncore.dispatcher.__init__(self, sock) asyncore.dispatcher.__init__(self, sock)
self.read_buf = b"" self.read_buf = b""
self.write_buf = b"" self.write_buf = b""
self.writeQueue = Queue.Queue()
self.receiveQueue = Queue.Queue()
self.state = "init" self.state = "init"
self.lastTx = time.time() self.lastTx = time.time()
self.sentBytes = 0 self.sentBytes = 0
self.receivedBytes = 0 self.receivedBytes = 0
self.expectBytes = 0 self.expectBytes = 0
self.readLock = threading.RLock()
self.writeLock = threading.RLock()
def append_write_buf(self, data):
if data:
with self.writeLock:
self.write_buf += data
def slice_write_buf(self, length=0): def slice_write_buf(self, length=0):
if length > 0: if length > 0:
with self.writeLock:
self.write_buf = self.write_buf[length:] self.write_buf = self.write_buf[length:]
def slice_read_buf(self, length=0): def slice_read_buf(self, length=0):
if length > 0: if length > 0:
with self.readLock:
self.read_buf = self.read_buf[length:] self.read_buf = self.read_buf[length:]
def read_buf_sufficient(self, length=0):
if len(self.read_buf) < length:
return False
return True
def process(self): def process(self):
if self.state != "tls_handshake" and not self.read_buf:
return
if not self.connected: if not self.connected:
return return
maxLoop = 20 loop = 0
while maxLoop > 0: while len(self.read_buf) >= self.expectBytes:
loop += 1
if loop > 1000:
logger.error("Stuck at state %s, report this bug please", self.state)
break
try: try:
if getattr(self, "state_" + str(self.state))() is False: if getattr(self, "state_" + str(self.state))() is False:
break break
except AttributeError: except AttributeError:
# missing state
raise raise
maxLoop -= 1
def set_state(self, state, length=0, expectBytes=0): def set_state(self, state, length=0, expectBytes=0):
self.expectBytes = expectBytes self.expectBytes = expectBytes
@ -57,7 +60,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
def writable(self): def writable(self):
return asyncore.dispatcher.writable(self) and \ return asyncore.dispatcher.writable(self) and \
(self.connecting or len(self.write_buf) > 0 or not self.writeQueue.empty()) (self.connecting or self.write_buf)
def readable(self): def readable(self):
return asyncore.dispatcher.readable(self) and \ return asyncore.dispatcher.readable(self) and \
@ -68,28 +71,20 @@ class AdvancedDispatcher(asyncore.dispatcher):
downloadBytes = AdvancedDispatcher._buf_len downloadBytes = AdvancedDispatcher._buf_len
if asyncore.maxDownloadRate > 0: if asyncore.maxDownloadRate > 0:
downloadBytes = asyncore.downloadBucket downloadBytes = asyncore.downloadBucket
if self.expectBytes > 0 and downloadBytes > self.expectBytes: if self.expectBytes > 0 and downloadBytes > self.expectBytes - len(self.read_buf):
downloadBytes = self.expectBytes downloadBytes = self.expectBytes - len(self.read_buf)
if downloadBytes > 0: if downloadBytes > 0:
newData = self.recv(downloadBytes) newData = self.recv(downloadBytes)
self.receivedBytes += len(newData) self.receivedBytes += len(newData)
if self.expectBytes > 0:
self.expectBytes -= len(newData)
asyncore.update_received(len(newData)) asyncore.update_received(len(newData))
with self.readLock:
self.read_buf += newData self.read_buf += newData
self.process()
def handle_write(self): def handle_write(self):
self.lastTx = time.time() self.lastTx = time.time()
bufSize = AdvancedDispatcher._buf_len bufSize = AdvancedDispatcher._buf_len
if asyncore.maxUploadRate > 0: if asyncore.maxUploadRate > 0:
bufSize = asyncore.uploadBucket bufSize = asyncore.uploadBucket
while len(self.write_buf) < bufSize:
try:
self.write_buf += self.writeQueue.get(False)
self.writeQueue.task_done()
except Queue.Empty:
break
if bufSize <= 0: if bufSize <= 0:
return return
if self.write_buf: if self.write_buf:
@ -107,25 +102,12 @@ class AdvancedDispatcher(asyncore.dispatcher):
def handle_connect(self): def handle_connect(self):
self.lastTx = time.time() self.lastTx = time.time()
self.process()
def state_close(self): def state_close(self):
pass return False
def handle_close(self): def handle_close(self):
self.read_buf = b"" self.read_buf = b""
self.write_buf = b"" self.write_buf = b""
self.state = "close" self.state = "close"
while True:
try:
self.writeQueue.get(False)
self.writeQueue.task_done()
except Queue.Empty:
break
while True:
try:
self.receiveQueue.get(False)
self.receiveQueue.task_done()
except Queue.Empty:
break
asyncore.dispatcher.close(self) asyncore.dispatcher.close(self)

View File

@ -30,4 +30,4 @@ class AnnounceThread(threading.Thread, StoppableThread):
for connection in BMConnectionPool().udpSockets.values(): for connection in BMConnectionPool().udpSockets.values():
for stream in state.streamsInWhichIAmParticipating: for stream in state.streamsInWhichIAmParticipating:
addr = (stream, state.Peer('127.0.0.1', BMConfigParser().safeGetInt("bitmessagesettings", "port")), time.time()) addr = (stream, state.Peer('127.0.0.1', BMConfigParser().safeGetInt("bitmessagesettings", "port")), time.time())
connection.writeQueue.put(BMProto.assembleAddr([addr])) connection.append_write_buf(BMProto.assembleAddr([addr]))

View File

@ -249,13 +249,16 @@ def poll_poller(timeout=0.0, map=None):
newflags |= OP_WRITE newflags |= OP_WRITE
else: else:
newflags &= ~ OP_WRITE newflags &= ~ OP_WRITE
if newflags != obj.flags: if newflags != obj.poller_flags:
obj.flags = newflags obj.poller_flags = newflags
try:
if obj.poller_registered: if obj.poller_registered:
poll_poller.pollster.modify(fd, flags) poll_poller.pollster.modify(fd, flags)
else: else:
poll_poller.pollster.register(fd, flags) poll_poller.pollster.register(fd, flags)
obj.poller_registered = True obj.poller_registered = True
except IOError:
pass
try: try:
r = poll_poller.pollster.poll(timeout) r = poll_poller.pollster.poll(timeout)
except KeyboardInterrupt: except KeyboardInterrupt:
@ -292,16 +295,19 @@ def epoll_poller(timeout=0.0, map=None):
newflags |= OP_WRITE newflags |= OP_WRITE
else: else:
newflags &= ~ OP_WRITE newflags &= ~ OP_WRITE
if newflags != obj.flags: if newflags != obj.poller_flags:
obj.flags = newflags obj.poller_flags = newflags
# Only check for exceptions if object was either readable # Only check for exceptions if object was either readable
# or writable. # or writable.
flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL
try:
if obj.poller_registered: if obj.poller_registered:
epoll_poller.pollster.modify(fd, flags) epoll_poller.pollster.modify(fd, flags)
else: else:
epoll_poller.pollster.register(fd, flags) epoll_poller.pollster.register(fd, flags)
obj.poller_registered = True obj.poller_registered = True
except IOError:
pass
try: try:
r = epoll_poller.pollster.poll(timeout) r = epoll_poller.pollster.poll(timeout)
except select.error, err: except select.error, err:
@ -329,9 +335,12 @@ def kqueue_poller(timeout=0.0, map=None):
if obj.writable(): if obj.writable():
filter |= select.KQ_FILTER_WRITE filter |= select.KQ_FILTER_WRITE
if filter: if filter:
try:
ev = select.kevent(fd, filter=filter, flags=flags) ev = select.kevent(fd, filter=filter, flags=flags)
kqueue.control([ev], 0) kqueue.control([ev], 0)
selectables += 1 selectables += 1
except IOError:
pass
events = kqueue.control(None, selectables, timeout) events = kqueue.control(None, selectables, timeout)
for event in random.sample(events, len(events)): for event in random.sample(events, len(events)):
@ -347,13 +356,15 @@ def kqueue_poller(timeout=0.0, map=None):
def loop(timeout=30.0, use_poll=False, map=None, count=None, def loop(timeout=30.0, use_poll=False, map=None, count=None,
poller=select_poller): poller=None):
if map is None: if map is None:
map = socket_map map = socket_map
# code which grants backward compatibility with "use_poll" # code which grants backward compatibility with "use_poll"
# argument which should no longer be used in favor of # argument which should no longer be used in favor of
# "poller" # "poller"
if poller is None:
if hasattr(select, 'epoll'): if hasattr(select, 'epoll'):
poller = epoll_poller poller = epoll_poller
elif hasattr(select, 'kqueue'): elif hasattr(select, 'kqueue'):
@ -363,10 +374,6 @@ def loop(timeout=30.0, use_poll=False, map=None, count=None,
elif hasattr(select, 'select'): elif hasattr(select, 'select'):
poller = select_poller poller = select_poller
poller = select_poller
# print "Poll loop using %s" % (poller.__name__)
if count is None: if count is None:
while map: while map:
# fill buckets first # fill buckets first
@ -400,7 +407,7 @@ class dispatcher:
addr = None addr = None
ignore_log_types = frozenset(['warning']) ignore_log_types = frozenset(['warning'])
poller_registered = False poller_registered = False
flags = 0 poller_flags = 0
# don't do network IO with a smaller bucket than this # don't do network IO with a smaller bucket than this
minTx = 1500 minTx = 1500
@ -456,23 +463,26 @@ class dispatcher:
if map is None: if map is None:
map = self._map map = self._map
map[self._fileno] = self map[self._fileno] = self
self.poller_flags = 0
def del_channel(self, map=None): def del_channel(self, map=None):
fd = self._fileno fd = self._fileno
if map is None: if map is None:
map = self._map map = self._map
self.poller_flags = 0
self.poller_registered = False
if fd in map: if fd in map:
#self.log_info('closing channel %d:%s' % (fd, self)) #self.log_info('closing channel %d:%s' % (fd, self))
del map[fd] del map[fd]
self._fileno = None self._fileno = None
try: try:
epoll_poller.pollster.unregister(fd) epoll_poller.pollster.unregister(fd)
except (AttributeError, KeyError, TypeError): except (AttributeError, KeyError, TypeError, IOError):
# no epoll used, or not registered # no epoll used, or not registered
pass pass
try: try:
poll_poller.pollster.unregister(fd) poll_poller.pollster.unregister(fd)
except (AttributeError, KeyError, TypeError): except (AttributeError, KeyError, TypeError, IOError):
# no poll used, or not registered # no poll used, or not registered
pass pass

View File

@ -5,6 +5,7 @@ import math
import time import time
import socket import socket
import struct import struct
import sys
from addresses import calculateInventoryHash from addresses import calculateInventoryHash
from bmconfigparser import BMConfigParser from bmconfigparser import BMConfigParser
@ -67,15 +68,12 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.object = None self.object = None
def state_bm_header(self): def state_bm_header(self):
if len(self.read_buf) < protocol.Header.size:
#print "Length below header size"
return False
self.magic, self.command, self.payloadLength, self.checksum = protocol.Header.unpack(self.read_buf[:protocol.Header.size]) self.magic, self.command, self.payloadLength, self.checksum = protocol.Header.unpack(self.read_buf[:protocol.Header.size])
self.command = self.command.rstrip('\x00') self.command = self.command.rstrip('\x00')
if self.magic != 0xE9BEB4D9: if self.magic != 0xE9BEB4D9:
# skip 1 byte in order to sync # skip 1 byte in order to sync
self.set_state("bm_header", length=1)
self.bm_proto_reset() self.bm_proto_reset()
self.set_state("bm_header", length=1, expectBytes=protocol.Header.size)
logger.debug("Bad magic") logger.debug("Bad magic")
self.handle_close("Bad magic") self.handle_close("Bad magic")
return False return False
@ -85,10 +83,6 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True return True
def state_bm_command(self): def state_bm_command(self):
if len(self.read_buf) < self.payloadLength:
#print "Length below announced object length"
return False
#logger.debug("%s:%i: command %s (%ib)", self.destination.host, self.destination.port, self.command, self.payloadLength)
self.payload = self.read_buf[:self.payloadLength] self.payload = self.read_buf[:self.payloadLength]
if self.checksum != hashlib.sha512(self.payload).digest()[0:4]: if self.checksum != hashlib.sha512(self.payload).digest()[0:4]:
logger.debug("Bad checksum, ignoring") logger.debug("Bad checksum, ignoring")
@ -127,7 +121,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.handle_close("Invalid command %s" % (self.command)) self.handle_close("Invalid command %s" % (self.command))
return False return False
if retval: if retval:
self.set_state("bm_header", length=self.payloadLength, expectBytes=protocol.Header.size) self.set_state("bm_header", length=self.payloadLength)
self.bm_proto_reset() self.bm_proto_reset()
# else assume the command requires a different state to follow # else assume the command requires a different state to follow
return True return True
@ -173,6 +167,12 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
retval = [] retval = []
size = None size = None
i = 0 i = 0
try:
sys._getframe(200)
logger.error("Stack depth warning, pattern: %s", pattern)
return
except ValueError:
pass
while i < len(pattern): while i < len(pattern):
if pattern[i] in "0123456789" and (i == 0 or pattern[i-1] not in "lL"): if pattern[i] in "0123456789" and (i == 0 or pattern[i-1] not in "lL"):
@ -237,8 +237,13 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
# skip? # skip?
if time.time() < self.skipUntil: if time.time() < self.skipUntil:
return True return True
#TODO make this more asynchronous and allow reordering
for i in items: for i in items:
self.receiveQueue.put(("object", i)) try:
self.append_write_buf(protocol.CreatePacket('object', Inventory()[i].payload))
except KeyError:
self.antiIntersectionDelay()
logger.info('%s asked for an object we don\'t have.', self.destination)
return True return True
def bm_command_inv(self): def bm_command_inv(self):
@ -251,7 +256,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
pass pass
for i in items: for i in items:
self.receiveQueue.put(("inv", i)) self.handleReceivedInventory(i)
return True return True
@ -321,7 +326,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return True return True
def bm_command_ping(self): def bm_command_ping(self):
self.writeQueue.put(protocol.CreatePacket('pong')) self.append_write_buf(protocol.CreatePacket('pong'))
return True return True
def bm_command_pong(self): def bm_command_pong(self):
@ -332,11 +337,10 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.verackReceived = True self.verackReceived = True
if self.verackSent: if self.verackSent:
if self.isSSL: if self.isSSL:
self.set_state("tls_init", self.payloadLength) self.set_state("tls_init", length=self.payloadLength, expectBytes=0)
self.bm_proto_reset() return False
self.set_state("connection_fully_established", length=self.payloadLength, expectBytes=0)
return False return False
self.set_connection_fully_established()
return True
return True return True
def bm_command_version(self): def bm_command_version(self):
@ -345,20 +349,20 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.nonce = struct.pack('>Q', self.nonce) self.nonce = struct.pack('>Q', self.nonce)
self.timeOffset = self.timestamp - int(time.time()) self.timeOffset = self.timestamp - int(time.time())
logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion) logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion)
logger.debug("services: %08X", self.services) logger.debug("services: 0x%08X", self.services)
logger.debug("time offset: %i", self.timestamp - int(time.time())) logger.debug("time offset: %i", self.timestamp - int(time.time()))
logger.debug("my external IP: %s", self.sockNode.host) logger.debug("my external IP: %s", self.sockNode.host)
logger.debug("remote node incoming port: %i", self.peerNode.port) logger.debug("remote node incoming address: %s:%i", self.destination.host, self.peerNode.port)
logger.debug("user agent: %s", self.userAgent) logger.debug("user agent: %s", self.userAgent)
logger.debug("streams: [%s]", ",".join(map(str,self.streams))) logger.debug("streams: [%s]", ",".join(map(str,self.streams)))
if not self.peerValidityChecks(): if not self.peerValidityChecks():
# TODO ABORT # TODO ABORT
return True return True
#shared.connectedHostsList[self.destination] = self.streams[0] #shared.connectedHostsList[self.destination] = self.streams[0]
self.writeQueue.put(protocol.CreatePacket('verack')) self.append_write_buf(protocol.CreatePacket('verack'))
self.verackSent = True self.verackSent = True
if not self.isOutbound: if not self.isOutbound:
self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \
network.connectionpool.BMConnectionPool().streams, True)) network.connectionpool.BMConnectionPool().streams, True))
#print "%s:%i: Sending version" % (self.destination.host, self.destination.port) #print "%s:%i: Sending version" % (self.destination.host, self.destination.port)
if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and
@ -366,29 +370,28 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.isSSL = True self.isSSL = True
if self.verackReceived: if self.verackReceived:
if self.isSSL: if self.isSSL:
self.set_state("tls_init", self.payloadLength) self.set_state("tls_init", length=self.payloadLength, expectBytes=0)
self.bm_proto_reset() return False
self.set_state("connection_fully_established", length=self.payloadLength, expectBytes=0)
return False return False
self.set_connection_fully_established()
return True
return True return True
def peerValidityChecks(self): def peerValidityChecks(self):
if self.remoteProtocolVersion < 3: if self.remoteProtocolVersion < 3:
self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your is using an old protocol. Closing connection.")) errorText="Your is using an old protocol. Closing connection."))
logger.debug ('Closing connection to old protocol version %s, node: %s', logger.debug ('Closing connection to old protocol version %s, node: %s',
str(self.remoteProtocolVersion), str(self.destination)) str(self.remoteProtocolVersion), str(self.destination))
return False return False
if self.timeOffset > BMProto.maxTimeOffset: if self.timeOffset > BMProto.maxTimeOffset:
self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the future compared to mine. Closing connection.")) errorText="Your time is too far in the future compared to mine. Closing connection."))
logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.", logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.",
self.destination, self.timeOffset) self.destination, self.timeOffset)
shared.timeOffsetWrongCount += 1 shared.timeOffsetWrongCount += 1
return False return False
elif self.timeOffset < -BMProto.maxTimeOffset: elif self.timeOffset < -BMProto.maxTimeOffset:
self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the past compared to mine. Closing connection.")) errorText="Your time is too far in the past compared to mine. Closing connection."))
logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.", logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.",
self.destination, self.timeOffset) self.destination, self.timeOffset)
@ -397,7 +400,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
else: else:
shared.timeOffsetWrongCount = 0 shared.timeOffsetWrongCount = 0
if not self.streams: if not self.streams:
self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="We don't have shared stream interests. Closing connection.")) errorText="We don't have shared stream interests. Closing connection."))
logger.debug ('Closed connection to %s because there is no overlapping interest in streams.', logger.debug ('Closed connection to %s because there is no overlapping interest in streams.',
str(self.destination)) str(self.destination))
@ -405,7 +408,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
if self.destination in network.connectionpool.BMConnectionPool().inboundConnections: if self.destination in network.connectionpool.BMConnectionPool().inboundConnections:
try: try:
if not protocol.checkSocksIP(self.destination.host): if not protocol.checkSocksIP(self.destination.host):
self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Too many connections from your IP. Closing connection.")) errorText="Too many connections from your IP. Closing connection."))
logger.debug ('Closed connection to %s because we are already connected to that IP.', logger.debug ('Closed connection to %s because we are already connected to that IP.',
str(self.destination)) str(self.destination))
@ -413,7 +416,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
except: except:
pass pass
if self.nonce == protocol.eightBytesOfRandomDataUsedToDetectConnectionsToSelf: if self.nonce == protocol.eightBytesOfRandomDataUsedToDetectConnectionsToSelf:
self.writeQueue.put(protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="I'm connected to myself. Closing connection.")) errorText="I'm connected to myself. Closing connection."))
logger.debug ("Closed connection to %s because I'm connected to myself.", logger.debug ("Closed connection to %s because I'm connected to myself.",
str(self.destination)) str(self.destination))

View File

@ -19,7 +19,6 @@ import state
@Singleton @Singleton
class BMConnectionPool(object): class BMConnectionPool(object):
def __init__(self): def __init__(self):
asyncore.set_rates( asyncore.set_rates(
BMConfigParser().safeGetInt("bitmessagesettings", "maxdownloadrate") * 1024, BMConfigParser().safeGetInt("bitmessagesettings", "maxdownloadrate") * 1024,
@ -30,7 +29,7 @@ class BMConnectionPool(object):
self.udpSockets = {} self.udpSockets = {}
self.streams = [] self.streams = []
self.lastSpawned = 0 self.lastSpawned = 0
self.spawnWait = 0.3 self.spawnWait = 2
self.bootstrapped = False self.bootstrapped = False
def handleReceivedObject(self, streamNumber, hashid, connection = None): def handleReceivedObject(self, streamNumber, hashid, connection = None):
@ -187,12 +186,10 @@ class BMConnectionPool(object):
i.handle_close() i.handle_close()
logger.info('Stopped udp sockets.') logger.info('Stopped udp sockets.')
# while len(asyncore.socket_map) > 0 and state.shutdown == 0:
# print "loop, state = %s" % (proxy.state)
loopTime = float(self.spawnWait) loopTime = float(self.spawnWait)
if self.lastSpawned < time.time() - self.spawnWait: if self.lastSpawned < time.time() - self.spawnWait:
loopTime = 1.0 loopTime = 2.0
asyncore.loop(timeout=loopTime, count=10) asyncore.loop(timeout=loopTime, count=1000)
reaper = [] reaper = []
for i in self.inboundConnections.values() + self.outboundConnections.values(): for i in self.inboundConnections.values() + self.outboundConnections.values():
@ -201,7 +198,7 @@ class BMConnectionPool(object):
minTx -= 300 - 20 minTx -= 300 - 20
if i.lastTx < minTx: if i.lastTx < minTx:
if i.fullyEstablished: if i.fullyEstablished:
i.writeQueue.put(protocol.CreatePacket('ping')) i.append_write_buf(protocol.CreatePacket('ping'))
else: else:
i.handle_close("Timeout (%is)" % (time.time() - i.lastTx)) i.handle_close("Timeout (%is)" % (time.time() - i.lastTx))
for i in self.inboundConnections.values() + self.outboundConnections.values() + self.listeningSockets.values() + self.udpSockets.values(): for i in self.inboundConnections.values() + self.outboundConnections.values() + self.listeningSockets.values() + self.udpSockets.values():

View File

@ -1,3 +1,4 @@
import random
import threading import threading
import time import time
@ -32,7 +33,10 @@ class DownloadThread(threading.Thread, StoppableThread):
def run(self): def run(self):
while not self._stopped: while not self._stopped:
requested = 0 requested = 0
for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values(): # Choose downloading peers randomly
connections = BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values()
random.shuffle(connections)
for i in connections:
now = time.time() now = time.time()
timedOut = now - DownloadThread.requestTimeout timedOut = now - DownloadThread.requestTimeout
# this may take a while, but it needs a consistency so I think it's better to lock a bigger chunk # this may take a while, but it needs a consistency so I think it's better to lock a bigger chunk
@ -52,7 +56,7 @@ class DownloadThread(threading.Thread, StoppableThread):
self.pending[k] = now self.pending[k] = now
payload = addresses.encodeVarint(len(request)) + ''.join(request) payload = addresses.encodeVarint(len(request)) + ''.join(request)
i.writeQueue.put(protocol.CreatePacket('getdata', payload)) i.append_write_buf(protocol.CreatePacket('getdata', payload))
logger.debug("%s:%i Requesting %i objects", i.destination.host, i.destination.port, len(request)) logger.debug("%s:%i Requesting %i objects", i.destination.host, i.destination.port, len(request))
requested += len(request) requested += len(request)
if time.time() >= self.lastCleaned + DownloadThread.cleanInterval: if time.time() >= self.lastCleaned + DownloadThread.cleanInterval:

View File

@ -42,7 +42,7 @@ class InvThread(threading.Thread, StoppableThread):
except KeyError: except KeyError:
continue continue
if hashes: if hashes:
connection.writeQueue.put(protocol.CreatePacket('inv', \ connection.append_write_buf(protocol.CreatePacket('inv', \
addresses.encodeVarint(len(hashes)) + "".join(hashes))) addresses.encodeVarint(len(hashes)) + "".join(hashes)))
invQueue.iterate() invQueue.iterate()
self.stop.wait(1) self.stop.wait(1)

View File

@ -10,6 +10,7 @@ from helper_threading import StoppableThread
from inventory import Inventory from inventory import Inventory
from network.connectionpool import BMConnectionPool from network.connectionpool import BMConnectionPool
from network.bmproto import BMProto from network.bmproto import BMProto
from queues import receiveDataQueue
import protocol import protocol
import state import state
@ -21,73 +22,23 @@ class ReceiveQueueThread(threading.Thread, StoppableThread):
logger.info("init receive queue thread") logger.info("init receive queue thread")
def run(self): def run(self):
lastprinted = int(time.time())
while not self._stopped and state.shutdown == 0: while not self._stopped and state.shutdown == 0:
if lastprinted < int(time.time()):
lastprinted = int(time.time())
# try:
# sys._getframe(200)
# logger.error("Stack depth warning")
# except ValueError:
# pass
processed = 0
for i in BMConnectionPool().inboundConnections.values() + BMConnectionPool().outboundConnections.values():
if self._stopped:
break
try: try:
command, args = i.receiveQueue.get(False) connection = receiveDataQueue.get(block=True, timeout=1)
receiveDataQueue.task_done()
except Queue.Empty: except Queue.Empty:
continue continue
processed += 1
if self._stopped:
break
# cycle as long as there is data
# methods should return False if there isn't enough data, or the connection is to be aborted
try: try:
getattr(self, "command_" + str(command))(i, args) while connection.process():
i.receiveQueue.task_done() pass
except AttributeError: except AttributeError:
i.receiveQueue.task_done()
# missing command # missing command
raise logger.error("Unknown state %s, ignoring", connection.state)
if processed == 0:
self.stop.wait(2)
def command_object(self, connection, objHash):
try:
connection.writeQueue.put(protocol.CreatePacket('object', Inventory()[objHash].payload))
except KeyError:
connection.antiIntersectionDelay()
logger.info('%s asked for an object we don\'t have.', connection.destination)
def command_biginv(self, connection, dummy):
def sendChunk():
if objectCount == 0:
return
logger.debug('Sending huge inv message with %i objects to just this one peer', objectCount)
connection.writeQueue.put(protocol.CreatePacket('inv', addresses.encodeVarint(objectCount) + payload))
# Select all hashes for objects in this stream.
bigInvList = {}
for stream in connection.streams:
# may lock for a long time, but I think it's better than thousands of small locks
with connection.objectsNewToThemLock:
for objHash in Inventory().unexpired_hashes_by_stream(stream):
bigInvList[objHash] = 0
connection.objectsNewToThem[objHash] = time.time()
objectCount = 0
payload = b''
# Now let us start appending all of these hashes together. They will be
# sent out in a big inv message to our new peer.
for hash, storedValue in bigInvList.items():
payload += hash
objectCount += 1
if objectCount >= BMProto.maxObjectCount:
self.sendChunk()
payload = b''
objectCount = 0
# flush
sendChunk()
def command_inv(self, connection, hashId):
connection.handleReceivedInventory(hashId)
def stopThread(self): def stopThread(self):
super(ReceiveQueueThread, self).stopThread() super(ReceiveQueueThread, self).stopThread()

View File

@ -19,10 +19,9 @@ class Socks4a(Proxy):
def state_init(self): def state_init(self):
self.set_state("auth_done", 0) self.set_state("auth_done", 0)
return True
def state_pre_connect(self): def state_pre_connect(self):
if not self.read_buf_sufficient(8):
return False
# Get the response # Get the response
if self.read_buf[0:1] != chr(0x00).encode(): if self.read_buf[0:1] != chr(0x00).encode():
# bad data # bad data
@ -44,14 +43,12 @@ class Socks4a(Proxy):
self.__proxypeername = (socket.inet_ntoa(self.ipaddr), self.destination[1]) self.__proxypeername = (socket.inet_ntoa(self.ipaddr), self.destination[1])
else: else:
self.__proxypeername = (self.destination[0], self.destport) self.__proxypeername = (self.destination[0], self.destport)
self.set_state("proxy_handshake_done", 8) self.set_state("proxy_handshake_done", length=8)
return True
def proxy_sock_name(self): def proxy_sock_name(self):
return socket.inet_ntoa(self.__proxysockname[0]) return socket.inet_ntoa(self.__proxysockname[0])
def state_socks_handshake_done(self):
return False
class Socks4aConnection(Socks4a): class Socks4aConnection(Socks4a):
def __init__(self, address): def __init__(self, address):
@ -60,33 +57,34 @@ class Socks4aConnection(Socks4a):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
rmtrslv = False rmtrslv = False
self.writeQueue.put(struct.pack('>BBH', 0x04, 0x01, self.destination[1])) self.append_write_buf(struct.pack('>BBH', 0x04, 0x01, self.destination[1]))
# If the given destination address is an IP address, we'll # If the given destination address is an IP address, we'll
# use the IPv4 address request even if remote resolving was specified. # use the IPv4 address request even if remote resolving was specified.
try: try:
self.ipaddr = socket.inet_aton(self.destination[0]) self.ipaddr = socket.inet_aton(self.destination[0])
self.writeQueue.put(self.ipaddr) self.append_write_buf(self.ipaddr)
except socket.error: except socket.error:
# Well it's not an IP number, so it's probably a DNS name. # Well it's not an IP number, so it's probably a DNS name.
if Proxy._remote_dns: if Proxy._remote_dns:
# Resolve remotely # Resolve remotely
rmtrslv = True rmtrslv = True
self.ipaddr = None self.ipaddr = None
self.writeQueue.put(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01))
else: else:
# Resolve locally # Resolve locally
self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0]))
self.writeQueue.put(self.ipaddr) self.append_write_buf(self.ipaddr)
if self._auth: if self._auth:
self.writeQueue.put(self._auth[0]) self.append_write_buf(self._auth[0])
self.writeQueue.put(chr(0x00).encode()) self.append_write_buf(chr(0x00).encode())
if rmtrslv: if rmtrslv:
self.writeQueue.put(self.destination[0] + chr(0x00).encode()) self.append_write_buf(self.destination[0] + chr(0x00).encode())
self.set_state("pre_connect", 0) self.set_state("pre_connect", length=0, expectBytes=8)
return True
def state_pre_connect(self): def state_pre_connect(self):
try: try:
Socks4a.state_pre_connect(self) return Socks4a.state_pre_connect(self)
except Socks4aError as e: except Socks4aError as e:
self.handle_close(e.message) self.handle_close(e.message)
@ -99,13 +97,14 @@ class Socks4aResolver(Socks4a):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
self.writeQueue.put(struct.pack('>BBH', 0x04, 0xF0, self.destination[1])) self.append_write_buf(struct.pack('>BBH', 0x04, 0xF0, self.destination[1]))
self.writeQueue.put(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01))
if self._auth: if self._auth:
self.writeQueue.put(self._auth[0]) self.append_write_buf(self._auth[0])
self.writeQueue.put(chr(0x00).encode()) self.append_write_buf(chr(0x00).encode())
self.writeQueue.put(self.host + chr(0x00).encode()) self.append_write_buf(self.host + chr(0x00).encode())
self.set_state("pre_connect", 0) self.set_state("pre_connect", length=0, expectBytes=8)
return True
def resolved(self): def resolved(self):
print "Resolved %s as %s" % (self.host, self.proxy_sock_name()) print "Resolved %s as %s" % (self.host, self.proxy_sock_name())

View File

@ -32,28 +32,26 @@ class Socks5(Proxy):
def state_init(self): def state_init(self):
if self._auth: if self._auth:
self.writeQueue.put(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)) self.append_write_buf(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02))
else: else:
self.writeQueue.put(struct.pack('BBB', 0x05, 0x01, 0x00)) self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00))
self.set_state("auth_1", 0) self.set_state("auth_1", length=0, expectBytes=2)
return True
def state_auth_1(self): def state_auth_1(self):
if not self.read_buf_sufficient(2):
return False
ret = struct.unpack('BB', self.read_buf) ret = struct.unpack('BB', self.read_buf)
self.read_buf = self.read_buf[2:]
if ret[0] != 5: if ret[0] != 5:
# general error # general error
raise GeneralProxyError(1) raise GeneralProxyError(1)
elif ret[1] == 0: elif ret[1] == 0:
# no auth required # no auth required
self.set_state("auth_done", 2) self.set_state("auth_done", length=2)
elif ret[1] == 2: elif ret[1] == 2:
# username/password # username/password
self.writeQueue.put(struct.pack('BB', 1, len(self._auth[0])) + \ self.append_write_buf(struct.pack('BB', 1, len(self._auth[0])) + \
self._auth[0] + struct.pack('B', len(self._auth[1])) + \ self._auth[0] + struct.pack('B', len(self._auth[1])) + \
self._auth[1]) self._auth[1])
self.set_state("auth_needed", 2) self.set_state("auth_needed", length=2, expectBytes=2)
else: else:
if ret[1] == 0xff: if ret[1] == 0xff:
# auth error # auth error
@ -61,11 +59,10 @@ class Socks5(Proxy):
else: else:
# other error # other error
raise GeneralProxyError(1) raise GeneralProxyError(1)
return True
def state_auth_needed(self): def state_auth_needed(self):
if not self.read_buf_sufficient(2): ret = struct.unpack('BB', self.read_buf[0:2])
return False
ret = struct.unpack('BB', self.read_buf)
if ret[0] != 1: if ret[0] != 1:
# general error # general error
raise GeneralProxyError(1) raise GeneralProxyError(1)
@ -73,11 +70,10 @@ class Socks5(Proxy):
# auth error # auth error
raise Socks5AuthError(3) raise Socks5AuthError(3)
# all ok # all ok
self.set_state = ("auth_done", 2) self.set_state("auth_done", length=2)
return True
def state_pre_connect(self): def state_pre_connect(self):
if not self.read_buf_sufficient(4):
return False
# Get the response # Get the response
if self.read_buf[0:1] != chr(0x05).encode(): if self.read_buf[0:1] != chr(0x05).encode():
self.close() self.close()
@ -91,41 +87,38 @@ class Socks5(Proxy):
raise Socks5Error(9) raise Socks5Error(9)
# Get the bound address/port # Get the bound address/port
elif self.read_buf[3:4] == chr(0x01).encode(): elif self.read_buf[3:4] == chr(0x01).encode():
self.set_state("proxy_addr_1", 4) self.set_state("proxy_addr_1", length=4, expectBytes=4)
elif self.read_buf[3:4] == chr(0x03).encode(): elif self.read_buf[3:4] == chr(0x03).encode():
self.set_state("proxy_addr_2_1", 4) self.set_state("proxy_addr_2_1", length=4, expectBytes=1)
else: else:
self.close() self.close()
raise GeneralProxyError(1) raise GeneralProxyError(1)
return True
def state_proxy_addr_1(self): def state_proxy_addr_1(self):
if not self.read_buf_sufficient(4):
return False
self.boundaddr = self.read_buf[0:4] self.boundaddr = self.read_buf[0:4]
self.set_state("proxy_port", 4) self.set_state("proxy_port", length=4, expectBytes=2)
return True
def state_proxy_addr_2_1(self): def state_proxy_addr_2_1(self):
if not self.read_buf_sufficient(1):
return False
self.address_length = ord(self.read_buf[0:1]) self.address_length = ord(self.read_buf[0:1])
self.set_state("proxy_addr_2_2", 1) self.set_state("proxy_addr_2_2", length=1, expectBytes=self.address_length)
return True
def state_proxy_addr_2_2(self): def state_proxy_addr_2_2(self):
if not self.read_buf_sufficient(self.address_length): self.boundaddr = self.read_buf[0:self.address_length]
return False self.set_state("proxy_port", length=self.address_length, expectBytes=2)
self.boundaddr = self.read_buf return True
self.set_state("proxy_port", self.address_length)
def state_proxy_port(self): def state_proxy_port(self):
if not self.read_buf_sufficient(2):
return False
self.boundport = struct.unpack(">H", self.read_buf[0:2])[0] self.boundport = struct.unpack(">H", self.read_buf[0:2])[0]
self.__proxysockname = (self.boundaddr, self.boundport) self.__proxysockname = (self.boundaddr, self.boundport)
if self.ipaddr is not None: if self.ipaddr is not None:
self.__proxypeername = (socket.inet_ntoa(self.ipaddr), self.destination[1]) self.__proxypeername = (socket.inet_ntoa(self.ipaddr), self.destination[1])
else: else:
self.__proxypeername = (self.destination[0], self.destport) self.__proxypeername = (self.destination[0], self.destport)
self.set_state("proxy_handshake_done", 2) self.set_state("proxy_handshake_done", length=2)
return True
def proxy_sock_name(self): def proxy_sock_name(self):
return socket.inet_ntoa(self.__proxysockname[0]) return socket.inet_ntoa(self.__proxysockname[0])
@ -137,28 +130,29 @@ class Socks5Connection(Socks5):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
self.writeQueue.put(struct.pack('BBB', 0x05, 0x01, 0x00)) self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00))
# If the given destination address is an IP address, we'll # If the given destination address is an IP address, we'll
# use the IPv4 address request even if remote resolving was specified. # use the IPv4 address request even if remote resolving was specified.
try: try:
self.ipaddr = socket.inet_aton(self.destination[0]) self.ipaddr = socket.inet_aton(self.destination[0])
self.writeQueue.put(chr(0x01).encode() + self.ipaddr) self.append_write_buf(chr(0x01).encode() + self.ipaddr)
except socket.error: except socket.error:
# Well it's not an IP number, so it's probably a DNS name. # Well it's not an IP number, so it's probably a DNS name.
if Proxy._remote_dns: if Proxy._remote_dns:
# Resolve remotely # Resolve remotely
self.ipaddr = None self.ipaddr = None
self.writeQueue.put(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0]) self.append_write_buf(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0])
else: else:
# Resolve locally # Resolve locally
self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0]))
self.writeQueue.put(chr(0x01).encode() + self.ipaddr) self.append_write_buf(chr(0x01).encode() + self.ipaddr)
self.writeQueue.put(struct.pack(">H", self.destination[1])) self.append_write_buf(struct.pack(">H", self.destination[1]))
self.set_state("pre_connect", 0) self.set_state("pre_connect", length=0, expectBytes=4)
return True
def state_pre_connect(self): def state_pre_connect(self):
try: try:
Socks5.state_pre_connect(self) return Socks5.state_pre_connect(self)
except Socks5Error as e: except Socks5Error as e:
self.handle_close(e.message) self.handle_close(e.message)
@ -171,10 +165,11 @@ class Socks5Resolver(Socks5):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
self.writeQueue.put(struct.pack('BBB', 0x05, 0xF0, 0x00)) self.append_write_buf(struct.pack('BBB', 0x05, 0xF0, 0x00))
self.writeQueue.put(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host)) self.append_write_buf(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host))
self.writeQueue.put(struct.pack(">H", self.port)) self.append_write_buf(struct.pack(">H", self.port))
self.set_state("pre_connect", 0) self.set_state("pre_connect", length=0, expectBytes=4)
return True
def resolved(self): def resolved(self):
print "Resolved %s as %s" % (self.host, self.proxy_sock_name()) print "Resolved %s as %s" % (self.host, self.proxy_sock_name())

View File

@ -29,7 +29,7 @@ from network.tls import TLSDispatcher
import addresses import addresses
from bmconfigparser import BMConfigParser from bmconfigparser import BMConfigParser
from queues import invQueue, objectProcessorQueue, portCheckerQueue, UISignalQueue from queues import invQueue, objectProcessorQueue, portCheckerQueue, UISignalQueue, receiveDataQueue
import shared import shared
import state import state
import protocol import protocol
@ -91,6 +91,12 @@ class TCPConnection(BMProto, TLSDispatcher):
logger.debug("Skipping processing getdata due to missing object for %.2fs", self.skipUntil - time.time()) logger.debug("Skipping processing getdata due to missing object for %.2fs", self.skipUntil - time.time())
self.skipUntil = time.time() + delay self.skipUntil = time.time() + delay
def state_connection_fully_established(self):
self.set_connection_fully_established()
self.set_state("bm_header")
self.bm_proto_reset()
return True
def set_connection_fully_established(self): def set_connection_fully_established(self):
if not self.isOutbound and not self.local: if not self.isOutbound and not self.local:
shared.clientHasReceivedIncomingConnections = True shared.clientHasReceivedIncomingConnections = True
@ -144,10 +150,37 @@ class TCPConnection(BMProto, TLSDispatcher):
for peer, params in addrs[substream]: for peer, params in addrs[substream]:
templist.append((substream, peer, params["lastseen"])) templist.append((substream, peer, params["lastseen"]))
if len(templist) > 0: if len(templist) > 0:
self.writeQueue.put(BMProto.assembleAddr(templist)) self.append_write_buf(BMProto.assembleAddr(templist))
def sendBigInv(self): def sendBigInv(self):
self.receiveQueue.put(("biginv", None)) def sendChunk():
if objectCount == 0:
return
logger.debug('Sending huge inv message with %i objects to just this one peer', objectCount)
self.append_write_buf(protocol.CreatePacket('inv', addresses.encodeVarint(objectCount) + payload))
# Select all hashes for objects in this stream.
bigInvList = {}
for stream in self.streams:
# may lock for a long time, but I think it's better than thousands of small locks
with self.objectsNewToThemLock:
for objHash in Inventory().unexpired_hashes_by_stream(stream):
bigInvList[objHash] = 0
self.objectsNewToThem[objHash] = time.time()
objectCount = 0
payload = b''
# Now let us start appending all of these hashes together. They will be
# sent out in a big inv message to our new peer.
for hash, storedValue in bigInvList.items():
payload += hash
objectCount += 1
if objectCount >= BMProto.maxObjectCount:
self.sendChunk()
payload = b''
objectCount = 0
# flush
sendChunk()
def handle_connect(self): def handle_connect(self):
try: try:
@ -156,9 +189,10 @@ class TCPConnection(BMProto, TLSDispatcher):
if e.errno in asyncore._DISCONNECTED: if e.errno in asyncore._DISCONNECTED:
logger.debug("%s:%i: Connection failed: %s" % (self.destination.host, self.destination.port, str(e))) logger.debug("%s:%i: Connection failed: %s" % (self.destination.host, self.destination.port, str(e)))
return return
self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, network.connectionpool.BMConnectionPool().streams, False)) self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, network.connectionpool.BMConnectionPool().streams, False))
#print "%s:%i: Sending version" % (self.destination.host, self.destination.port) #print "%s:%i: Sending version" % (self.destination.host, self.destination.port)
self.connectedAt = time.time() self.connectedAt = time.time()
receiveDataQueue.put(self)
def handle_read(self): def handle_read(self):
TLSDispatcher.handle_read(self) TLSDispatcher.handle_read(self)
@ -169,6 +203,7 @@ class TCPConnection(BMProto, TLSDispatcher):
knownnodes.knownNodes[s][self.destination]["lastseen"] = time.time() knownnodes.knownNodes[s][self.destination]["lastseen"] = time.time()
except KeyError: except KeyError:
pass pass
receiveDataQueue.put(self)
def handle_write(self): def handle_write(self):
TLSDispatcher.handle_write(self) TLSDispatcher.handle_write(self)
@ -187,10 +222,10 @@ class Socks5BMConnection(Socks5Connection, TCPConnection):
def state_proxy_handshake_done(self): def state_proxy_handshake_done(self):
Socks5Connection.state_proxy_handshake_done(self) Socks5Connection.state_proxy_handshake_done(self)
self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \
network.connectionpool.BMConnectionPool().streams, False)) network.connectionpool.BMConnectionPool().streams, False))
self.set_state("bm_header", expectBytes=protocol.Header.size) self.set_state("bm_header", expectBytes=protocol.Header.size)
return False return True
class Socks4aBMConnection(Socks4aConnection, TCPConnection): class Socks4aBMConnection(Socks4aConnection, TCPConnection):
@ -201,10 +236,10 @@ class Socks4aBMConnection(Socks4aConnection, TCPConnection):
def state_proxy_handshake_done(self): def state_proxy_handshake_done(self):
Socks4aConnection.state_proxy_handshake_done(self) Socks4aConnection.state_proxy_handshake_done(self)
self.writeQueue.put(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \ self.append_write_buf(protocol.assembleVersionMessage(self.destination.host, self.destination.port, \
network.connectionpool.BMConnectionPool().streams, False)) network.connectionpool.BMConnectionPool().streams, False))
self.set_state("bm_header", expectBytes=protocol.Header.size) self.set_state("bm_header", expectBytes=protocol.Header.size)
return False return True
class TCPServer(AdvancedDispatcher): class TCPServer(AdvancedDispatcher):

View File

@ -10,6 +10,7 @@ import sys
from debug import logger from debug import logger
from network.advanceddispatcher import AdvancedDispatcher from network.advanceddispatcher import AdvancedDispatcher
import network.asyncore_pollchoose as asyncore import network.asyncore_pollchoose as asyncore
from queues import receiveDataQueue
import paths import paths
import protocol import protocol
@ -58,14 +59,17 @@ class TLSDispatcher(AdvancedDispatcher):
do_handshake_on_connect=False) do_handshake_on_connect=False)
self.sslSocket.setblocking(0) self.sslSocket.setblocking(0)
self.want_read = self.want_write = True self.want_read = self.want_write = True
self.set_state("bm_header") self.set_state("tls_handshake")
return False
# if hasattr(self.socket, "context"): # if hasattr(self.socket, "context"):
# self.socket.context.set_ecdh_curve("secp256k1") # self.socket.context.set_ecdh_curve("secp256k1")
def state_tls_handshake(self):
return False
def writable(self): def writable(self):
try: try:
if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): if self.tlsStarted and not self.tlsDone and not self.write_buf:
#print "tls writable, %r" % (self.want_write)
return self.want_write return self.want_write
return AdvancedDispatcher.writable(self) return AdvancedDispatcher.writable(self)
except AttributeError: except AttributeError:
@ -73,9 +77,13 @@ class TLSDispatcher(AdvancedDispatcher):
def readable(self): def readable(self):
try: try:
if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): # during TLS handshake, and after flushing write buffer, return status of last handshake attempt
if self.tlsStarted and not self.tlsDone and not self.write_buf:
#print "tls readable, %r" % (self.want_read) #print "tls readable, %r" % (self.want_read)
return self.want_read return self.want_read
# prior to TLS handshake, receiveDataThread should emulate synchronous behaviour
elif not self.fullyEstablished and (self.expectBytes == 0 or not self.write_buf_empty()):
return False
return AdvancedDispatcher.readable(self) return AdvancedDispatcher.readable(self)
except AttributeError: except AttributeError:
return AdvancedDispatcher.readable(self) return AdvancedDispatcher.readable(self)
@ -83,11 +91,11 @@ class TLSDispatcher(AdvancedDispatcher):
def handle_read(self): def handle_read(self):
try: try:
# wait for write buffer flush # wait for write buffer flush
if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): if self.tlsStarted and not self.tlsDone and not self.write_buf:
#print "handshaking (read)" #logger.debug("%s:%i TLS handshaking (read)", self.destination.host, self.destination.port)
self.tls_handshake() self.tls_handshake()
else: else:
#print "not handshaking (read)" #logger.debug("%s:%i Not TLS handshaking (read)", self.destination.host, self.destination.port)
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
except AttributeError: except AttributeError:
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
@ -104,14 +112,14 @@ class TLSDispatcher(AdvancedDispatcher):
def handle_write(self): def handle_write(self):
try: try:
# wait for write buffer flush # wait for write buffer flush
if self.tlsStarted and not self.tlsDone and not self.write_buf and self.writeQueue.empty(): if self.tlsStarted and not self.tlsDone and not self.write_buf:
#print "handshaking (write)" #logger.debug("%s:%i TLS handshaking (write)", self.destination.host, self.destination.port)
self.tls_handshake() self.tls_handshake()
else: else:
#print "not handshaking (write)" #logger.debug("%s:%i Not TLS handshaking (write)", self.destination.host, self.destination.port)
return AdvancedDispatcher.handle_write(self) return AdvancedDispatcher.handle_write(self)
except AttributeError: except AttributeError:
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_write(self)
except ssl.SSLError as err: except ssl.SSLError as err:
if err.errno == ssl.SSL_ERROR_WANT_WRITE: if err.errno == ssl.SSL_ERROR_WANT_WRITE:
return 0 return 0
@ -158,6 +166,8 @@ class TLSDispatcher(AdvancedDispatcher):
self.del_channel() self.del_channel()
self.set_socket(self.sslSocket) self.set_socket(self.sslSocket)
self.tlsDone = True self.tlsDone = True
self.set_state("bm_header")
self.set_connection_fully_established() self.bm_proto_reset()
self.set_state("connection_fully_established")
receiveDataQueue.put(self)
return False return False

View File

@ -9,7 +9,7 @@ from network.bmobject import BMObject, BMObjectInsufficientPOWError, BMObjectInv
import network.asyncore_pollchoose as asyncore import network.asyncore_pollchoose as asyncore
from network.objectracker import ObjectTracker from network.objectracker import ObjectTracker
from queues import objectProcessorQueue, peerDiscoveryQueue, UISignalQueue from queues import objectProcessorQueue, peerDiscoveryQueue, UISignalQueue, receiveDataQueue
import state import state
import protocol import protocol
@ -80,7 +80,7 @@ class UDPSocket(BMProto):
addresses = self._decode_addr() addresses = self._decode_addr()
# only allow peer discovery from private IPs in order to avoid attacks from random IPs on the internet # only allow peer discovery from private IPs in order to avoid attacks from random IPs on the internet
if not self.local: if not self.local:
return return True
remoteport = False remoteport = False
for i in addresses: for i in addresses:
seenTime, stream, services, ip, port = i seenTime, stream, services, ip, port = i
@ -93,7 +93,7 @@ class UDPSocket(BMProto):
# if the address isn't local, interpret it as the hosts' own announcement # if the address isn't local, interpret it as the hosts' own announcement
remoteport = port remoteport = port
if remoteport is False: if remoteport is False:
return return True
logger.debug("received peer discovery from %s:%i (port %i):", self.destination.host, self.destination.port, remoteport) logger.debug("received peer discovery from %s:%i (port %i):", self.destination.host, self.destination.port, remoteport)
if self.local: if self.local:
peerDiscoveryQueue.put(state.Peer(self.destination.host, remoteport)) peerDiscoveryQueue.put(state.Peer(self.destination.host, remoteport))
@ -118,7 +118,7 @@ class UDPSocket(BMProto):
return return
def writable(self): def writable(self):
return not self.writeQueue.empty() return self.write_buf
def readable(self): def readable(self):
return len(self.read_buf) < AdvancedDispatcher._buf_len return len(self.read_buf) < AdvancedDispatcher._buf_len
@ -139,18 +139,14 @@ class UDPSocket(BMProto):
# overwrite the old buffer to avoid mixing data and so that self.local works correctly # overwrite the old buffer to avoid mixing data and so that self.local works correctly
self.read_buf = recdata self.read_buf = recdata
self.bm_proto_reset() self.bm_proto_reset()
self.process() receiveDataQueue.put(self)
def handle_write(self): def handle_write(self):
try: try:
data = self.writeQueue.get(False) retval = self.socket.sendto(self.write_buf, ('<broadcast>', UDPSocket.port))
except Queue.Empty:
return
try:
retval = self.socket.sendto(data, ('<broadcast>', UDPSocket.port))
except socket.error as e: except socket.error as e:
logger.error("socket error on sendato: %s", str(e)) logger.error("socket error on sendato: %s", str(e))
self.writeQueue.task_done() self.slice_write_buf(retval)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -12,5 +12,6 @@ invQueue = MultiQueue()
addrQueue = MultiQueue() addrQueue = MultiQueue()
portCheckerQueue = Queue.Queue() portCheckerQueue = Queue.Queue()
peerDiscoveryQueue = Queue.Queue() peerDiscoveryQueue = Queue.Queue()
receiveDataQueue = Queue.Queue()
apiAddressGeneratorReturnQueue = Queue.Queue( apiAddressGeneratorReturnQueue = Queue.Queue(
) # The address generator thread uses this queue to get information back to the API thread. ) # The address generator thread uses this queue to get information back to the API thread.