Add sqlite functions while versioning
This commit is contained in:
parent
6f9b66ddff
commit
04bd5a0264
|
@ -8,7 +8,6 @@ import sqlite3
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import helper_sql
|
||||
import helper_startup
|
||||
import paths
|
||||
|
@ -18,6 +17,7 @@ import tr
|
|||
from bmconfigparser import BMConfigParser
|
||||
from debug import logger
|
||||
# pylint: disable=attribute-defined-outside-init,protected-access
|
||||
from addresses import encodeAddress
|
||||
|
||||
|
||||
class sqlThread(threading.Thread):
|
||||
|
@ -35,6 +35,9 @@ class sqlThread(threading.Thread):
|
|||
|
||||
self.cur.execute('PRAGMA secure_delete = true')
|
||||
|
||||
# call create_function for encode address
|
||||
self.create_function()
|
||||
|
||||
try:
|
||||
self.cur.execute(
|
||||
'''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,'''
|
||||
|
@ -325,6 +328,7 @@ class sqlThread(threading.Thread):
|
|||
|
||||
# We'll also need a `sleeptill` field and a `ttl` field. Also we
|
||||
# can combine the pubkeyretrynumber and msgretrynumber into one.
|
||||
|
||||
item = '''SELECT value FROM settings WHERE key='version';'''
|
||||
parameters = ''
|
||||
self.cur.execute(item, parameters)
|
||||
|
@ -358,16 +362,11 @@ class sqlThread(threading.Thread):
|
|||
logger.debug('In messages.dat database, adding address field to the pubkeys table.')
|
||||
# We're going to have to calculate the address for each row in the pubkeys
|
||||
# table. Then we can take out the hash field.
|
||||
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''')
|
||||
self.cur.execute('''SELECT hash, addressversion FROM pubkeys''')
|
||||
queryResult = self.cur.fetchall()
|
||||
from addresses import encodeAddress
|
||||
for row in queryResult:
|
||||
addressHash, addressVersion = row
|
||||
address = encodeAddress(addressVersion, 1, hash)
|
||||
item = '''UPDATE pubkeys SET address=? WHERE hash=?;'''
|
||||
parameters = (address, addressHash)
|
||||
self.cur.execute(item, parameters)
|
||||
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ;''')
|
||||
|
||||
# replica for loop to update hashed address
|
||||
self.cur.execute('''UPDATE pubkeys SET address=(enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''')
|
||||
|
||||
# Now we can remove the hash field from the pubkeys table.
|
||||
self.cur.execute(
|
||||
'''CREATE TEMPORARY TABLE pubkeys_backup'''
|
||||
|
@ -622,3 +621,12 @@ class sqlThread(threading.Thread):
|
|||
|
||||
helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount))
|
||||
# helper_sql.sqlSubmitQueue.task_done()
|
||||
|
||||
def create_function(self):
|
||||
# create_function
|
||||
try:
|
||||
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
|
||||
except (TypeError, sqlite3.NotSupportedError) as err:
|
||||
logger.debug(
|
||||
"Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err))
|
||||
self.conn.create_function("enaddr", 3, encodeAddress)
|
||||
|
|
|
@ -16,9 +16,16 @@ SQLite objects can only be used from one thread.
|
|||
or isn't thread-safe.
|
||||
"""
|
||||
|
||||
import Queue
|
||||
|
||||
# import Queue
|
||||
try:
|
||||
import queue as Queue #python3
|
||||
except ImportError:
|
||||
import Queue #python2
|
||||
|
||||
import threading
|
||||
|
||||
|
||||
sqlSubmitQueue = Queue.Queue()
|
||||
"""the queue for SQL"""
|
||||
sqlReturnQueue = Queue.Queue()
|
||||
|
@ -105,6 +112,15 @@ def sqlExecute(sql_statement, *args):
|
|||
return rowcount
|
||||
|
||||
|
||||
def sqlExecuteScript(sql_statement):
|
||||
"""Execute SQL script statement"""
|
||||
|
||||
statements = sql_statement.split(";")
|
||||
with SqlBulkExecute() as sql:
|
||||
for q in statements:
|
||||
sql.execute("{}".format(q))
|
||||
|
||||
|
||||
def sqlStoredProcedure(procName):
|
||||
"""Schedule procName to be run"""
|
||||
assert sql_available
|
||||
|
|
11
src/tests/sql/create_function.sql
Normal file
11
src/tests/sql/create_function.sql
Normal file
|
@ -0,0 +1,11 @@
|
|||
CREATE TABLE `testhash` (
|
||||
`addressversion` int DEFAULT NULL,
|
||||
`hash` blob DEFAULT NULL,
|
||||
`address` text DEFAULT NULL,
|
||||
UNIQUE(address) ON CONFLICT IGNORE
|
||||
);
|
||||
|
||||
|
||||
|
||||
INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112");
|
||||
|
73
src/tests/test_sqlthread.py
Normal file
73
src/tests/test_sqlthread.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
"""
|
||||
Test for sqlThread
|
||||
"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript
|
||||
from ..class_sqlThread import (sqlThread)
|
||||
from ..addresses import encodeAddress
|
||||
from .common import skip_python3
|
||||
|
||||
|
||||
skip_python3()
|
||||
|
||||
|
||||
class TestSqlThread(unittest.TestCase):
|
||||
"""
|
||||
Test case for SQLThread
|
||||
"""
|
||||
|
||||
# query file path
|
||||
root_path = os.path.dirname(os.path.dirname(__file__))
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Start SQL thread
|
||||
sqlLookup = sqlThread()
|
||||
sqlLookup.daemon = False
|
||||
sqlLookup.start()
|
||||
sql_ready.wait()
|
||||
|
||||
@classmethod
|
||||
def setUp(cls):
|
||||
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
|
||||
with SqlBulkExecute() as sql:
|
||||
for q in tables:
|
||||
sql.execute("drop table if exists %s" % q)
|
||||
|
||||
@classmethod
|
||||
def tearDown(cls):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
# Stop sql thread
|
||||
sqlStoredProcedure('exit')
|
||||
|
||||
def initialise_database(self, file):
|
||||
"""
|
||||
Initialise DB
|
||||
"""
|
||||
|
||||
sql_as_string = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))).read()
|
||||
sqlExecuteScript(sql_as_string)
|
||||
|
||||
|
||||
def test_create_function(self):
|
||||
# call create function
|
||||
|
||||
encoded_str = encodeAddress(4, 1, "21122112211221122112")
|
||||
|
||||
# Initialise Database
|
||||
self.initialise_database("create_function")
|
||||
|
||||
sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''')
|
||||
# call function in query
|
||||
|
||||
sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash''')
|
||||
|
||||
# Assertion
|
||||
query = sqlQuery('''select * from testhash;''')
|
||||
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
|
||||
sqlExecute('''DROP TABLE testhash''')
|
Reference in New Issue
Block a user