Async network IO updates
- WIP
This commit is contained in:
parent
fe93473fc5
commit
51aeb284ca
192
src/bmproto.py
192
src/bmproto.py
|
@ -9,14 +9,19 @@ from network.socks5 import Socks5Connection, Socks5Resolver, Socks5AuthError, So
|
|||
from network.socks4a import Socks4aConnection, Socks4aResolver, Socks4aError
|
||||
|
||||
import addresses
|
||||
from bmconfigparser import BMConfigParser
|
||||
import protocol
|
||||
|
||||
class BMProtoError(ProxyError): pass
|
||||
|
||||
|
||||
class BMConnection(AdvancedDispatcher):
|
||||
class BMConnection(TLSDispatcher):
|
||||
# ~1.6 MB which is the maximum possible size of an inv message.
|
||||
maxMessageSize = 1600100
|
||||
# protocol specification says max 1000 addresses in one addr command
|
||||
maxAddrCount = 1000
|
||||
# protocol specification says max 50000 objects in one inv command
|
||||
maxObjectCount = 50000
|
||||
|
||||
def __init__(self, address=None, sock=None):
|
||||
AdvancedDispatcher.__init__(self, sock)
|
||||
|
@ -25,12 +30,14 @@ class BMConnection(AdvancedDispatcher):
|
|||
if address is None and sock is not None:
|
||||
self.destination = self.addr()
|
||||
self.isOutbound = False
|
||||
TLSHandshake.__init__(self, sock, server_side=True)
|
||||
print "received connection in background from %s:%i" % (self.destination[0], self.destination[1])
|
||||
else:
|
||||
self.destination = address
|
||||
self.isOutbound = True
|
||||
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.connect(self.destination)
|
||||
TLSHandshake.__init__(self, sock, server_side=False)
|
||||
print "connecting in background to %s:%i" % (self.destination[0], self.destination[1])
|
||||
|
||||
def bm_proto_reset(self):
|
||||
|
@ -49,6 +56,11 @@ class BMConnection(AdvancedDispatcher):
|
|||
self.set_state("bm_header", 0)
|
||||
return False
|
||||
|
||||
def state_bm_ready(self):
|
||||
self.sendAddr()
|
||||
self.sendBigInv()
|
||||
return True
|
||||
|
||||
def state_bm_header(self):
|
||||
if len(self.read_buf) < protocol.Header.size:
|
||||
print "Length below header size"
|
||||
|
@ -74,18 +86,27 @@ class BMConnection(AdvancedDispatcher):
|
|||
if self.checksum != hashlib.sha512(self.payload).digest()[0:4]:
|
||||
print "Bad checksum, ignoring"
|
||||
self.invalid = True
|
||||
retval = True
|
||||
if not self.invalid:
|
||||
try:
|
||||
getattr(self, "bm_command_" + str(self.command))()
|
||||
retval = getattr(self, "bm_command_" + str(self.command).lower())()
|
||||
except AttributeError:
|
||||
# unimplemented command
|
||||
print "unimplemented command %s" % (self.command)
|
||||
else:
|
||||
print "Skipping command %s due to invalid data" % (self.command)
|
||||
self.set_state("bm_header", self.payloadLength)
|
||||
self.bm_proto_reset()
|
||||
if retval:
|
||||
self.set_state("bm_header", self.payloadLength)
|
||||
self.bm_proto_reset()
|
||||
# else assume the command requires a different state to follow
|
||||
return True
|
||||
|
||||
def bm_command_error(self):
|
||||
def bm_command_getdata(self):
|
||||
def bm_command_object(self):
|
||||
def bm_command_ping(self):
|
||||
def bm_command_pong(self):
|
||||
|
||||
def bm_command_verack(self):
|
||||
self.verackReceived = True
|
||||
return True
|
||||
|
@ -102,8 +123,171 @@ class BMConnection(AdvancedDispatcher):
|
|||
self.userAgent = self.payload[readPosition:readPosition + useragentLength]
|
||||
readPosition += useragentLength
|
||||
print "user agent: %s" % (self.userAgent)
|
||||
if not self.peerValidityChecks():
|
||||
# TODO ABORT
|
||||
return True
|
||||
self.write_buf += protocol.CreatePacket('verack')
|
||||
self.verackSent = True
|
||||
if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and
|
||||
protocol.haveSSL(not self.isOutbound)):
|
||||
self.isSSL = True
|
||||
if self.verackReceived:
|
||||
if self.isSSL:
|
||||
self.set_state("tls_init", self.payloadLength)
|
||||
else:
|
||||
self.set_state("bm_ready", self.payloadLength)
|
||||
self.bm_proto_reset()
|
||||
return False
|
||||
|
||||
def peerValidityChecks(self):
|
||||
if self.remoteProtocolVersion < 3:
|
||||
self.write_buf += protocol.assembleErrorMessage(fatal=2,
|
||||
errorText="Your is using an old protocol. Closing connection.")
|
||||
logger.debug ('Closing connection to old protocol version %s, node: %s',
|
||||
str(self.remoteProtocolVersion), str(self.peer))
|
||||
return False
|
||||
if self.timeOffset > 3600:
|
||||
self.write_buf += protocol.assembleErrorMessage(fatal=2,
|
||||
errorText="Your time is too far in the future compared to mine. Closing connection.")
|
||||
logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.",
|
||||
self.peer, self.timeOffset)
|
||||
shared.timeOffsetWrongCount += 1
|
||||
return False
|
||||
elif self.timeOffset < -3600:
|
||||
self.write_buf += protocol.assembleErrorMessage(fatal=2,
|
||||
errorText="Your time is too far in the past compared to mine. Closing connection.")
|
||||
logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.",
|
||||
self.peer, self.timeOffset)
|
||||
shared.timeOffsetWrongCount += 1
|
||||
return False
|
||||
else:
|
||||
shared.timeOffsetWrongCount = 0
|
||||
if len(self.streams) == 0:
|
||||
self.write_buf += protocol.assembleErrorMessage(fatal=2,
|
||||
errorText="We don't have shared stream interests. Closing connection.")))
|
||||
logger.debug ('Closed connection to %s because there is no overlapping interest in streams.',
|
||||
str(self.peer))
|
||||
return False
|
||||
return True
|
||||
|
||||
def sendAddr(self):
|
||||
def sendChunk():
|
||||
if numberOfAddressesInAddrMessage == 0:
|
||||
return
|
||||
self.write_buf += protocol.CreatePacket('addr', \
|
||||
addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload)))
|
||||
|
||||
# We are going to share a maximum number of 1000 addrs (per overlapping
|
||||
# stream) with our peer. 500 from overlapping streams, 250 from the
|
||||
# left child stream, and 250 from the right child stream.
|
||||
maxAddrCount = BMConfigParser().safeGetInt("bitmessagesettings", "maxaddrperstreamsend", 500)
|
||||
|
||||
# init
|
||||
addressCount = 0
|
||||
payload = ''
|
||||
|
||||
for stream in self.streams:
|
||||
addrsInMyStream = {}
|
||||
addrsInChildStreamLeft = {}
|
||||
addrsInChildStreamRight = {}
|
||||
|
||||
with knownnodes.knownNodesLock:
|
||||
if len(knownnodes.knownNodes[stream]) > 0:
|
||||
filtered = {k: v for k, v in knownnodes.knownNodes[stream].items()
|
||||
if v > (int(time.time()) - shared.maximumAgeOfNodesThatIAdvertiseToOthers)}
|
||||
elemCount = len(filtered)
|
||||
if elemCount > maxAddrCount:
|
||||
elemCount = maxAddrCount
|
||||
# only if more recent than 3 hours
|
||||
addrsInMyStream = random.sample(filtered.items(), elemCount)
|
||||
# sent 250 only if the remote isn't interested in it
|
||||
if len(knownnodes.knownNodes[stream * 2]) > 0 and stream not in self.streams:
|
||||
filtered = {k: v for k, v in knownnodes.knownNodes[stream*2].items()
|
||||
if v > (int(time.time()) - shared.maximumAgeOfNodesThatIAdvertiseToOthers)}
|
||||
elemCount = len(filtered)
|
||||
if elemCount > maxAddrCount / 2:
|
||||
elemCount = int(maxAddrCount / 2)
|
||||
addrsInChildStreamLeft = random.sample(filtered.items(), elemCount)
|
||||
if len(knownnodes.knownNodes[(stream * 2) + 1]) > 0 and stream not in self.streams:
|
||||
filtered = {k: v for k, v in knownnodes.knownNodes[stream*2+1].items()
|
||||
if v > (int(time.time()) - shared.maximumAgeOfNodesThatIAdvertiseToOthers)}
|
||||
elemCount = len(filtered)
|
||||
if elemCount > maxAddrCount / 2:
|
||||
elemCount = int(maxAddrCount / 2)
|
||||
addrsInChildStreamRight = random.sample(filtered.items(), elemCount)
|
||||
for (HOST, PORT), timeLastReceivedMessageFromThisNode in addrsInMyStream:
|
||||
addressCount += 1
|
||||
payload += pack(
|
||||
'>Q', timeLastReceivedMessageFromThisNode) # 64-bit time
|
||||
payload += pack('>I', stream)
|
||||
payload += pack(
|
||||
'>q', 1) # service bit flags offered by this node
|
||||
payload += protocol.encodeHost(HOST)
|
||||
payload += pack('>H', PORT) # remote port
|
||||
if addressCount >= BMConnection.maxAddrCount:
|
||||
sendChunk()
|
||||
payload = ''
|
||||
addressCount = 0
|
||||
for (HOST, PORT), timeLastReceivedMessageFromThisNode in addrsInChildStreamLeft:
|
||||
addressCount += 1
|
||||
payload += pack(
|
||||
'>Q', timeLastReceivedMessageFromThisNode) # 64-bit time
|
||||
payload += pack('>I', stream * 2)
|
||||
payload += pack(
|
||||
'>q', 1) # service bit flags offered by this node
|
||||
payload += protocol.encodeHost(HOST)
|
||||
payload += pack('>H', PORT) # remote port
|
||||
if addressCount >= BMConnection.maxAddrCount:
|
||||
sendChunk()
|
||||
payload = ''
|
||||
addressCount = 0
|
||||
for (HOST, PORT), timeLastReceivedMessageFromThisNode in addrsInChildStreamRight:
|
||||
addressCount += 1
|
||||
payload += pack(
|
||||
'>Q', timeLastReceivedMessageFromThisNode) # 64-bit time
|
||||
payload += pack('>I', (stream * 2) + 1)
|
||||
payload += pack(
|
||||
'>q', 1) # service bit flags offered by this node
|
||||
payload += protocol.encodeHost(HOST)
|
||||
payload += pack('>H', PORT) # remote port
|
||||
if addressCount >= BMConnection.maxAddrCount:
|
||||
sendChunk()
|
||||
payload = ''
|
||||
addressCount = 0
|
||||
|
||||
# flush
|
||||
sendChunk()
|
||||
|
||||
def sendBigInv(self):
|
||||
def sendChunk():
|
||||
if objectCount == 0:
|
||||
return
|
||||
payload = encodeVarint(objectCount) + payload
|
||||
logger.debug('Sending huge inv message with %i objects to just this one peer',
|
||||
str(numberOfObjects))
|
||||
self.write_buf += protocol.CreatePacket('inv', payload)
|
||||
|
||||
# Select all hashes for objects in this stream.
|
||||
bigInvList = {}
|
||||
for stream in self.streams:
|
||||
for hash in Inventory().unexpired_hashes_by_stream(stream):
|
||||
if not self.objectHashHolderInstance.hasHash(hash):
|
||||
bigInvList[hash] = 0
|
||||
objectCount = 0
|
||||
payload = ''
|
||||
# Now let us start appending all of these hashes together. They will be
|
||||
# sent out in a big inv message to our new peer.
|
||||
for hash, storedValue in bigInvList.items():
|
||||
payload += hash
|
||||
objectCount += 1
|
||||
if objectCount >= BMConnection.maxObjectCount:
|
||||
self.sendChunk()
|
||||
payload = ''
|
||||
objectCount = 0
|
||||
|
||||
# flush
|
||||
sendChunk()
|
||||
|
||||
|
||||
class Socks5BMConnection(Socks5Connection, BMConnection):
|
||||
def __init__(self, address):
|
||||
|
|
|
@ -2,57 +2,45 @@
|
|||
SSL/TLS negotiation.
|
||||
"""
|
||||
|
||||
import asyncore
|
||||
from network.advanceddispatcher import AdvancedDispatcher
|
||||
import network.asyncore_pollchoose as asyncore
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
import protocol
|
||||
|
||||
class TLSHandshake(asyncore.dispatcher):
|
||||
"""
|
||||
Negotiates a SSL/TLS connection before handing itself spawning a
|
||||
dispatcher that can deal with the overlying protocol as soon as the
|
||||
handshake has been completed.
|
||||
|
||||
`handoff` is a function/method called when the handshake has completed.
|
||||
`address` is a tuple consisting of hostname/address and port to connect to
|
||||
if nothing is passed in `sock`, which can take an already-connected socket.
|
||||
`certfile` can take a path to a certificate bundle, and `server_side`
|
||||
indicates whether the socket is intended to be a server-side or client-side
|
||||
socket.
|
||||
"""
|
||||
|
||||
class TLSDispatcher(AdvancedDispatcher):
|
||||
def __init__(self, address=None, sock=None,
|
||||
certfile=None, keyfile=None, server_side=False, ciphers=None, init_parent=True):
|
||||
if not hasattr(self, '_map'):
|
||||
asyncore.dispatcher.__init__(self, sock)
|
||||
certfile=None, keyfile=None, server_side=False, ciphers=protocol.sslProtocolCiphers):
|
||||
self.want_read = self.want_write = True
|
||||
self.certfile = certfile
|
||||
self.keyfile = keyfile
|
||||
if certfile is None:
|
||||
self.certfile = os.path.join(paths.codePath(), 'sslkeys', 'cert.pem')
|
||||
else:
|
||||
self.certfile = certfile
|
||||
if keyfile is None:
|
||||
self.keyfile = os.path.join(paths.codePath(), 'sslkeys', 'key.pem')
|
||||
else:
|
||||
self.keyfile = keyfile
|
||||
self.server_side = server_side
|
||||
self.ciphers = ciphers
|
||||
self.tlsStarted = False
|
||||
self.tlsDone = False
|
||||
if sock is None:
|
||||
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# logger.info('Connecting to %s%d', address[0], address[1])
|
||||
self.connect(address)
|
||||
elif self.connected:
|
||||
# Initiate the handshake for an already-connected socket.
|
||||
self.handle_connect()
|
||||
self.isSSL = False
|
||||
|
||||
def handle_connect(self):
|
||||
def state_tls_init(self):
|
||||
self.isSSL = True
|
||||
# Once the connection has been established, it's safe to wrap the
|
||||
# socket.
|
||||
if sys.version_info >= (2,7,9):
|
||||
context = ssl.create_default_context(purpose = ssl.Purpose.SERVER_AUTH if self.server_side else ssl.Purpose.CLIENT_AUTH)
|
||||
context.set_ciphers(self.ciphers)
|
||||
# context.set_ecdh_curve("secp256k1")
|
||||
context.set_ecdh_curve("secp256k1")
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
# also exclude TLSv1 and TLSv1.1 in the future
|
||||
context.options |= ssl.OP_NOSSLv2 | ssl.OP_NOSSLv3
|
||||
self.sslSock = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False)
|
||||
context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE
|
||||
self.sslSocket = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False)
|
||||
else:
|
||||
self.sslSocket = ssl.wrap_socket(self.socket,
|
||||
server_side=self.server_side,
|
||||
|
@ -67,20 +55,30 @@ class TLSHandshake(asyncore.dispatcher):
|
|||
# self.socket.context.set_ecdh_curve("secp256k1")
|
||||
|
||||
def writable(self):
|
||||
return self.want_write
|
||||
if self.tlsStarted and not self.tlsDone:
|
||||
return self.want_write
|
||||
else:
|
||||
return AdvancedDispacher.writable(self)
|
||||
|
||||
def readable(self):
|
||||
return self.want_read
|
||||
if self.tlsStarted and not self.tlsDone:
|
||||
return self.want_read
|
||||
else:
|
||||
return AdvancedDispacher.readable(self)
|
||||
|
||||
def handle_read(self):
|
||||
if not self.tlsDone:
|
||||
if self.tlsStarted and not self.tlsDone:
|
||||
self._handshake()
|
||||
else:
|
||||
return AdvancedDispacher.handle_read(self)
|
||||
|
||||
def handle_write(self):
|
||||
if not self.tlsDone:
|
||||
if self.tlsStarted and not not self.tlsDone:
|
||||
self._handshake()
|
||||
else:
|
||||
return AdvancedDispacher.handle_write(self)
|
||||
|
||||
def _handshake(self):
|
||||
def state_tls_handshake(self):
|
||||
"""
|
||||
Perform the handshake.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user