Add sqlite functions while versioning
This commit is contained in:
parent
6f9b66ddff
commit
04bd5a0264
|
@ -8,7 +8,6 @@ import sqlite3
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import helper_sql
|
import helper_sql
|
||||||
import helper_startup
|
import helper_startup
|
||||||
import paths
|
import paths
|
||||||
|
@ -18,6 +17,7 @@ import tr
|
||||||
from bmconfigparser import BMConfigParser
|
from bmconfigparser import BMConfigParser
|
||||||
from debug import logger
|
from debug import logger
|
||||||
# pylint: disable=attribute-defined-outside-init,protected-access
|
# pylint: disable=attribute-defined-outside-init,protected-access
|
||||||
|
from addresses import encodeAddress
|
||||||
|
|
||||||
|
|
||||||
class sqlThread(threading.Thread):
|
class sqlThread(threading.Thread):
|
||||||
|
@ -35,6 +35,9 @@ class sqlThread(threading.Thread):
|
||||||
|
|
||||||
self.cur.execute('PRAGMA secure_delete = true')
|
self.cur.execute('PRAGMA secure_delete = true')
|
||||||
|
|
||||||
|
# call create_function for encode address
|
||||||
|
self.create_function()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
'''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,'''
|
'''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,'''
|
||||||
|
@ -325,6 +328,7 @@ class sqlThread(threading.Thread):
|
||||||
|
|
||||||
# We'll also need a `sleeptill` field and a `ttl` field. Also we
|
# We'll also need a `sleeptill` field and a `ttl` field. Also we
|
||||||
# can combine the pubkeyretrynumber and msgretrynumber into one.
|
# can combine the pubkeyretrynumber and msgretrynumber into one.
|
||||||
|
|
||||||
item = '''SELECT value FROM settings WHERE key='version';'''
|
item = '''SELECT value FROM settings WHERE key='version';'''
|
||||||
parameters = ''
|
parameters = ''
|
||||||
self.cur.execute(item, parameters)
|
self.cur.execute(item, parameters)
|
||||||
|
@ -358,16 +362,11 @@ class sqlThread(threading.Thread):
|
||||||
logger.debug('In messages.dat database, adding address field to the pubkeys table.')
|
logger.debug('In messages.dat database, adding address field to the pubkeys table.')
|
||||||
# We're going to have to calculate the address for each row in the pubkeys
|
# We're going to have to calculate the address for each row in the pubkeys
|
||||||
# table. Then we can take out the hash field.
|
# table. Then we can take out the hash field.
|
||||||
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''')
|
self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ;''')
|
||||||
self.cur.execute('''SELECT hash, addressversion FROM pubkeys''')
|
|
||||||
queryResult = self.cur.fetchall()
|
# replica for loop to update hashed address
|
||||||
from addresses import encodeAddress
|
self.cur.execute('''UPDATE pubkeys SET address=(enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''')
|
||||||
for row in queryResult:
|
|
||||||
addressHash, addressVersion = row
|
|
||||||
address = encodeAddress(addressVersion, 1, hash)
|
|
||||||
item = '''UPDATE pubkeys SET address=? WHERE hash=?;'''
|
|
||||||
parameters = (address, addressHash)
|
|
||||||
self.cur.execute(item, parameters)
|
|
||||||
# Now we can remove the hash field from the pubkeys table.
|
# Now we can remove the hash field from the pubkeys table.
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
'''CREATE TEMPORARY TABLE pubkeys_backup'''
|
'''CREATE TEMPORARY TABLE pubkeys_backup'''
|
||||||
|
@ -622,3 +621,12 @@ class sqlThread(threading.Thread):
|
||||||
|
|
||||||
helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount))
|
helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount))
|
||||||
# helper_sql.sqlSubmitQueue.task_done()
|
# helper_sql.sqlSubmitQueue.task_done()
|
||||||
|
|
||||||
|
def create_function(self):
|
||||||
|
# create_function
|
||||||
|
try:
|
||||||
|
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
|
||||||
|
except (TypeError, sqlite3.NotSupportedError) as err:
|
||||||
|
logger.debug(
|
||||||
|
"Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err))
|
||||||
|
self.conn.create_function("enaddr", 3, encodeAddress)
|
||||||
|
|
|
@ -16,9 +16,16 @@ SQLite objects can only be used from one thread.
|
||||||
or isn't thread-safe.
|
or isn't thread-safe.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import Queue
|
|
||||||
|
# import Queue
|
||||||
|
try:
|
||||||
|
import queue as Queue #python3
|
||||||
|
except ImportError:
|
||||||
|
import Queue #python2
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|
||||||
sqlSubmitQueue = Queue.Queue()
|
sqlSubmitQueue = Queue.Queue()
|
||||||
"""the queue for SQL"""
|
"""the queue for SQL"""
|
||||||
sqlReturnQueue = Queue.Queue()
|
sqlReturnQueue = Queue.Queue()
|
||||||
|
@ -105,6 +112,15 @@ def sqlExecute(sql_statement, *args):
|
||||||
return rowcount
|
return rowcount
|
||||||
|
|
||||||
|
|
||||||
|
def sqlExecuteScript(sql_statement):
|
||||||
|
"""Execute SQL script statement"""
|
||||||
|
|
||||||
|
statements = sql_statement.split(";")
|
||||||
|
with SqlBulkExecute() as sql:
|
||||||
|
for q in statements:
|
||||||
|
sql.execute("{}".format(q))
|
||||||
|
|
||||||
|
|
||||||
def sqlStoredProcedure(procName):
|
def sqlStoredProcedure(procName):
|
||||||
"""Schedule procName to be run"""
|
"""Schedule procName to be run"""
|
||||||
assert sql_available
|
assert sql_available
|
||||||
|
|
11
src/tests/sql/create_function.sql
Normal file
11
src/tests/sql/create_function.sql
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
CREATE TABLE `testhash` (
|
||||||
|
`addressversion` int DEFAULT NULL,
|
||||||
|
`hash` blob DEFAULT NULL,
|
||||||
|
`address` text DEFAULT NULL,
|
||||||
|
UNIQUE(address) ON CONFLICT IGNORE
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112");
|
||||||
|
|
73
src/tests/test_sqlthread.py
Normal file
73
src/tests/test_sqlthread.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
"""
|
||||||
|
Test for sqlThread
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript
|
||||||
|
from ..class_sqlThread import (sqlThread)
|
||||||
|
from ..addresses import encodeAddress
|
||||||
|
from .common import skip_python3
|
||||||
|
|
||||||
|
|
||||||
|
skip_python3()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSqlThread(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for SQLThread
|
||||||
|
"""
|
||||||
|
|
||||||
|
# query file path
|
||||||
|
root_path = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
# Start SQL thread
|
||||||
|
sqlLookup = sqlThread()
|
||||||
|
sqlLookup.daemon = False
|
||||||
|
sqlLookup.start()
|
||||||
|
sql_ready.wait()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUp(cls):
|
||||||
|
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
|
||||||
|
with SqlBulkExecute() as sql:
|
||||||
|
for q in tables:
|
||||||
|
sql.execute("drop table if exists %s" % q)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDown(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
# Stop sql thread
|
||||||
|
sqlStoredProcedure('exit')
|
||||||
|
|
||||||
|
def initialise_database(self, file):
|
||||||
|
"""
|
||||||
|
Initialise DB
|
||||||
|
"""
|
||||||
|
|
||||||
|
sql_as_string = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))).read()
|
||||||
|
sqlExecuteScript(sql_as_string)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_function(self):
|
||||||
|
# call create function
|
||||||
|
|
||||||
|
encoded_str = encodeAddress(4, 1, "21122112211221122112")
|
||||||
|
|
||||||
|
# Initialise Database
|
||||||
|
self.initialise_database("create_function")
|
||||||
|
|
||||||
|
sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''')
|
||||||
|
# call function in query
|
||||||
|
|
||||||
|
sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash''')
|
||||||
|
|
||||||
|
# Assertion
|
||||||
|
query = sqlQuery('''select * from testhash;''')
|
||||||
|
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
|
||||||
|
sqlExecute('''DROP TABLE testhash''')
|
Reference in New Issue
Block a user