From 96d58f3c1191098dc1b7a22e960dfc691e32bbfb Mon Sep 17 00:00:00 2001 From: Peter Surda Date: Sun, 16 Apr 2017 18:27:15 +0200 Subject: [PATCH] Asyncore update (WIP) --- src/bmproto.py | 159 ++++++++++++++++++++++++----- src/network/advanceddispatcher.py | 16 ++- src/network/asyncore_pollchoose.py | 42 ++++++-- src/network/http.py | 2 +- src/network/node.py | 66 ++++++++++++ src/network/socks4a.py | 24 ++--- src/network/socks5.py | 24 ++--- src/network/tls.py | 52 +++++++--- src/protocol.py | 5 + 9 files changed, 306 insertions(+), 84 deletions(-) create mode 100644 src/network/node.py diff --git a/src/bmproto.py b/src/bmproto.py index 5cd08779..c9160c3b 100644 --- a/src/bmproto.py +++ b/src/bmproto.py @@ -1,15 +1,20 @@ import hashlib import time +from pprint import pprint import socket +from struct import unpack from network.advanceddispatcher import AdvancedDispatcher +from network.node import Node import network.asyncore_pollchoose as asyncore from network.proxy import Proxy, ProxyError, GeneralProxyError from network.socks5 import Socks5Connection, Socks5Resolver, Socks5AuthError, Socks5Error from network.socks4a import Socks4aConnection, Socks4aResolver, Socks4aError +from network.tls import TLSDispatcher import addresses from bmconfigparser import BMConfigParser +import shared import protocol class BMProtoError(ProxyError): pass @@ -30,14 +35,14 @@ class BMConnection(TLSDispatcher): if address is None and sock is not None: self.destination = self.addr() self.isOutbound = False - TLSHandshake.__init__(self, sock, server_side=True) + TLSDispatcher.__init__(self, sock, server_side=True) print "received connection in background from %s:%i" % (self.destination[0], self.destination[1]) else: self.destination = address self.isOutbound = True self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.connect(self.destination) - TLSHandshake.__init__(self, sock, server_side=False) + TLSDispatcher.__init__(self, sock, server_side=False) print "connecting in background to %s:%i" % (self.destination[0], self.destination[1]) def bm_proto_reset(self): @@ -47,19 +52,22 @@ class BMConnection(TLSDispatcher): self.checksum = None self.payload = None self.invalid = False + self.payloadOffset = 0 def state_init(self): self.bm_proto_reset() - self.write_buf += protocol.assembleVersionMessage(self.destination[0], self.destination[1], (1,), False) + self.append_write_buf(protocol.assembleVersionMessage(self.destination[0], self.destination[1], (1,), False)) if True: print "Sending version (%ib)" % len(self.write_buf) - self.set_state("bm_header", 0) + self.set_state("bm_header") return False def state_bm_ready(self): + print "doing bm ready" self.sendAddr() self.sendBigInv() - return True + self.set_state("bm_header") + return False def state_bm_header(self): if len(self.read_buf) < protocol.Header.size: @@ -101,32 +109,127 @@ class BMConnection(TLSDispatcher): # else assume the command requires a different state to follow return True + def decode_payload_string(self, length): + value = self.payload[self.payloadOffset:self.payloadOffset+length] + self.payloadOffset += length + return value + + def decode_payload_varint(self): + value, offset = addresses.decodeVarint(self.payload[self.payloadOffset:]) + self.payloadOffset += offset + return value + + def decode_payload_node(self): + services, address, port = self.decode_payload_content("Q16sH") + return Node(services, address, port) + + def decode_payload_content(self, pattern = "v"): + # l = varint indicating the length of the next item + # v = varint (or array) + # H = uint16 + # I = uint32 + # Q = uint64 + # i = net_addr (without time and stream number) + # s = string + # 0-9 = length of the next item + # , = end of array + + retval = [] + size = 0 + insideDigit = False + + for i in range(len(pattern)): + if pattern[i] in "0123456789": + size = size * 10 + int(pattern[i]) + continue + elif pattern[i] == "l": + size = self.decode_payload_varint() + continue + if size > 0: + innerval = [] + if pattern[i] == "s": + retval.append(self.payload[self.payloadOffset:self.payloadOffset + size]) + self.payloadOffset += size + else: + for j in range(size): + if "," in pattern[i:]: + retval.append(self.decode_payload_content(pattern[i:pattern.index(",")])) + else: + retval.append(self.decode_payload_content(pattern[i:])) + size = 0 + else: + if pattern[i] == "v": + retval.append(self.decode_payload_varint()) + if pattern[i] == "i": + retval.append(self.decode_payload_node()) + if pattern[i] == "H": + retval.append(unpack(">H", self.payload[self.payloadOffset:self.payloadOffset+2])[0]) + self.payloadOffset += 2 + if pattern[i] == "I": + retval.append(unpack(">I", self.payload[self.payloadOffset:self.payloadOffset+4])[0]) + self.payloadOffset += 4 + if pattern[i] == "Q": + retval.append(unpack(">Q", self.payload[self.payloadOffset:self.payloadOffset+8])[0]) + self.payloadOffset += 8 + return retval + def bm_command_error(self): + fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls") + def bm_command_getdata(self): + items = self.decode_payload_content("l32s") + #self.antiIntersectionDelay(True) # only handle getdata requests if we have been connected long enough + for i in items: + logger.debug('received getdata request for item:' + hexlify(i)) + if self.objectHashHolderInstance.hasHash(i): + self.antiIntersectionDelay() + else: + if i in Inventory(): + self.append_write_buf(protocol.CreatePacket('object', Inventory()[i].payload)) + else: + #self.antiIntersectionDelay() + logger.warning('%s asked for an object with a getdata which is not in either our memory inventory or our SQL inventory. We probably cleaned it out after advertising it but before they got around to asking for it.' % (self.peer,)) + def bm_command_object(self): + lengthOfTimeWeShouldUseToProcessThisMessage = shared.checkAndShareObjectWithPeers(self.payload) + self.downloadQueue.task_done(calculateInventoryHash(self.payload)) + + def bm_command_addr(self): + addresses = self.decode_payload_content("lQbQ16sH") + def bm_command_ping(self): + self.append_write_buf(protocol.CreatePacket('pong')) + def bm_command_pong(self): + # nothing really + pass def bm_command_verack(self): self.verackReceived = True - return True + if self.verackSent: + if self.isSSL: + self.set_state("tls_init", self.payloadLength) + else: + self.set_state("bm_ready", self.payloadLength) + else: + self.set_state("bm_header", self.payloadLength) + self.bm_proto_reset() + return False def bm_command_version(self): - self.remoteProtocolVersion, self.services, self.timestamp, padding1, self.myExternalIP, padding2, self.remoteNodeIncomingPort = protocol.VersionPacket.unpack(self.payload[:protocol.VersionPacket.size]) + #self.remoteProtocolVersion, self.services, self.timestamp, padding1, self.myExternalIP, padding2, self.remoteNodeIncomingPort = protocol.VersionPacket.unpack(self.payload[:protocol.VersionPacket.size]) + self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv") + self.timeOffset = self.timestamp - int(time.time()) print "remoteProtocolVersion: %i" % (self.remoteProtocolVersion) print "services: %08X" % (self.services) print "time offset: %i" % (self.timestamp - int(time.time())) - print "my external IP: %s" % (socket.inet_ntoa(self.myExternalIP)) - print "remote node incoming port: %i" % (self.remoteNodeIncomingPort) - useragentLength, lengthOfUseragentVarint = addresses.decodeVarint(self.payload[80:84]) - readPosition = 80 + lengthOfUseragentVarint - self.userAgent = self.payload[readPosition:readPosition + useragentLength] - readPosition += useragentLength + print "my external IP: %s" % (self.sockNode.address) + print "remote node incoming port: %i" % (self.peerNode.port) print "user agent: %s" % (self.userAgent) if not self.peerValidityChecks(): # TODO ABORT return True - self.write_buf += protocol.CreatePacket('verack') + self.append_write_buf(protocol.CreatePacket('verack')) self.verackSent = True if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and protocol.haveSSL(not self.isOutbound)): @@ -141,21 +244,21 @@ class BMConnection(TLSDispatcher): def peerValidityChecks(self): if self.remoteProtocolVersion < 3: - self.write_buf += protocol.assembleErrorMessage(fatal=2, - errorText="Your is using an old protocol. Closing connection.") + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, + errorText="Your is using an old protocol. Closing connection.")) logger.debug ('Closing connection to old protocol version %s, node: %s', str(self.remoteProtocolVersion), str(self.peer)) return False if self.timeOffset > 3600: - self.write_buf += protocol.assembleErrorMessage(fatal=2, - errorText="Your time is too far in the future compared to mine. Closing connection.") + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, + errorText="Your time is too far in the future compared to mine. Closing connection.")) logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.", self.peer, self.timeOffset) shared.timeOffsetWrongCount += 1 return False elif self.timeOffset < -3600: - self.write_buf += protocol.assembleErrorMessage(fatal=2, - errorText="Your time is too far in the past compared to mine. Closing connection.") + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, + errorText="Your time is too far in the past compared to mine. Closing connection.")) logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.", self.peer, self.timeOffset) shared.timeOffsetWrongCount += 1 @@ -163,8 +266,8 @@ class BMConnection(TLSDispatcher): else: shared.timeOffsetWrongCount = 0 if len(self.streams) == 0: - self.write_buf += protocol.assembleErrorMessage(fatal=2, - errorText="We don't have shared stream interests. Closing connection."))) + self.append_write_buf(protocol.assembleErrorMessage(fatal=2, + errorText="We don't have shared stream interests. Closing connection.")) logger.debug ('Closed connection to %s because there is no overlapping interest in streams.', str(self.peer)) return False @@ -174,8 +277,8 @@ class BMConnection(TLSDispatcher): def sendChunk(): if numberOfAddressesInAddrMessage == 0: return - self.write_buf += protocol.CreatePacket('addr', \ - addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload))) + self.append_write_buf(protocol.CreatePacket('addr', \ + addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload)) # We are going to share a maximum number of 1000 addrs (per overlapping # stream) with our peer. 500 from overlapping streams, 250 from the @@ -265,7 +368,7 @@ class BMConnection(TLSDispatcher): payload = encodeVarint(objectCount) + payload logger.debug('Sending huge inv message with %i objects to just this one peer', str(numberOfObjects)) - self.write_buf += protocol.CreatePacket('inv', payload) + self.append_write_buf(protocol.CreatePacket('inv', payload)) # Select all hashes for objects in this stream. bigInvList = {} @@ -335,15 +438,15 @@ if __name__ == "__main__": direct = BMConnection(host) while len(asyncore.socket_map) > 0: print "loop, state = %s" % (direct.state) - asyncore.loop(timeout=1, count=1) + asyncore.loop(timeout=10, count=1) continue proxy = Socks5BMConnection(host) while len(asyncore.socket_map) > 0: # print "loop, state = %s" % (proxy.state) - asyncore.loop(timeout=1, count=1) + asyncore.loop(timeout=10, count=1) proxy = Socks4aBMConnection(host) while len(asyncore.socket_map) > 0: # print "loop, state = %s" % (proxy.state) - asyncore.loop(timeout=1, count=1) + asyncore.loop(timeout=10, count=1) diff --git a/src/network/advanceddispatcher.py b/src/network/advanceddispatcher.py index df6e58ef..dc7eedb0 100644 --- a/src/network/advanceddispatcher.py +++ b/src/network/advanceddispatcher.py @@ -10,11 +10,16 @@ class AdvancedDispatcher(asyncore.dispatcher): self.write_buf = b"" self.state = "init" - def slice_read_buf(self, length=0): - self.read_buf = self.read_buf[length:] + def append_write_buf(self, string = None): + self.write_buf += string def slice_write_buf(self, length=0): - self.write_buf = self.read_buf[length:] + if length > 0: + self.write_buf = self.write_buf[length:] + + def slice_read_buf(self, length=0): + if length > 0: + self.read_buf = self.read_buf[length:] def read_buf_sufficient(self, length=0): if len(self.read_buf) < length: @@ -23,7 +28,7 @@ class AdvancedDispatcher(asyncore.dispatcher): return True def process(self): - if self.state != "init" and len(self.read_buf) == 0: + if self.state not in ["init", "tls_handshake"] and len(self.read_buf) == 0: return while True: try: @@ -34,7 +39,7 @@ class AdvancedDispatcher(asyncore.dispatcher): # missing state raise - def set_state(self, state, length): + def set_state(self, state, length=0): self.slice_read_buf(length) self.state = state @@ -45,6 +50,7 @@ class AdvancedDispatcher(asyncore.dispatcher): return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len def handle_read(self): + print "handle_read" self.read_buf += self.recv(AdvancedDispatcher._buf_len) self.process() diff --git a/src/network/asyncore_pollchoose.py b/src/network/asyncore_pollchoose.py index 7fa19f4a..4ccce7f9 100644 --- a/src/network/asyncore_pollchoose.py +++ b/src/network/asyncore_pollchoose.py @@ -60,6 +60,9 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ _DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF)) +OP_READ = 1 +OP_WRITE = 2 + try: socket_map except NameError: @@ -178,17 +181,25 @@ def poll_poller(timeout=0.0, map=None): poll_poller.pollster = select.poll() if map: for fd, obj in list(map.items()): - flags = 0 + flags = newflags = 0 if obj.readable(): flags |= select.POLLIN | select.POLLPRI + newflags |= OP_READ + else: + newflags &= ~ OP_READ # accepting sockets should not be writable if obj.writable() and not obj.accepting: flags |= select.POLLOUT - if flags: - try: + newflags |= OP_WRITE + else: + newflags &= ~ OP_WRITE + if newflags != obj.flags: + obj.flags = newflags + if obj.poller_registered: poll_poller.pollster.modify(fd, flags) - except IOError: + else: poll_poller.pollster.register(fd, flags) + obj.poller_registered = True try: r = poll_poller.pollster.poll(timeout) except KeyboardInterrupt: @@ -213,19 +224,28 @@ def epoll_poller(timeout=0.0, map=None): epoll_poller.pollster = select.epoll() if map: for fd, obj in map.items(): - flags = 0 + flags = newflags = 0 if obj.readable(): flags |= select.POLLIN | select.POLLPRI - if obj.writable(): + newflags |= OP_READ + else: + newflags &= ~ OP_READ + # accepting sockets should not be writable + if obj.writable() and not obj.accepting: flags |= select.POLLOUT - if flags: + newflags |= OP_WRITE + else: + newflags &= ~ OP_WRITE + if newflags != obj.flags: + obj.flags = newflags # Only check for exceptions if object was either readable # or writable. flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL - try: - epoll_poller.pollster.register(fd, flags) - except IOError: + if obj.poller_registered: epoll_poller.pollster.modify(fd, flags) + else: + epoll_poller.pollster.register(fd, flags) + obj.poller_registered = True try: r = epoll_poller.pollster.poll(timeout) except select.error, err: @@ -306,6 +326,8 @@ class dispatcher: closing = False addr = None ignore_log_types = frozenset(['warning']) + poller_registered = False + flags = 0 def __init__(self, sock=None, map=None): if map is None: diff --git a/src/network/http.py b/src/network/http.py index 93828c83..55cb81a1 100644 --- a/src/network/http.py +++ b/src/network/http.py @@ -19,7 +19,7 @@ class HttpConnection(AdvancedDispatcher): print "connecting in background to %s:%i" % (self.destination[0], self.destination[1]) def state_init(self): - self.write_buf += "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (self.path, self.destination[0]) + self.append_write_buf("GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (self.path, self.destination[0])) print "Sending %ib" % (len(self.write_buf)) self.set_state("http_request_sent", 0) return False diff --git a/src/network/node.py b/src/network/node.py new file mode 100644 index 00000000..904ff4d1 --- /dev/null +++ b/src/network/node.py @@ -0,0 +1,66 @@ +import socket +import protocol + +class Node (object): + TYPE_IPV4 = 1 + TYPE_IPV6 = 2 + TYPE_ONION = 3 + TYPE_LOCAL = 4 + TYPE_LOOPBACK = 8 + TYPE_UNDEF = 12 + + def __init__(self, services, address, port): + self.services = services + self.address, self.addressType = Node.decodeIPAddress(address) + self.port = port + + def isLocal(self): + return self.addressType | Node.TYPE_LOCAL > 0 + + def isGlobal(self): + return self.addressType <= Node.TYPE_ONION + + def isOnion(self): + return self.addressType | Node.TYPE_ONION > 0 + + def isLoopback(self): + return self.addressType | Node.TYPE_LOOPBACK > 0 + + @staticmethod + def decodeIPAddress(host): + if host[0:12] == '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF': + hostStandardFormat = socket.inet_ntop(socket.AF_INET, host[12:]) + return Node.decodeIPv4Address(host[12:], hostStandardFormat) + elif host[0:6] == '\xfd\x87\xd8\x7e\xeb\x43': + # Onion, based on BMD/bitcoind + hostStandardFormat = base64.b32encode(host[6:]).lower() + ".onion" + return hostStandardFormat, Node.TYPE_ONION + else: + hostStandardFormat = socket.inet_ntop(socket.AF_INET6, host) + if hostStandardFormat == "": + # This can happen on Windows systems which are not 64-bit compatible + # so let us drop the IPv6 address. + return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_UNDEF + return Node.decodeIPv6Address(host, hostStandardFormat) + + @staticmethod + def decodeIPv4Address(host, hostStandardFormat): + if host[0] == '\x7F': # 127/8 + return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOOPBACK + if host[0] == '\x0A': # 10/8 + return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL + if host[0:2] == '\xC0\xA8': # 192.168/16 + return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL + if host[0:2] >= '\xAC\x10' and host[0:2] < '\xAC\x20': # 172.16/12 + return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL + return hostStandardFormat, Node.TYPE_IPV4 + + @staticmethod + def _checkIPv6Address(host, hostStandardFormat): + if host == ('\x00' * 15) + '\x01': + return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_LOOPBACK + if host[0] == '\xFE' and (ord(host[1]) & 0xc0) == 0x80: + return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_LOCAL + if (ord(host[0]) & 0xfe) == 0xfc: + return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_UNDEF + return hostStandardFormat, Node.TYPE_IPV6 diff --git a/src/network/socks4a.py b/src/network/socks4a.py index 091e09a5..02c8d4af 100644 --- a/src/network/socks4a.py +++ b/src/network/socks4a.py @@ -59,28 +59,28 @@ class Socks4aConnection(Socks4a): def state_auth_done(self): # Now we can request the actual connection rmtrslv = False - self.write_buf += struct.pack('>BBH', 0x04, 0x01, self.destination[1]) + self.append_write_buf(struct.pack('>BBH', 0x04, 0x01, self.destination[1])) # If the given destination address is an IP address, we'll # use the IPv4 address request even if remote resolving was specified. try: self.ipaddr = socket.inet_aton(self.destination[0]) - self.write_buf += ipaddr + self.append_write_buf(self.ipaddr) except socket.error: # Well it's not an IP number, so it's probably a DNS name. if Proxy._remote_dns: # Resolve remotely rmtrslv = True self.ipaddr = None - self.write_buf += struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01) + self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) else: # Resolve locally self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) - self.write_buf += self.ipaddr + self.append_write_buf(self.ipaddr) if self._auth: - self.write_buf += self._auth[0] - self.write_buf += chr(0x00).encode() + self.append_write_buf(self._auth[0]) + self.append_write_buf(chr(0x00).encode()) if rmtrslv: - self.write_buf += self.destination[0] + chr(0x00).encode() + self.append_write_buf(self.destination[0] + chr(0x00).encode()) self.set_state("pre_connect", 0) @@ -92,12 +92,12 @@ class Socks4aResolver(Socks4a): def state_auth_done(self): # Now we can request the actual connection - self.write_buf += struct.pack('>BBH', 0x04, 0xF0, self.destination[1]) - self.write_buf += struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01) + self.append_write_buf(struct.pack('>BBH', 0x04, 0xF0, self.destination[1])) + self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01)) if self._auth: - self.write_buf += self._auth[0] - self.write_buf += chr(0x00).encode() - self.write_buf += self.host + chr(0x00).encode() + self.append_write_buf(self._auth[0]) + self.append_write_buf(chr(0x00).encode()) + self.append_write_buf(self.host + chr(0x00).encode()) self.set_state("pre_connect", 0) def resolved(self): diff --git a/src/network/socks5.py b/src/network/socks5.py index 841c253b..5ba6f3e3 100644 --- a/src/network/socks5.py +++ b/src/network/socks5.py @@ -17,9 +17,9 @@ class Socks5(Proxy): def state_init(self): if self._auth: - self.write_buf += struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02) + self.append_write_buf(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)) else: - self.write_buf += struct.pack('BBB', 0x05, 0x01, 0x00) + self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00)) self.set_state("auth_1", 0) def state_auth_1(self): @@ -35,9 +35,9 @@ class Socks5(Proxy): self.set_state("auth_done", 2) elif ret[1] == 2: # username/password - self.write_buf += struct.pack('BB', 1, len(self._auth[0])) + \ + self.append_write_buf(struct.pack('BB', 1, len(self._auth[0])) + \ self._auth[0] + struct.pack('B', len(self._auth[1])) + \ - self._auth[1] + self._auth[1]) self.set_state("auth_1", 2) else: if ret[1] == 0xff: @@ -130,23 +130,23 @@ class Socks5Connection(Socks5): def state_auth_done(self): # Now we can request the actual connection - self.write_buf += struct.pack('BBB', 0x05, 0x01, 0x00) + self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00)) # If the given destination address is an IP address, we'll # use the IPv4 address request even if remote resolving was specified. try: self.ipaddr = socket.inet_aton(self.destination[0]) - self.write_buf += chr(0x01).encode() + self.ipaddr + self.append_write_buf(chr(0x01).encode() + self.ipaddr) except socket.error: # Well it's not an IP number, so it's probably a DNS name. if Proxy._remote_dns: # Resolve remotely self.ipaddr = None - self.write_buf += chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0] + self.append_write_buf(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0]) else: # Resolve locally self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) - self.write_buf += chr(0x01).encode() + self.ipaddr - self.write_buf += struct.pack(">H", self.destination[1]) + self.append_write_buf(chr(0x01).encode() + self.ipaddr) + self.append_write_buf(struct.pack(">H", self.destination[1])) self.set_state("pre_connect", 0) @@ -158,9 +158,9 @@ class Socks5Resolver(Socks5): def state_auth_done(self): # Now we can request the actual connection - self.write_buf += struct.pack('BBB', 0x05, 0xF0, 0x00) - self.write_buf += chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host) - self.write_buf += struct.pack(">H", self.port) + self.append_write_buf(struct.pack('BBB', 0x05, 0xF0, 0x00)) + self.append_write_buf(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host)) + self.append_write_buf(struct.pack(">H", self.port)) self.set_state("pre_connect", 0) def resolved(self): diff --git a/src/network/tls.py b/src/network/tls.py index 023f6cac..c7554891 100644 --- a/src/network/tls.py +++ b/src/network/tls.py @@ -2,12 +2,14 @@ SSL/TLS negotiation. """ -from network.advanceddispatcher import AdvancedDispatcher -import network.asyncore_pollchoose as asyncore +import os import socket import ssl import sys +from network.advanceddispatcher import AdvancedDispatcher +import network.asyncore_pollchoose as asyncore +import paths import protocol class TLSDispatcher(AdvancedDispatcher): @@ -30,6 +32,7 @@ class TLSDispatcher(AdvancedDispatcher): def state_tls_init(self): self.isSSL = True + self.tlsStarted = True # Once the connection has been established, it's safe to wrap the # socket. if sys.version_info >= (2,7,9): @@ -40,7 +43,7 @@ class TLSDispatcher(AdvancedDispatcher): context.verify_mode = ssl.CERT_NONE # also exclude TLSv1 and TLSv1.1 in the future context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE - self.sslSocket = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False) + self.sslSocket = context.wrap_socket(self.socket, server_side = self.server_side, do_handshake_on_connect=False) else: self.sslSocket = ssl.wrap_socket(self.socket, server_side=self.server_side, @@ -51,49 +54,66 @@ class TLSDispatcher(AdvancedDispatcher): do_handshake_on_connect=False) self.sslSocket.setblocking(0) self.want_read = self.want_write = True + self.set_state("tls_handshake") # if hasattr(self.socket, "context"): # self.socket.context.set_ecdh_curve("secp256k1") def writable(self): - if self.tlsStarted and not self.tlsDone: + if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0: + print "tls writable, %r" % (self.want_write) return self.want_write else: - return AdvancedDispacher.writable(self) + return AdvancedDispatcher.writable(self) def readable(self): - if self.tlsStarted and not self.tlsDone: + if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0: + print "tls readable, %r" % (self.want_read) return self.want_read else: - return AdvancedDispacher.readable(self) + return AdvancedDispatcher.readable(self) def handle_read(self): - if self.tlsStarted and not self.tlsDone: - self._handshake() + # wait for write buffer flush + if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0: + print "handshaking (read)" + self.state_tls_handshake() else: - return AdvancedDispacher.handle_read(self) + print "not handshaking (read)" + return AdvancedDispatcher.handle_read(self) def handle_write(self): - if self.tlsStarted and not not self.tlsDone: - self._handshake() + # wait for write buffer flush + if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0: + print "handshaking (write)" + self.state_tls_handshake() else: - return AdvancedDispacher.handle_write(self) + print "not handshaking (write)" + return AdvancedDispatcher.handle_write(self) def state_tls_handshake(self): - """ - Perform the handshake. - """ + # wait for flush + if len(self.write_buf) > 0: + return False + # Perform the handshake. try: + print "handshaking (internal)" self.sslSocket.do_handshake() except ssl.SSLError, err: + print "handshake fail" self.want_read = self.want_write = False if err.args[0] == ssl.SSL_ERROR_WANT_READ: + print "want read" self.want_read = True elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + print "want write" self.want_write = True else: raise else: + print "handshake success" # The handshake has completed, so remove this channel and... self.del_channel() self.set_socket(self.sslSocket) self.tlsDone = True + self.state_bm_ready() + return False diff --git a/src/protocol.py b/src/protocol.py index 9698f917..9397cd8b 100644 --- a/src/protocol.py +++ b/src/protocol.py @@ -27,6 +27,11 @@ NODE_SSL = 2 #Bitfield flags BITFIELD_DOESACK = 1 +#Error types +STATUS_WARNING = 0 +STATUS_ERROR = 1 +STATUS_FATAL = 2 + eightBytesOfRandomDataUsedToDetectConnectionsToSelf = pack( '>Q', random.randrange(1, 18446744073709551615))