From 04bd5a0264603798cf5472f6bc54be88305e7fe7 Mon Sep 17 00:00:00 2001 From: Muzahid Date: Thu, 4 Mar 2021 19:45:41 +0530 Subject: [PATCH] Add sqlite functions while versioning --- src/class_sqlThread.py | 30 ++++++++----- src/helper_sql.py | 18 +++++++- src/tests/sql/create_function.sql | 11 +++++ src/tests/test_sqlthread.py | 73 +++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 12 deletions(-) create mode 100644 src/tests/sql/create_function.sql create mode 100644 src/tests/test_sqlthread.py diff --git a/src/class_sqlThread.py b/src/class_sqlThread.py index 84188408..86460e7b 100644 --- a/src/class_sqlThread.py +++ b/src/class_sqlThread.py @@ -8,7 +8,6 @@ import sqlite3 import sys import threading import time - import helper_sql import helper_startup import paths @@ -18,6 +17,7 @@ import tr from bmconfigparser import BMConfigParser from debug import logger # pylint: disable=attribute-defined-outside-init,protected-access +from addresses import encodeAddress class sqlThread(threading.Thread): @@ -35,6 +35,9 @@ class sqlThread(threading.Thread): self.cur.execute('PRAGMA secure_delete = true') + # call create_function for encode address + self.create_function() + try: self.cur.execute( '''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,''' @@ -325,6 +328,7 @@ class sqlThread(threading.Thread): # We'll also need a `sleeptill` field and a `ttl` field. Also we # can combine the pubkeyretrynumber and msgretrynumber into one. + item = '''SELECT value FROM settings WHERE key='version';''' parameters = '' self.cur.execute(item, parameters) @@ -358,16 +362,11 @@ class sqlThread(threading.Thread): logger.debug('In messages.dat database, adding address field to the pubkeys table.') # We're going to have to calculate the address for each row in the pubkeys # table. Then we can take out the hash field. - self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''') - self.cur.execute('''SELECT hash, addressversion FROM pubkeys''') - queryResult = self.cur.fetchall() - from addresses import encodeAddress - for row in queryResult: - addressHash, addressVersion = row - address = encodeAddress(addressVersion, 1, hash) - item = '''UPDATE pubkeys SET address=? WHERE hash=?;''' - parameters = (address, addressHash) - self.cur.execute(item, parameters) + self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ;''') + + # replica for loop to update hashed address + self.cur.execute('''UPDATE pubkeys SET address=(enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''') + # Now we can remove the hash field from the pubkeys table. self.cur.execute( '''CREATE TEMPORARY TABLE pubkeys_backup''' @@ -622,3 +621,12 @@ class sqlThread(threading.Thread): helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount)) # helper_sql.sqlSubmitQueue.task_done() + + def create_function(self): + # create_function + try: + self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True) + except (TypeError, sqlite3.NotSupportedError) as err: + logger.debug( + "Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err)) + self.conn.create_function("enaddr", 3, encodeAddress) diff --git a/src/helper_sql.py b/src/helper_sql.py index 043bccf2..26398848 100644 --- a/src/helper_sql.py +++ b/src/helper_sql.py @@ -16,9 +16,16 @@ SQLite objects can only be used from one thread. or isn't thread-safe. """ -import Queue + +# import Queue +try: + import queue as Queue #python3 +except ImportError: + import Queue #python2 + import threading + sqlSubmitQueue = Queue.Queue() """the queue for SQL""" sqlReturnQueue = Queue.Queue() @@ -105,6 +112,15 @@ def sqlExecute(sql_statement, *args): return rowcount +def sqlExecuteScript(sql_statement): + """Execute SQL script statement""" + + statements = sql_statement.split(";") + with SqlBulkExecute() as sql: + for q in statements: + sql.execute("{}".format(q)) + + def sqlStoredProcedure(procName): """Schedule procName to be run""" assert sql_available diff --git a/src/tests/sql/create_function.sql b/src/tests/sql/create_function.sql new file mode 100644 index 00000000..b66d1aa1 --- /dev/null +++ b/src/tests/sql/create_function.sql @@ -0,0 +1,11 @@ +CREATE TABLE `testhash` ( + `addressversion` int DEFAULT NULL, + `hash` blob DEFAULT NULL, + `address` text DEFAULT NULL, + UNIQUE(address) ON CONFLICT IGNORE +); + + + +INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112"); + diff --git a/src/tests/test_sqlthread.py b/src/tests/test_sqlthread.py new file mode 100644 index 00000000..d51c6623 --- /dev/null +++ b/src/tests/test_sqlthread.py @@ -0,0 +1,73 @@ +""" + Test for sqlThread +""" + +import os +import unittest +from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript +from ..class_sqlThread import (sqlThread) +from ..addresses import encodeAddress +from .common import skip_python3 + + +skip_python3() + + +class TestSqlThread(unittest.TestCase): + """ + Test case for SQLThread + """ + + # query file path + root_path = os.path.dirname(os.path.dirname(__file__)) + + @classmethod + def setUpClass(cls): + # Start SQL thread + sqlLookup = sqlThread() + sqlLookup.daemon = False + sqlLookup.start() + sql_ready.wait() + + @classmethod + def setUp(cls): + tables = list(sqlQuery("select name from sqlite_master where type is 'table'")) + with SqlBulkExecute() as sql: + for q in tables: + sql.execute("drop table if exists %s" % q) + + @classmethod + def tearDown(cls): + pass + + @classmethod + def tearDownClass(cls): + # Stop sql thread + sqlStoredProcedure('exit') + + def initialise_database(self, file): + """ + Initialise DB + """ + + sql_as_string = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))).read() + sqlExecuteScript(sql_as_string) + + + def test_create_function(self): + # call create function + + encoded_str = encodeAddress(4, 1, "21122112211221122112") + + # Initialise Database + self.initialise_database("create_function") + + sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''') + # call function in query + + sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash''') + + # Assertion + query = sqlQuery('''select * from testhash;''') + self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function") + sqlExecute('''DROP TABLE testhash''')