change as per creae_function

This commit is contained in:
Muzahid 2021-03-11 14:08:41 +05:30
parent 27e07954a0
commit ff84454a95
Signed by untrusted user: cis-muzahid
GPG Key ID: 1DC85E7D3AB613EA
2 changed files with 24 additions and 17 deletions

View File

@ -207,18 +207,18 @@ class UpgradeDB():
# table. Then we can take out the hash field. # table. Then we can take out the hash field.
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''') self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''')
# item = '''UPDATE 'pubkeys' SET `hash`='87788778877887788787' where hash=''; ''' item = '''UPDATE 'pubkeys' SET `hash`='87788778877887788787' where hash=''; '''
# self.cur.execute(item) self.cur.execute(item)
# create_function
try: try:
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True) self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
except Exception as err: except (TypeError, sqlite3.NotSupportedError) as err:
logger.error("Got error while pass deterministic in sqlite create function {}".format(err)) logger.error("Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err))
print("Got error while pass deterministic in sqlite create function {}".format(err))
self.conn.create_function("enaddr", 3, encodeAddress) self.conn.create_function("enaddr", 3, encodeAddress)
# replica for loop to update hashed address # replica for loop to update hashed address
self.cur.execute(''' UPDATE pubkeys SET address=(select enaddr(pubkeys.addressversion, 1, pubkeys.hash)) WHERE hash=pubkeys.hash; ''') self.cur.execute('''UPDATE pubkeys SET address=(select enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''')
self.run_migrations("9_1") self.run_migrations("9_1")

View File

@ -1,5 +1,5 @@
""" """
Test for sqlThread blind signatures Test for sqlThread
""" """
import os import os
@ -7,10 +7,11 @@ import unittest
import sqlite3 import sqlite3
import sys import sys
from ..state import appdata from ..state import appdata
from ..helper_sql import sqlStoredProcedure from ..helper_sql import sqlStoredProcedure, sql_ready
from ..class_sqlThread import (sqlThread, UpgradeDB) from ..class_sqlThread import (sqlThread, UpgradeDB)
from ..addresses import encodeAddress from ..addresses import encodeAddress
class TestSqlThread(unittest.TestCase): class TestSqlThread(unittest.TestCase):
""" """
Test case for SQLThread Test case for SQLThread
@ -20,12 +21,16 @@ class TestSqlThread(unittest.TestCase):
conn.text_factory = str conn.text_factory = str
cur = conn.cursor() cur = conn.cursor()
# query file path
root_path = os.path.dirname(os.path.dirname(__file__))
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# Start SQL thread # Start SQL thread
sqlLookup = sqlThread() sqlLookup = sqlThread()
sqlLookup.daemon = False sqlLookup.daemon = False
sqlLookup.start() sqlLookup.start()
sql_ready.wait()
@classmethod @classmethod
def setUp(cls): def setUp(cls):
@ -41,12 +46,12 @@ class TestSqlThread(unittest.TestCase):
# Stop sql thread # Stop sql thread
sqlStoredProcedure('exit') sqlStoredProcedure('exit')
def normalize_version(self, file): def initialise_database(self, file):
"""
Initialise DB
"""
try: try:
root_path = os.path.dirname(os.path.dirname(__file__)) sql_file = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file)))
sql_file_path = os.path.join(root_path, 'tests/sql/')
sql_file_path = os.path.join(sql_file_path, "init_version_{}.sql".format(file))
sql_file = open(sql_file_path)
sql_as_string = sql_file.read() sql_as_string = sql_file.read()
self.cur.executescript(sql_as_string) self.cur.executescript(sql_as_string)
self.conn.commit() self.conn.commit()
@ -71,7 +76,7 @@ class TestSqlThread(unittest.TestCase):
version = func_name.rsplit('_', 1)[-1] version = func_name.rsplit('_', 1)[-1]
# Update versions DB mocking # Update versions DB mocking
self.normalize_version(version) self.initialise_database("init_version_{}".format(version))
# Test versions # Test versions
upgrade_db = UpgradeDB() upgrade_db = UpgradeDB()
@ -225,15 +230,17 @@ class TestSqlThread(unittest.TestCase):
print("Got error while pass deterministic in sqlite create function {}, So called function directly".format(err)) print("Got error while pass deterministic in sqlite create function {}, So called function directly".format(err))
self.conn.create_function("enaddr", 3, encodeAddress) self.conn.create_function("enaddr", 3, encodeAddress)
self.cur.execute('''CREATE TABLE testhash (addressversion int, hash blob, address text, UNIQUE(address) ON CONFLICT IGNORE) ''') encoded_str = encodeAddress(1, 1, "21122112211221122112")
self.conn.execute('''INSERT INTO testhash (addressversion, hash) VALUES(1, "21122112211221122112"); ''') # Initialise Database
self.initialise_database("create_function")
# call function in query # call function in query
self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, testhash.hash)) WHERE hash=testhash.hash; ''') self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash; ''')
# Assertion # Assertion
self.cur.execute('''select address from testhash;''') self.cur.execute('''select address from testhash;''')
hsh = self.cur.fetchone()[0] hsh = self.cur.fetchone()[0]
self.assertNotEqual(hsh, 1, "test case fail for create_function") self.assertNotEqual(hsh, 1, "test case fail for create_function")
self.assertEqual(hsh, encoded_str, "test case fail for create_function")
self.conn.execute('''DROP TABLE testhash;''') self.conn.execute('''DROP TABLE testhash;''')