use bytes() on keys to be hashable in Python3

This commit is contained in:
Kashiko Koibumi 2024-05-26 23:31:47 +09:00
parent 3a04e351cc
commit a4c4338190
No known key found for this signature in database
GPG Key ID: 8F06E069E37C40C4
14 changed files with 96 additions and 77 deletions

View File

@ -140,9 +140,10 @@ class objectProcessor(threading.Thread):
# bypass nonce and time, retain object type/version/stream + body # bypass nonce and time, retain object type/version/stream + body
readPosition = 16 readPosition = 16
if data[readPosition:] in state.ackdataForWhichImWatching: data_bytes = bytes(data[readPosition:])
if data_bytes in state.ackdataForWhichImWatching:
logger.info('This object is an acknowledgement bound for me.') logger.info('This object is an acknowledgement bound for me.')
del state.ackdataForWhichImWatching[data[readPosition:]] del state.ackdataForWhichImWatching[data_bytes]
sqlExecute( sqlExecute(
"UPDATE sent SET status='ackreceived', lastactiontime=?" "UPDATE sent SET status='ackreceived', lastactiontime=?"
" WHERE ackdata=?", int(time.time()), data[readPosition:]) " WHERE ackdata=?", int(time.time()), data[readPosition:])
@ -215,9 +216,10 @@ class objectProcessor(threading.Thread):
logger.info( logger.info(
'the hash requested in this getpubkey request is: %s', 'the hash requested in this getpubkey request is: %s',
hexlify(requestedHash)) hexlify(requestedHash))
requestedHash_bytes = bytes(requestedHash)
# if this address hash is one of mine # if this address hash is one of mine
if requestedHash in shared.myAddressesByHash: if requestedHash_bytes in shared.myAddressesByHash:
myAddress = shared.myAddressesByHash[requestedHash] myAddress = shared.myAddressesByHash[requestedHash_bytes]
elif requestedAddressVersionNumber >= 4: elif requestedAddressVersionNumber >= 4:
requestedTag = data[readPosition:readPosition + 32] requestedTag = data[readPosition:readPosition + 32]
if len(requestedTag) != 32: if len(requestedTag) != 32:
@ -227,8 +229,9 @@ class objectProcessor(threading.Thread):
logger.debug( logger.debug(
'the tag requested in this getpubkey request is: %s', 'the tag requested in this getpubkey request is: %s',
hexlify(requestedTag)) hexlify(requestedTag))
if requestedTag in shared.myAddressesByTag: requestedTag_bytes = bytes(requestedTag)
myAddress = shared.myAddressesByTag[requestedTag] if requestedTag_bytes in shared.myAddressesByTag:
myAddress = shared.myAddressesByTag[requestedTag_bytes]
if myAddress == '': if myAddress == '':
logger.info('This getpubkey request is not for any of my keys.') logger.info('This getpubkey request is not for any of my keys.')
@ -413,12 +416,13 @@ class objectProcessor(threading.Thread):
' Sanity check failed.') ' Sanity check failed.')
tag = data[readPosition:readPosition + 32] tag = data[readPosition:readPosition + 32]
if tag not in state.neededPubkeys: tag_bytes = bytes(tag)
if tag_bytes not in state.neededPubkeys:
return logger.info( return logger.info(
'We don\'t need this v4 pubkey. We didn\'t ask for it.') 'We don\'t need this v4 pubkey. We didn\'t ask for it.')
# Let us try to decrypt the pubkey # Let us try to decrypt the pubkey
toAddress = state.neededPubkeys[tag][0] toAddress = state.neededPubkeys[tag_bytes][0]
if protocol.decryptAndCheckPubkeyPayload(data, toAddress) == \ if protocol.decryptAndCheckPubkeyPayload(data, toAddress) == \
'successful': 'successful':
# At this point we know that we have been waiting on this # At this point we know that we have been waiting on this
@ -483,7 +487,7 @@ class objectProcessor(threading.Thread):
# This is a message bound for me. # This is a message bound for me.
# Look up my address based on the RIPE hash. # Look up my address based on the RIPE hash.
toAddress = shared.myAddressesByHash[toRipe] toAddress = shared.myAddressesByHash[bytes(toRipe)]
readPosition = 0 readPosition = 0
sendersAddressVersionNumber, sendersAddressVersionNumberLength = \ sendersAddressVersionNumber, sendersAddressVersionNumberLength = \
decodeVarint(decryptedData[readPosition:readPosition + 10]) decodeVarint(decryptedData[readPosition:readPosition + 10])
@ -558,7 +562,7 @@ class objectProcessor(threading.Thread):
readPosition += signatureLengthLength readPosition += signatureLengthLength
signature = decryptedData[ signature = decryptedData[
readPosition:readPosition + signatureLength] readPosition:readPosition + signatureLength]
signedData = data[8:20] + encodeVarint(1) + encodeVarint( signedData = bytes(data[8:20]) + encodeVarint(1) + encodeVarint(
streamNumberAsClaimedByMsg streamNumberAsClaimedByMsg
) + decryptedData[:positionOfBottomOfAckData] ) + decryptedData[:positionOfBottomOfAckData]
@ -808,13 +812,14 @@ class objectProcessor(threading.Thread):
elif broadcastVersion == 5: elif broadcastVersion == 5:
embeddedTag = data[readPosition:readPosition + 32] embeddedTag = data[readPosition:readPosition + 32]
readPosition += 32 readPosition += 32
if embeddedTag not in shared.MyECSubscriptionCryptorObjects: embeddedTag_bytes = bytes(embeddedTag)
if embeddedTag_bytes not in shared.MyECSubscriptionCryptorObjects:
logger.debug('We\'re not interested in this broadcast.') logger.debug('We\'re not interested in this broadcast.')
return return
# We are interested in this broadcast because of its tag. # We are interested in this broadcast because of its tag.
# We're going to add some more data which is signed further down. # We're going to add some more data which is signed further down.
signedData = data[8:readPosition] signedData = bytes(data[8:readPosition])
cryptorObject = shared.MyECSubscriptionCryptorObjects[embeddedTag] cryptorObject = shared.MyECSubscriptionCryptorObjects[embeddedTag_bytes]
try: try:
decryptedData = cryptorObject.decrypt(data[readPosition:]) decryptedData = cryptorObject.decrypt(data[readPosition:])
logger.debug('EC decryption successful') logger.debug('EC decryption successful')
@ -997,8 +1002,9 @@ class objectProcessor(threading.Thread):
encodeVarint(addressVersion) + encodeVarint(streamNumber) encodeVarint(addressVersion) + encodeVarint(streamNumber)
+ ripe + ripe
)[32:] )[32:]
if tag in state.neededPubkeys: tag_bytes = bytes(tag)
del state.neededPubkeys[tag] if tag_bytes in state.neededPubkeys:
del state.neededPubkeys[tag_bytes]
self.sendMessages(address) self.sendMessages(address)
@staticmethod @staticmethod

View File

@ -87,7 +87,7 @@ class singleWorker(StoppableThread):
tag = doubleHashOfAddressData[32:] tag = doubleHashOfAddressData[32:]
# We'll need this for when we receive a pubkey reply: # We'll need this for when we receive a pubkey reply:
# it will be encrypted and we'll need to decrypt it. # it will be encrypted and we'll need to decrypt it.
state.neededPubkeys[tag] = ( state.neededPubkeys[bytes(tag)] = (
toAddress, toAddress,
highlevelcrypto.makeCryptor( highlevelcrypto.makeCryptor(
hexlify(privEncryptionKey)) hexlify(privEncryptionKey))
@ -99,14 +99,14 @@ class singleWorker(StoppableThread):
for row in queryreturn: for row in queryreturn:
ackdata, = row ackdata, = row
self.logger.info('Watching for ackdata %s', hexlify(ackdata)) self.logger.info('Watching for ackdata %s', hexlify(ackdata))
state.ackdataForWhichImWatching[ackdata] = 0 state.ackdataForWhichImWatching[bytes(ackdata)] = 0
# Fix legacy (headerless) watched ackdata to include header # Fix legacy (headerless) watched ackdata to include header
for oldack in state.ackdataForWhichImWatching: for oldack in state.ackdataForWhichImWatching:
if len(oldack) == 32: if len(oldack) == 32:
# attach legacy header, always constant (msg/1/1) # attach legacy header, always constant (msg/1/1)
newack = '\x00\x00\x00\x02\x01\x01' + oldack newack = '\x00\x00\x00\x02\x01\x01' + oldack
state.ackdataForWhichImWatching[newack] = 0 state.ackdataForWhichImWatching[bytes(newack)] = 0
sqlExecute( sqlExecute(
'''UPDATE sent SET ackdata=? WHERE ackdata=? AND folder = 'sent' ''', '''UPDATE sent SET ackdata=? WHERE ackdata=? AND folder = 'sent' ''',
newack, oldack newack, oldack
@ -794,8 +794,9 @@ class singleWorker(StoppableThread):
encodeVarint(toAddressVersionNumber) encodeVarint(toAddressVersionNumber)
+ encodeVarint(toStreamNumber) + toRipe + encodeVarint(toStreamNumber) + toRipe
)[32:] )[32:]
toTag_bytes = bytes(toTag)
if toaddress in state.neededPubkeys or \ if toaddress in state.neededPubkeys or \
toTag in state.neededPubkeys: toTag_bytes in state.neededPubkeys:
# We already sent a request for the pubkey # We already sent a request for the pubkey
sqlExecute( sqlExecute(
'''UPDATE sent SET status='awaitingpubkey', ''' '''UPDATE sent SET status='awaitingpubkey', '''
@ -836,7 +837,8 @@ class singleWorker(StoppableThread):
privEncryptionKey = doubleHashOfToAddressData[:32] privEncryptionKey = doubleHashOfToAddressData[:32]
# The second half of the sha512 hash. # The second half of the sha512 hash.
tag = doubleHashOfToAddressData[32:] tag = doubleHashOfToAddressData[32:]
state.neededPubkeys[tag] = ( tag_bytes = bytes(tag)
state.neededPubkeys[tag_bytes] = (
toaddress, toaddress,
highlevelcrypto.makeCryptor( highlevelcrypto.makeCryptor(
hexlify(privEncryptionKey)) hexlify(privEncryptionKey))
@ -859,7 +861,7 @@ class singleWorker(StoppableThread):
''' status='doingpubkeypow') AND ''' ''' status='doingpubkeypow') AND '''
''' folder='sent' ''', ''' folder='sent' ''',
toaddress) toaddress)
del state.neededPubkeys[tag] del state.neededPubkeys[tag_bytes]
break break
# else: # else:
# There was something wrong with this # There was something wrong with this
@ -901,7 +903,7 @@ class singleWorker(StoppableThread):
# if we aren't sending this to ourselves or a chan # if we aren't sending this to ourselves or a chan
if not config.has_section(toaddress): if not config.has_section(toaddress):
state.ackdataForWhichImWatching[ackdata] = 0 state.ackdataForWhichImWatching[bytes(ackdata)] = 0
queues.UISignalQueue.put(( queues.UISignalQueue.put((
'updateSentItemStatusByAckdata', ( 'updateSentItemStatusByAckdata', (
ackdata, ackdata,
@ -1412,10 +1414,11 @@ class singleWorker(StoppableThread):
privEncryptionKey = doubleHashOfAddressData[:32] privEncryptionKey = doubleHashOfAddressData[:32]
# Note that this is the second half of the sha512 hash. # Note that this is the second half of the sha512 hash.
tag = doubleHashOfAddressData[32:] tag = doubleHashOfAddressData[32:]
if tag not in state.neededPubkeys: tag_bytes = bytes(tag)
if tag_bytes not in state.neededPubkeys:
# We'll need this for when we receive a pubkey reply: # We'll need this for when we receive a pubkey reply:
# it will be encrypted and we'll need to decrypt it. # it will be encrypted and we'll need to decrypt it.
state.neededPubkeys[tag] = ( state.neededPubkeys[tag_bytes] = (
toAddress, toAddress,
highlevelcrypto.makeCryptor(hexlify(privEncryptionKey)) highlevelcrypto.makeCryptor(hexlify(privEncryptionKey))
) )

View File

@ -415,7 +415,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
BMProto.stopDownloadingObject(self.object.inventoryHash, True) BMProto.stopDownloadingObject(self.object.inventoryHash, True)
else: else:
try: try:
del missingObjects[self.object.inventoryHash] del missingObjects[bytes(self.object.inventoryHash)]
except KeyError: except KeyError:
pass pass
@ -653,7 +653,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker):
except KeyError: except KeyError:
pass pass
try: try:
del missingObjects[hashId] del missingObjects[bytes(hashId)]
except KeyError: except KeyError:
pass pass

View File

@ -52,7 +52,7 @@ class Dandelion: # pylint: disable=old-style-class
if not state.dandelion_enabled: if not state.dandelion_enabled:
return return
with self.lock: with self.lock:
self.hashMap[hashId] = Stem( self.hashMap[bytes(hashId)] = Stem(
self.getNodeStem(source), self.getNodeStem(source),
stream, stream,
self.poissonTimeout()) self.poissonTimeout())
@ -63,9 +63,10 @@ class Dandelion: # pylint: disable=old-style-class
include streams, we only learn this after receiving the object) include streams, we only learn this after receiving the object)
""" """
with self.lock: with self.lock:
if hashId in self.hashMap: hashId_bytes = bytes(hashId)
self.hashMap[hashId] = Stem( if hashId_bytes in self.hashMap:
self.hashMap[hashId].child, self.hashMap[hashId_bytes] = Stem(
self.hashMap[hashId_bytes].child,
stream, stream,
self.poissonTimeout()) self.poissonTimeout())
@ -77,17 +78,17 @@ class Dandelion: # pylint: disable=old-style-class
''.join('%02x' % ord(i) for i in hashId), reason) ''.join('%02x' % ord(i) for i in hashId), reason)
with self.lock: with self.lock:
try: try:
del self.hashMap[hashId] del self.hashMap[bytes(hashId)]
except KeyError: except KeyError:
pass pass
def hasHash(self, hashId): def hasHash(self, hashId):
"""Is inventory vector in stem mode?""" """Is inventory vector in stem mode?"""
return hashId in self.hashMap return bytes(hashId) in self.hashMap
def objectChildStem(self, hashId): def objectChildStem(self, hashId):
"""Child (i.e. next) node for an inventory vector during stem mode""" """Child (i.e. next) node for an inventory vector during stem mode"""
return self.hashMap[hashId].child return self.hashMap[bytes(hashId)].child
def maybeAddStem(self, connection): def maybeAddStem(self, connection):
""" """

View File

@ -67,7 +67,7 @@ class DownloadThread(StoppableThread):
continue continue
payload.extend(chunk) payload.extend(chunk)
chunkCount += 1 chunkCount += 1
missingObjects[chunk] = now missingObjects[bytes(chunk)] = now
if not chunkCount: if not chunkCount:
continue continue
payload[0:0] = addresses.encodeVarint(chunkCount) payload[0:0] = addresses.encodeVarint(chunkCount)

View File

@ -81,25 +81,28 @@ class ObjectTracker(object):
def hasObj(self, hashid): def hasObj(self, hashid):
"""Do we already have object?""" """Do we already have object?"""
hashid_bytes = bytes(hashid)
if haveBloom: if haveBloom:
return hashid in self.invBloom return hashid_bytes in self.invBloom
return hashid in self.objectsNewToMe return hashid_bytes in self.objectsNewToMe
def handleReceivedInventory(self, hashId): def handleReceivedInventory(self, hashId):
"""Handling received inventory""" """Handling received inventory"""
hashId_bytes = bytes(hashId)
if haveBloom: if haveBloom:
self.invBloom.add(hashId) self.invBloom.add(hashId_bytes)
try: try:
with self.objectsNewToThemLock: with self.objectsNewToThemLock:
del self.objectsNewToThem[hashId] del self.objectsNewToThem[hashId_bytes]
except KeyError: except KeyError:
pass pass
if hashId not in missingObjects: if hashId_bytes not in missingObjects:
missingObjects[hashId] = time.time() missingObjects[hashId_bytes] = time.time()
self.objectsNewToMe[hashId] = True self.objectsNewToMe[hashId] = True
def handleReceivedObject(self, streamNumber, hashid): def handleReceivedObject(self, streamNumber, hashid):
"""Handling received object""" """Handling received object"""
hashid_bytes = bytes(hashid)
for i in connectionpool.pool.connections(): for i in connectionpool.pool.connections():
if not i.fullyEstablished: if not i.fullyEstablished:
continue continue
@ -110,7 +113,7 @@ class ObjectTracker(object):
not state.Dandelion.hasHash(hashid) not state.Dandelion.hasHash(hashid)
or state.Dandelion.objectChildStem(hashid) == i): or state.Dandelion.objectChildStem(hashid) == i):
with i.objectsNewToThemLock: with i.objectsNewToThemLock:
i.objectsNewToThem[hashid] = time.time() i.objectsNewToThem[hashid_bytes] = time.time()
# update stream number, # update stream number,
# which we didn't have when we just received the dinv # which we didn't have when we just received the dinv
# also resets expiration of the stem mode # also resets expiration of the stem mode
@ -119,7 +122,7 @@ class ObjectTracker(object):
if i == self: if i == self:
try: try:
with i.objectsNewToThemLock: with i.objectsNewToThemLock:
del i.objectsNewToThem[hashid] del i.objectsNewToThem[hashid_bytes]
except KeyError: except KeyError:
pass pass
self.objectsNewToMe.setLastObject() self.objectsNewToMe.setLastObject()
@ -133,4 +136,4 @@ class ObjectTracker(object):
def addAddr(self, hashid): def addAddr(self, hashid):
"""WIP, should be moved to addrthread.py or removed""" """WIP, should be moved to addrthread.py or removed"""
if haveBloom: if haveBloom:
self.addrBloom.add(hashid) self.addrBloom.add(bytes(hashid))

View File

@ -293,7 +293,7 @@ def isProofOfWorkSufficient(
if TTL < 300: if TTL < 300:
TTL = 300 TTL = 300
POW, = unpack('>Q', highlevelcrypto.double_sha512( POW, = unpack('>Q', highlevelcrypto.double_sha512(
data[:8] + hashlib.sha512(data[8:]).digest())[0:8]) bytes(data[:8]) + hashlib.sha512(data[8:]).digest())[0:8])
return POW <= 2 ** 64 / ( return POW <= 2 ** 64 / (
nonceTrialsPerByte * ( nonceTrialsPerByte * (
len(data) + payloadLengthExtraBytes len(data) + payloadLengthExtraBytes
@ -465,7 +465,7 @@ def decryptAndCheckPubkeyPayload(data, address):
readPosition += varintLength readPosition += varintLength
# We'll store the address version and stream number # We'll store the address version and stream number
# (and some more) in the pubkeys table. # (and some more) in the pubkeys table.
storedData = data[20:readPosition] storedData = bytes(data[20:readPosition])
if addressVersion != embeddedAddressVersion: if addressVersion != embeddedAddressVersion:
logger.info( logger.info(
@ -482,11 +482,11 @@ def decryptAndCheckPubkeyPayload(data, address):
readPosition += 32 readPosition += 32
# the time through the tag. More data is appended onto # the time through the tag. More data is appended onto
# signedData below after the decryption. # signedData below after the decryption.
signedData = data[8:readPosition] signedData = bytes(data[8:readPosition])
encryptedData = data[readPosition:] encryptedData = data[readPosition:]
# Let us try to decrypt the pubkey # Let us try to decrypt the pubkey
toAddress, cryptorObject = state.neededPubkeys[tag] toAddress, cryptorObject = state.neededPubkeys[bytes(tag)]
if toAddress != address: if toAddress != address:
logger.critical( logger.critical(
'decryptAndCheckPubkeyPayload failed due to toAddress' 'decryptAndCheckPubkeyPayload failed due to toAddress'

View File

@ -30,7 +30,7 @@ class Cipher(object):
self.ctx = OpenSSL.EVP_CIPHER_CTX_new() self.ctx = OpenSSL.EVP_CIPHER_CTX_new()
if do == 1 or do == 0: if do == 1 or do == 0:
k = OpenSSL.malloc(key, len(key)) k = OpenSSL.malloc(key, len(key))
IV = OpenSSL.malloc(iv, len(iv)) IV = OpenSSL.malloc(bytes(iv), len(iv))
OpenSSL.EVP_CipherInit_ex( OpenSSL.EVP_CipherInit_ex(
self.ctx, self.cipher.get_pointer(), 0, k, IV, do) self.ctx, self.cipher.get_pointer(), 0, k, IV, do)
else: else:
@ -59,7 +59,7 @@ class Cipher(object):
"""Update result with more data""" """Update result with more data"""
i = OpenSSL.c_int(0) i = OpenSSL.c_int(0)
buffer = OpenSSL.malloc(b"", len(input) + self.cipher.get_blocksize()) buffer = OpenSSL.malloc(b"", len(input) + self.cipher.get_blocksize())
inp = OpenSSL.malloc(input, len(input)) inp = OpenSSL.malloc(bytes(input), len(input))
if OpenSSL.EVP_CipherUpdate(self.ctx, OpenSSL.byref(buffer), if OpenSSL.EVP_CipherUpdate(self.ctx, OpenSSL.byref(buffer),
OpenSSL.byref(i), inp, len(input)) == 0: OpenSSL.byref(i), inp, len(input)) == 0:
raise Exception("[OpenSSL] EVP_CipherUpdate FAIL ...") raise Exception("[OpenSSL] EVP_CipherUpdate FAIL ...")

View File

@ -7,6 +7,7 @@ Asymmetric cryptography using elliptic curves
from hashlib import sha512 from hashlib import sha512
from struct import pack, unpack from struct import pack, unpack
from ctypes import c_char_p
from .cipher import Cipher from .cipher import Cipher
from .hash import equals, hmac_sha256 from .hash import equals, hmac_sha256
@ -218,8 +219,8 @@ class ECC(object):
if other_key == 0: if other_key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...") raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")
other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), None) other_pub_key_x = OpenSSL.BN_bin2bn(c_char_p(bytes(pubkey_x)), len(pubkey_x), None)
other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), None) other_pub_key_y = OpenSSL.BN_bin2bn(c_char_p(bytes(pubkey_y)), len(pubkey_y), None)
other_group = OpenSSL.EC_KEY_get0_group(other_key) other_group = OpenSSL.EC_KEY_get0_group(other_key)
other_pub_key = OpenSSL.EC_POINT_new(other_group) other_pub_key = OpenSSL.EC_POINT_new(other_group)

View File

@ -38,7 +38,7 @@ def hmac_sha256(k, m):
Compute the key and the message with HMAC SHA5256 Compute the key and the message with HMAC SHA5256
""" """
key = OpenSSL.malloc(k, len(k)) key = OpenSSL.malloc(k, len(k))
d = OpenSSL.malloc(m, len(m)) d = OpenSSL.malloc(bytes(m), len(m))
md = OpenSSL.malloc(0, 32) md = OpenSSL.malloc(0, 32)
i = OpenSSL.pointer(OpenSSL.c_int(0)) i = OpenSSL.pointer(OpenSSL.c_int(0))
OpenSSL.HMAC(OpenSSL.EVP_sha256(), key, len(k), d, len(m), md, i) OpenSSL.HMAC(OpenSSL.EVP_sha256(), key, len(k), d, len(m), md, i)

View File

@ -38,10 +38,10 @@ class RandomTrackingDict(object):
return self.len return self.len
def __contains__(self, key): def __contains__(self, key):
return key in self.dictionary return bytes(key) in self.dictionary
def __getitem__(self, key): def __getitem__(self, key):
return self.dictionary[key][1] return self.dictionary[bytes(key)][1]
def _swap(self, i1, i2): def _swap(self, i1, i2):
with self.lock: with self.lock:
@ -49,26 +49,28 @@ class RandomTrackingDict(object):
key2 = self.indexDict[i2] key2 = self.indexDict[i2]
self.indexDict[i1] = key2 self.indexDict[i1] = key2
self.indexDict[i2] = key1 self.indexDict[i2] = key1
self.dictionary[key1][0] = i2 self.dictionary[bytes(key1)][0] = i2
self.dictionary[key2][0] = i1 self.dictionary[bytes(key2)][0] = i1
# for quick reassignment # for quick reassignment
return i2 return i2
def __setitem__(self, key, value): def __setitem__(self, key, value):
with self.lock: with self.lock:
if key in self.dictionary: key_bytes = bytes(key)
self.dictionary[key][1] = value if key_bytes in self.dictionary:
self.dictionary[key_bytes][1] = value
else: else:
self.indexDict.append(key) self.indexDict.append(key)
self.dictionary[key] = [self.len, value] self.dictionary[key_bytes] = [self.len, value]
self._swap(self.len, self.len - self.pendingLen) self._swap(self.len, self.len - self.pendingLen)
self.len += 1 self.len += 1
def __delitem__(self, key): def __delitem__(self, key):
if key not in self.dictionary: key_bytes = bytes(key)
if key_bytes not in self.dictionary:
raise KeyError raise KeyError
with self.lock: with self.lock:
index = self.dictionary[key][0] index = self.dictionary[key_bytes][0]
# not pending # not pending
if index < self.len - self.pendingLen: if index < self.len - self.pendingLen:
# left of pending part # left of pending part
@ -82,7 +84,7 @@ class RandomTrackingDict(object):
# operation can improve 4x, but it's already very fast so we'll # operation can improve 4x, but it's already very fast so we'll
# ignore it for the time being # ignore it for the time being
del self.indexDict[-1] del self.indexDict[-1]
del self.dictionary[key] del self.dictionary[key_bytes]
self.len -= 1 self.len -= 1
def setMaxPending(self, maxPending): def setMaxPending(self, maxPending):

View File

@ -114,11 +114,11 @@ def reloadMyAddressHashes():
if len(privEncryptionKey) == 64: if len(privEncryptionKey) == 64:
myECCryptorObjects[hashobj] = \ myECCryptorObjects[hashobj] = \
highlevelcrypto.makeCryptor(privEncryptionKey) highlevelcrypto.makeCryptor(privEncryptionKey)
myAddressesByHash[hashobj] = addressInKeysFile myAddressesByHash[bytes(hashobj)] = addressInKeysFile
tag = highlevelcrypto.double_sha512( tag = highlevelcrypto.double_sha512(
encodeVarint(addressVersionNumber) encodeVarint(addressVersionNumber)
+ encodeVarint(streamNumber) + hashobj)[32:] + encodeVarint(streamNumber) + hashobj)[32:]
myAddressesByTag[tag] = addressInKeysFile myAddressesByTag[bytes(tag)] = addressInKeysFile
if not keyfileSecure: if not keyfileSecure:
fixSensitiveFilePermissions(os.path.join( fixSensitiveFilePermissions(os.path.join(
@ -149,7 +149,7 @@ def reloadBroadcastSendersForWhichImWatching():
encodeVarint(addressVersionNumber) encodeVarint(addressVersionNumber)
+ encodeVarint(streamNumber) + hashobj + encodeVarint(streamNumber) + hashobj
).digest()[:32] ).digest()[:32]
MyECSubscriptionCryptorObjects[hashobj] = \ MyECSubscriptionCryptorObjects[bytes(hashobj)] = \
highlevelcrypto.makeCryptor(hexlify(privEncryptionKey)) highlevelcrypto.makeCryptor(hexlify(privEncryptionKey))
else: else:
doubleHashOfAddressData = highlevelcrypto.double_sha512( doubleHashOfAddressData = highlevelcrypto.double_sha512(
@ -158,7 +158,7 @@ def reloadBroadcastSendersForWhichImWatching():
) )
tag = doubleHashOfAddressData[32:] tag = doubleHashOfAddressData[32:]
privEncryptionKey = doubleHashOfAddressData[:32] privEncryptionKey = doubleHashOfAddressData[:32]
MyECSubscriptionCryptorObjects[tag] = \ MyECSubscriptionCryptorObjects[bytes(tag)] = \
highlevelcrypto.makeCryptor(hexlify(privEncryptionKey)) highlevelcrypto.makeCryptor(hexlify(privEncryptionKey))

View File

@ -29,20 +29,22 @@ class SqliteInventory(InventoryStorage):
def __contains__(self, hash_): def __contains__(self, hash_):
with self.lock: with self.lock:
if hash_ in self._objects: hash_bytes = bytes(hash_)
if hash_bytes in self._objects:
return True return True
rows = sqlQuery( rows = sqlQuery(
'SELECT streamnumber FROM inventory WHERE hash=?', 'SELECT streamnumber FROM inventory WHERE hash=?',
sqlite3.Binary(hash_)) sqlite3.Binary(hash_))
if not rows: if not rows:
return False return False
self._objects[hash_] = rows[0][0] self._objects[hash_bytes] = rows[0][0]
return True return True
def __getitem__(self, hash_): def __getitem__(self, hash_):
with self.lock: with self.lock:
if hash_ in self._inventory: hash_bytes = bytes(hash_)
return self._inventory[hash_] if hash_bytes in self._inventory:
return self._inventory[hash_bytes]
rows = sqlQuery( rows = sqlQuery(
'SELECT objecttype, streamnumber, payload, expirestime, tag' 'SELECT objecttype, streamnumber, payload, expirestime, tag'
' FROM inventory WHERE hash=?', sqlite3.Binary(hash_)) ' FROM inventory WHERE hash=?', sqlite3.Binary(hash_))
@ -53,15 +55,16 @@ class SqliteInventory(InventoryStorage):
def __setitem__(self, hash_, value): def __setitem__(self, hash_, value):
with self.lock: with self.lock:
value = InventoryItem(*value) value = InventoryItem(*value)
self._inventory[hash_] = value hash_bytes = bytes(hash_)
self._objects[hash_] = value.stream self._inventory[hash_bytes] = value
self._objects[hash_bytes] = value.stream
def __delitem__(self, hash_): def __delitem__(self, hash_):
raise NotImplementedError raise NotImplementedError
def __iter__(self): def __iter__(self):
with self.lock: with self.lock:
hashes = self._inventory.keys()[:] hashes = [] + self._inventory.keys()[:]
hashes += (x for x, in sqlQuery('SELECT hash FROM inventory')) hashes += (x for x, in sqlQuery('SELECT hash FROM inventory'))
return hashes.__iter__() return hashes.__iter__()

View File

@ -15,10 +15,10 @@ class TestRandomTrackingDict(unittest.TestCase):
@staticmethod @staticmethod
def randString(): def randString():
"""helper function for tests, generates a random string""" """helper function for tests, generates a random string"""
retval = '' retval = bytearray(32)
for _ in range(32): for i in range(32):
retval += chr(random.randint(0, 255)) retval[i] = random.randint(0, 255)
return retval return bytes(retval)
def test_check_randomtrackingdict(self): def test_check_randomtrackingdict(self):
"""Check the logic of RandomTrackingDict class""" """Check the logic of RandomTrackingDict class"""