From a4c43381909f2f1ac4b8c495761a69527c72c14c Mon Sep 17 00:00:00 2001 From: Kashiko Koibumi Date: Sun, 26 May 2024 23:31:47 +0900 Subject: [PATCH] use bytes() on keys to be hashable in Python3 --- src/class_objectProcessor.py | 36 ++++++++++++++++------------ src/class_singleWorker.py | 21 +++++++++------- src/network/bmproto.py | 4 ++-- src/network/dandelion.py | 15 ++++++------ src/network/downloadthread.py | 2 +- src/network/objectracker.py | 21 +++++++++------- src/protocol.py | 8 +++---- src/pyelliptic/cipher.py | 4 ++-- src/pyelliptic/ecc.py | 5 ++-- src/pyelliptic/hash.py | 2 +- src/randomtrackingdict.py | 22 +++++++++-------- src/shared.py | 8 +++---- src/storage/sqlite.py | 17 +++++++------ src/tests/test_randomtrackingdict.py | 8 +++---- 14 files changed, 96 insertions(+), 77 deletions(-) diff --git a/src/class_objectProcessor.py b/src/class_objectProcessor.py index 469ccbfa..9a6f0099 100644 --- a/src/class_objectProcessor.py +++ b/src/class_objectProcessor.py @@ -140,9 +140,10 @@ class objectProcessor(threading.Thread): # bypass nonce and time, retain object type/version/stream + body readPosition = 16 - if data[readPosition:] in state.ackdataForWhichImWatching: + data_bytes = bytes(data[readPosition:]) + if data_bytes in state.ackdataForWhichImWatching: logger.info('This object is an acknowledgement bound for me.') - del state.ackdataForWhichImWatching[data[readPosition:]] + del state.ackdataForWhichImWatching[data_bytes] sqlExecute( "UPDATE sent SET status='ackreceived', lastactiontime=?" " WHERE ackdata=?", int(time.time()), data[readPosition:]) @@ -215,9 +216,10 @@ class objectProcessor(threading.Thread): logger.info( 'the hash requested in this getpubkey request is: %s', hexlify(requestedHash)) + requestedHash_bytes = bytes(requestedHash) # if this address hash is one of mine - if requestedHash in shared.myAddressesByHash: - myAddress = shared.myAddressesByHash[requestedHash] + if requestedHash_bytes in shared.myAddressesByHash: + myAddress = shared.myAddressesByHash[requestedHash_bytes] elif requestedAddressVersionNumber >= 4: requestedTag = data[readPosition:readPosition + 32] if len(requestedTag) != 32: @@ -227,8 +229,9 @@ class objectProcessor(threading.Thread): logger.debug( 'the tag requested in this getpubkey request is: %s', hexlify(requestedTag)) - if requestedTag in shared.myAddressesByTag: - myAddress = shared.myAddressesByTag[requestedTag] + requestedTag_bytes = bytes(requestedTag) + if requestedTag_bytes in shared.myAddressesByTag: + myAddress = shared.myAddressesByTag[requestedTag_bytes] if myAddress == '': logger.info('This getpubkey request is not for any of my keys.') @@ -413,12 +416,13 @@ class objectProcessor(threading.Thread): ' Sanity check failed.') tag = data[readPosition:readPosition + 32] - if tag not in state.neededPubkeys: + tag_bytes = bytes(tag) + if tag_bytes not in state.neededPubkeys: return logger.info( 'We don\'t need this v4 pubkey. We didn\'t ask for it.') # Let us try to decrypt the pubkey - toAddress = state.neededPubkeys[tag][0] + toAddress = state.neededPubkeys[tag_bytes][0] if protocol.decryptAndCheckPubkeyPayload(data, toAddress) == \ 'successful': # At this point we know that we have been waiting on this @@ -483,7 +487,7 @@ class objectProcessor(threading.Thread): # This is a message bound for me. # Look up my address based on the RIPE hash. - toAddress = shared.myAddressesByHash[toRipe] + toAddress = shared.myAddressesByHash[bytes(toRipe)] readPosition = 0 sendersAddressVersionNumber, sendersAddressVersionNumberLength = \ decodeVarint(decryptedData[readPosition:readPosition + 10]) @@ -558,7 +562,7 @@ class objectProcessor(threading.Thread): readPosition += signatureLengthLength signature = decryptedData[ readPosition:readPosition + signatureLength] - signedData = data[8:20] + encodeVarint(1) + encodeVarint( + signedData = bytes(data[8:20]) + encodeVarint(1) + encodeVarint( streamNumberAsClaimedByMsg ) + decryptedData[:positionOfBottomOfAckData] @@ -808,13 +812,14 @@ class objectProcessor(threading.Thread): elif broadcastVersion == 5: embeddedTag = data[readPosition:readPosition + 32] readPosition += 32 - if embeddedTag not in shared.MyECSubscriptionCryptorObjects: + embeddedTag_bytes = bytes(embeddedTag) + if embeddedTag_bytes not in shared.MyECSubscriptionCryptorObjects: logger.debug('We\'re not interested in this broadcast.') return # We are interested in this broadcast because of its tag. # We're going to add some more data which is signed further down. - signedData = data[8:readPosition] - cryptorObject = shared.MyECSubscriptionCryptorObjects[embeddedTag] + signedData = bytes(data[8:readPosition]) + cryptorObject = shared.MyECSubscriptionCryptorObjects[embeddedTag_bytes] try: decryptedData = cryptorObject.decrypt(data[readPosition:]) logger.debug('EC decryption successful') @@ -997,8 +1002,9 @@ class objectProcessor(threading.Thread): encodeVarint(addressVersion) + encodeVarint(streamNumber) + ripe )[32:] - if tag in state.neededPubkeys: - del state.neededPubkeys[tag] + tag_bytes = bytes(tag) + if tag_bytes in state.neededPubkeys: + del state.neededPubkeys[tag_bytes] self.sendMessages(address) @staticmethod diff --git a/src/class_singleWorker.py b/src/class_singleWorker.py index f2821f65..adcae170 100644 --- a/src/class_singleWorker.py +++ b/src/class_singleWorker.py @@ -87,7 +87,7 @@ class singleWorker(StoppableThread): tag = doubleHashOfAddressData[32:] # We'll need this for when we receive a pubkey reply: # it will be encrypted and we'll need to decrypt it. - state.neededPubkeys[tag] = ( + state.neededPubkeys[bytes(tag)] = ( toAddress, highlevelcrypto.makeCryptor( hexlify(privEncryptionKey)) @@ -99,14 +99,14 @@ class singleWorker(StoppableThread): for row in queryreturn: ackdata, = row self.logger.info('Watching for ackdata %s', hexlify(ackdata)) - state.ackdataForWhichImWatching[ackdata] = 0 + state.ackdataForWhichImWatching[bytes(ackdata)] = 0 # Fix legacy (headerless) watched ackdata to include header for oldack in state.ackdataForWhichImWatching: if len(oldack) == 32: # attach legacy header, always constant (msg/1/1) newack = '\x00\x00\x00\x02\x01\x01' + oldack - state.ackdataForWhichImWatching[newack] = 0 + state.ackdataForWhichImWatching[bytes(newack)] = 0 sqlExecute( '''UPDATE sent SET ackdata=? WHERE ackdata=? AND folder = 'sent' ''', newack, oldack @@ -794,8 +794,9 @@ class singleWorker(StoppableThread): encodeVarint(toAddressVersionNumber) + encodeVarint(toStreamNumber) + toRipe )[32:] + toTag_bytes = bytes(toTag) if toaddress in state.neededPubkeys or \ - toTag in state.neededPubkeys: + toTag_bytes in state.neededPubkeys: # We already sent a request for the pubkey sqlExecute( '''UPDATE sent SET status='awaitingpubkey', ''' @@ -836,7 +837,8 @@ class singleWorker(StoppableThread): privEncryptionKey = doubleHashOfToAddressData[:32] # The second half of the sha512 hash. tag = doubleHashOfToAddressData[32:] - state.neededPubkeys[tag] = ( + tag_bytes = bytes(tag) + state.neededPubkeys[tag_bytes] = ( toaddress, highlevelcrypto.makeCryptor( hexlify(privEncryptionKey)) @@ -859,7 +861,7 @@ class singleWorker(StoppableThread): ''' status='doingpubkeypow') AND ''' ''' folder='sent' ''', toaddress) - del state.neededPubkeys[tag] + del state.neededPubkeys[tag_bytes] break # else: # There was something wrong with this @@ -901,7 +903,7 @@ class singleWorker(StoppableThread): # if we aren't sending this to ourselves or a chan if not config.has_section(toaddress): - state.ackdataForWhichImWatching[ackdata] = 0 + state.ackdataForWhichImWatching[bytes(ackdata)] = 0 queues.UISignalQueue.put(( 'updateSentItemStatusByAckdata', ( ackdata, @@ -1412,10 +1414,11 @@ class singleWorker(StoppableThread): privEncryptionKey = doubleHashOfAddressData[:32] # Note that this is the second half of the sha512 hash. tag = doubleHashOfAddressData[32:] - if tag not in state.neededPubkeys: + tag_bytes = bytes(tag) + if tag_bytes not in state.neededPubkeys: # We'll need this for when we receive a pubkey reply: # it will be encrypted and we'll need to decrypt it. - state.neededPubkeys[tag] = ( + state.neededPubkeys[tag_bytes] = ( toAddress, highlevelcrypto.makeCryptor(hexlify(privEncryptionKey)) ) diff --git a/src/network/bmproto.py b/src/network/bmproto.py index ed1d48c4..4b01cdad 100644 --- a/src/network/bmproto.py +++ b/src/network/bmproto.py @@ -415,7 +415,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): BMProto.stopDownloadingObject(self.object.inventoryHash, True) else: try: - del missingObjects[self.object.inventoryHash] + del missingObjects[bytes(self.object.inventoryHash)] except KeyError: pass @@ -653,7 +653,7 @@ class BMProto(AdvancedDispatcher, ObjectTracker): except KeyError: pass try: - del missingObjects[hashId] + del missingObjects[bytes(hashId)] except KeyError: pass diff --git a/src/network/dandelion.py b/src/network/dandelion.py index 35e70c95..f85ece9e 100644 --- a/src/network/dandelion.py +++ b/src/network/dandelion.py @@ -52,7 +52,7 @@ class Dandelion: # pylint: disable=old-style-class if not state.dandelion_enabled: return with self.lock: - self.hashMap[hashId] = Stem( + self.hashMap[bytes(hashId)] = Stem( self.getNodeStem(source), stream, self.poissonTimeout()) @@ -63,9 +63,10 @@ class Dandelion: # pylint: disable=old-style-class include streams, we only learn this after receiving the object) """ with self.lock: - if hashId in self.hashMap: - self.hashMap[hashId] = Stem( - self.hashMap[hashId].child, + hashId_bytes = bytes(hashId) + if hashId_bytes in self.hashMap: + self.hashMap[hashId_bytes] = Stem( + self.hashMap[hashId_bytes].child, stream, self.poissonTimeout()) @@ -77,17 +78,17 @@ class Dandelion: # pylint: disable=old-style-class ''.join('%02x' % ord(i) for i in hashId), reason) with self.lock: try: - del self.hashMap[hashId] + del self.hashMap[bytes(hashId)] except KeyError: pass def hasHash(self, hashId): """Is inventory vector in stem mode?""" - return hashId in self.hashMap + return bytes(hashId) in self.hashMap def objectChildStem(self, hashId): """Child (i.e. next) node for an inventory vector during stem mode""" - return self.hashMap[hashId].child + return self.hashMap[bytes(hashId)].child def maybeAddStem(self, connection): """ diff --git a/src/network/downloadthread.py b/src/network/downloadthread.py index 4f108c72..baacba23 100644 --- a/src/network/downloadthread.py +++ b/src/network/downloadthread.py @@ -67,7 +67,7 @@ class DownloadThread(StoppableThread): continue payload.extend(chunk) chunkCount += 1 - missingObjects[chunk] = now + missingObjects[bytes(chunk)] = now if not chunkCount: continue payload[0:0] = addresses.encodeVarint(chunkCount) diff --git a/src/network/objectracker.py b/src/network/objectracker.py index a458e5d2..f1a112af 100644 --- a/src/network/objectracker.py +++ b/src/network/objectracker.py @@ -81,25 +81,28 @@ class ObjectTracker(object): def hasObj(self, hashid): """Do we already have object?""" + hashid_bytes = bytes(hashid) if haveBloom: - return hashid in self.invBloom - return hashid in self.objectsNewToMe + return hashid_bytes in self.invBloom + return hashid_bytes in self.objectsNewToMe def handleReceivedInventory(self, hashId): """Handling received inventory""" + hashId_bytes = bytes(hashId) if haveBloom: - self.invBloom.add(hashId) + self.invBloom.add(hashId_bytes) try: with self.objectsNewToThemLock: - del self.objectsNewToThem[hashId] + del self.objectsNewToThem[hashId_bytes] except KeyError: pass - if hashId not in missingObjects: - missingObjects[hashId] = time.time() + if hashId_bytes not in missingObjects: + missingObjects[hashId_bytes] = time.time() self.objectsNewToMe[hashId] = True def handleReceivedObject(self, streamNumber, hashid): """Handling received object""" + hashid_bytes = bytes(hashid) for i in connectionpool.pool.connections(): if not i.fullyEstablished: continue @@ -110,7 +113,7 @@ class ObjectTracker(object): not state.Dandelion.hasHash(hashid) or state.Dandelion.objectChildStem(hashid) == i): with i.objectsNewToThemLock: - i.objectsNewToThem[hashid] = time.time() + i.objectsNewToThem[hashid_bytes] = time.time() # update stream number, # which we didn't have when we just received the dinv # also resets expiration of the stem mode @@ -119,7 +122,7 @@ class ObjectTracker(object): if i == self: try: with i.objectsNewToThemLock: - del i.objectsNewToThem[hashid] + del i.objectsNewToThem[hashid_bytes] except KeyError: pass self.objectsNewToMe.setLastObject() @@ -133,4 +136,4 @@ class ObjectTracker(object): def addAddr(self, hashid): """WIP, should be moved to addrthread.py or removed""" if haveBloom: - self.addrBloom.add(hashid) + self.addrBloom.add(bytes(hashid)) diff --git a/src/protocol.py b/src/protocol.py index 7f9830e5..2f435cbb 100644 --- a/src/protocol.py +++ b/src/protocol.py @@ -293,7 +293,7 @@ def isProofOfWorkSufficient( if TTL < 300: TTL = 300 POW, = unpack('>Q', highlevelcrypto.double_sha512( - data[:8] + hashlib.sha512(data[8:]).digest())[0:8]) + bytes(data[:8]) + hashlib.sha512(data[8:]).digest())[0:8]) return POW <= 2 ** 64 / ( nonceTrialsPerByte * ( len(data) + payloadLengthExtraBytes @@ -465,7 +465,7 @@ def decryptAndCheckPubkeyPayload(data, address): readPosition += varintLength # We'll store the address version and stream number # (and some more) in the pubkeys table. - storedData = data[20:readPosition] + storedData = bytes(data[20:readPosition]) if addressVersion != embeddedAddressVersion: logger.info( @@ -482,11 +482,11 @@ def decryptAndCheckPubkeyPayload(data, address): readPosition += 32 # the time through the tag. More data is appended onto # signedData below after the decryption. - signedData = data[8:readPosition] + signedData = bytes(data[8:readPosition]) encryptedData = data[readPosition:] # Let us try to decrypt the pubkey - toAddress, cryptorObject = state.neededPubkeys[tag] + toAddress, cryptorObject = state.neededPubkeys[bytes(tag)] if toAddress != address: logger.critical( 'decryptAndCheckPubkeyPayload failed due to toAddress' diff --git a/src/pyelliptic/cipher.py b/src/pyelliptic/cipher.py index af6c08ca..2c2c54da 100644 --- a/src/pyelliptic/cipher.py +++ b/src/pyelliptic/cipher.py @@ -30,7 +30,7 @@ class Cipher(object): self.ctx = OpenSSL.EVP_CIPHER_CTX_new() if do == 1 or do == 0: k = OpenSSL.malloc(key, len(key)) - IV = OpenSSL.malloc(iv, len(iv)) + IV = OpenSSL.malloc(bytes(iv), len(iv)) OpenSSL.EVP_CipherInit_ex( self.ctx, self.cipher.get_pointer(), 0, k, IV, do) else: @@ -59,7 +59,7 @@ class Cipher(object): """Update result with more data""" i = OpenSSL.c_int(0) buffer = OpenSSL.malloc(b"", len(input) + self.cipher.get_blocksize()) - inp = OpenSSL.malloc(input, len(input)) + inp = OpenSSL.malloc(bytes(input), len(input)) if OpenSSL.EVP_CipherUpdate(self.ctx, OpenSSL.byref(buffer), OpenSSL.byref(i), inp, len(input)) == 0: raise Exception("[OpenSSL] EVP_CipherUpdate FAIL ...") diff --git a/src/pyelliptic/ecc.py b/src/pyelliptic/ecc.py index c670d023..8f254561 100644 --- a/src/pyelliptic/ecc.py +++ b/src/pyelliptic/ecc.py @@ -7,6 +7,7 @@ Asymmetric cryptography using elliptic curves from hashlib import sha512 from struct import pack, unpack +from ctypes import c_char_p from .cipher import Cipher from .hash import equals, hmac_sha256 @@ -218,8 +219,8 @@ class ECC(object): if other_key == 0: raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...") - other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), None) - other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), None) + other_pub_key_x = OpenSSL.BN_bin2bn(c_char_p(bytes(pubkey_x)), len(pubkey_x), None) + other_pub_key_y = OpenSSL.BN_bin2bn(c_char_p(bytes(pubkey_y)), len(pubkey_y), None) other_group = OpenSSL.EC_KEY_get0_group(other_key) other_pub_key = OpenSSL.EC_POINT_new(other_group) diff --git a/src/pyelliptic/hash.py b/src/pyelliptic/hash.py index 70c9a6ce..b133f447 100644 --- a/src/pyelliptic/hash.py +++ b/src/pyelliptic/hash.py @@ -38,7 +38,7 @@ def hmac_sha256(k, m): Compute the key and the message with HMAC SHA5256 """ key = OpenSSL.malloc(k, len(k)) - d = OpenSSL.malloc(m, len(m)) + d = OpenSSL.malloc(bytes(m), len(m)) md = OpenSSL.malloc(0, 32) i = OpenSSL.pointer(OpenSSL.c_int(0)) OpenSSL.HMAC(OpenSSL.EVP_sha256(), key, len(k), d, len(m), md, i) diff --git a/src/randomtrackingdict.py b/src/randomtrackingdict.py index 5bf19181..0944da2a 100644 --- a/src/randomtrackingdict.py +++ b/src/randomtrackingdict.py @@ -38,10 +38,10 @@ class RandomTrackingDict(object): return self.len def __contains__(self, key): - return key in self.dictionary + return bytes(key) in self.dictionary def __getitem__(self, key): - return self.dictionary[key][1] + return self.dictionary[bytes(key)][1] def _swap(self, i1, i2): with self.lock: @@ -49,26 +49,28 @@ class RandomTrackingDict(object): key2 = self.indexDict[i2] self.indexDict[i1] = key2 self.indexDict[i2] = key1 - self.dictionary[key1][0] = i2 - self.dictionary[key2][0] = i1 + self.dictionary[bytes(key1)][0] = i2 + self.dictionary[bytes(key2)][0] = i1 # for quick reassignment return i2 def __setitem__(self, key, value): with self.lock: - if key in self.dictionary: - self.dictionary[key][1] = value + key_bytes = bytes(key) + if key_bytes in self.dictionary: + self.dictionary[key_bytes][1] = value else: self.indexDict.append(key) - self.dictionary[key] = [self.len, value] + self.dictionary[key_bytes] = [self.len, value] self._swap(self.len, self.len - self.pendingLen) self.len += 1 def __delitem__(self, key): - if key not in self.dictionary: + key_bytes = bytes(key) + if key_bytes not in self.dictionary: raise KeyError with self.lock: - index = self.dictionary[key][0] + index = self.dictionary[key_bytes][0] # not pending if index < self.len - self.pendingLen: # left of pending part @@ -82,7 +84,7 @@ class RandomTrackingDict(object): # operation can improve 4x, but it's already very fast so we'll # ignore it for the time being del self.indexDict[-1] - del self.dictionary[key] + del self.dictionary[key_bytes] self.len -= 1 def setMaxPending(self, maxPending): diff --git a/src/shared.py b/src/shared.py index b85ddb20..a1541eac 100644 --- a/src/shared.py +++ b/src/shared.py @@ -114,11 +114,11 @@ def reloadMyAddressHashes(): if len(privEncryptionKey) == 64: myECCryptorObjects[hashobj] = \ highlevelcrypto.makeCryptor(privEncryptionKey) - myAddressesByHash[hashobj] = addressInKeysFile + myAddressesByHash[bytes(hashobj)] = addressInKeysFile tag = highlevelcrypto.double_sha512( encodeVarint(addressVersionNumber) + encodeVarint(streamNumber) + hashobj)[32:] - myAddressesByTag[tag] = addressInKeysFile + myAddressesByTag[bytes(tag)] = addressInKeysFile if not keyfileSecure: fixSensitiveFilePermissions(os.path.join( @@ -149,7 +149,7 @@ def reloadBroadcastSendersForWhichImWatching(): encodeVarint(addressVersionNumber) + encodeVarint(streamNumber) + hashobj ).digest()[:32] - MyECSubscriptionCryptorObjects[hashobj] = \ + MyECSubscriptionCryptorObjects[bytes(hashobj)] = \ highlevelcrypto.makeCryptor(hexlify(privEncryptionKey)) else: doubleHashOfAddressData = highlevelcrypto.double_sha512( @@ -158,7 +158,7 @@ def reloadBroadcastSendersForWhichImWatching(): ) tag = doubleHashOfAddressData[32:] privEncryptionKey = doubleHashOfAddressData[:32] - MyECSubscriptionCryptorObjects[tag] = \ + MyECSubscriptionCryptorObjects[bytes(tag)] = \ highlevelcrypto.makeCryptor(hexlify(privEncryptionKey)) diff --git a/src/storage/sqlite.py b/src/storage/sqlite.py index eb5df098..5d967aec 100644 --- a/src/storage/sqlite.py +++ b/src/storage/sqlite.py @@ -29,20 +29,22 @@ class SqliteInventory(InventoryStorage): def __contains__(self, hash_): with self.lock: - if hash_ in self._objects: + hash_bytes = bytes(hash_) + if hash_bytes in self._objects: return True rows = sqlQuery( 'SELECT streamnumber FROM inventory WHERE hash=?', sqlite3.Binary(hash_)) if not rows: return False - self._objects[hash_] = rows[0][0] + self._objects[hash_bytes] = rows[0][0] return True def __getitem__(self, hash_): with self.lock: - if hash_ in self._inventory: - return self._inventory[hash_] + hash_bytes = bytes(hash_) + if hash_bytes in self._inventory: + return self._inventory[hash_bytes] rows = sqlQuery( 'SELECT objecttype, streamnumber, payload, expirestime, tag' ' FROM inventory WHERE hash=?', sqlite3.Binary(hash_)) @@ -53,15 +55,16 @@ class SqliteInventory(InventoryStorage): def __setitem__(self, hash_, value): with self.lock: value = InventoryItem(*value) - self._inventory[hash_] = value - self._objects[hash_] = value.stream + hash_bytes = bytes(hash_) + self._inventory[hash_bytes] = value + self._objects[hash_bytes] = value.stream def __delitem__(self, hash_): raise NotImplementedError def __iter__(self): with self.lock: - hashes = self._inventory.keys()[:] + hashes = [] + self._inventory.keys()[:] hashes += (x for x, in sqlQuery('SELECT hash FROM inventory')) return hashes.__iter__() diff --git a/src/tests/test_randomtrackingdict.py b/src/tests/test_randomtrackingdict.py index 2db3c423..cbe0ee55 100644 --- a/src/tests/test_randomtrackingdict.py +++ b/src/tests/test_randomtrackingdict.py @@ -15,10 +15,10 @@ class TestRandomTrackingDict(unittest.TestCase): @staticmethod def randString(): """helper function for tests, generates a random string""" - retval = '' - for _ in range(32): - retval += chr(random.randint(0, 255)) - return retval + retval = bytearray(32) + for i in range(32): + retval[i] = random.randint(0, 255) + return bytes(retval) def test_check_randomtrackingdict(self): """Check the logic of RandomTrackingDict class"""