From 2b93736a33e694d9670864adf1832ac48d167d02 Mon Sep 17 00:00:00 2001
From: anand k <anand.theskss@gmail.com>
Date: Thu, 20 Jun 2024 07:55:26 +0530
Subject: [PATCH] Used dependency injection in network threads

---
 src/bitmessageqt/settings.py      |  2 +-
 src/network/__init__.py           |  9 +++++----
 src/network/addrthread.py         | 15 +++++++++------
 src/network/announcethread.py     |  9 ++++++---
 src/network/downloadthread.py     | 13 ++++++++-----
 src/network/invthread.py          | 28 +++++++++++++++++-----------
 src/network/networkthread.py      |  7 +++++--
 src/network/receivequeuethread.py | 10 +++++-----
 src/network/uploadthread.py       | 10 +++++++---
 src/tests/test_network.py         |  2 +-
 10 files changed, 64 insertions(+), 41 deletions(-)

diff --git a/src/bitmessageqt/settings.py b/src/bitmessageqt/settings.py
index eeb507c7..c9fb21cd 100644
--- a/src/bitmessageqt/settings.py
+++ b/src/bitmessageqt/settings.py
@@ -414,7 +414,7 @@ class SettingsDialog(QtGui.QDialog):
                 'bitmessagesettings', 'udp'):
             self.config.set('bitmessagesettings', 'udp', str(udp_enabled))
             if udp_enabled:
-                announceThread = AnnounceThread()
+                announceThread = AnnounceThread(self.config)
                 announceThread.daemon = True
                 announceThread.start()
             else:
diff --git a/src/network/__init__.py b/src/network/__init__.py
index c3d7a21d..730a9579 100644
--- a/src/network/__init__.py
+++ b/src/network/__init__.py
@@ -38,18 +38,19 @@ def start(config, state):
     readKnownNodes()
     connectionpool.pool.connectToStream(1)
     for thread in (
-        BMNetworkThread(), InvThread(), AddrThread(),
-        DownloadThread(), UploadThread()
+        BMNetworkThread(queues), InvThread(protocol, state, queues, addresses),
+        AddrThread(protocol, queues), DownloadThread(state, protocol, addresses),
+        UploadThread(protocol, state)
     ):
         thread.daemon = True
         thread.start()
 
     # Optional components
     for i in range(config.getint('threads', 'receive')):
-        thread = ReceiveQueueThread(i)
+        thread = ReceiveQueueThread(queues, i)
         thread.daemon = True
         thread.start()
     if config.safeGetBoolean('bitmessagesettings', 'udp'):
-        state.announceThread = AnnounceThread()
+        state.announceThread = AnnounceThread(config)
         state.announceThread.daemon = True
         state.announceThread.start()
diff --git a/src/network/addrthread.py b/src/network/addrthread.py
index 489401cc..132223c0 100644
--- a/src/network/addrthread.py
+++ b/src/network/addrthread.py
@@ -6,8 +6,6 @@ from six.moves import queue
 # magic imports!
 import connectionpool
 from helper_random import randomshuffle
-from protocol import assembleAddrMessage
-from network import queues  # FIXME: init with queue
 
 from threads import StoppableThread
 
@@ -16,12 +14,17 @@ class AddrThread(StoppableThread):
     """(Node) address broadcasting thread"""
     name = "AddrBroadcaster"
 
+    def __init__(self, protocol, queues):
+        self.protocol = protocol
+        self.queues = queues
+        StoppableThread.__init__(self)
+
     def run(self):
         while not self._stopped:
             chunk = []
             while True:
                 try:
-                    data = queues.addrQueue.get(False)
+                    data = self.queues.addrQueue.get(False)
                     chunk.append(data)
                 except queue.Empty:
                     break
@@ -41,9 +44,9 @@ class AddrThread(StoppableThread):
                             continue
                         filtered.append((stream, peer, seen))
                     if filtered:
-                        i.append_write_buf(assembleAddrMessage(filtered))
+                        i.append_write_buf(self.protocol.assembleAddrMessage(filtered))
 
-            queues.addrQueue.iterate()
+            self.queues.addrQueue.iterate()
             for i in range(len(chunk)):
-                queues.addrQueue.task_done()
+                self.queues.addrQueue.task_done()
             self.stop.wait(1)
diff --git a/src/network/announcethread.py b/src/network/announcethread.py
index 1ef1c87f..cb26165e 100644
--- a/src/network/announcethread.py
+++ b/src/network/announcethread.py
@@ -5,7 +5,6 @@ import time
 
 # magic imports!
 import connectionpool
-from network import config
 from protocol import assembleAddrMessage
 
 from node import Peer
@@ -17,18 +16,22 @@ class AnnounceThread(StoppableThread):
     name = "Announcer"
     announceInterval = 60
 
+    def __init__(self, config):
+        self.config = config
+        StoppableThread.__init__(self)
+
     def run(self):
         lastSelfAnnounced = 0
         while not self._stopped:
             processed = 0
             if lastSelfAnnounced < time.time() - self.announceInterval:
-                self.announceSelf()
+                self.announceSelf(self.config)
                 lastSelfAnnounced = time.time()
             if processed == 0:
                 self.stop.wait(10)
 
     @staticmethod
-    def announceSelf():
+    def announceSelf(config):
         """Announce our presence"""
         for connection in connectionpool.pool.udpSockets.values():
             if not connection.announcing:
diff --git a/src/network/downloadthread.py b/src/network/downloadthread.py
index d8cc303e..54b238fe 100644
--- a/src/network/downloadthread.py
+++ b/src/network/downloadthread.py
@@ -2,7 +2,7 @@
 `DownloadThread` class definition
 """
 import time
-from network import state, protocol, addresses, dandelion_ins
+from network import dandelion_ins
 import helper_random
 import connectionpool
 from objectracker import missingObjects
@@ -17,8 +17,11 @@ class DownloadThread(StoppableThread):
     cleanInterval = 60
     requestExpires = 3600
 
-    def __init__(self):
+    def __init__(self, state, protocol, addresses):
         super(DownloadThread, self).__init__(name="Downloader")
+        self.state = state
+        self.protocol = protocol
+        self.addresses = addresses
         self.lastCleaned = time.time()
 
     def cleanPending(self):
@@ -57,7 +60,7 @@ class DownloadThread(StoppableThread):
                 payload = bytearray()
                 chunkCount = 0
                 for chunk in request:
-                    if chunk in state.Inventory and not dandelion_ins.hasHash(chunk):
+                    if chunk in self.state.Inventory and not dandelion_ins.hasHash(chunk):
                         try:
                             del i.objectsNewToMe[chunk]
                         except KeyError:
@@ -68,8 +71,8 @@ class DownloadThread(StoppableThread):
                     missingObjects[chunk] = now
                 if not chunkCount:
                     continue
-                payload[0:0] = addresses.encodeVarint(chunkCount)
-                i.append_write_buf(protocol.CreatePacket('getdata', payload))
+                payload[0:0] = self.addresses.encodeVarint(chunkCount)
+                i.append_write_buf(self.protocol.CreatePacket('getdata', payload))
                 self.logger.debug(
                     '%s:%i Requesting %i objects',
                     i.destination.host, i.destination.port, chunkCount)
diff --git a/src/network/invthread.py b/src/network/invthread.py
index 02a5e48f..7c9fdcf4 100644
--- a/src/network/invthread.py
+++ b/src/network/invthread.py
@@ -5,7 +5,6 @@ import Queue
 import random
 from time import time
 
-from network import protocol, state, queues, addresses
 import connectionpool
 from network import dandelion_ins
 from threads import StoppableThread
@@ -34,6 +33,13 @@ class InvThread(StoppableThread):
 
     name = "InvBroadcaster"
 
+    def __init__(self, protocol, state, queues, addresses):
+        self.protocol = protocol
+        self.state = state
+        self.queues = queues
+        self.addresses = addresses
+        StoppableThread.__init__(self)
+
     @staticmethod
     def handleLocallyGenerated(stream, hashId):
         """Locally generated inventory items require special handling"""
@@ -45,13 +51,13 @@ class InvThread(StoppableThread):
             connection.objectsNewToThem[hashId] = time()
 
     def run(self):  # pylint: disable=too-many-branches
-        while not state.shutdown:  # pylint: disable=too-many-nested-blocks
+        while not self.state.shutdown:  # pylint: disable=too-many-nested-blocks
             chunk = []
             while True:
                 # Dandelion fluff trigger by expiration
-                handleExpiredDandelion(dandelion_ins.expire(queues.invQueue))
+                handleExpiredDandelion(dandelion_ins.expire(self.queues.invQueue))
                 try:
-                    data = queues.invQueue.get(False)
+                    data = self.queues.invQueue.get(False)
                     chunk.append((data[0], data[1]))
                     # locally generated
                     if len(data) == 2 or data[2] is None:
@@ -78,7 +84,7 @@ class InvThread(StoppableThread):
                                 if random.randint(1, 100) >= dandelion_ins.enabled:  # nosec B311
                                     fluffs.append(inv[1])
                                 # send a dinv only if the stem node supports dandelion
-                                elif connection.services & protocol.NODE_DANDELION > 0:
+                                elif connection.services & self.protocol.NODE_DANDELION > 0:
                                     stems.append(inv[1])
                                 else:
                                     fluffs.append(inv[1])
@@ -87,20 +93,20 @@ class InvThread(StoppableThread):
 
                     if fluffs:
                         random.shuffle(fluffs)
-                        connection.append_write_buf(protocol.CreatePacket(
+                        connection.append_write_buf(self.protocol.CreatePacket(
                             'inv',
-                            addresses.encodeVarint(
+                            self.addresses.encodeVarint(
                                 len(fluffs)) + ''.join(fluffs)))
                     if stems:
                         random.shuffle(stems)
-                        connection.append_write_buf(protocol.CreatePacket(
+                        connection.append_write_buf(self.protocol.CreatePacket(
                             'dinv',
-                            addresses.encodeVarint(
+                            self.addresses.encodeVarint(
                                 len(stems)) + ''.join(stems)))
 
-            queues.invQueue.iterate()
+            self.queues.invQueue.iterate()
             for _ in range(len(chunk)):
-                queues.invQueue.task_done()
+                self.queues.invQueue.task_done()
 
             dandelion_ins.reRandomiseStems()
 
diff --git a/src/network/networkthread.py b/src/network/networkthread.py
index 1fe9ed77..ea1ca264 100644
--- a/src/network/networkthread.py
+++ b/src/network/networkthread.py
@@ -3,7 +3,6 @@ A thread to handle network concerns
 """
 import network.asyncore_pollchoose as asyncore
 import connectionpool
-from network import queues
 from threads import StoppableThread
 
 
@@ -11,12 +10,16 @@ class BMNetworkThread(StoppableThread):
     """Main network thread"""
     name = "Asyncore"
 
+    def __init__(self, queues):
+        self.queues = queues
+        StoppableThread.__init__(self)
+
     def run(self):
         try:
             while not self._stopped:
                 connectionpool.pool.loop()
         except Exception as e:
-            queues.excQueue.put((self.name, e))
+            self.queues.excQueue.put((self.name, e))
             raise
 
     def stopThread(self):
diff --git a/src/network/receivequeuethread.py b/src/network/receivequeuethread.py
index a6b38bd5..49a9dccd 100644
--- a/src/network/receivequeuethread.py
+++ b/src/network/receivequeuethread.py
@@ -7,20 +7,20 @@ import socket
 
 import connectionpool
 from network.advanceddispatcher import UnknownStateError
-from network import queues
 from threads import StoppableThread
 
 
 class ReceiveQueueThread(StoppableThread):
     """This thread processes data received from the network
     (which is done by the asyncore thread)"""
-    def __init__(self, num=0):
+    def __init__(self, queues, num=0):
+        self.queues = queues
         super(ReceiveQueueThread, self).__init__(name="ReceiveQueue_%i" % num)
 
     def run(self):
         while not self._stopped:
             try:
-                dest = queues.receiveDataQueue.get(block=True, timeout=1)
+                dest = self.queues.receiveDataQueue.get(block=True, timeout=1)
             except Queue.Empty:
                 continue
 
@@ -38,7 +38,7 @@ class ReceiveQueueThread(StoppableThread):
                 connection = connectionpool.pool.getConnectionByAddr(dest)
             # connection object not found
             except KeyError:
-                queues.receiveDataQueue.task_done()
+                self.queues.receiveDataQueue.task_done()
                 continue
             try:
                 connection.process()
@@ -52,4 +52,4 @@ class ReceiveQueueThread(StoppableThread):
                     self.logger.error('Socket error: %s', err)
             except:  # noqa:E722
                 self.logger.error('Error processing', exc_info=True)
-            queues.receiveDataQueue.task_done()
+            self.queues.receiveDataQueue.task_done()
diff --git a/src/network/uploadthread.py b/src/network/uploadthread.py
index 6f67cc64..c309ec7d 100644
--- a/src/network/uploadthread.py
+++ b/src/network/uploadthread.py
@@ -4,7 +4,6 @@
 import time
 
 import helper_random
-from network import protocol, state
 import connectionpool
 from randomtrackingdict import RandomTrackingDict
 from network import dandelion_ins
@@ -18,6 +17,11 @@ class UploadThread(StoppableThread):
     maxBufSize = 2097152  # 2MB
     name = "Uploader"
 
+    def __init__(self, protocol, state):
+        self.protocol = protocol
+        self.state = state
+        StoppableThread.__init__(self)
+
     def run(self):
         while not self._stopped:
             uploaded = 0
@@ -48,8 +52,8 @@ class UploadThread(StoppableThread):
                             i.destination)
                         break
                     try:
-                        payload.extend(protocol.CreatePacket(
-                            'object', state.Inventory[chunk].payload))
+                        payload.extend(self.protocol.CreatePacket(
+                            'object', self.state.Inventory[chunk].payload))
                         chunk_count += 1
                     except KeyError:
                         i.antiIntersectionDelay()
diff --git a/src/tests/test_network.py b/src/tests/test_network.py
index 206117e0..f9a4cfe0 100644
--- a/src/tests/test_network.py
+++ b/src/tests/test_network.py
@@ -74,7 +74,7 @@ class TestNetwork(TestPartialRun):
 
         for _ in range(10):
             try:
-                self.state.announceThread.announceSelf()
+                self.state.announceThread.announceSelf(self.config)
             except AttributeError:
                 self.fail('state.announceThread is not set properly')
             time.sleep(1)