diff --git a/src/class_sqlThread.py b/src/class_sqlThread.py index 37c32e70..6e623ccb 100644 --- a/src/class_sqlThread.py +++ b/src/class_sqlThread.py @@ -69,8 +69,6 @@ class UpgradeDB(object): Execute SQL files and queries """ try: - print("=======================") - print(file_name) if int(file_name) == 8: res = self.cur.execute('''PRAGMA table_info('inbox');''') print("""""""""""""""-----------res""""""""""""""") @@ -139,6 +137,19 @@ class sqlThread(threading.Thread, UpgradeDB): def __init__(self): super(sqlThread, self).__init__() threading.Thread.__init__(self, name="SQL") + self.__flag = threading.Event() + + # Adding for pause/resume thread + self.__flag = threading.Event() + self.__flag.set() # Set to True + self.__running = threading.Event() # Used to stop the thread identification + self.__running.set() # Set running to True + + def pause(self): + self.__flag.clear() # Set to False to block the thread + + def resume(self): + self.__flag.set() # Set to True, let the thread stop blocking def run(self): # pylint: disable=too-many-locals, too-many-branches, too-many-statements, # Redefinition-of-parameters-type-from-tuple-to-str, R0204, line-too-long, E501 diff --git a/src/tests/sql/init_version_9.sql b/src/tests/sql/init_version_9.sql index b3f8db44..0f6a308e 100644 --- a/src/tests/sql/init_version_9.sql +++ b/src/tests/sql/init_version_9.sql @@ -17,11 +17,13 @@ CREATE TABLE IF NOT EXISTS `sent` ( UNIQUE(msgid) ON CONFLICT REPLACE ) ; +DROP TABLE IF EXISTS `pubkeys`; + CREATE TABLE IF NOT EXISTS `pubkeys` ( - `hash` text, - `addressversion` int, - `transmitdata` blob, - `time` int, - `usedpersonally` text, + `hash` text DEFAULT NULL, + `addressversion` int DEFAULT NULL, + `transmitdata` blob DEFAULT NULL, + `time` int DEFAULT NULL, + `usedpersonally` text DEFAULT NULL UNIQUE(hash) ON CONFLICT REPLACE ) ; diff --git a/src/tests/test_sqlthread.py b/src/tests/test_sqlthread.py index 151b5095..3b5d2239 100644 --- a/src/tests/test_sqlthread.py +++ b/src/tests/test_sqlthread.py @@ -110,20 +110,20 @@ class TestSqlThread(unittest.TestCase): func_name = func.__name__ version = func_name.rsplit('_', 1)[-1] - print("-------------=========") - print(version) - print("-------------=========") if int(version) == 8: res = sqlQuery('''PRAGMA table_info('inbox');''') print("""""""""""""""res""""""""""""""") print(res) + # sqlThread().pause() # Update versions DB mocking self.initialise_database("init_version_{}".format(version)) - + print("-------------=========") + print(version) + print("-------------=========") if int(version) == 9: sqlThread().create_function() @@ -133,12 +133,12 @@ class TestSqlThread(unittest.TestCase): print("""""""""""""""-----------res""""""""""""""") print(res) - # Test versions upgrade_db = UpgradeDB() upgrade_db._upgrade_one_level_sql_statement(int(version)) # pylint: disable= W0212, protected-access # upgrade_db.upgrade_to_latest(upgrade_db.cur, upgrade_db.conn) # upgrade_db.upgrade_to_latest(upgrade_db.cur, upgrade_db.conn, int(version)) + # sqlThread().resume() return func(*args) # <-- use (self, ...) func = self return wrapper @@ -255,20 +255,20 @@ class TestSqlThread(unittest.TestCase): result = list(filter_table_column(res, "sighash")) self.assertEqual(result, ['sighash'], "Data not migrated for version 8") - # @version - # def test_sql_thread_version_9(self): - # """ - # Test with version 9 - # """ - # - # # Assertion - # res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''') - # self.assertNotEqual(res[0][0], 1, "Table pubkeys_backup not deleted") - # - # res = sqlQuery('''PRAGMA table_info('pubkeys');''') - # # res = res.fetchall() - # result = list(filter_table_column(res, "address")) - # self.assertEqual(result, ['address'], "Data not migrated for version 9") + @version + def test_sql_thread_version_9(self): + """ + Test with version 9 + """ + + # Assertion + res = sqlQuery(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''') + self.assertNotEqual(res[0][0], 1, "Table pubkeys_backup not deleted") + + res = sqlQuery('''PRAGMA table_info('pubkeys');''') + # res = res.fetchall() + result = list(filter_table_column(res, "address")) + self.assertEqual(result, ['address'], "Data not migrated for version 9") # @version # def test_sql_thread_version_10(self):