Asyncore update (WIP)

This commit is contained in:
Peter Šurda 2017-04-16 18:27:15 +02:00
parent e6f0b34f9b
commit 96d58f3c11
Signed by: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87
9 changed files with 306 additions and 84 deletions

View File

@ -1,15 +1,20 @@
import hashlib
import time
from pprint import pprint
import socket
from struct import unpack
from network.advanceddispatcher import AdvancedDispatcher
from network.node import Node
import network.asyncore_pollchoose as asyncore
from network.proxy import Proxy, ProxyError, GeneralProxyError
from network.socks5 import Socks5Connection, Socks5Resolver, Socks5AuthError, Socks5Error
from network.socks4a import Socks4aConnection, Socks4aResolver, Socks4aError
from network.tls import TLSDispatcher
import addresses
from bmconfigparser import BMConfigParser
import shared
import protocol
class BMProtoError(ProxyError): pass
@ -30,14 +35,14 @@ class BMConnection(TLSDispatcher):
if address is None and sock is not None:
self.destination = self.addr()
self.isOutbound = False
TLSHandshake.__init__(self, sock, server_side=True)
TLSDispatcher.__init__(self, sock, server_side=True)
print "received connection in background from %s:%i" % (self.destination[0], self.destination[1])
else:
self.destination = address
self.isOutbound = True
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.connect(self.destination)
TLSHandshake.__init__(self, sock, server_side=False)
TLSDispatcher.__init__(self, sock, server_side=False)
print "connecting in background to %s:%i" % (self.destination[0], self.destination[1])
def bm_proto_reset(self):
@ -47,19 +52,22 @@ class BMConnection(TLSDispatcher):
self.checksum = None
self.payload = None
self.invalid = False
self.payloadOffset = 0
def state_init(self):
self.bm_proto_reset()
self.write_buf += protocol.assembleVersionMessage(self.destination[0], self.destination[1], (1,), False)
self.append_write_buf(protocol.assembleVersionMessage(self.destination[0], self.destination[1], (1,), False))
if True:
print "Sending version (%ib)" % len(self.write_buf)
self.set_state("bm_header", 0)
self.set_state("bm_header")
return False
def state_bm_ready(self):
print "doing bm ready"
self.sendAddr()
self.sendBigInv()
return True
self.set_state("bm_header")
return False
def state_bm_header(self):
if len(self.read_buf) < protocol.Header.size:
@ -101,32 +109,127 @@ class BMConnection(TLSDispatcher):
# else assume the command requires a different state to follow
return True
def decode_payload_string(self, length):
value = self.payload[self.payloadOffset:self.payloadOffset+length]
self.payloadOffset += length
return value
def decode_payload_varint(self):
value, offset = addresses.decodeVarint(self.payload[self.payloadOffset:])
self.payloadOffset += offset
return value
def decode_payload_node(self):
services, address, port = self.decode_payload_content("Q16sH")
return Node(services, address, port)
def decode_payload_content(self, pattern = "v"):
# l = varint indicating the length of the next item
# v = varint (or array)
# H = uint16
# I = uint32
# Q = uint64
# i = net_addr (without time and stream number)
# s = string
# 0-9 = length of the next item
# , = end of array
retval = []
size = 0
insideDigit = False
for i in range(len(pattern)):
if pattern[i] in "0123456789":
size = size * 10 + int(pattern[i])
continue
elif pattern[i] == "l":
size = self.decode_payload_varint()
continue
if size > 0:
innerval = []
if pattern[i] == "s":
retval.append(self.payload[self.payloadOffset:self.payloadOffset + size])
self.payloadOffset += size
else:
for j in range(size):
if "," in pattern[i:]:
retval.append(self.decode_payload_content(pattern[i:pattern.index(",")]))
else:
retval.append(self.decode_payload_content(pattern[i:]))
size = 0
else:
if pattern[i] == "v":
retval.append(self.decode_payload_varint())
if pattern[i] == "i":
retval.append(self.decode_payload_node())
if pattern[i] == "H":
retval.append(unpack(">H", self.payload[self.payloadOffset:self.payloadOffset+2])[0])
self.payloadOffset += 2
if pattern[i] == "I":
retval.append(unpack(">I", self.payload[self.payloadOffset:self.payloadOffset+4])[0])
self.payloadOffset += 4
if pattern[i] == "Q":
retval.append(unpack(">Q", self.payload[self.payloadOffset:self.payloadOffset+8])[0])
self.payloadOffset += 8
return retval
def bm_command_error(self):
fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls")
def bm_command_getdata(self):
items = self.decode_payload_content("l32s")
#self.antiIntersectionDelay(True) # only handle getdata requests if we have been connected long enough
for i in items:
logger.debug('received getdata request for item:' + hexlify(i))
if self.objectHashHolderInstance.hasHash(i):
self.antiIntersectionDelay()
else:
if i in Inventory():
self.append_write_buf(protocol.CreatePacket('object', Inventory()[i].payload))
else:
#self.antiIntersectionDelay()
logger.warning('%s asked for an object with a getdata which is not in either our memory inventory or our SQL inventory. We probably cleaned it out after advertising it but before they got around to asking for it.' % (self.peer,))
def bm_command_object(self):
lengthOfTimeWeShouldUseToProcessThisMessage = shared.checkAndShareObjectWithPeers(self.payload)
self.downloadQueue.task_done(calculateInventoryHash(self.payload))
def bm_command_addr(self):
addresses = self.decode_payload_content("lQbQ16sH")
def bm_command_ping(self):
self.append_write_buf(protocol.CreatePacket('pong'))
def bm_command_pong(self):
# nothing really
pass
def bm_command_verack(self):
self.verackReceived = True
return True
if self.verackSent:
if self.isSSL:
self.set_state("tls_init", self.payloadLength)
else:
self.set_state("bm_ready", self.payloadLength)
else:
self.set_state("bm_header", self.payloadLength)
self.bm_proto_reset()
return False
def bm_command_version(self):
self.remoteProtocolVersion, self.services, self.timestamp, padding1, self.myExternalIP, padding2, self.remoteNodeIncomingPort = protocol.VersionPacket.unpack(self.payload[:protocol.VersionPacket.size])
#self.remoteProtocolVersion, self.services, self.timestamp, padding1, self.myExternalIP, padding2, self.remoteNodeIncomingPort = protocol.VersionPacket.unpack(self.payload[:protocol.VersionPacket.size])
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv")
self.timeOffset = self.timestamp - int(time.time())
print "remoteProtocolVersion: %i" % (self.remoteProtocolVersion)
print "services: %08X" % (self.services)
print "time offset: %i" % (self.timestamp - int(time.time()))
print "my external IP: %s" % (socket.inet_ntoa(self.myExternalIP))
print "remote node incoming port: %i" % (self.remoteNodeIncomingPort)
useragentLength, lengthOfUseragentVarint = addresses.decodeVarint(self.payload[80:84])
readPosition = 80 + lengthOfUseragentVarint
self.userAgent = self.payload[readPosition:readPosition + useragentLength]
readPosition += useragentLength
print "my external IP: %s" % (self.sockNode.address)
print "remote node incoming port: %i" % (self.peerNode.port)
print "user agent: %s" % (self.userAgent)
if not self.peerValidityChecks():
# TODO ABORT
return True
self.write_buf += protocol.CreatePacket('verack')
self.append_write_buf(protocol.CreatePacket('verack'))
self.verackSent = True
if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and
protocol.haveSSL(not self.isOutbound)):
@ -141,21 +244,21 @@ class BMConnection(TLSDispatcher):
def peerValidityChecks(self):
if self.remoteProtocolVersion < 3:
self.write_buf += protocol.assembleErrorMessage(fatal=2,
errorText="Your is using an old protocol. Closing connection.")
self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your is using an old protocol. Closing connection."))
logger.debug ('Closing connection to old protocol version %s, node: %s',
str(self.remoteProtocolVersion), str(self.peer))
return False
if self.timeOffset > 3600:
self.write_buf += protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the future compared to mine. Closing connection.")
self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the future compared to mine. Closing connection."))
logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.",
self.peer, self.timeOffset)
shared.timeOffsetWrongCount += 1
return False
elif self.timeOffset < -3600:
self.write_buf += protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the past compared to mine. Closing connection.")
self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the past compared to mine. Closing connection."))
logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.",
self.peer, self.timeOffset)
shared.timeOffsetWrongCount += 1
@ -163,8 +266,8 @@ class BMConnection(TLSDispatcher):
else:
shared.timeOffsetWrongCount = 0
if len(self.streams) == 0:
self.write_buf += protocol.assembleErrorMessage(fatal=2,
errorText="We don't have shared stream interests. Closing connection.")))
self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="We don't have shared stream interests. Closing connection."))
logger.debug ('Closed connection to %s because there is no overlapping interest in streams.',
str(self.peer))
return False
@ -174,8 +277,8 @@ class BMConnection(TLSDispatcher):
def sendChunk():
if numberOfAddressesInAddrMessage == 0:
return
self.write_buf += protocol.CreatePacket('addr', \
addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload)))
self.append_write_buf(protocol.CreatePacket('addr', \
addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload))
# We are going to share a maximum number of 1000 addrs (per overlapping
# stream) with our peer. 500 from overlapping streams, 250 from the
@ -265,7 +368,7 @@ class BMConnection(TLSDispatcher):
payload = encodeVarint(objectCount) + payload
logger.debug('Sending huge inv message with %i objects to just this one peer',
str(numberOfObjects))
self.write_buf += protocol.CreatePacket('inv', payload)
self.append_write_buf(protocol.CreatePacket('inv', payload))
# Select all hashes for objects in this stream.
bigInvList = {}
@ -335,15 +438,15 @@ if __name__ == "__main__":
direct = BMConnection(host)
while len(asyncore.socket_map) > 0:
print "loop, state = %s" % (direct.state)
asyncore.loop(timeout=1, count=1)
asyncore.loop(timeout=10, count=1)
continue
proxy = Socks5BMConnection(host)
while len(asyncore.socket_map) > 0:
# print "loop, state = %s" % (proxy.state)
asyncore.loop(timeout=1, count=1)
asyncore.loop(timeout=10, count=1)
proxy = Socks4aBMConnection(host)
while len(asyncore.socket_map) > 0:
# print "loop, state = %s" % (proxy.state)
asyncore.loop(timeout=1, count=1)
asyncore.loop(timeout=10, count=1)

View File

@ -10,11 +10,16 @@ class AdvancedDispatcher(asyncore.dispatcher):
self.write_buf = b""
self.state = "init"
def slice_read_buf(self, length=0):
self.read_buf = self.read_buf[length:]
def append_write_buf(self, string = None):
self.write_buf += string
def slice_write_buf(self, length=0):
self.write_buf = self.read_buf[length:]
if length > 0:
self.write_buf = self.write_buf[length:]
def slice_read_buf(self, length=0):
if length > 0:
self.read_buf = self.read_buf[length:]
def read_buf_sufficient(self, length=0):
if len(self.read_buf) < length:
@ -23,7 +28,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
return True
def process(self):
if self.state != "init" and len(self.read_buf) == 0:
if self.state not in ["init", "tls_handshake"] and len(self.read_buf) == 0:
return
while True:
try:
@ -34,7 +39,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
# missing state
raise
def set_state(self, state, length):
def set_state(self, state, length=0):
self.slice_read_buf(length)
self.state = state
@ -45,6 +50,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len
def handle_read(self):
print "handle_read"
self.read_buf += self.recv(AdvancedDispatcher._buf_len)
self.process()

View File

@ -60,6 +60,9 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \
_DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE,
EBADF))
OP_READ = 1
OP_WRITE = 2
try:
socket_map
except NameError:
@ -178,17 +181,25 @@ def poll_poller(timeout=0.0, map=None):
poll_poller.pollster = select.poll()
if map:
for fd, obj in list(map.items()):
flags = 0
flags = newflags = 0
if obj.readable():
flags |= select.POLLIN | select.POLLPRI
newflags |= OP_READ
else:
newflags &= ~ OP_READ
# accepting sockets should not be writable
if obj.writable() and not obj.accepting:
flags |= select.POLLOUT
if flags:
try:
newflags |= OP_WRITE
else:
newflags &= ~ OP_WRITE
if newflags != obj.flags:
obj.flags = newflags
if obj.poller_registered:
poll_poller.pollster.modify(fd, flags)
except IOError:
else:
poll_poller.pollster.register(fd, flags)
obj.poller_registered = True
try:
r = poll_poller.pollster.poll(timeout)
except KeyboardInterrupt:
@ -213,19 +224,28 @@ def epoll_poller(timeout=0.0, map=None):
epoll_poller.pollster = select.epoll()
if map:
for fd, obj in map.items():
flags = 0
flags = newflags = 0
if obj.readable():
flags |= select.POLLIN | select.POLLPRI
if obj.writable():
newflags |= OP_READ
else:
newflags &= ~ OP_READ
# accepting sockets should not be writable
if obj.writable() and not obj.accepting:
flags |= select.POLLOUT
if flags:
newflags |= OP_WRITE
else:
newflags &= ~ OP_WRITE
if newflags != obj.flags:
obj.flags = newflags
# Only check for exceptions if object was either readable
# or writable.
flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL
try:
epoll_poller.pollster.register(fd, flags)
except IOError:
if obj.poller_registered:
epoll_poller.pollster.modify(fd, flags)
else:
epoll_poller.pollster.register(fd, flags)
obj.poller_registered = True
try:
r = epoll_poller.pollster.poll(timeout)
except select.error, err:
@ -306,6 +326,8 @@ class dispatcher:
closing = False
addr = None
ignore_log_types = frozenset(['warning'])
poller_registered = False
flags = 0
def __init__(self, sock=None, map=None):
if map is None:

View File

@ -19,7 +19,7 @@ class HttpConnection(AdvancedDispatcher):
print "connecting in background to %s:%i" % (self.destination[0], self.destination[1])
def state_init(self):
self.write_buf += "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (self.path, self.destination[0])
self.append_write_buf("GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (self.path, self.destination[0]))
print "Sending %ib" % (len(self.write_buf))
self.set_state("http_request_sent", 0)
return False

66
src/network/node.py Normal file
View File

@ -0,0 +1,66 @@
import socket
import protocol
class Node (object):
TYPE_IPV4 = 1
TYPE_IPV6 = 2
TYPE_ONION = 3
TYPE_LOCAL = 4
TYPE_LOOPBACK = 8
TYPE_UNDEF = 12
def __init__(self, services, address, port):
self.services = services
self.address, self.addressType = Node.decodeIPAddress(address)
self.port = port
def isLocal(self):
return self.addressType | Node.TYPE_LOCAL > 0
def isGlobal(self):
return self.addressType <= Node.TYPE_ONION
def isOnion(self):
return self.addressType | Node.TYPE_ONION > 0
def isLoopback(self):
return self.addressType | Node.TYPE_LOOPBACK > 0
@staticmethod
def decodeIPAddress(host):
if host[0:12] == '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF':
hostStandardFormat = socket.inet_ntop(socket.AF_INET, host[12:])
return Node.decodeIPv4Address(host[12:], hostStandardFormat)
elif host[0:6] == '\xfd\x87\xd8\x7e\xeb\x43':
# Onion, based on BMD/bitcoind
hostStandardFormat = base64.b32encode(host[6:]).lower() + ".onion"
return hostStandardFormat, Node.TYPE_ONION
else:
hostStandardFormat = socket.inet_ntop(socket.AF_INET6, host)
if hostStandardFormat == "":
# This can happen on Windows systems which are not 64-bit compatible
# so let us drop the IPv6 address.
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_UNDEF
return Node.decodeIPv6Address(host, hostStandardFormat)
@staticmethod
def decodeIPv4Address(host, hostStandardFormat):
if host[0] == '\x7F': # 127/8
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOOPBACK
if host[0] == '\x0A': # 10/8
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL
if host[0:2] == '\xC0\xA8': # 192.168/16
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL
if host[0:2] >= '\xAC\x10' and host[0:2] < '\xAC\x20': # 172.16/12
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL
return hostStandardFormat, Node.TYPE_IPV4
@staticmethod
def _checkIPv6Address(host, hostStandardFormat):
if host == ('\x00' * 15) + '\x01':
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_LOOPBACK
if host[0] == '\xFE' and (ord(host[1]) & 0xc0) == 0x80:
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_LOCAL
if (ord(host[0]) & 0xfe) == 0xfc:
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_UNDEF
return hostStandardFormat, Node.TYPE_IPV6

View File

@ -59,28 +59,28 @@ class Socks4aConnection(Socks4a):
def state_auth_done(self):
# Now we can request the actual connection
rmtrslv = False
self.write_buf += struct.pack('>BBH', 0x04, 0x01, self.destination[1])
self.append_write_buf(struct.pack('>BBH', 0x04, 0x01, self.destination[1]))
# If the given destination address is an IP address, we'll
# use the IPv4 address request even if remote resolving was specified.
try:
self.ipaddr = socket.inet_aton(self.destination[0])
self.write_buf += ipaddr
self.append_write_buf(self.ipaddr)
except socket.error:
# Well it's not an IP number, so it's probably a DNS name.
if Proxy._remote_dns:
# Resolve remotely
rmtrslv = True
self.ipaddr = None
self.write_buf += struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)
self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01))
else:
# Resolve locally
self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0]))
self.write_buf += self.ipaddr
self.append_write_buf(self.ipaddr)
if self._auth:
self.write_buf += self._auth[0]
self.write_buf += chr(0x00).encode()
self.append_write_buf(self._auth[0])
self.append_write_buf(chr(0x00).encode())
if rmtrslv:
self.write_buf += self.destination[0] + chr(0x00).encode()
self.append_write_buf(self.destination[0] + chr(0x00).encode())
self.set_state("pre_connect", 0)
@ -92,12 +92,12 @@ class Socks4aResolver(Socks4a):
def state_auth_done(self):
# Now we can request the actual connection
self.write_buf += struct.pack('>BBH', 0x04, 0xF0, self.destination[1])
self.write_buf += struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)
self.append_write_buf(struct.pack('>BBH', 0x04, 0xF0, self.destination[1]))
self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01))
if self._auth:
self.write_buf += self._auth[0]
self.write_buf += chr(0x00).encode()
self.write_buf += self.host + chr(0x00).encode()
self.append_write_buf(self._auth[0])
self.append_write_buf(chr(0x00).encode())
self.append_write_buf(self.host + chr(0x00).encode())
self.set_state("pre_connect", 0)
def resolved(self):

View File

@ -17,9 +17,9 @@ class Socks5(Proxy):
def state_init(self):
if self._auth:
self.write_buf += struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)
self.append_write_buf(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02))
else:
self.write_buf += struct.pack('BBB', 0x05, 0x01, 0x00)
self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00))
self.set_state("auth_1", 0)
def state_auth_1(self):
@ -35,9 +35,9 @@ class Socks5(Proxy):
self.set_state("auth_done", 2)
elif ret[1] == 2:
# username/password
self.write_buf += struct.pack('BB', 1, len(self._auth[0])) + \
self.append_write_buf(struct.pack('BB', 1, len(self._auth[0])) + \
self._auth[0] + struct.pack('B', len(self._auth[1])) + \
self._auth[1]
self._auth[1])
self.set_state("auth_1", 2)
else:
if ret[1] == 0xff:
@ -130,23 +130,23 @@ class Socks5Connection(Socks5):
def state_auth_done(self):
# Now we can request the actual connection
self.write_buf += struct.pack('BBB', 0x05, 0x01, 0x00)
self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00))
# If the given destination address is an IP address, we'll
# use the IPv4 address request even if remote resolving was specified.
try:
self.ipaddr = socket.inet_aton(self.destination[0])
self.write_buf += chr(0x01).encode() + self.ipaddr
self.append_write_buf(chr(0x01).encode() + self.ipaddr)
except socket.error:
# Well it's not an IP number, so it's probably a DNS name.
if Proxy._remote_dns:
# Resolve remotely
self.ipaddr = None
self.write_buf += chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0]
self.append_write_buf(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0])
else:
# Resolve locally
self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0]))
self.write_buf += chr(0x01).encode() + self.ipaddr
self.write_buf += struct.pack(">H", self.destination[1])
self.append_write_buf(chr(0x01).encode() + self.ipaddr)
self.append_write_buf(struct.pack(">H", self.destination[1]))
self.set_state("pre_connect", 0)
@ -158,9 +158,9 @@ class Socks5Resolver(Socks5):
def state_auth_done(self):
# Now we can request the actual connection
self.write_buf += struct.pack('BBB', 0x05, 0xF0, 0x00)
self.write_buf += chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host)
self.write_buf += struct.pack(">H", self.port)
self.append_write_buf(struct.pack('BBB', 0x05, 0xF0, 0x00))
self.append_write_buf(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host))
self.append_write_buf(struct.pack(">H", self.port))
self.set_state("pre_connect", 0)
def resolved(self):

View File

@ -2,12 +2,14 @@
SSL/TLS negotiation.
"""
from network.advanceddispatcher import AdvancedDispatcher
import network.asyncore_pollchoose as asyncore
import os
import socket
import ssl
import sys
from network.advanceddispatcher import AdvancedDispatcher
import network.asyncore_pollchoose as asyncore
import paths
import protocol
class TLSDispatcher(AdvancedDispatcher):
@ -30,6 +32,7 @@ class TLSDispatcher(AdvancedDispatcher):
def state_tls_init(self):
self.isSSL = True
self.tlsStarted = True
# Once the connection has been established, it's safe to wrap the
# socket.
if sys.version_info >= (2,7,9):
@ -40,7 +43,7 @@ class TLSDispatcher(AdvancedDispatcher):
context.verify_mode = ssl.CERT_NONE
# also exclude TLSv1 and TLSv1.1 in the future
context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE
self.sslSocket = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False)
self.sslSocket = context.wrap_socket(self.socket, server_side = self.server_side, do_handshake_on_connect=False)
else:
self.sslSocket = ssl.wrap_socket(self.socket,
server_side=self.server_side,
@ -51,49 +54,66 @@ class TLSDispatcher(AdvancedDispatcher):
do_handshake_on_connect=False)
self.sslSocket.setblocking(0)
self.want_read = self.want_write = True
self.set_state("tls_handshake")
# if hasattr(self.socket, "context"):
# self.socket.context.set_ecdh_curve("secp256k1")
def writable(self):
if self.tlsStarted and not self.tlsDone:
if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "tls writable, %r" % (self.want_write)
return self.want_write
else:
return AdvancedDispacher.writable(self)
return AdvancedDispatcher.writable(self)
def readable(self):
if self.tlsStarted and not self.tlsDone:
if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "tls readable, %r" % (self.want_read)
return self.want_read
else:
return AdvancedDispacher.readable(self)
return AdvancedDispatcher.readable(self)
def handle_read(self):
if self.tlsStarted and not self.tlsDone:
self._handshake()
# wait for write buffer flush
if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "handshaking (read)"
self.state_tls_handshake()
else:
return AdvancedDispacher.handle_read(self)
print "not handshaking (read)"
return AdvancedDispatcher.handle_read(self)
def handle_write(self):
if self.tlsStarted and not not self.tlsDone:
self._handshake()
# wait for write buffer flush
if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "handshaking (write)"
self.state_tls_handshake()
else:
return AdvancedDispacher.handle_write(self)
print "not handshaking (write)"
return AdvancedDispatcher.handle_write(self)
def state_tls_handshake(self):
"""
Perform the handshake.
"""
# wait for flush
if len(self.write_buf) > 0:
return False
# Perform the handshake.
try:
print "handshaking (internal)"
self.sslSocket.do_handshake()
except ssl.SSLError, err:
print "handshake fail"
self.want_read = self.want_write = False
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
print "want read"
self.want_read = True
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
print "want write"
self.want_write = True
else:
raise
else:
print "handshake success"
# The handshake has completed, so remove this channel and...
self.del_channel()
self.set_socket(self.sslSocket)
self.tlsDone = True
self.state_bm_ready()
return False

View File

@ -27,6 +27,11 @@ NODE_SSL = 2
#Bitfield flags
BITFIELD_DOESACK = 1
#Error types
STATUS_WARNING = 0
STATUS_ERROR = 1
STATUS_FATAL = 2
eightBytesOfRandomDataUsedToDetectConnectionsToSelf = pack(
'>Q', random.randrange(1, 18446744073709551615))