change as per creae_function
This commit is contained in:
parent
27e07954a0
commit
ff84454a95
|
@ -207,18 +207,18 @@ class UpgradeDB():
|
||||||
# table. Then we can take out the hash field.
|
# table. Then we can take out the hash field.
|
||||||
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''')
|
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''')
|
||||||
|
|
||||||
# item = '''UPDATE 'pubkeys' SET `hash`='87788778877887788787' where hash=''; '''
|
item = '''UPDATE 'pubkeys' SET `hash`='87788778877887788787' where hash=''; '''
|
||||||
# self.cur.execute(item)
|
self.cur.execute(item)
|
||||||
|
|
||||||
|
# create_function
|
||||||
try:
|
try:
|
||||||
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
|
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
|
||||||
except Exception as err:
|
except (TypeError, sqlite3.NotSupportedError) as err:
|
||||||
logger.error("Got error while pass deterministic in sqlite create function {}".format(err))
|
logger.error("Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err))
|
||||||
print("Got error while pass deterministic in sqlite create function {}".format(err))
|
|
||||||
self.conn.create_function("enaddr", 3, encodeAddress)
|
self.conn.create_function("enaddr", 3, encodeAddress)
|
||||||
|
|
||||||
# replica for loop to update hashed address
|
# replica for loop to update hashed address
|
||||||
self.cur.execute(''' UPDATE pubkeys SET address=(select enaddr(pubkeys.addressversion, 1, pubkeys.hash)) WHERE hash=pubkeys.hash; ''')
|
self.cur.execute('''UPDATE pubkeys SET address=(select enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''')
|
||||||
|
|
||||||
self.run_migrations("9_1")
|
self.run_migrations("9_1")
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
Test for sqlThread blind signatures
|
Test for sqlThread
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
@ -7,10 +7,11 @@ import unittest
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
from ..state import appdata
|
from ..state import appdata
|
||||||
from ..helper_sql import sqlStoredProcedure
|
from ..helper_sql import sqlStoredProcedure, sql_ready
|
||||||
from ..class_sqlThread import (sqlThread, UpgradeDB)
|
from ..class_sqlThread import (sqlThread, UpgradeDB)
|
||||||
from ..addresses import encodeAddress
|
from ..addresses import encodeAddress
|
||||||
|
|
||||||
|
|
||||||
class TestSqlThread(unittest.TestCase):
|
class TestSqlThread(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
Test case for SQLThread
|
Test case for SQLThread
|
||||||
|
@ -20,12 +21,16 @@ class TestSqlThread(unittest.TestCase):
|
||||||
conn.text_factory = str
|
conn.text_factory = str
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
# query file path
|
||||||
|
root_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
# Start SQL thread
|
# Start SQL thread
|
||||||
sqlLookup = sqlThread()
|
sqlLookup = sqlThread()
|
||||||
sqlLookup.daemon = False
|
sqlLookup.daemon = False
|
||||||
sqlLookup.start()
|
sqlLookup.start()
|
||||||
|
sql_ready.wait()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUp(cls):
|
def setUp(cls):
|
||||||
|
@ -41,12 +46,12 @@ class TestSqlThread(unittest.TestCase):
|
||||||
# Stop sql thread
|
# Stop sql thread
|
||||||
sqlStoredProcedure('exit')
|
sqlStoredProcedure('exit')
|
||||||
|
|
||||||
def normalize_version(self, file):
|
def initialise_database(self, file):
|
||||||
|
"""
|
||||||
|
Initialise DB
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
root_path = os.path.dirname(os.path.dirname(__file__))
|
sql_file = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file)))
|
||||||
sql_file_path = os.path.join(root_path, 'tests/sql/')
|
|
||||||
sql_file_path = os.path.join(sql_file_path, "init_version_{}.sql".format(file))
|
|
||||||
sql_file = open(sql_file_path)
|
|
||||||
sql_as_string = sql_file.read()
|
sql_as_string = sql_file.read()
|
||||||
self.cur.executescript(sql_as_string)
|
self.cur.executescript(sql_as_string)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
@ -71,7 +76,7 @@ class TestSqlThread(unittest.TestCase):
|
||||||
version = func_name.rsplit('_', 1)[-1]
|
version = func_name.rsplit('_', 1)[-1]
|
||||||
|
|
||||||
# Update versions DB mocking
|
# Update versions DB mocking
|
||||||
self.normalize_version(version)
|
self.initialise_database("init_version_{}".format(version))
|
||||||
|
|
||||||
# Test versions
|
# Test versions
|
||||||
upgrade_db = UpgradeDB()
|
upgrade_db = UpgradeDB()
|
||||||
|
@ -225,15 +230,17 @@ class TestSqlThread(unittest.TestCase):
|
||||||
print("Got error while pass deterministic in sqlite create function {}, So called function directly".format(err))
|
print("Got error while pass deterministic in sqlite create function {}, So called function directly".format(err))
|
||||||
self.conn.create_function("enaddr", 3, encodeAddress)
|
self.conn.create_function("enaddr", 3, encodeAddress)
|
||||||
|
|
||||||
self.cur.execute('''CREATE TABLE testhash (addressversion int, hash blob, address text, UNIQUE(address) ON CONFLICT IGNORE) ''')
|
encoded_str = encodeAddress(1, 1, "21122112211221122112")
|
||||||
|
|
||||||
self.conn.execute('''INSERT INTO testhash (addressversion, hash) VALUES(1, "21122112211221122112"); ''')
|
# Initialise Database
|
||||||
|
self.initialise_database("create_function")
|
||||||
|
|
||||||
# call function in query
|
# call function in query
|
||||||
self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, testhash.hash)) WHERE hash=testhash.hash; ''')
|
self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash; ''')
|
||||||
|
|
||||||
# Assertion
|
# Assertion
|
||||||
self.cur.execute('''select address from testhash;''')
|
self.cur.execute('''select address from testhash;''')
|
||||||
hsh = self.cur.fetchone()[0]
|
hsh = self.cur.fetchone()[0]
|
||||||
self.assertNotEqual(hsh, 1, "test case fail for create_function")
|
self.assertNotEqual(hsh, 1, "test case fail for create_function")
|
||||||
|
self.assertEqual(hsh, encoded_str, "test case fail for create_function")
|
||||||
self.conn.execute('''DROP TABLE testhash;''')
|
self.conn.execute('''DROP TABLE testhash;''')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user