Add sqlite functions while versioning

This commit is contained in:
Muzahid 2021-03-04 19:45:41 +05:30
parent 6f9b66ddff
commit 04bd5a0264
Signed by untrusted user: cis-muzahid
GPG Key ID: 1DC85E7D3AB613EA
4 changed files with 120 additions and 12 deletions

View File

@ -8,7 +8,6 @@ import sqlite3
import sys import sys
import threading import threading
import time import time
import helper_sql import helper_sql
import helper_startup import helper_startup
import paths import paths
@ -18,6 +17,7 @@ import tr
from bmconfigparser import BMConfigParser from bmconfigparser import BMConfigParser
from debug import logger from debug import logger
# pylint: disable=attribute-defined-outside-init,protected-access # pylint: disable=attribute-defined-outside-init,protected-access
from addresses import encodeAddress
class sqlThread(threading.Thread): class sqlThread(threading.Thread):
@ -35,6 +35,9 @@ class sqlThread(threading.Thread):
self.cur.execute('PRAGMA secure_delete = true') self.cur.execute('PRAGMA secure_delete = true')
# call create_function for encode address
self.create_function()
try: try:
self.cur.execute( self.cur.execute(
'''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,''' '''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,'''
@ -325,6 +328,7 @@ class sqlThread(threading.Thread):
# We'll also need a `sleeptill` field and a `ttl` field. Also we # We'll also need a `sleeptill` field and a `ttl` field. Also we
# can combine the pubkeyretrynumber and msgretrynumber into one. # can combine the pubkeyretrynumber and msgretrynumber into one.
item = '''SELECT value FROM settings WHERE key='version';''' item = '''SELECT value FROM settings WHERE key='version';'''
parameters = '' parameters = ''
self.cur.execute(item, parameters) self.cur.execute(item, parameters)
@ -358,16 +362,11 @@ class sqlThread(threading.Thread):
logger.debug('In messages.dat database, adding address field to the pubkeys table.') logger.debug('In messages.dat database, adding address field to the pubkeys table.')
# We're going to have to calculate the address for each row in the pubkeys # We're going to have to calculate the address for each row in the pubkeys
# table. Then we can take out the hash field. # table. Then we can take out the hash field.
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''') self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ;''')
self.cur.execute('''SELECT hash, addressversion FROM pubkeys''')
queryResult = self.cur.fetchall() # replica for loop to update hashed address
from addresses import encodeAddress self.cur.execute('''UPDATE pubkeys SET address=(enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''')
for row in queryResult:
addressHash, addressVersion = row
address = encodeAddress(addressVersion, 1, hash)
item = '''UPDATE pubkeys SET address=? WHERE hash=?;'''
parameters = (address, addressHash)
self.cur.execute(item, parameters)
# Now we can remove the hash field from the pubkeys table. # Now we can remove the hash field from the pubkeys table.
self.cur.execute( self.cur.execute(
'''CREATE TEMPORARY TABLE pubkeys_backup''' '''CREATE TEMPORARY TABLE pubkeys_backup'''
@ -622,3 +621,12 @@ class sqlThread(threading.Thread):
helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount)) helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount))
# helper_sql.sqlSubmitQueue.task_done() # helper_sql.sqlSubmitQueue.task_done()
def create_function(self):
# create_function
try:
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
except (TypeError, sqlite3.NotSupportedError) as err:
logger.debug(
"Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err))
self.conn.create_function("enaddr", 3, encodeAddress)

View File

@ -16,9 +16,16 @@ SQLite objects can only be used from one thread.
or isn't thread-safe. or isn't thread-safe.
""" """
import Queue
# import Queue
try:
import queue as Queue #python3
except ImportError:
import Queue #python2
import threading import threading
sqlSubmitQueue = Queue.Queue() sqlSubmitQueue = Queue.Queue()
"""the queue for SQL""" """the queue for SQL"""
sqlReturnQueue = Queue.Queue() sqlReturnQueue = Queue.Queue()
@ -105,6 +112,15 @@ def sqlExecute(sql_statement, *args):
return rowcount return rowcount
def sqlExecuteScript(sql_statement):
"""Execute SQL script statement"""
statements = sql_statement.split(";")
with SqlBulkExecute() as sql:
for q in statements:
sql.execute("{}".format(q))
def sqlStoredProcedure(procName): def sqlStoredProcedure(procName):
"""Schedule procName to be run""" """Schedule procName to be run"""
assert sql_available assert sql_available

View File

@ -0,0 +1,11 @@
CREATE TABLE `testhash` (
`addressversion` int DEFAULT NULL,
`hash` blob DEFAULT NULL,
`address` text DEFAULT NULL,
UNIQUE(address) ON CONFLICT IGNORE
);
INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112");

View File

@ -0,0 +1,73 @@
"""
Test for sqlThread
"""
import os
import unittest
from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript
from ..class_sqlThread import (sqlThread)
from ..addresses import encodeAddress
from .common import skip_python3
skip_python3()
class TestSqlThread(unittest.TestCase):
"""
Test case for SQLThread
"""
# query file path
root_path = os.path.dirname(os.path.dirname(__file__))
@classmethod
def setUpClass(cls):
# Start SQL thread
sqlLookup = sqlThread()
sqlLookup.daemon = False
sqlLookup.start()
sql_ready.wait()
@classmethod
def setUp(cls):
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
with SqlBulkExecute() as sql:
for q in tables:
sql.execute("drop table if exists %s" % q)
@classmethod
def tearDown(cls):
pass
@classmethod
def tearDownClass(cls):
# Stop sql thread
sqlStoredProcedure('exit')
def initialise_database(self, file):
"""
Initialise DB
"""
sql_as_string = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))).read()
sqlExecuteScript(sql_as_string)
def test_create_function(self):
# call create function
encoded_str = encodeAddress(4, 1, "21122112211221122112")
# Initialise Database
self.initialise_database("create_function")
sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''')
# call function in query
sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash''')
# Assertion
query = sqlQuery('''select * from testhash;''')
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
sqlExecute('''DROP TABLE testhash''')