From 3052a3c2c24917ee120a8fd954ce6bbe9a5b8b25 Mon Sep 17 00:00:00 2001
From: "kuldeep.k@cisinlabs.com" <kuldeep.k@cisinlabs.com>
Date: Mon, 8 Nov 2021 21:44:12 +0530
Subject: [PATCH] Fixed import issues, updated message.py functions and added
 testcases for messagetypes

---
 src/messagetypes/__init__.py | 18 +++++++++--------
 src/messagetypes/message.py  | 32 ++++++++++++++++-------------
 src/messagetypes/vote.py     |  2 +-
 src/tests/test_msg.py        | 39 ++++++++++++++++++++++++++++++++++++
 4 files changed, 68 insertions(+), 23 deletions(-)
 create mode 100644 src/tests/test_msg.py

diff --git a/src/messagetypes/__init__.py b/src/messagetypes/__init__.py
index 86aca1f4..774f8932 100644
--- a/src/messagetypes/__init__.py
+++ b/src/messagetypes/__init__.py
@@ -1,10 +1,8 @@
 import logging
 from importlib import import_module
 from os import listdir, path
-from string import lower
 
-import messagetypes
-import paths
+from pybitmessage import paths
 
 logger = logging.getLogger('default')
 
@@ -12,7 +10,7 @@ logger = logging.getLogger('default')
 class MsgBase(object):  # pylint: disable=too-few-public-methods
     """Base class for message types"""
     def __init__(self):
-        self.data = {"": lower(type(self).__name__)}
+        self.data = {"": type(self).__name__.lower()}
 
 
 def constructObject(data):
@@ -21,10 +19,14 @@ def constructObject(data):
     if data[""] not in whitelist:
         return None
     try:
-        classBase = getattr(getattr(messagetypes, data[""]), data[""].title())
-    except (NameError, AttributeError):
+        classBase = getattr(import_module(".{}".format(data[""]), __name__), data[""].title())
+    except (NameError, AttributeError, ValueError, ImportError):
         logger.error("Don't know how to handle message type: \"%s\"", data[""], exc_info=True)
         return None
+    except:  # noqa:E722
+        logger.error("Don't know how to handle message type: \"%s\"", data[""], exc_info=True)
+        return None
+
     try:
         returnObj = classBase()
         returnObj.decode(data)
@@ -40,7 +42,7 @@ def constructObject(data):
 
 if paths.frozen is not None:
     import message  # noqa : F401 flake8: disable=unused-import
-    import vote     # noqa : F401 flake8: disable=unused-import
+    import vote  # noqa : F401 flake8: disable=unused-import
 else:
     for mod in listdir(path.dirname(__file__)):
         if mod == "__init__.py":
@@ -49,7 +51,7 @@ else:
         if splitted[1] != ".py":
             continue
         try:
-            import_module(".{}".format(splitted[0]), "messagetypes")
+            import_module(".{}".format(splitted[0]), __name__)
         except ImportError:
             logger.error("Error importing %s", mod, exc_info=True)
         else:
diff --git a/src/messagetypes/message.py b/src/messagetypes/message.py
index 573732d4..a3ceafe5 100644
--- a/src/messagetypes/message.py
+++ b/src/messagetypes/message.py
@@ -1,6 +1,6 @@
 import logging
 
-from messagetypes import MsgBase
+from pybitmessage.messagetypes import MsgBase
 
 logger = logging.getLogger('default')
 
@@ -12,23 +12,27 @@ class Message(MsgBase):
     def decode(self, data):
         """Decode a message"""
         # UTF-8 and variable type validator
-        if isinstance(data["subject"], str):
-            self.subject = unicode(data["subject"], 'utf-8', 'replace')
-        else:
-            self.subject = unicode(str(data["subject"]), 'utf-8', 'replace')
-        if isinstance(data["body"], str):
-            self.body = unicode(data["body"], 'utf-8', 'replace')
-        else:
-            self.body = unicode(str(data["body"]), 'utf-8', 'replace')
+        subject = data.get("subject", "")
+        body = data.get("body", "")
+        try:
+            data["subject"] = subject.decode('utf-8', 'replace')
+        except:
+            data["subject"] = ''
+
+        try:
+            data["body"] = body.decode('utf-8', 'replace')
+        except:
+            data["body"] = ''
+
+        self.subject = data["subject"]
+        self.body = data["body"]
 
     def encode(self, data):
         """Encode a message"""
         super(Message, self).__init__()
-        try:
-            self.data["subject"] = data["subject"]
-            self.data["body"] = data["body"]
-        except KeyError as e:
-            logger.error("Missing key %s", e)
+        self.data["subject"] = data.get("subject", "")
+        self.data["body"] = data.get("body", "")
+
         return self.data
 
     def process(self):
diff --git a/src/messagetypes/vote.py b/src/messagetypes/vote.py
index b559c256..b494c15d 100644
--- a/src/messagetypes/vote.py
+++ b/src/messagetypes/vote.py
@@ -1,6 +1,6 @@
 import logging
 
-from messagetypes import MsgBase
+from pybitmessage.messagetypes import MsgBase
 
 logger = logging.getLogger('default')
 
diff --git a/src/tests/test_msg.py b/src/tests/test_msg.py
new file mode 100644
index 00000000..cb586fa5
--- /dev/null
+++ b/src/tests/test_msg.py
@@ -0,0 +1,39 @@
+"""Tests for messagetypes module"""
+import unittest
+
+from six import text_type
+
+from pybitmessage import messagetypes
+
+sample_data = {"": "message", "subject": "subject", "body": "body"}
+invalid_data = {"": "message", "subject": b"\x01\x02\x03", "body": b"\x01\x02\x03\x04"}
+
+
+class TestMessageTypes(unittest.TestCase):
+    """A test case for messagetypes"""
+
+    def test_msg_encode(self):
+        """Test msg encode"""
+        msgObj = messagetypes.message.Message()
+        encoded_message = msgObj.encode(sample_data)
+        self.assertEqual(type(encoded_message), dict)
+        self.assertEqual(encoded_message["subject"], sample_data["subject"])
+        self.assertEqual(encoded_message["body"], sample_data["body"])
+
+    def test_msg_decode(self):
+        """Test msg decode"""
+        msgObj = messagetypes.constructObject(sample_data)
+        self.assertEqual(msgObj.subject, sample_data["subject"])
+        self.assertEqual(msgObj.body, sample_data["body"])
+
+    def test_invalid_data_type(self):
+        """Test invalid data type"""
+        msgObj = messagetypes.constructObject(invalid_data)
+        self.assertTrue(isinstance(msgObj.subject, text_type))
+        self.assertTrue(isinstance(msgObj.body, text_type))
+
+    def test_msg_process(self):
+        """Test msg process"""
+        msgObj = messagetypes.constructObject(sample_data)
+        self.assertTrue(isinstance(msgObj, messagetypes.message.Message))
+        self.assertIsNone(msgObj.process())