From 5d4e1e2007c7cb40abdd16ffd89a12873282ccfd Mon Sep 17 00:00:00 2001 From: Peter Surda Date: Sat, 27 May 2017 22:30:30 +0200 Subject: [PATCH] asyncore fixes - bm headers and commands are only read up to expected length. On a very fast connection (e.g. local VM), reading verack also read a part of the TLS handshake - some debugging info moved from print to logger.debug - tls handshake cleanup --- src/network/advanceddispatcher.py | 17 ++++++++++++----- src/network/bmproto.py | 17 +++++++++-------- src/network/tcp.py | 8 ++++---- src/network/tls.py | 12 ++++++------ 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/network/advanceddispatcher.py b/src/network/advanceddispatcher.py index 8c976b36..96b206cf 100644 --- a/src/network/advanceddispatcher.py +++ b/src/network/advanceddispatcher.py @@ -18,6 +18,7 @@ class AdvancedDispatcher(asyncore.dispatcher): self.lastTx = time.time() self.sentBytes = 0 self.receivedBytes = 0 + self.expectBytes = 0 def slice_write_buf(self, length=0): if length > 0: @@ -49,24 +50,30 @@ class AdvancedDispatcher(asyncore.dispatcher): raise maxLoop -= 1 - def set_state(self, state, length=0): + def set_state(self, state, length=0, expectBytes=0): + self.expectBytes = expectBytes self.slice_read_buf(length) self.state = state def writable(self): - return self.connected and (len(self.write_buf) > 0 or not self.writeQueue.empty()) + return self.connecting or len(self.write_buf) > 0 or not self.writeQueue.empty() def readable(self): return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len def handle_read(self): self.lastTx = time.time() + downloadBytes = AdvancedDispatcher._buf_len + if asyncore.maxDownloadRate > 0: + downloadBytes = asyncore.downloadChunk + if self.expectBytes > 0 and downloadBytes > self.expectBytes: + downloadBytes = self.expectBytes + newData = self.recv(downloadBytes) if asyncore.maxDownloadRate > 0: - newData = self.recv(asyncore.downloadChunk) asyncore.downloadBucket -= len(newData) - else: - newData = self.recv(AdvancedDispatcher._buf_len) self.receivedBytes += len(newData) + if self.expectBytes > 0: + self.expectBytes -= len(newData) asyncore.updateReceived(len(newData)) self.read_buf += newData self.process() diff --git a/src/network/bmproto.py b/src/network/bmproto.py index c706b81a..d3f02568 100644 --- a/src/network/bmproto.py +++ b/src/network/bmproto.py @@ -91,6 +91,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.payload = None self.invalid = False self.payloadOffset = 0 + self.expectBytes = protocol.Header.size self.object = None def state_bm_header(self): @@ -109,7 +110,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): return False if self.payloadLength > BMProto.maxMessageSize: self.invalid = True - self.set_state("bm_command", protocol.Header.size) + self.set_state("bm_command", protocol.Header.size, expectBytes=self.payloadLength) return True def state_bm_command(self): @@ -376,13 +377,13 @@ class BMProto(AdvancedDispatcher, ObjectTracker): self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv") self.nonce = struct.pack('>Q', self.nonce) self.timeOffset = self.timestamp - int(time.time()) - #print "remoteProtocolVersion: %i" % (self.remoteProtocolVersion) - #print "services: %08X" % (self.services) - #print "time offset: %i" % (self.timestamp - int(time.time())) - #print "my external IP: %s" % (self.sockNode.host) - #print "remote node incoming port: %i" % (self.peerNode.port) - #print "user agent: %s" % (self.userAgent) - #print "streams: [%s]" % (",".join(map(str,self.streams))) + logger.debug("remoteProtocolVersion: %i", self.remoteProtocolVersion) + logger.debug("services: %08X", self.services) + logger.debug("time offset: %i", self.timestamp - int(time.time())) + logger.debug("my external IP: %s", self.sockNode.host) + logger.debug("remote node incoming port: %i", self.peerNode.port) + logger.debug("user agent: %s", self.userAgent) + logger.debug("streams: [%s]", ",".join(map(str,self.streams))) if not self.peerValidityChecks(): # TODO ABORT return True diff --git a/src/network/tcp.py b/src/network/tcp.py index 8f7e60d7..8c5fb968 100644 --- a/src/network/tcp.py +++ b/src/network/tcp.py @@ -48,7 +48,7 @@ class TCPConnection(BMProto, TLSDispatcher): self.isOutbound = False TLSDispatcher.__init__(self, sock, server_side=True) self.connectedAt = time.time() - #print "received connection in background from %s:%i" % (self.destination.host, self.destination.port) + logger.debug("Received connection from %s:%i", self.destination.host, self.destination.port) else: self.destination = address self.isOutbound = True @@ -59,7 +59,7 @@ class TCPConnection(BMProto, TLSDispatcher): self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) TLSDispatcher.__init__(self, sock, server_side=False) self.connect(self.destination) - #print "connecting in background to %s:%i" % (self.destination.host, self.destination.port) + logger.debug("Connecting to %s:%i", self.destination.host, self.destination.port) shared.connectedHostsList[self.destination] = 0 ObjectTracker.__init__(self) UISignalQueue.put(('updateNetworkStatusTab', 'no data')) @@ -152,14 +152,14 @@ class TCPConnection(BMProto, TLSDispatcher): def handle_read(self): try: - AdvancedDispatcher.handle_read(self) + TLSDispatcher.handle_read(self) except socket.error as e: #print "%s:%i: socket error: %s" % (self.destination.host, self.destination.port, str(e)) self.close() def handle_write(self): try: - AdvancedDispatcher.handle_write(self) + TLSDispatcher.handle_write(self) except socket.error as e: #print "%s:%i: socket error: %s" % (self.destination.host, self.destination.port, str(e)) self.close() diff --git a/src/network/tls.py b/src/network/tls.py index f79f0650..f813e3be 100644 --- a/src/network/tls.py +++ b/src/network/tls.py @@ -54,7 +54,7 @@ class TLSDispatcher(AdvancedDispatcher): do_handshake_on_connect=False) self.sslSocket.setblocking(0) self.want_read = self.want_write = True - self.set_state("tls_handshake") + self.set_state("bm_header") # if hasattr(self.socket, "context"): # self.socket.context.set_ecdh_curve("secp256k1") @@ -83,7 +83,7 @@ class TLSDispatcher(AdvancedDispatcher): # wait for write buffer flush if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0 and self.writeQueue.empty(): #print "handshaking (read)" - self.state_tls_handshake() + self.tls_handshake() else: #print "not handshaking (read)" return AdvancedDispatcher.handle_read(self) @@ -95,23 +95,23 @@ class TLSDispatcher(AdvancedDispatcher): # wait for write buffer flush if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0 and self.writeQueue.empty(): #print "handshaking (write)" - self.state_tls_handshake() + self.tls_handshake() else: #print "not handshaking (write)" return AdvancedDispatcher.handle_write(self) except AttributeError: return AdvancedDispatcher.handle_read(self) - def state_tls_handshake(self): + def tls_handshake(self): # wait for flush if len(self.write_buf) > 0: return False # Perform the handshake. try: - #print "handshaking (internal)" + print "handshaking (internal)" self.sslSocket.do_handshake() except ssl.SSLError, err: - #print "%s:%i: handshake fail" % (self.destination.host, self.destination.port) + print "%s:%i: handshake fail" % (self.destination.host, self.destination.port) self.want_read = self.want_write = False if err.args[0] == ssl.SSL_ERROR_WANT_READ: #print "want read"