network code quality fixes
This commit is contained in:
parent
a31d6c8422
commit
81872c7f2f
|
@ -1,3 +1,6 @@
|
||||||
|
"""
|
||||||
|
A thread to handle network concerns
|
||||||
|
"""
|
||||||
import network.asyncore_pollchoose as asyncore
|
import network.asyncore_pollchoose as asyncore
|
||||||
import state
|
import state
|
||||||
from network.connectionpool import BMConnectionPool
|
from network.connectionpool import BMConnectionPool
|
||||||
|
@ -6,7 +9,7 @@ from threads import StoppableThread
|
||||||
|
|
||||||
|
|
||||||
class BMNetworkThread(StoppableThread):
|
class BMNetworkThread(StoppableThread):
|
||||||
"""A thread to handle network concerns"""
|
"""Main network thread"""
|
||||||
name = "Asyncore"
|
name = "Asyncore"
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""
|
"""
|
||||||
src/network/objectracker.py
|
Module for tracking objects
|
||||||
===========================
|
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
|
@ -50,15 +49,18 @@ class ObjectTracker(object):
|
||||||
"""Init bloom filter for tracking. WIP."""
|
"""Init bloom filter for tracking. WIP."""
|
||||||
if haveBloom:
|
if haveBloom:
|
||||||
# lock?
|
# lock?
|
||||||
self.invBloom = BloomFilter(capacity=ObjectTracker.invInitialCapacity,
|
self.invBloom = BloomFilter(
|
||||||
error_rate=ObjectTracker.invErrorRate)
|
capacity=ObjectTracker.invInitialCapacity,
|
||||||
|
error_rate=ObjectTracker.invErrorRate)
|
||||||
|
|
||||||
def initAddrBloom(self):
|
def initAddrBloom(self):
|
||||||
"""Init bloom filter for tracking addrs, WIP. This either needs to be moved to addrthread.py or removed."""
|
"""Init bloom filter for tracking addrs, WIP.
|
||||||
|
This either needs to be moved to addrthread.py or removed."""
|
||||||
if haveBloom:
|
if haveBloom:
|
||||||
# lock?
|
# lock?
|
||||||
self.addrBloom = BloomFilter(capacity=ObjectTracker.invInitialCapacity,
|
self.addrBloom = BloomFilter(
|
||||||
error_rate=ObjectTracker.invErrorRate)
|
capacity=ObjectTracker.invInitialCapacity,
|
||||||
|
error_rate=ObjectTracker.invErrorRate)
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
"""Clean up tracking to prevent memory bloat"""
|
"""Clean up tracking to prevent memory bloat"""
|
||||||
|
@ -71,7 +73,10 @@ class ObjectTracker(object):
|
||||||
# release memory
|
# release memory
|
||||||
deadline = time.time() - ObjectTracker.trackingExpires
|
deadline = time.time() - ObjectTracker.trackingExpires
|
||||||
with self.objectsNewToThemLock:
|
with self.objectsNewToThemLock:
|
||||||
self.objectsNewToThem = {k: v for k, v in self.objectsNewToThem.iteritems() if v >= deadline}
|
self.objectsNewToThem = {
|
||||||
|
k: v
|
||||||
|
for k, v in self.objectsNewToThem.iteritems()
|
||||||
|
if v >= deadline}
|
||||||
self.lastCleaned = time.time()
|
self.lastCleaned = time.time()
|
||||||
|
|
||||||
def hasObj(self, hashid):
|
def hasObj(self, hashid):
|
||||||
|
@ -102,10 +107,12 @@ class ObjectTracker(object):
|
||||||
del i.objectsNewToMe[hashid]
|
del i.objectsNewToMe[hashid]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if streamNumber in i.streams and (
|
if streamNumber in i.streams and (
|
||||||
not Dandelion().hasHash(hashid) or Dandelion().objectChildStem(hashid) == i):
|
not Dandelion().hasHash(hashid) or
|
||||||
|
Dandelion().objectChildStem(hashid) == i):
|
||||||
with i.objectsNewToThemLock:
|
with i.objectsNewToThemLock:
|
||||||
i.objectsNewToThem[hashid] = time.time()
|
i.objectsNewToThem[hashid] = time.time()
|
||||||
# update stream number, which we didn't have when we just received the dinv
|
# update stream number,
|
||||||
|
# which we didn't have when we just received the dinv
|
||||||
# also resets expiration of the stem mode
|
# also resets expiration of the stem mode
|
||||||
Dandelion().setHashStream(hashid, streamNumber)
|
Dandelion().setHashStream(hashid, streamNumber)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""
|
"""
|
||||||
src/network/proxy.py
|
Set proxy if avaiable otherwise exception
|
||||||
====================
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import logging
|
import logging
|
||||||
|
@ -122,8 +121,7 @@ class Proxy(AdvancedDispatcher):
|
||||||
BMConfigParser().safeGet(
|
BMConfigParser().safeGet(
|
||||||
"bitmessagesettings", "socksusername"),
|
"bitmessagesettings", "socksusername"),
|
||||||
BMConfigParser().safeGet(
|
BMConfigParser().safeGet(
|
||||||
"bitmessagesettings", "sockspassword")
|
"bitmessagesettings", "sockspassword"))
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.auth = None
|
self.auth = None
|
||||||
self.connect(
|
self.connect(
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
"""
|
"""
|
||||||
src/randomtrackingdict.py
|
Track randomize ordered dict
|
||||||
=========================
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from time import time
|
from time import time
|
||||||
|
@ -14,10 +12,12 @@ class RandomTrackingDict(object):
|
||||||
"""
|
"""
|
||||||
Dict with randomised order and tracking.
|
Dict with randomised order and tracking.
|
||||||
|
|
||||||
Keeps a track of how many items have been requested from the dict, and timeouts. Resets after all objects have been
|
Keeps a track of how many items have been requested from the dict,
|
||||||
retrieved and timed out. The main purpose of this isn't as much putting related code together as performance
|
and timeouts. Resets after all objects have been retrieved and timed out.
|
||||||
optimisation and anonymisation of downloading of objects from other peers. If done using a standard dict or array,
|
The main purpose of this isn't as much putting related code together
|
||||||
it takes too much CPU (and looks convoluted). Randomisation helps with anonymity.
|
as performance optimisation and anonymisation of downloading of objects
|
||||||
|
from other peers. If done using a standard dict or array, it takes
|
||||||
|
too much CPU (and looks convoluted). Randomisation helps with anonymity.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=too-many-instance-attributes
|
# pylint: disable=too-many-instance-attributes
|
||||||
maxPending = 10
|
maxPending = 10
|
||||||
|
@ -85,13 +85,14 @@ class RandomTrackingDict(object):
|
||||||
|
|
||||||
def setMaxPending(self, maxPending):
|
def setMaxPending(self, maxPending):
|
||||||
"""
|
"""
|
||||||
Sets maximum number of objects that can be retrieved from the class simultaneously as long as there is no
|
Sets maximum number of objects that can be retrieved from the class
|
||||||
timeout
|
simultaneously as long as there is no timeout
|
||||||
"""
|
"""
|
||||||
self.maxPending = maxPending
|
self.maxPending = maxPending
|
||||||
|
|
||||||
def setPendingTimeout(self, pendingTimeout):
|
def setPendingTimeout(self, pendingTimeout):
|
||||||
"""Sets how long to wait for a timeout if max pending is reached (or all objects have been retrieved)"""
|
"""Sets how long to wait for a timeout if max pending is reached
|
||||||
|
(or all objects have been retrieved)"""
|
||||||
self.pendingTimeout = pendingTimeout
|
self.pendingTimeout = pendingTimeout
|
||||||
|
|
||||||
def setLastObject(self):
|
def setLastObject(self):
|
||||||
|
@ -99,7 +100,8 @@ class RandomTrackingDict(object):
|
||||||
self.lastObject = time()
|
self.lastObject = time()
|
||||||
|
|
||||||
def randomKeys(self, count=1):
|
def randomKeys(self, count=1):
|
||||||
"""Retrieve count random keys from the dict that haven't already been retrieved"""
|
"""Retrieve count random keys from the dict
|
||||||
|
that haven't already been retrieved"""
|
||||||
if self.len == 0 or ((self.pendingLen >= self.maxPending or
|
if self.len == 0 or ((self.pendingLen >= self.maxPending or
|
||||||
self.pendingLen == self.len) and self.lastPoll +
|
self.pendingLen == self.len) and self.lastPoll +
|
||||||
self.pendingTimeout > time()):
|
self.pendingTimeout > time()):
|
||||||
|
@ -109,13 +111,15 @@ class RandomTrackingDict(object):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
# reset if we've requested all
|
# reset if we've requested all
|
||||||
# and if last object received too long time ago
|
# and if last object received too long time ago
|
||||||
if self.pendingLen == self.len and self.lastObject + self.pendingTimeout < time():
|
if self.pendingLen == self.len and self.lastObject + \
|
||||||
|
self.pendingTimeout < time():
|
||||||
self.pendingLen = 0
|
self.pendingLen = 0
|
||||||
self.setLastObject()
|
self.setLastObject()
|
||||||
available = self.len - self.pendingLen
|
available = self.len - self.pendingLen
|
||||||
if count > available:
|
if count > available:
|
||||||
count = available
|
count = available
|
||||||
randomIndex = helper_random.randomsample(range(self.len - self.pendingLen), count)
|
randomIndex = helper_random.randomsample(
|
||||||
|
range(self.len - self.pendingLen), count)
|
||||||
retval = [self.indexDict[i] for i in randomIndex]
|
retval = [self.indexDict[i] for i in randomIndex]
|
||||||
|
|
||||||
for i in sorted(randomIndex, reverse=True):
|
for i in sorted(randomIndex, reverse=True):
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
"""
|
||||||
|
Process data incoming from network
|
||||||
|
"""
|
||||||
import errno
|
import errno
|
||||||
import Queue
|
import Queue
|
||||||
import socket
|
import socket
|
||||||
|
@ -10,6 +13,8 @@ from threads import StoppableThread
|
||||||
|
|
||||||
|
|
||||||
class ReceiveQueueThread(StoppableThread):
|
class ReceiveQueueThread(StoppableThread):
|
||||||
|
"""This thread processes data received from the network
|
||||||
|
(which is done by the asyncore thread)"""
|
||||||
def __init__(self, num=0):
|
def __init__(self, num=0):
|
||||||
super(ReceiveQueueThread, self).__init__(name="ReceiveQueue_%i" % num)
|
super(ReceiveQueueThread, self).__init__(name="ReceiveQueue_%i" % num)
|
||||||
|
|
||||||
|
@ -32,12 +37,14 @@ class ReceiveQueueThread(StoppableThread):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connection = BMConnectionPool().getConnectionByAddr(dest)
|
connection = BMConnectionPool().getConnectionByAddr(dest)
|
||||||
except KeyError: # connection object not found
|
# connection object not found
|
||||||
|
except KeyError:
|
||||||
receiveDataQueue.task_done()
|
receiveDataQueue.task_done()
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
connection.process()
|
connection.process()
|
||||||
except UnknownStateError: # state isn't implemented
|
# state isn't implemented
|
||||||
|
except UnknownStateError:
|
||||||
pass
|
pass
|
||||||
except socket.error as err:
|
except socket.error as err:
|
||||||
if err.errno == errno.EBADF:
|
if err.errno == errno.EBADF:
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""
|
"""
|
||||||
src/network/socks4a.py
|
SOCKS4a proxy module
|
||||||
=================================
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
import socket
|
import socket
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
"""
|
"""
|
||||||
src/network/socks5.py
|
SOCKS5 proxy module
|
||||||
=====================
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
|
|
||||||
|
@ -155,7 +153,8 @@ class Socks5(Proxy):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def proxy_sock_name(self):
|
def proxy_sock_name(self):
|
||||||
"""Handle return value when using SOCKS5 for DNS resolving instead of connecting."""
|
"""Handle return value when using SOCKS5
|
||||||
|
for DNS resolving instead of connecting."""
|
||||||
return socket.inet_ntoa(self.__proxysockname[0])
|
return socket.inet_ntoa(self.__proxysockname[0])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""
|
"""
|
||||||
src/network/stats.py
|
Network statistics
|
||||||
====================
|
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
@ -34,7 +33,9 @@ def uploadSpeed():
|
||||||
currentTimestamp = time.time()
|
currentTimestamp = time.time()
|
||||||
if int(lastSentTimestamp) < int(currentTimestamp):
|
if int(lastSentTimestamp) < int(currentTimestamp):
|
||||||
currentSentBytes = asyncore.sentBytes
|
currentSentBytes = asyncore.sentBytes
|
||||||
currentSentSpeed = int((currentSentBytes - lastSentBytes) / (currentTimestamp - lastSentTimestamp))
|
currentSentSpeed = int(
|
||||||
|
(currentSentBytes - lastSentBytes) / (
|
||||||
|
currentTimestamp - lastSentTimestamp))
|
||||||
lastSentBytes = currentSentBytes
|
lastSentBytes = currentSentBytes
|
||||||
lastSentTimestamp = currentTimestamp
|
lastSentTimestamp = currentTimestamp
|
||||||
return currentSentSpeed
|
return currentSentSpeed
|
||||||
|
@ -53,7 +54,8 @@ def downloadSpeed():
|
||||||
if int(lastReceivedTimestamp) < int(currentTimestamp):
|
if int(lastReceivedTimestamp) < int(currentTimestamp):
|
||||||
currentReceivedBytes = asyncore.receivedBytes
|
currentReceivedBytes = asyncore.receivedBytes
|
||||||
currentReceivedSpeed = int(
|
currentReceivedSpeed = int(
|
||||||
(currentReceivedBytes - lastReceivedBytes) / (currentTimestamp - lastReceivedTimestamp))
|
(currentReceivedBytes - lastReceivedBytes) / (
|
||||||
|
currentTimestamp - lastReceivedTimestamp))
|
||||||
lastReceivedBytes = currentReceivedBytes
|
lastReceivedBytes = currentReceivedBytes
|
||||||
lastReceivedTimestamp = currentTimestamp
|
lastReceivedTimestamp = currentTimestamp
|
||||||
return currentReceivedSpeed
|
return currentReceivedSpeed
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
|
"""
|
||||||
|
TCP protocol handler
|
||||||
|
"""
|
||||||
# pylint: disable=too-many-ancestors
|
# pylint: disable=too-many-ancestors
|
||||||
"""
|
|
||||||
src/network/tcp.py
|
|
||||||
==================
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
@ -31,7 +29,7 @@ from network.socks4a import Socks4aConnection
|
||||||
from network.socks5 import Socks5Connection
|
from network.socks5 import Socks5Connection
|
||||||
from network.tls import TLSDispatcher
|
from network.tls import TLSDispatcher
|
||||||
from node import Peer
|
from node import Peer
|
||||||
from queues import UISignalQueue, invQueue, receiveDataQueue
|
from queues import invQueue, receiveDataQueue, UISignalQueue
|
||||||
|
|
||||||
logger = logging.getLogger('default')
|
logger = logging.getLogger('default')
|
||||||
|
|
||||||
|
@ -39,7 +37,6 @@ logger = logging.getLogger('default')
|
||||||
class TCPConnection(BMProto, TLSDispatcher):
|
class TCPConnection(BMProto, TLSDispatcher):
|
||||||
# pylint: disable=too-many-instance-attributes
|
# pylint: disable=too-many-instance-attributes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
.. todo:: Look to understand and/or fix the non-parent-init-called
|
.. todo:: Look to understand and/or fix the non-parent-init-called
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -85,7 +82,8 @@ class TCPConnection(BMProto, TLSDispatcher):
|
||||||
not protocol.checkSocksIP(self.destination.host)
|
not protocol.checkSocksIP(self.destination.host)
|
||||||
)
|
)
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass # it's probably a hostname
|
# it's probably a hostname
|
||||||
|
pass
|
||||||
self.network_group = protocol.network_group(self.destination.host)
|
self.network_group = protocol.network_group(self.destination.host)
|
||||||
ObjectTracker.__init__(self) # pylint: disable=non-parent-init-called
|
ObjectTracker.__init__(self) # pylint: disable=non-parent-init-called
|
||||||
self.bm_proto_reset()
|
self.bm_proto_reset()
|
||||||
|
@ -140,10 +138,9 @@ class TCPConnection(BMProto, TLSDispatcher):
|
||||||
if not self.isOutbound and not self.local:
|
if not self.isOutbound and not self.local:
|
||||||
shared.clientHasReceivedIncomingConnections = True
|
shared.clientHasReceivedIncomingConnections = True
|
||||||
UISignalQueue.put(('setStatusIcon', 'green'))
|
UISignalQueue.put(('setStatusIcon', 'green'))
|
||||||
UISignalQueue.put((
|
UISignalQueue.put(
|
||||||
'updateNetworkStatusTab',
|
('updateNetworkStatusTab', (
|
||||||
(self.isOutbound, True, self.destination)
|
self.isOutbound, True, self.destination)))
|
||||||
))
|
|
||||||
self.antiIntersectionDelay(True)
|
self.antiIntersectionDelay(True)
|
||||||
self.fullyEstablished = True
|
self.fullyEstablished = True
|
||||||
if self.isOutbound:
|
if self.isOutbound:
|
||||||
|
@ -215,8 +212,8 @@ class TCPConnection(BMProto, TLSDispatcher):
|
||||||
bigInvList[objHash] = 0
|
bigInvList[objHash] = 0
|
||||||
objectCount = 0
|
objectCount = 0
|
||||||
payload = b''
|
payload = b''
|
||||||
# Now let us start appending all of these hashes together. They will be
|
# Now let us start appending all of these hashes together.
|
||||||
# sent out in a big inv message to our new peer.
|
# They will be sent out in a big inv message to our new peer.
|
||||||
for obj_hash, _ in bigInvList.items():
|
for obj_hash, _ in bigInvList.items():
|
||||||
payload += obj_hash
|
payload += obj_hash
|
||||||
objectCount += 1
|
objectCount += 1
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
"""
|
"""
|
||||||
SSL/TLS negotiation.
|
SSL/TLS negotiation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
@ -10,6 +9,7 @@ import sys
|
||||||
|
|
||||||
from network.advanceddispatcher import AdvancedDispatcher
|
from network.advanceddispatcher import AdvancedDispatcher
|
||||||
import network.asyncore_pollchoose as asyncore
|
import network.asyncore_pollchoose as asyncore
|
||||||
|
|
||||||
from queues import receiveDataQueue
|
from queues import receiveDataQueue
|
||||||
import paths
|
import paths
|
||||||
|
|
||||||
|
@ -24,7 +24,8 @@ if sys.version_info >= (2, 7, 13):
|
||||||
# ssl.PROTOCOL_TLS1.2
|
# ssl.PROTOCOL_TLS1.2
|
||||||
sslProtocolVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
sslProtocolVersion = ssl.PROTOCOL_TLS # pylint: disable=no-member
|
||||||
elif sys.version_info >= (2, 7, 9):
|
elif sys.version_info >= (2, 7, 9):
|
||||||
# this means any SSL/TLS. SSLv2 and 3 are excluded with an option after context is created
|
# this means any SSL/TLS.
|
||||||
|
# SSLv2 and 3 are excluded with an option after context is created
|
||||||
sslProtocolVersion = ssl.PROTOCOL_SSLv23
|
sslProtocolVersion = ssl.PROTOCOL_SSLv23
|
||||||
else:
|
else:
|
||||||
# this means TLSv1, there is no way to set "TLSv1 or higher" or
|
# this means TLSv1, there is no way to set "TLSv1 or higher" or
|
||||||
|
@ -33,7 +34,8 @@ else:
|
||||||
|
|
||||||
|
|
||||||
# ciphers
|
# ciphers
|
||||||
if ssl.OPENSSL_VERSION_NUMBER >= 0x10100000 and not ssl.OPENSSL_VERSION.startswith("LibreSSL"):
|
if ssl.OPENSSL_VERSION_NUMBER >= 0x10100000 and not \
|
||||||
|
ssl.OPENSSL_VERSION.startswith("LibreSSL"):
|
||||||
sslProtocolCiphers = "AECDH-AES256-SHA@SECLEVEL=0"
|
sslProtocolCiphers = "AECDH-AES256-SHA@SECLEVEL=0"
|
||||||
else:
|
else:
|
||||||
sslProtocolCiphers = "AECDH-AES256-SHA"
|
sslProtocolCiphers = "AECDH-AES256-SHA"
|
||||||
|
@ -41,19 +43,19 @@ else:
|
||||||
|
|
||||||
class TLSDispatcher(AdvancedDispatcher):
|
class TLSDispatcher(AdvancedDispatcher):
|
||||||
"""TLS functionality for classes derived from AdvancedDispatcher"""
|
"""TLS functionality for classes derived from AdvancedDispatcher"""
|
||||||
# pylint: disable=too-many-instance-attributes
|
# pylint: disable=too-many-instance-attributes, too-many-arguments
|
||||||
# pylint: disable=too-many-arguments,super-init-not-called,unused-argument
|
# pylint: disable=super-init-not-called
|
||||||
def __init__(
|
def __init__(self, _=None, sock=None, certfile=None, keyfile=None,
|
||||||
self, address=None, sock=None, certfile=None, keyfile=None,
|
server_side=False, ciphers=sslProtocolCiphers):
|
||||||
server_side=False, ciphers=sslProtocolCiphers
|
|
||||||
):
|
|
||||||
self.want_read = self.want_write = True
|
self.want_read = self.want_write = True
|
||||||
if certfile is None:
|
if certfile is None:
|
||||||
self.certfile = os.path.join(paths.codePath(), 'sslkeys', 'cert.pem')
|
self.certfile = os.path.join(
|
||||||
|
paths.codePath(), 'sslkeys', 'cert.pem')
|
||||||
else:
|
else:
|
||||||
self.certfile = certfile
|
self.certfile = certfile
|
||||||
if keyfile is None:
|
if keyfile is None:
|
||||||
self.keyfile = os.path.join(paths.codePath(), 'sslkeys', 'key.pem')
|
self.keyfile = os.path.join(
|
||||||
|
paths.codePath(), 'sslkeys', 'key.pem')
|
||||||
else:
|
else:
|
||||||
self.keyfile = keyfile
|
self.keyfile = keyfile
|
||||||
self.server_side = server_side
|
self.server_side = server_side
|
||||||
|
@ -68,20 +70,23 @@ class TLSDispatcher(AdvancedDispatcher):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
self.isSSL = True
|
self.isSSL = True
|
||||||
self.tlsStarted = True
|
self.tlsStarted = True
|
||||||
# Once the connection has been established, it's safe to wrap the
|
# Once the connection has been established,
|
||||||
# socket.
|
# it's safe to wrap the socket.
|
||||||
if sys.version_info >= (2, 7, 9):
|
if sys.version_info >= (2, 7, 9):
|
||||||
context = ssl.create_default_context(
|
context = ssl.create_default_context(
|
||||||
purpose=ssl.Purpose.SERVER_AUTH if self.server_side else ssl.Purpose.CLIENT_AUTH)
|
purpose=ssl.Purpose.SERVER_AUTH
|
||||||
|
if self.server_side else ssl.Purpose.CLIENT_AUTH)
|
||||||
context.set_ciphers(self.ciphers)
|
context.set_ciphers(self.ciphers)
|
||||||
context.set_ecdh_curve("secp256k1")
|
context.set_ecdh_curve("secp256k1")
|
||||||
context.check_hostname = False
|
context.check_hostname = False
|
||||||
context.verify_mode = ssl.CERT_NONE
|
context.verify_mode = ssl.CERT_NONE
|
||||||
# also exclude TLSv1 and TLSv1.1 in the future
|
# also exclude TLSv1 and TLSv1.1 in the future
|
||||||
context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 |\
|
context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 |\
|
||||||
ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE | ssl.OP_CIPHER_SERVER_PREFERENCE
|
ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE |\
|
||||||
|
ssl.OP_CIPHER_SERVER_PREFERENCE
|
||||||
self.sslSocket = context.wrap_socket(
|
self.sslSocket = context.wrap_socket(
|
||||||
self.socket, server_side=self.server_side, do_handshake_on_connect=False)
|
self.socket, server_side=self.server_side,
|
||||||
|
do_handshake_on_connect=False)
|
||||||
else:
|
else:
|
||||||
self.sslSocket = ssl.wrap_socket(
|
self.sslSocket = ssl.wrap_socket(
|
||||||
self.socket, server_side=self.server_side,
|
self.socket, server_side=self.server_side,
|
||||||
|
@ -115,12 +120,15 @@ class TLSDispatcher(AdvancedDispatcher):
|
||||||
def readable(self):
|
def readable(self):
|
||||||
"""Handle readable check for TLS-enabled sockets"""
|
"""Handle readable check for TLS-enabled sockets"""
|
||||||
try:
|
try:
|
||||||
# during TLS handshake, and after flushing write buffer, return status of last handshake attempt
|
# during TLS handshake, and after flushing write buffer,
|
||||||
|
# return status of last handshake attempt
|
||||||
if self.tlsStarted and not self.tlsDone and not self.write_buf:
|
if self.tlsStarted and not self.tlsDone and not self.write_buf:
|
||||||
# print "tls readable, %r" % (self.want_read)
|
# print "tls readable, %r" % (self.want_read)
|
||||||
return self.want_read
|
return self.want_read
|
||||||
# prior to TLS handshake, receiveDataThread should emulate synchronous behaviour
|
# prior to TLS handshake,
|
||||||
elif not self.fullyEstablished and (self.expectBytes == 0 or not self.write_buf_empty()):
|
# receiveDataThread should emulate synchronous behaviour
|
||||||
|
elif not self.fullyEstablished and (
|
||||||
|
self.expectBytes == 0 or not self.write_buf_empty()):
|
||||||
return False
|
return False
|
||||||
return AdvancedDispatcher.readable(self)
|
return AdvancedDispatcher.readable(self)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
@ -135,10 +143,14 @@ class TLSDispatcher(AdvancedDispatcher):
|
||||||
try:
|
try:
|
||||||
# wait for write buffer flush
|
# wait for write buffer flush
|
||||||
if self.tlsStarted and not self.tlsDone and not self.write_buf:
|
if self.tlsStarted and not self.tlsDone and not self.write_buf:
|
||||||
# logger.debug("%s:%i TLS handshaking (read)", self.destination.host, self.destination.port)
|
# logger.debug(
|
||||||
|
# "%s:%i TLS handshaking (read)", self.destination.host,
|
||||||
|
# self.destination.port)
|
||||||
self.tls_handshake()
|
self.tls_handshake()
|
||||||
else:
|
else:
|
||||||
# logger.debug("%s:%i Not TLS handshaking (read)", self.destination.host, self.destination.port)
|
# logger.debug(
|
||||||
|
# "%s:%i Not TLS handshaking (read)", self.destination.host,
|
||||||
|
# self.destination.port)
|
||||||
return AdvancedDispatcher.handle_read(self)
|
return AdvancedDispatcher.handle_read(self)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return AdvancedDispatcher.handle_read(self)
|
return AdvancedDispatcher.handle_read(self)
|
||||||
|
@ -161,10 +173,14 @@ class TLSDispatcher(AdvancedDispatcher):
|
||||||
try:
|
try:
|
||||||
# wait for write buffer flush
|
# wait for write buffer flush
|
||||||
if self.tlsStarted and not self.tlsDone and not self.write_buf:
|
if self.tlsStarted and not self.tlsDone and not self.write_buf:
|
||||||
# logger.debug("%s:%i TLS handshaking (write)", self.destination.host, self.destination.port)
|
# logger.debug(
|
||||||
|
# "%s:%i TLS handshaking (write)", self.destination.host,
|
||||||
|
# self.destination.port)
|
||||||
self.tls_handshake()
|
self.tls_handshake()
|
||||||
else:
|
else:
|
||||||
# logger.debug("%s:%i Not TLS handshaking (write)", self.destination.host, self.destination.port)
|
# logger.debug(
|
||||||
|
# "%s:%i Not TLS handshaking (write)", self.destination.host,
|
||||||
|
# self.destination.port)
|
||||||
return AdvancedDispatcher.handle_write(self)
|
return AdvancedDispatcher.handle_write(self)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return AdvancedDispatcher.handle_write(self)
|
return AdvancedDispatcher.handle_write(self)
|
||||||
|
@ -188,7 +204,8 @@ class TLSDispatcher(AdvancedDispatcher):
|
||||||
# print "handshaking (internal)"
|
# print "handshaking (internal)"
|
||||||
self.sslSocket.do_handshake()
|
self.sslSocket.do_handshake()
|
||||||
except ssl.SSLError as err:
|
except ssl.SSLError as err:
|
||||||
# print "%s:%i: handshake fail" % (self.destination.host, self.destination.port)
|
# print "%s:%i: handshake fail" % (
|
||||||
|
# self.destination.host, self.destination.port)
|
||||||
self.want_read = self.want_write = False
|
self.want_read = self.want_write = False
|
||||||
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
|
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
|
||||||
# print "want read"
|
# print "want read"
|
||||||
|
|
|
@ -1,13 +1,12 @@
|
||||||
"""
|
"""
|
||||||
src/network/udp.py
|
UDP protocol handler
|
||||||
==================
|
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import socket
|
import socket
|
||||||
|
import time
|
||||||
|
|
||||||
import state
|
|
||||||
import protocol
|
import protocol
|
||||||
|
import state
|
||||||
from bmproto import BMProto
|
from bmproto import BMProto
|
||||||
from node import Peer
|
from node import Peer
|
||||||
from objectracker import ObjectTracker
|
from objectracker import ObjectTracker
|
||||||
|
@ -79,7 +78,7 @@ class UDPSocket(BMProto): # pylint: disable=too-many-instance-attributes
|
||||||
if not self.local:
|
if not self.local:
|
||||||
return True
|
return True
|
||||||
remoteport = False
|
remoteport = False
|
||||||
for seenTime, stream, services, ip, port in addresses:
|
for seenTime, stream, _, ip, port in addresses:
|
||||||
decodedIP = protocol.checkIPAddress(str(ip))
|
decodedIP = protocol.checkIPAddress(str(ip))
|
||||||
if stream not in state.streamsInWhichIAmParticipating:
|
if stream not in state.streamsInWhichIAmParticipating:
|
||||||
continue
|
continue
|
||||||
|
@ -96,9 +95,8 @@ class UDPSocket(BMProto): # pylint: disable=too-many-instance-attributes
|
||||||
"received peer discovery from %s:%i (port %i):",
|
"received peer discovery from %s:%i (port %i):",
|
||||||
self.destination.host, self.destination.port, remoteport)
|
self.destination.host, self.destination.port, remoteport)
|
||||||
if self.local:
|
if self.local:
|
||||||
state.discoveredPeers[
|
state.discoveredPeers[Peer(self.destination.host, remoteport)] = \
|
||||||
Peer(self.destination.host, remoteport)
|
time.time()
|
||||||
] = time.time()
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def bm_command_portcheck(self):
|
def bm_command_portcheck(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user