Asyncore update (WIP)

This commit is contained in:
Peter Šurda 2017-04-16 18:27:15 +02:00
parent e6f0b34f9b
commit 96d58f3c11
Signed by: PeterSurda
GPG Key ID: 0C5F50C0B5F37D87
9 changed files with 306 additions and 84 deletions

View File

@ -1,15 +1,20 @@
import hashlib import hashlib
import time import time
from pprint import pprint
import socket import socket
from struct import unpack
from network.advanceddispatcher import AdvancedDispatcher from network.advanceddispatcher import AdvancedDispatcher
from network.node import Node
import network.asyncore_pollchoose as asyncore import network.asyncore_pollchoose as asyncore
from network.proxy import Proxy, ProxyError, GeneralProxyError from network.proxy import Proxy, ProxyError, GeneralProxyError
from network.socks5 import Socks5Connection, Socks5Resolver, Socks5AuthError, Socks5Error from network.socks5 import Socks5Connection, Socks5Resolver, Socks5AuthError, Socks5Error
from network.socks4a import Socks4aConnection, Socks4aResolver, Socks4aError from network.socks4a import Socks4aConnection, Socks4aResolver, Socks4aError
from network.tls import TLSDispatcher
import addresses import addresses
from bmconfigparser import BMConfigParser from bmconfigparser import BMConfigParser
import shared
import protocol import protocol
class BMProtoError(ProxyError): pass class BMProtoError(ProxyError): pass
@ -30,14 +35,14 @@ class BMConnection(TLSDispatcher):
if address is None and sock is not None: if address is None and sock is not None:
self.destination = self.addr() self.destination = self.addr()
self.isOutbound = False self.isOutbound = False
TLSHandshake.__init__(self, sock, server_side=True) TLSDispatcher.__init__(self, sock, server_side=True)
print "received connection in background from %s:%i" % (self.destination[0], self.destination[1]) print "received connection in background from %s:%i" % (self.destination[0], self.destination[1])
else: else:
self.destination = address self.destination = address
self.isOutbound = True self.isOutbound = True
self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.connect(self.destination) self.connect(self.destination)
TLSHandshake.__init__(self, sock, server_side=False) TLSDispatcher.__init__(self, sock, server_side=False)
print "connecting in background to %s:%i" % (self.destination[0], self.destination[1]) print "connecting in background to %s:%i" % (self.destination[0], self.destination[1])
def bm_proto_reset(self): def bm_proto_reset(self):
@ -47,19 +52,22 @@ class BMConnection(TLSDispatcher):
self.checksum = None self.checksum = None
self.payload = None self.payload = None
self.invalid = False self.invalid = False
self.payloadOffset = 0
def state_init(self): def state_init(self):
self.bm_proto_reset() self.bm_proto_reset()
self.write_buf += protocol.assembleVersionMessage(self.destination[0], self.destination[1], (1,), False) self.append_write_buf(protocol.assembleVersionMessage(self.destination[0], self.destination[1], (1,), False))
if True: if True:
print "Sending version (%ib)" % len(self.write_buf) print "Sending version (%ib)" % len(self.write_buf)
self.set_state("bm_header", 0) self.set_state("bm_header")
return False return False
def state_bm_ready(self): def state_bm_ready(self):
print "doing bm ready"
self.sendAddr() self.sendAddr()
self.sendBigInv() self.sendBigInv()
return True self.set_state("bm_header")
return False
def state_bm_header(self): def state_bm_header(self):
if len(self.read_buf) < protocol.Header.size: if len(self.read_buf) < protocol.Header.size:
@ -101,32 +109,127 @@ class BMConnection(TLSDispatcher):
# else assume the command requires a different state to follow # else assume the command requires a different state to follow
return True return True
def decode_payload_string(self, length):
value = self.payload[self.payloadOffset:self.payloadOffset+length]
self.payloadOffset += length
return value
def decode_payload_varint(self):
value, offset = addresses.decodeVarint(self.payload[self.payloadOffset:])
self.payloadOffset += offset
return value
def decode_payload_node(self):
services, address, port = self.decode_payload_content("Q16sH")
return Node(services, address, port)
def decode_payload_content(self, pattern = "v"):
# l = varint indicating the length of the next item
# v = varint (or array)
# H = uint16
# I = uint32
# Q = uint64
# i = net_addr (without time and stream number)
# s = string
# 0-9 = length of the next item
# , = end of array
retval = []
size = 0
insideDigit = False
for i in range(len(pattern)):
if pattern[i] in "0123456789":
size = size * 10 + int(pattern[i])
continue
elif pattern[i] == "l":
size = self.decode_payload_varint()
continue
if size > 0:
innerval = []
if pattern[i] == "s":
retval.append(self.payload[self.payloadOffset:self.payloadOffset + size])
self.payloadOffset += size
else:
for j in range(size):
if "," in pattern[i:]:
retval.append(self.decode_payload_content(pattern[i:pattern.index(",")]))
else:
retval.append(self.decode_payload_content(pattern[i:]))
size = 0
else:
if pattern[i] == "v":
retval.append(self.decode_payload_varint())
if pattern[i] == "i":
retval.append(self.decode_payload_node())
if pattern[i] == "H":
retval.append(unpack(">H", self.payload[self.payloadOffset:self.payloadOffset+2])[0])
self.payloadOffset += 2
if pattern[i] == "I":
retval.append(unpack(">I", self.payload[self.payloadOffset:self.payloadOffset+4])[0])
self.payloadOffset += 4
if pattern[i] == "Q":
retval.append(unpack(">Q", self.payload[self.payloadOffset:self.payloadOffset+8])[0])
self.payloadOffset += 8
return retval
def bm_command_error(self): def bm_command_error(self):
fatalStatus, banTime, inventoryVector, errorText = self.decode_payload_content("vvlsls")
def bm_command_getdata(self): def bm_command_getdata(self):
items = self.decode_payload_content("l32s")
#self.antiIntersectionDelay(True) # only handle getdata requests if we have been connected long enough
for i in items:
logger.debug('received getdata request for item:' + hexlify(i))
if self.objectHashHolderInstance.hasHash(i):
self.antiIntersectionDelay()
else:
if i in Inventory():
self.append_write_buf(protocol.CreatePacket('object', Inventory()[i].payload))
else:
#self.antiIntersectionDelay()
logger.warning('%s asked for an object with a getdata which is not in either our memory inventory or our SQL inventory. We probably cleaned it out after advertising it but before they got around to asking for it.' % (self.peer,))
def bm_command_object(self): def bm_command_object(self):
lengthOfTimeWeShouldUseToProcessThisMessage = shared.checkAndShareObjectWithPeers(self.payload)
self.downloadQueue.task_done(calculateInventoryHash(self.payload))
def bm_command_addr(self):
addresses = self.decode_payload_content("lQbQ16sH")
def bm_command_ping(self): def bm_command_ping(self):
self.append_write_buf(protocol.CreatePacket('pong'))
def bm_command_pong(self): def bm_command_pong(self):
# nothing really
pass
def bm_command_verack(self): def bm_command_verack(self):
self.verackReceived = True self.verackReceived = True
return True if self.verackSent:
if self.isSSL:
self.set_state("tls_init", self.payloadLength)
else:
self.set_state("bm_ready", self.payloadLength)
else:
self.set_state("bm_header", self.payloadLength)
self.bm_proto_reset()
return False
def bm_command_version(self): def bm_command_version(self):
self.remoteProtocolVersion, self.services, self.timestamp, padding1, self.myExternalIP, padding2, self.remoteNodeIncomingPort = protocol.VersionPacket.unpack(self.payload[:protocol.VersionPacket.size]) #self.remoteProtocolVersion, self.services, self.timestamp, padding1, self.myExternalIP, padding2, self.remoteNodeIncomingPort = protocol.VersionPacket.unpack(self.payload[:protocol.VersionPacket.size])
self.remoteProtocolVersion, self.services, self.timestamp, self.sockNode, self.peerNode, self.nonce, self.userAgent, self.streams = self.decode_payload_content("IQQiiQlslv")
self.timeOffset = self.timestamp - int(time.time())
print "remoteProtocolVersion: %i" % (self.remoteProtocolVersion) print "remoteProtocolVersion: %i" % (self.remoteProtocolVersion)
print "services: %08X" % (self.services) print "services: %08X" % (self.services)
print "time offset: %i" % (self.timestamp - int(time.time())) print "time offset: %i" % (self.timestamp - int(time.time()))
print "my external IP: %s" % (socket.inet_ntoa(self.myExternalIP)) print "my external IP: %s" % (self.sockNode.address)
print "remote node incoming port: %i" % (self.remoteNodeIncomingPort) print "remote node incoming port: %i" % (self.peerNode.port)
useragentLength, lengthOfUseragentVarint = addresses.decodeVarint(self.payload[80:84])
readPosition = 80 + lengthOfUseragentVarint
self.userAgent = self.payload[readPosition:readPosition + useragentLength]
readPosition += useragentLength
print "user agent: %s" % (self.userAgent) print "user agent: %s" % (self.userAgent)
if not self.peerValidityChecks(): if not self.peerValidityChecks():
# TODO ABORT # TODO ABORT
return True return True
self.write_buf += protocol.CreatePacket('verack') self.append_write_buf(protocol.CreatePacket('verack'))
self.verackSent = True self.verackSent = True
if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and if ((self.services & protocol.NODE_SSL == protocol.NODE_SSL) and
protocol.haveSSL(not self.isOutbound)): protocol.haveSSL(not self.isOutbound)):
@ -141,21 +244,21 @@ class BMConnection(TLSDispatcher):
def peerValidityChecks(self): def peerValidityChecks(self):
if self.remoteProtocolVersion < 3: if self.remoteProtocolVersion < 3:
self.write_buf += protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your is using an old protocol. Closing connection.") errorText="Your is using an old protocol. Closing connection."))
logger.debug ('Closing connection to old protocol version %s, node: %s', logger.debug ('Closing connection to old protocol version %s, node: %s',
str(self.remoteProtocolVersion), str(self.peer)) str(self.remoteProtocolVersion), str(self.peer))
return False return False
if self.timeOffset > 3600: if self.timeOffset > 3600:
self.write_buf += protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the future compared to mine. Closing connection.") errorText="Your time is too far in the future compared to mine. Closing connection."))
logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.", logger.info("%s's time is too far in the future (%s seconds). Closing connection to it.",
self.peer, self.timeOffset) self.peer, self.timeOffset)
shared.timeOffsetWrongCount += 1 shared.timeOffsetWrongCount += 1
return False return False
elif self.timeOffset < -3600: elif self.timeOffset < -3600:
self.write_buf += protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="Your time is too far in the past compared to mine. Closing connection.") errorText="Your time is too far in the past compared to mine. Closing connection."))
logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.", logger.info("%s's time is too far in the past (timeOffset %s seconds). Closing connection to it.",
self.peer, self.timeOffset) self.peer, self.timeOffset)
shared.timeOffsetWrongCount += 1 shared.timeOffsetWrongCount += 1
@ -163,8 +266,8 @@ class BMConnection(TLSDispatcher):
else: else:
shared.timeOffsetWrongCount = 0 shared.timeOffsetWrongCount = 0
if len(self.streams) == 0: if len(self.streams) == 0:
self.write_buf += protocol.assembleErrorMessage(fatal=2, self.append_write_buf(protocol.assembleErrorMessage(fatal=2,
errorText="We don't have shared stream interests. Closing connection."))) errorText="We don't have shared stream interests. Closing connection."))
logger.debug ('Closed connection to %s because there is no overlapping interest in streams.', logger.debug ('Closed connection to %s because there is no overlapping interest in streams.',
str(self.peer)) str(self.peer))
return False return False
@ -174,8 +277,8 @@ class BMConnection(TLSDispatcher):
def sendChunk(): def sendChunk():
if numberOfAddressesInAddrMessage == 0: if numberOfAddressesInAddrMessage == 0:
return return
self.write_buf += protocol.CreatePacket('addr', \ self.append_write_buf(protocol.CreatePacket('addr', \
addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload))) addresses.encodeVarint(numberOfAddressesInAddrMessage) + payload))
# We are going to share a maximum number of 1000 addrs (per overlapping # We are going to share a maximum number of 1000 addrs (per overlapping
# stream) with our peer. 500 from overlapping streams, 250 from the # stream) with our peer. 500 from overlapping streams, 250 from the
@ -265,7 +368,7 @@ class BMConnection(TLSDispatcher):
payload = encodeVarint(objectCount) + payload payload = encodeVarint(objectCount) + payload
logger.debug('Sending huge inv message with %i objects to just this one peer', logger.debug('Sending huge inv message with %i objects to just this one peer',
str(numberOfObjects)) str(numberOfObjects))
self.write_buf += protocol.CreatePacket('inv', payload) self.append_write_buf(protocol.CreatePacket('inv', payload))
# Select all hashes for objects in this stream. # Select all hashes for objects in this stream.
bigInvList = {} bigInvList = {}
@ -335,15 +438,15 @@ if __name__ == "__main__":
direct = BMConnection(host) direct = BMConnection(host)
while len(asyncore.socket_map) > 0: while len(asyncore.socket_map) > 0:
print "loop, state = %s" % (direct.state) print "loop, state = %s" % (direct.state)
asyncore.loop(timeout=1, count=1) asyncore.loop(timeout=10, count=1)
continue continue
proxy = Socks5BMConnection(host) proxy = Socks5BMConnection(host)
while len(asyncore.socket_map) > 0: while len(asyncore.socket_map) > 0:
# print "loop, state = %s" % (proxy.state) # print "loop, state = %s" % (proxy.state)
asyncore.loop(timeout=1, count=1) asyncore.loop(timeout=10, count=1)
proxy = Socks4aBMConnection(host) proxy = Socks4aBMConnection(host)
while len(asyncore.socket_map) > 0: while len(asyncore.socket_map) > 0:
# print "loop, state = %s" % (proxy.state) # print "loop, state = %s" % (proxy.state)
asyncore.loop(timeout=1, count=1) asyncore.loop(timeout=10, count=1)

View File

@ -10,11 +10,16 @@ class AdvancedDispatcher(asyncore.dispatcher):
self.write_buf = b"" self.write_buf = b""
self.state = "init" self.state = "init"
def slice_read_buf(self, length=0): def append_write_buf(self, string = None):
self.read_buf = self.read_buf[length:] self.write_buf += string
def slice_write_buf(self, length=0): def slice_write_buf(self, length=0):
self.write_buf = self.read_buf[length:] if length > 0:
self.write_buf = self.write_buf[length:]
def slice_read_buf(self, length=0):
if length > 0:
self.read_buf = self.read_buf[length:]
def read_buf_sufficient(self, length=0): def read_buf_sufficient(self, length=0):
if len(self.read_buf) < length: if len(self.read_buf) < length:
@ -23,7 +28,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
return True return True
def process(self): def process(self):
if self.state != "init" and len(self.read_buf) == 0: if self.state not in ["init", "tls_handshake"] and len(self.read_buf) == 0:
return return
while True: while True:
try: try:
@ -34,7 +39,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
# missing state # missing state
raise raise
def set_state(self, state, length): def set_state(self, state, length=0):
self.slice_read_buf(length) self.slice_read_buf(length)
self.state = state self.state = state
@ -45,6 +50,7 @@ class AdvancedDispatcher(asyncore.dispatcher):
return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len return self.connecting or len(self.read_buf) < AdvancedDispatcher._buf_len
def handle_read(self): def handle_read(self):
print "handle_read"
self.read_buf += self.recv(AdvancedDispatcher._buf_len) self.read_buf += self.recv(AdvancedDispatcher._buf_len)
self.process() self.process()

View File

@ -60,6 +60,9 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \
_DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, _DISCONNECTED = frozenset((ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE,
EBADF)) EBADF))
OP_READ = 1
OP_WRITE = 2
try: try:
socket_map socket_map
except NameError: except NameError:
@ -178,17 +181,25 @@ def poll_poller(timeout=0.0, map=None):
poll_poller.pollster = select.poll() poll_poller.pollster = select.poll()
if map: if map:
for fd, obj in list(map.items()): for fd, obj in list(map.items()):
flags = 0 flags = newflags = 0
if obj.readable(): if obj.readable():
flags |= select.POLLIN | select.POLLPRI flags |= select.POLLIN | select.POLLPRI
newflags |= OP_READ
else:
newflags &= ~ OP_READ
# accepting sockets should not be writable # accepting sockets should not be writable
if obj.writable() and not obj.accepting: if obj.writable() and not obj.accepting:
flags |= select.POLLOUT flags |= select.POLLOUT
if flags: newflags |= OP_WRITE
try: else:
newflags &= ~ OP_WRITE
if newflags != obj.flags:
obj.flags = newflags
if obj.poller_registered:
poll_poller.pollster.modify(fd, flags) poll_poller.pollster.modify(fd, flags)
except IOError: else:
poll_poller.pollster.register(fd, flags) poll_poller.pollster.register(fd, flags)
obj.poller_registered = True
try: try:
r = poll_poller.pollster.poll(timeout) r = poll_poller.pollster.poll(timeout)
except KeyboardInterrupt: except KeyboardInterrupt:
@ -213,19 +224,28 @@ def epoll_poller(timeout=0.0, map=None):
epoll_poller.pollster = select.epoll() epoll_poller.pollster = select.epoll()
if map: if map:
for fd, obj in map.items(): for fd, obj in map.items():
flags = 0 flags = newflags = 0
if obj.readable(): if obj.readable():
flags |= select.POLLIN | select.POLLPRI flags |= select.POLLIN | select.POLLPRI
if obj.writable(): newflags |= OP_READ
else:
newflags &= ~ OP_READ
# accepting sockets should not be writable
if obj.writable() and not obj.accepting:
flags |= select.POLLOUT flags |= select.POLLOUT
if flags: newflags |= OP_WRITE
else:
newflags &= ~ OP_WRITE
if newflags != obj.flags:
obj.flags = newflags
# Only check for exceptions if object was either readable # Only check for exceptions if object was either readable
# or writable. # or writable.
flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL flags |= select.POLLERR | select.POLLHUP | select.POLLNVAL
try: if obj.poller_registered:
epoll_poller.pollster.register(fd, flags)
except IOError:
epoll_poller.pollster.modify(fd, flags) epoll_poller.pollster.modify(fd, flags)
else:
epoll_poller.pollster.register(fd, flags)
obj.poller_registered = True
try: try:
r = epoll_poller.pollster.poll(timeout) r = epoll_poller.pollster.poll(timeout)
except select.error, err: except select.error, err:
@ -306,6 +326,8 @@ class dispatcher:
closing = False closing = False
addr = None addr = None
ignore_log_types = frozenset(['warning']) ignore_log_types = frozenset(['warning'])
poller_registered = False
flags = 0
def __init__(self, sock=None, map=None): def __init__(self, sock=None, map=None):
if map is None: if map is None:

View File

@ -19,7 +19,7 @@ class HttpConnection(AdvancedDispatcher):
print "connecting in background to %s:%i" % (self.destination[0], self.destination[1]) print "connecting in background to %s:%i" % (self.destination[0], self.destination[1])
def state_init(self): def state_init(self):
self.write_buf += "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (self.path, self.destination[0]) self.append_write_buf("GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (self.path, self.destination[0]))
print "Sending %ib" % (len(self.write_buf)) print "Sending %ib" % (len(self.write_buf))
self.set_state("http_request_sent", 0) self.set_state("http_request_sent", 0)
return False return False

66
src/network/node.py Normal file
View File

@ -0,0 +1,66 @@
import socket
import protocol
class Node (object):
TYPE_IPV4 = 1
TYPE_IPV6 = 2
TYPE_ONION = 3
TYPE_LOCAL = 4
TYPE_LOOPBACK = 8
TYPE_UNDEF = 12
def __init__(self, services, address, port):
self.services = services
self.address, self.addressType = Node.decodeIPAddress(address)
self.port = port
def isLocal(self):
return self.addressType | Node.TYPE_LOCAL > 0
def isGlobal(self):
return self.addressType <= Node.TYPE_ONION
def isOnion(self):
return self.addressType | Node.TYPE_ONION > 0
def isLoopback(self):
return self.addressType | Node.TYPE_LOOPBACK > 0
@staticmethod
def decodeIPAddress(host):
if host[0:12] == '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF':
hostStandardFormat = socket.inet_ntop(socket.AF_INET, host[12:])
return Node.decodeIPv4Address(host[12:], hostStandardFormat)
elif host[0:6] == '\xfd\x87\xd8\x7e\xeb\x43':
# Onion, based on BMD/bitcoind
hostStandardFormat = base64.b32encode(host[6:]).lower() + ".onion"
return hostStandardFormat, Node.TYPE_ONION
else:
hostStandardFormat = socket.inet_ntop(socket.AF_INET6, host)
if hostStandardFormat == "":
# This can happen on Windows systems which are not 64-bit compatible
# so let us drop the IPv6 address.
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_UNDEF
return Node.decodeIPv6Address(host, hostStandardFormat)
@staticmethod
def decodeIPv4Address(host, hostStandardFormat):
if host[0] == '\x7F': # 127/8
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOOPBACK
if host[0] == '\x0A': # 10/8
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL
if host[0:2] == '\xC0\xA8': # 192.168/16
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL
if host[0:2] >= '\xAC\x10' and host[0:2] < '\xAC\x20': # 172.16/12
return hostStandardFormat, Node.TYPE_IPV4|Node.TYPE_LOCAL
return hostStandardFormat, Node.TYPE_IPV4
@staticmethod
def _checkIPv6Address(host, hostStandardFormat):
if host == ('\x00' * 15) + '\x01':
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_LOOPBACK
if host[0] == '\xFE' and (ord(host[1]) & 0xc0) == 0x80:
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_LOCAL
if (ord(host[0]) & 0xfe) == 0xfc:
return hostStandardFormat, Node.TYPE_IPV6|Node.TYPE_UNDEF
return hostStandardFormat, Node.TYPE_IPV6

View File

@ -59,28 +59,28 @@ class Socks4aConnection(Socks4a):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
rmtrslv = False rmtrslv = False
self.write_buf += struct.pack('>BBH', 0x04, 0x01, self.destination[1]) self.append_write_buf(struct.pack('>BBH', 0x04, 0x01, self.destination[1]))
# If the given destination address is an IP address, we'll # If the given destination address is an IP address, we'll
# use the IPv4 address request even if remote resolving was specified. # use the IPv4 address request even if remote resolving was specified.
try: try:
self.ipaddr = socket.inet_aton(self.destination[0]) self.ipaddr = socket.inet_aton(self.destination[0])
self.write_buf += ipaddr self.append_write_buf(self.ipaddr)
except socket.error: except socket.error:
# Well it's not an IP number, so it's probably a DNS name. # Well it's not an IP number, so it's probably a DNS name.
if Proxy._remote_dns: if Proxy._remote_dns:
# Resolve remotely # Resolve remotely
rmtrslv = True rmtrslv = True
self.ipaddr = None self.ipaddr = None
self.write_buf += struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01) self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01))
else: else:
# Resolve locally # Resolve locally
self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0]))
self.write_buf += self.ipaddr self.append_write_buf(self.ipaddr)
if self._auth: if self._auth:
self.write_buf += self._auth[0] self.append_write_buf(self._auth[0])
self.write_buf += chr(0x00).encode() self.append_write_buf(chr(0x00).encode())
if rmtrslv: if rmtrslv:
self.write_buf += self.destination[0] + chr(0x00).encode() self.append_write_buf(self.destination[0] + chr(0x00).encode())
self.set_state("pre_connect", 0) self.set_state("pre_connect", 0)
@ -92,12 +92,12 @@ class Socks4aResolver(Socks4a):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
self.write_buf += struct.pack('>BBH', 0x04, 0xF0, self.destination[1]) self.append_write_buf(struct.pack('>BBH', 0x04, 0xF0, self.destination[1]))
self.write_buf += struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01) self.append_write_buf(struct.pack("BBBB", 0x00, 0x00, 0x00, 0x01))
if self._auth: if self._auth:
self.write_buf += self._auth[0] self.append_write_buf(self._auth[0])
self.write_buf += chr(0x00).encode() self.append_write_buf(chr(0x00).encode())
self.write_buf += self.host + chr(0x00).encode() self.append_write_buf(self.host + chr(0x00).encode())
self.set_state("pre_connect", 0) self.set_state("pre_connect", 0)
def resolved(self): def resolved(self):

View File

@ -17,9 +17,9 @@ class Socks5(Proxy):
def state_init(self): def state_init(self):
if self._auth: if self._auth:
self.write_buf += struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02) self.append_write_buf(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02))
else: else:
self.write_buf += struct.pack('BBB', 0x05, 0x01, 0x00) self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00))
self.set_state("auth_1", 0) self.set_state("auth_1", 0)
def state_auth_1(self): def state_auth_1(self):
@ -35,9 +35,9 @@ class Socks5(Proxy):
self.set_state("auth_done", 2) self.set_state("auth_done", 2)
elif ret[1] == 2: elif ret[1] == 2:
# username/password # username/password
self.write_buf += struct.pack('BB', 1, len(self._auth[0])) + \ self.append_write_buf(struct.pack('BB', 1, len(self._auth[0])) + \
self._auth[0] + struct.pack('B', len(self._auth[1])) + \ self._auth[0] + struct.pack('B', len(self._auth[1])) + \
self._auth[1] self._auth[1])
self.set_state("auth_1", 2) self.set_state("auth_1", 2)
else: else:
if ret[1] == 0xff: if ret[1] == 0xff:
@ -130,23 +130,23 @@ class Socks5Connection(Socks5):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
self.write_buf += struct.pack('BBB', 0x05, 0x01, 0x00) self.append_write_buf(struct.pack('BBB', 0x05, 0x01, 0x00))
# If the given destination address is an IP address, we'll # If the given destination address is an IP address, we'll
# use the IPv4 address request even if remote resolving was specified. # use the IPv4 address request even if remote resolving was specified.
try: try:
self.ipaddr = socket.inet_aton(self.destination[0]) self.ipaddr = socket.inet_aton(self.destination[0])
self.write_buf += chr(0x01).encode() + self.ipaddr self.append_write_buf(chr(0x01).encode() + self.ipaddr)
except socket.error: except socket.error:
# Well it's not an IP number, so it's probably a DNS name. # Well it's not an IP number, so it's probably a DNS name.
if Proxy._remote_dns: if Proxy._remote_dns:
# Resolve remotely # Resolve remotely
self.ipaddr = None self.ipaddr = None
self.write_buf += chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0] self.append_write_buf(chr(0x03).encode() + chr(len(self.destination[0])).encode() + self.destination[0])
else: else:
# Resolve locally # Resolve locally
self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0])) self.ipaddr = socket.inet_aton(socket.gethostbyname(self.destination[0]))
self.write_buf += chr(0x01).encode() + self.ipaddr self.append_write_buf(chr(0x01).encode() + self.ipaddr)
self.write_buf += struct.pack(">H", self.destination[1]) self.append_write_buf(struct.pack(">H", self.destination[1]))
self.set_state("pre_connect", 0) self.set_state("pre_connect", 0)
@ -158,9 +158,9 @@ class Socks5Resolver(Socks5):
def state_auth_done(self): def state_auth_done(self):
# Now we can request the actual connection # Now we can request the actual connection
self.write_buf += struct.pack('BBB', 0x05, 0xF0, 0x00) self.append_write_buf(struct.pack('BBB', 0x05, 0xF0, 0x00))
self.write_buf += chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host) self.append_write_buf(chr(0x03).encode() + chr(len(self.host)).encode() + str(self.host))
self.write_buf += struct.pack(">H", self.port) self.append_write_buf(struct.pack(">H", self.port))
self.set_state("pre_connect", 0) self.set_state("pre_connect", 0)
def resolved(self): def resolved(self):

View File

@ -2,12 +2,14 @@
SSL/TLS negotiation. SSL/TLS negotiation.
""" """
from network.advanceddispatcher import AdvancedDispatcher import os
import network.asyncore_pollchoose as asyncore
import socket import socket
import ssl import ssl
import sys import sys
from network.advanceddispatcher import AdvancedDispatcher
import network.asyncore_pollchoose as asyncore
import paths
import protocol import protocol
class TLSDispatcher(AdvancedDispatcher): class TLSDispatcher(AdvancedDispatcher):
@ -30,6 +32,7 @@ class TLSDispatcher(AdvancedDispatcher):
def state_tls_init(self): def state_tls_init(self):
self.isSSL = True self.isSSL = True
self.tlsStarted = True
# Once the connection has been established, it's safe to wrap the # Once the connection has been established, it's safe to wrap the
# socket. # socket.
if sys.version_info >= (2,7,9): if sys.version_info >= (2,7,9):
@ -40,7 +43,7 @@ class TLSDispatcher(AdvancedDispatcher):
context.verify_mode = ssl.CERT_NONE context.verify_mode = ssl.CERT_NONE
# also exclude TLSv1 and TLSv1.1 in the future # also exclude TLSv1 and TLSv1.1 in the future
context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE
self.sslSocket = context.wrap_socket(self.sock, server_side = self.server_side, do_handshake_on_connect=False) self.sslSocket = context.wrap_socket(self.socket, server_side = self.server_side, do_handshake_on_connect=False)
else: else:
self.sslSocket = ssl.wrap_socket(self.socket, self.sslSocket = ssl.wrap_socket(self.socket,
server_side=self.server_side, server_side=self.server_side,
@ -51,49 +54,66 @@ class TLSDispatcher(AdvancedDispatcher):
do_handshake_on_connect=False) do_handshake_on_connect=False)
self.sslSocket.setblocking(0) self.sslSocket.setblocking(0)
self.want_read = self.want_write = True self.want_read = self.want_write = True
self.set_state("tls_handshake")
# if hasattr(self.socket, "context"): # if hasattr(self.socket, "context"):
# self.socket.context.set_ecdh_curve("secp256k1") # self.socket.context.set_ecdh_curve("secp256k1")
def writable(self): def writable(self):
if self.tlsStarted and not self.tlsDone: if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "tls writable, %r" % (self.want_write)
return self.want_write return self.want_write
else: else:
return AdvancedDispacher.writable(self) return AdvancedDispatcher.writable(self)
def readable(self): def readable(self):
if self.tlsStarted and not self.tlsDone: if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "tls readable, %r" % (self.want_read)
return self.want_read return self.want_read
else: else:
return AdvancedDispacher.readable(self) return AdvancedDispatcher.readable(self)
def handle_read(self): def handle_read(self):
if self.tlsStarted and not self.tlsDone: # wait for write buffer flush
self._handshake() if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "handshaking (read)"
self.state_tls_handshake()
else: else:
return AdvancedDispacher.handle_read(self) print "not handshaking (read)"
return AdvancedDispatcher.handle_read(self)
def handle_write(self): def handle_write(self):
if self.tlsStarted and not not self.tlsDone: # wait for write buffer flush
self._handshake() if self.tlsStarted and not self.tlsDone and len(self.write_buf) == 0:
print "handshaking (write)"
self.state_tls_handshake()
else: else:
return AdvancedDispacher.handle_write(self) print "not handshaking (write)"
return AdvancedDispatcher.handle_write(self)
def state_tls_handshake(self): def state_tls_handshake(self):
""" # wait for flush
Perform the handshake. if len(self.write_buf) > 0:
""" return False
# Perform the handshake.
try: try:
print "handshaking (internal)"
self.sslSocket.do_handshake() self.sslSocket.do_handshake()
except ssl.SSLError, err: except ssl.SSLError, err:
print "handshake fail"
self.want_read = self.want_write = False self.want_read = self.want_write = False
if err.args[0] == ssl.SSL_ERROR_WANT_READ: if err.args[0] == ssl.SSL_ERROR_WANT_READ:
print "want read"
self.want_read = True self.want_read = True
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
print "want write"
self.want_write = True self.want_write = True
else: else:
raise raise
else: else:
print "handshake success"
# The handshake has completed, so remove this channel and... # The handshake has completed, so remove this channel and...
self.del_channel() self.del_channel()
self.set_socket(self.sslSocket) self.set_socket(self.sslSocket)
self.tlsDone = True self.tlsDone = True
self.state_bm_ready()
return False

View File

@ -27,6 +27,11 @@ NODE_SSL = 2
#Bitfield flags #Bitfield flags
BITFIELD_DOESACK = 1 BITFIELD_DOESACK = 1
#Error types
STATUS_WARNING = 0
STATUS_ERROR = 1
STATUS_FATAL = 2
eightBytesOfRandomDataUsedToDetectConnectionsToSelf = pack( eightBytesOfRandomDataUsedToDetectConnectionsToSelf = pack(
'>Q', random.randrange(1, 18446744073709551615)) '>Q', random.randrange(1, 18446744073709551615))