From ff84454a95ce250f2492202048a874e9bff574b0 Mon Sep 17 00:00:00 2001 From: Muzahid Date: Thu, 11 Mar 2021 14:08:41 +0530 Subject: [PATCH] change as per creae_function --- src/class_sqlThread.py | 12 ++++++------ src/tests/test_sqlthread.py | 29 ++++++++++++++++++----------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/class_sqlThread.py b/src/class_sqlThread.py index dad9591a..613a9298 100644 --- a/src/class_sqlThread.py +++ b/src/class_sqlThread.py @@ -207,18 +207,18 @@ class UpgradeDB(): # table. Then we can take out the hash field. self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''') - # item = '''UPDATE 'pubkeys' SET `hash`='87788778877887788787' where hash=''; ''' - # self.cur.execute(item) + item = '''UPDATE 'pubkeys' SET `hash`='87788778877887788787' where hash=''; ''' + self.cur.execute(item) + # create_function try: self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True) - except Exception as err: - logger.error("Got error while pass deterministic in sqlite create function {}".format(err)) - print("Got error while pass deterministic in sqlite create function {}".format(err)) + except (TypeError, sqlite3.NotSupportedError) as err: + logger.error("Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err)) self.conn.create_function("enaddr", 3, encodeAddress) # replica for loop to update hashed address - self.cur.execute(''' UPDATE pubkeys SET address=(select enaddr(pubkeys.addressversion, 1, pubkeys.hash)) WHERE hash=pubkeys.hash; ''') + self.cur.execute('''UPDATE pubkeys SET address=(select enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''') self.run_migrations("9_1") diff --git a/src/tests/test_sqlthread.py b/src/tests/test_sqlthread.py index 11b39f6f..81b68881 100644 --- a/src/tests/test_sqlthread.py +++ b/src/tests/test_sqlthread.py @@ -1,5 +1,5 @@ """ - Test for sqlThread blind signatures + Test for sqlThread """ import os @@ -7,10 +7,11 @@ import unittest import sqlite3 import sys from ..state import appdata -from ..helper_sql import sqlStoredProcedure +from ..helper_sql import sqlStoredProcedure, sql_ready from ..class_sqlThread import (sqlThread, UpgradeDB) from ..addresses import encodeAddress + class TestSqlThread(unittest.TestCase): """ Test case for SQLThread @@ -20,12 +21,16 @@ class TestSqlThread(unittest.TestCase): conn.text_factory = str cur = conn.cursor() + # query file path + root_path = os.path.dirname(os.path.dirname(__file__)) + @classmethod def setUpClass(cls): # Start SQL thread sqlLookup = sqlThread() sqlLookup.daemon = False sqlLookup.start() + sql_ready.wait() @classmethod def setUp(cls): @@ -41,12 +46,12 @@ class TestSqlThread(unittest.TestCase): # Stop sql thread sqlStoredProcedure('exit') - def normalize_version(self, file): + def initialise_database(self, file): + """ + Initialise DB + """ try: - root_path = os.path.dirname(os.path.dirname(__file__)) - sql_file_path = os.path.join(root_path, 'tests/sql/') - sql_file_path = os.path.join(sql_file_path, "init_version_{}.sql".format(file)) - sql_file = open(sql_file_path) + sql_file = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))) sql_as_string = sql_file.read() self.cur.executescript(sql_as_string) self.conn.commit() @@ -71,7 +76,7 @@ class TestSqlThread(unittest.TestCase): version = func_name.rsplit('_', 1)[-1] # Update versions DB mocking - self.normalize_version(version) + self.initialise_database("init_version_{}".format(version)) # Test versions upgrade_db = UpgradeDB() @@ -225,15 +230,17 @@ class TestSqlThread(unittest.TestCase): print("Got error while pass deterministic in sqlite create function {}, So called function directly".format(err)) self.conn.create_function("enaddr", 3, encodeAddress) - self.cur.execute('''CREATE TABLE testhash (addressversion int, hash blob, address text, UNIQUE(address) ON CONFLICT IGNORE) ''') + encoded_str = encodeAddress(1, 1, "21122112211221122112") - self.conn.execute('''INSERT INTO testhash (addressversion, hash) VALUES(1, "21122112211221122112"); ''') + # Initialise Database + self.initialise_database("create_function") # call function in query - self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, testhash.hash)) WHERE hash=testhash.hash; ''') + self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash; ''') # Assertion self.cur.execute('''select address from testhash;''') hsh = self.cur.fetchone()[0] self.assertNotEqual(hsh, 1, "test case fail for create_function") + self.assertEqual(hsh, encoded_str, "test case fail for create_function") self.conn.execute('''DROP TABLE testhash;''')