diff --git a/src/bmproto.py b/src/bmproto.py index 0a147d3e..5cd08779 100644 --- a/src/bmproto.py +++ b/src/bmproto.py @@ -9,14 +9,19 @@ from network.socks5 import Socks5Connection, Socks5Resolver, Socks5AuthError, So from network.socks4a import Socks4aConnection, Socks4aResolver, Socks4aError import addresses +from bmconfigparser import BMConfigParser import protocol class BMProtoError(ProxyError): pass -class BMConnection(AdvancedDispatcher): +class BMConnection(TLSDispatcher): # ~1.6 MB which is the maximum possible size of an inv message. maxMessageSize = 1600100 + # protocol specification says max 1000 addresses in one addr command + maxAddrCount = 1000 + # protocol specification says max 50000 objects in one inv command + maxObjectCount = 50000 def __init__(self, address=None, sock=None): AdvancedDispatcher.__init__(self, sock) @@ -25,12 +30,14 @@ class BMConnection(AdvancedDispatcher): if address is None and sock is not None: self.destination = self.addr() self.isOutbound = False + TLSHandshake.__init__(self, sock, server_side=True) print "received connection in background from %s:%i" % (self.destination[0], self.destination[1]) else: self.destination = address self.isOutbound = True self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.connect(self.destination) + TLSHandshake.__init__(self, sock, server_side=False) print "connecting in background to %s:%i" % (self.destination[0], self.destination[1]) def bm_proto_reset(self): @@ -49,6 +56,11 @@ class BMConnection(AdvancedDispatcher): self.set_state("bm_header", 0) return False + def state_bm_ready(self): + self.sendAddr() + self.sendBigInv() + return True + def state_bm_header(self): if len(self.read_buf) < protocol.Header.size: print "Length below header size" @@ -74,18 +86,27 @@ class BMConnection(AdvancedDispatcher): if self.checksum != hashlib.sha512(self.payload).digest()[0:4]: print "Bad checksum, ignoring" self.invalid = True + retval = True if not self.invalid: try: - getattr(self, "bm_command_" + str(self.command))() + retval = getattr(self, "bm_command_" + str(self.command).lower())() except AttributeError: # unimplemented command print "unimplemented command %s" % (self.command) else: print "Skipping command %s due to invalid data" % (self.command) - self.set_state("bm_header", self.payloadLength) - self.bm_proto_reset() + if retval: + self.set_state("bm_header", self.payloadLength) + self.bm_proto_reset() + # else assume the command requires a different state to follow return True + def bm_command_error(self): + def bm_command_getdata(self): + def bm_command_object(self): + def bm_command_ping(self): + def bm_command_pong(self): + def bm_command_verack(self): self.verackReceived = True return True @@ -102,8 +123,171 @@ class BMConnection(AdvancedDispatcher): self.userAgent = self.payload[readPosition:readPosition + useragentLength] readPosition += useragentLength print "user agent: %s" % (self.userAgent) + if not self.peerValidityChecks(): + # TODO ABORT + return True + self.write_buf += protocol.CreatePacket('verack') + self.verackSent = True + if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and + protocol.haveSSL(not self.isOutbound)): + self.isSSL = True + if self.verackReceived: + if self.isSSL: + self.set_state("tls_init", self.payloadLength) + else: + self.set_state("bm_ready", self.payloadLength) + self.bm_proto_reset() + return False + + def peerValidityChecks(self): + if self.remoteProtocolVersion < 3: + self.write_buf += protocol.assembleErrorMessage(fatal=2, + errorText="Your is using an old protocol. Closing connection.") + logger.debug ('Closing connection to old protocol version %s, node: %s', + str(self.remoteProtocolVersion), str(self.peer)) + return False + if self.timeOffset > 3600: + self.write_buf += protocol.assembleErrorMessage(fatal=2, + errorText="Your time is too far in the future compared to mine. Closing connection.") + logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.", + self.peer, self.timeOffset) + shared.timeOffsetWrongCount += 1 + return False + elif self.timeOffset < -3600: + self.write_buf += protocol.assembleErrorMessage(fatal=2, + errorText="Your time is too far in the past compared to mine. Closing connection.") + logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.", + self.peer, self.timeOffset) + shared.timeOffsetWrongCount += 1 + return False + else: + shared.timeOffsetWrongCount = 0 + if len(self.streams) == 0: + self.write_buf += protocol.assembleErrorMessage(fatal=2, + errorText="We don't have shared stream interests. Closing connection."))) + logger.debug ('Closed connection to %s because there is no overlapping interest in streams.', + str(self.peer)) + return False return True + def sendAddr(self): + def sendChunk(): + if numberOfAddressesInAddrMessage == 0: + return + self.write_buf += protocol.CreatePacket('addr', \ + addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload))) + + # We are going to share a maximum number of 1000 addrs (per overlapping + # stream) with our peer. 500 from overlapping streams, 250 from the + # left child stream, and 250 from the right child stream. + maxAddrCount = BMConfigParser().safeGetInt("bitmessagesettings", "maxaddrperstreamsend", 500) + + # init + addressCount = 0 + payload = '' + + for stream in self.streams: + addrsInMyStream = {} + addrsInChildStreamLeft = {} + addrsInChildStreamRight = {} + + with knownnodes.knownNodesLock: + if len(knownnodes.knownNodes[stream]) > 0: + filtered = {k: v for k, v in knownnodes.knownNodes[stream].items() + if v > (int(time.time()) - shared.maximumAgeOfNodesThatIAdvertiseToOthers)} + elemCount = len(filtered) + if elemCount > maxAddrCount: + elemCount = maxAddrCount + # only if more recent than 3 hours + addrsInMyStream = random.sample(filtered.items(), elemCount) + # sent 250 only if the remote isn't interested in it + if len(knownnodes.knownNodes[stream * 2]) > 0 and stream not in self.streams: + filtered = {k: v for k, v in knownnodes.knownNodes[stream*2].items() + if v > (int(time.time()) - shared.maximumAgeOfNodesThatIAdvertiseToOthers)} + elemCount = len(filtered) + if elemCount > maxAddrCount / 2: + elemCount = int(maxAddrCount / 2) + addrsInChildStreamLeft = random.sample(filtered.items(), elemCount) + if len(knownnodes.knownNodes[(stream * 2) + 1]) > 0 and stream not in self.streams: + filtered = {k: v for k, v in knownnodes.knownNodes[stream*2+1].items() + if v > (int(time.time()) - shared.maximumAgeOfNodesThatIAdvertiseToOthers)} + elemCount = len(filtered) + if elemCount > maxAddrCount / 2: + elemCount = int(maxAddrCount / 2) + addrsInChildStreamRight = random.sample(filtered.items(), elemCount) + for (HOST, PORT), timeLastReceivedMessageFromThisNode in addrsInMyStream: + addressCount += 1 + payload += pack( + '>Q', timeLastReceivedMessageFromThisNode) # 64-bit time + payload += pack('>I', stream) + payload += pack( + '>q', 1) # service bit flags offered by this node + payload += protocol.encodeHost(HOST) + payload += pack('>H', PORT) # remote port + if addressCount >= BMConnection.maxAddrCount: + sendChunk() + payload = '' + addressCount = 0 + for (HOST, PORT), timeLastReceivedMessageFromThisNode in addrsInChildStreamLeft: + addressCount += 1 + payload += pack( + '>Q', timeLastReceivedMessageFromThisNode) # 64-bit time + payload += pack('>I', stream * 2) + payload += pack( + '>q', 1) # service bit flags offered by this node + payload += protocol.encodeHost(HOST) + payload += pack('>H', PORT) # remote port + if addressCount >= BMConnection.maxAddrCount: + sendChunk() + payload = '' + addressCount = 0 + for (HOST, PORT), timeLastReceivedMessageFromThisNode in addrsInChildStreamRight: + addressCount += 1 + payload += pack( + '>Q', timeLastReceivedMessageFromThisNode) # 64-bit time + payload += pack('>I', (stream * 2) + 1) + payload += pack( + '>q', 1) # service bit flags offered by this node + payload += protocol.encodeHost(HOST) + payload += pack('>H', PORT) # remote port + if addressCount >= BMConnection.maxAddrCount: + sendChunk() + payload = '' + addressCount = 0 + + # flush + sendChunk() + + def sendBigInv(self): + def sendChunk(): + if objectCount == 0: + return + payload = encodeVarint(objectCount) + payload + logger.debug('Sending huge inv message with %i objects to just this one peer', + str(numberOfObjects)) + self.write_buf += protocol.CreatePacket('inv', payload) + + # Select all hashes for objects in this stream. + bigInvList = {} + for stream in self.streams: + for hash in Inventory().unexpired_hashes_by_stream(stream): + if not self.objectHashHolderInstance.hasHash(hash): + bigInvList[hash] = 0 + objectCount = 0 + payload = '' + # Now let us start appending all of these hashes together. They will be + # sent out in a big inv message to our new peer. + for hash, storedValue in bigInvList.items(): + payload += hash + objectCount += 1 + if objectCount >= BMConnection.maxObjectCount: + self.sendChunk() + payload = '' + objectCount = 0 + + # flush + sendChunk() + class Socks5BMConnection(Socks5Connection, BMConnection): def __init__(self, address): diff --git a/src/network/tls.py b/src/network/tls.py index 8f104c55..023f6cac 100644 --- a/src/network/tls.py +++ b/src/network/tls.py @@ -2,57 +2,45 @@ SSL/TLS negotiation. """ -import asyncore +from network.advanceddispatcher import AdvancedDispatcher +import network.asyncore_pollchoose as asyncore import socket import ssl import sys import protocol -class TLSHandshake(asyncore.dispatcher): - """ - Negotiates a SSL/TLS connection before handing itself spawning a - dispatcher that can deal with the overlying protocol as soon as the - handshake has been completed. - - `handoff` is a function/method called when the handshake has completed. - `address` is a tuple consisting of hostname/address and port to connect to - if nothing is passed in `sock`, which can take an already-connected socket. - `certfile` can take a path to a certificate bundle, and `server_side` - indicates whether the socket is intended to be a server-side or client-side - socket. - """ - +class TLSDispatcher(AdvancedDispatcher): def __init__(self, address=None, sock=None, - certfile=None, keyfile=None, server_side=False, ciphers=None, init_parent=True): - if not hasattr(self, '_map'): - asyncore.dispatcher.__init__(self, sock) + certfile=None, keyfile=None, server_side=False, ciphers=protocol.sslProtocolCiphers): self.want_read = self.want_write = True - self.certfile = certfile - self.keyfile = keyfile + if certfile is None: + self.certfile = os.path.join(paths.codePath(), 'sslkeys', 'cert.pem') + else: + self.certfile = certfile + if keyfile is None: + self.keyfile = os.path.join(paths.codePath(), 'sslkeys', 'key.pem') + else: + self.keyfile = keyfile self.server_side = server_side self.ciphers = ciphers + self.tlsStarted = False self.tlsDone = False - if sock is None: - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) -# logger.info('Connecting to %s%d', address[0], address[1]) - self.connect(address) - elif self.connected: - # Initiate the handshake for an already-connected socket. - self.handle_connect() + self.isSSL = False - def handle_connect(self): + def state_tls_init(self): + self.isSSL = True # Once the connection has been established, it's safe to wrap the # socket. if sys.version_info >= (2,7,9): context = ssl.create_default_context(purpose = ssl.Purpose.SERVER_AUTH if self.server_side else ssl.Purpose.CLIENT_AUTH) context.set_ciphers(self.ciphers) - # context.set_ecdh_curve("secp256k1") + context.set_ecdh_curve("secp256k1") context.check_hostname = False context.verify_mode = ssl.CERT_NONE # also exclude TLSv1 and TLSv1.1 in the future - context.options |= ssl.OP_NOSSLv2 | ssl.OP_NOSSLv3 - self.sslSock = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False) + context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE + self.sslSocket = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False) else: self.sslSocket = ssl.wrap_socket(self.socket, server_side=self.server_side, @@ -67,20 +55,30 @@ class TLSHandshake(asyncore.dispatcher): # self.socket.context.set_ecdh_curve("secp256k1") def writable(self): - return self.want_write + if self.tlsStarted and not self.tlsDone: + return self.want_write + else: + return AdvancedDispacher.writable(self) def readable(self): - return self.want_read + if self.tlsStarted and not self.tlsDone: + return self.want_read + else: + return AdvancedDispacher.readable(self) def handle_read(self): - if not self.tlsDone: + if self.tlsStarted and not self.tlsDone: self._handshake() + else: + return AdvancedDispacher.handle_read(self) def handle_write(self): - if not self.tlsDone: + if self.tlsStarted and not not self.tlsDone: self._handshake() + else: + return AdvancedDispacher.handle_write(self) - def _handshake(self): + def state_tls_handshake(self): """ Perform the handshake. """