asyncore fixes

- bm headers and commands are only read up to expected length.
On a very fast connection (e.g. local VM), reading verack
also read a part of the TLS handshake
- some debugging info moved from print to logger.debug
- tls handshake cleanup
This commit is contained in:
Peter Šurda 2017-05-27 22:30:30 +02:00
parent 21f6d38ec2
commit 5d4e1e2007
Signed by untrusted user: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87
4 changed files with 31 additions and 23 deletions

View File

@ -18,6 +18,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
self.lastTx = time.time() self.lastTx = time.time()
self.sentBytes = 0 self.sentBytes = 0
self.receivedBytes = 0 self.receivedBytes = 0
self.expectBytes = 0
def slice_write_buf(self, length=0): def slice_write_buf(self, length=0):
if length > 0: if length > 0:
@ -49,24 +50,30 @@ class AdvancedDispatcher(asyncore.dispatcher):
raise raise
maxLoop -= 1 maxLoop -= 1
def set_state(self, state, length=0): def set_state(self, state, length=0, expectBytes=0):
self.expectBytes = expectBytes
self.slice_read_buf(length) self.slice_read_buf(length)
self.state = state self.state = state
def writable(self): def writable(self):
return self.connected and (len(self.write_buf) > 0 or not self.writeQueue.empty()) return self.connecting or len(self.write_buf) > 0 or not self.writeQueue.empty()
def readable(self): def readable(self):
return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len
def handle_read(self): def handle_read(self):
self.lastTx = time.time() self.lastTx = time.time()
downloadBytes = AdvancedDispatcher._buf_len
if asyncore.maxDownloadRate > 0:
downloadBytes = asyncore.downloadChunk
if self.expectBytes > 0 and downloadBytes > self.expectBytes:
downloadBytes = self.expectBytes
newData = self.recv(downloadBytes)
if asyncore.maxDownloadRate > 0: if asyncore.maxDownloadRate > 0:
newData = self.recv(asyncore.downloadChunk)
asyncore.downloadBucket -= len(newData) asyncore.downloadBucket -= len(newData)
else:
newData = self.recv(AdvancedDispatcher._buf_len)
self.receivedBytes += len(newData) self.receivedBytes += len(newData)
if self.expectBytes > 0:
self.expectBytes -= len(newData)
asyncore.updateReceived(len(newData)) asyncore.updateReceived(len(newData))
self.read_buf += newData self.read_buf += newData
self.process() self.process()

View File

@ -91,6 +91,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.payload = None self.payload = None
self.invalid = False self.invalid = False
self.payloadOffset = 0 self.payloadOffset = 0
self.expectBytes = protocol.Header.size
self.object = None self.object = None
def state_bm_header(self): def state_bm_header(self):
@ -109,7 +110,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
return False return False
if self.payloadLength > BMProto.maxMessageSize: if self.payloadLength > BMProto.maxMessageSize:
self.invalid = True self.invalid = True
self.set_state("bm_command", protocol.Header.size) self.set_state("bm_command", protocol.Header.size, expectBytes=self.payloadLength)
return True return True
def state_bm_command(self): def state_bm_command(self):
@ -376,13 +377,13 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv") self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv")
self.nonce = struct.pack('>Q', self.nonce) self.nonce = struct.pack('>Q', self.nonce)
self.timeOffset = self.timestamp - int(time.time()) self.timeOffset = self.timestamp - int(time.time())
#print "remoteProtocolVersion: %i" % (self.remoteProtocolVersion) logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion)
#print "services: %08X" % (self.services) logger.debug("services: %08X", self.services)
#print "time offset: %i" % (self.timestamp - int(time.time())) logger.debug("time offset: %i", self.timestamp - int(time.time()))
#print "my external IP: %s" % (self.sockNode.host) logger.debug("my external IP: %s", self.sockNode.host)
#print "remote node incoming port: %i" % (self.peerNode.port) logger.debug("remote node incoming port: %i", self.peerNode.port)
#print "user agent: %s" % (self.userAgent) logger.debug("user agent: %s", self.userAgent)
#print "streams: [%s]" % (",".join(map(str,self.streams))) logger.debug("streams: [%s]", ",".join(map(str,self.streams)))
if not self.peerValidityChecks(): if not self.peerValidityChecks():
# TODO ABORT # TODO ABORT
return True return True

View File

@ -48,7 +48,7 @@ class TCPConnection(BMProto, TLSDispatcher):
self.isOutbound = False self.isOutbound = False
TLSDispatcher.__init__(self, sock, server_side=True) TLSDispatcher.__init__(self, sock, server_side=True)
self.connectedAt = time.time() self.connectedAt = time.time()
#print "received connection in background from %s:%i" % (self.destination.host, self.destination.port) logger.debug("Received connection from %s:%i", self.destination.host, self.destination.port)
else: else:
self.destination = address self.destination = address
self.isOutbound = True self.isOutbound = True
@ -59,7 +59,7 @@ class TCPConnection(BMProto, TLSDispatcher):
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
TLSDispatcher.__init__(self, sock, server_side=False) TLSDispatcher.__init__(self, sock, server_side=False)
self.connect(self.destination) self.connect(self.destination)
#print "connecting in background to %s:%i" % (self.destination.host, self.destination.port) logger.debug("Connecting to %s:%i", self.destination.host, self.destination.port)
shared.connectedHostsList[self.destination] = 0 shared.connectedHostsList[self.destination] = 0
ObjectTracker.__init__(self) ObjectTracker.__init__(self)
UISignalQueue.put(('updateNetworkStatusTab', 'no data')) UISignalQueue.put(('updateNetworkStatusTab', 'no data'))
@ -152,14 +152,14 @@ class TCPConnection(BMProto, TLSDispatcher):
def handle_read(self): def handle_read(self):
try: try:
AdvancedDispatcher.handle_read(self) TLSDispatcher.handle_read(self)
except socket.error as e: except socket.error as e:
#print "%s:%i: socket error: %s" % (self.destination.host, self.destination.port, str(e)) #print "%s:%i: socket error: %s" % (self.destination.host, self.destination.port, str(e))
self.close() self.close()
def handle_write(self): def handle_write(self):
try: try:
AdvancedDispatcher.handle_write(self) TLSDispatcher.handle_write(self)
except socket.error as e: except socket.error as e:
#print "%s:%i: socket error: %s" % (self.destination.host, self.destination.port, str(e)) #print "%s:%i: socket error: %s" % (self.destination.host, self.destination.port, str(e))
self.close() self.close()

View File

@ -54,7 +54,7 @@ class TLSDispatcher(AdvancedDispatcher):
do_handshake_on_connect=False) do_handshake_on_connect=False)
self.sslSocket.setblocking(0) self.sslSocket.setblocking(0)
self.want_read = self.want_write = True self.want_read = self.want_write = True
self.set_state("tls_handshake") self.set_state("bm_header")
# if hasattr(self.socket, "context"): # if hasattr(self.socket, "context"):
# self.socket.context.set_ecdh_curve("secp256k1") # self.socket.context.set_ecdh_curve("secp256k1")
@ -83,7 +83,7 @@ class TLSDispatcher(AdvancedDispatcher):
# wait for write buffer flush # wait for write buffer flush
if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0 and self.writeQueue.empty(): if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0 and self.writeQueue.empty():
#print "handshaking (read)" #print "handshaking (read)"
self.state_tls_handshake() self.tls_handshake()
else: else:
#print "not handshaking (read)" #print "not handshaking (read)"
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
@ -95,23 +95,23 @@ class TLSDispatcher(AdvancedDispatcher):
# wait for write buffer flush # wait for write buffer flush
if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0 and self.writeQueue.empty(): if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0 and self.writeQueue.empty():
#print "handshaking (write)" #print "handshaking (write)"
self.state_tls_handshake() self.tls_handshake()
else: else:
#print "not handshaking (write)" #print "not handshaking (write)"
return AdvancedDispatcher.handle_write(self) return AdvancedDispatcher.handle_write(self)
except AttributeError: except AttributeError:
return AdvancedDispatcher.handle_read(self) return AdvancedDispatcher.handle_read(self)
def state_tls_handshake(self): def tls_handshake(self):
# wait for flush # wait for flush
if len(self.write_buf) > 0: if len(self.write_buf) > 0:
return False return False
# Perform the handshake. # Perform the handshake.
try: try:
#print "handshaking (internal)" print "handshaking (internal)"
self.sslSocket.do_handshake() self.sslSocket.do_handshake()
except ssl.SSLError, err: except ssl.SSLError, err:
#print "%s:%i: handshake fail" % (self.destination.host, self.destination.port) print "%s:%i: handshake fail" % (self.destination.host, self.destination.port)
self.want_read = self.want_write = False self.want_read = self.want_write = False
if err.args[0] == ssl.SSL_ERROR_WANT_READ: if err.args[0] == ssl.SSL_ERROR_WANT_READ:
#print "want read" #print "want read"