update core changes reuse sql_helper in testcase

This commit is contained in:
Muzahid 2021-04-06 17:51:17 +05:30
parent a3e52099bc
commit 59311f3f31
Signed by untrusted user: cis-muzahid
GPG Key ID: 1DC85E7D3AB613EA
5 changed files with 227 additions and 237 deletions

View File

@ -23,13 +23,25 @@ from addresses import encodeAddress
root_path = os.path.dirname(os.path.dirname(__file__))
def connection_build():
conn = sqlite3.connect(state.appdata + 'messages.dat')
conn.text_factory = str
cur = conn.cursor()
return conn, cur
class UpgradeDB():
"""Upgrade Db with respect to versions"""
cur = None
# cur = None
parameters = None
current_level = None
max_level = 11
conn = None
conn, cur = connection_build()
def __index__(self):
self.conn = conn
self.cur = cur
def get_current_level(self):
# Upgrade Db with respect to their versions
@ -51,10 +63,15 @@ class UpgradeDB():
def run_migrations(self, file):
try:
print("----------file", file)
sql_file = open(os.path.join(root_path, "src/sql/init_version_{}.sql".format(file)))
sql_as_string = sql_file.read()
# self.cur.executescript(sql_as_string)
self.cur.executescript(sql_as_string)
except Exception as err:
print("err")
print(err)
print("err")
if str(err) == 'table inbox already exists':
return "table inbox already exists"
else:
@ -238,9 +255,13 @@ class sqlThread(threading.Thread, UpgradeDB):
def run(self): # pylint: disable=too-many-locals, too-many-branches, too-many-statements
"""Process SQL queries from `.helper_sql.sqlSubmitQueue`"""
helper_sql.sql_available = True
self.conn = sqlite3.connect(state.appdata + 'messages.dat')
self.conn.text_factory = str
self.cur = self.conn.cursor()
# self.conn = sqlite3.connect(state.appdata + 'messages.dat')
# self.conn.text_factory = str
# self.cur = self.conn.cursor()
conn, cur = connection_build()
self.conn = conn
self.cur = cur
self.cur.execute('PRAGMA secure_delete = true')
@ -248,50 +269,38 @@ class sqlThread(threading.Thread, UpgradeDB):
self.create_function()
try:
# sql_file = open(os.path.join(root_path, "src/sql/{}.sql".format("run")))
# sql_as_string = sql_file.read()
# self.cur.executescript(sql_as_string)
# self.cur.execute('''INSERT INTO settings VALUES('version','11')''')
# self.cur.execute('''INSERT INTO settings VALUES('lastvacuumtime',?)''', (int(time.time()),))
# self.cur.execute(
# '''CREATE TABLE objectprocessorqueue'''
# ''' (objecttype int, data blob, UNIQUE(objecttype, data) ON CONFLICT REPLACE)''')
# self.conn.commit()
# logger.info('Created messages database file')
self.cur.execute(
'''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,'''
''' received text, message text, folder text, encodingtype int, read bool, sighash blob,'''
''' UNIQUE(msgid) ON CONFLICT REPLACE)''')
self.cur.execute(
'''CREATE TABLE sent (msgid blob, toaddress text, toripe blob, fromaddress text, subject text,'''
''' message text, ackdata blob, senttime integer, lastactiontime integer,'''
''' sleeptill integer, status text, retrynumber integer, folder text, encodingtype int, ttl int)''')
self.cur.execute(
'''CREATE TABLE subscriptions (label text, address text, enabled bool)''')
self.cur.execute(
'''CREATE TABLE addressbook (label text, address text, UNIQUE(address) ON CONFLICT IGNORE)''')
self.cur.execute(
'''CREATE TABLE blacklist (label text, address text, enabled bool)''')
self.cur.execute(
'''CREATE TABLE whitelist (label text, address text, enabled bool)''')
self.cur.execute(
'''CREATE TABLE pubkeys (address text, addressversion int, transmitdata blob, time int,'''
''' usedpersonally text, UNIQUE(address) ON CONFLICT REPLACE)''')
self.cur.execute(
'''CREATE TABLE inventory (hash blob, objecttype int, streamnumber int, payload blob,'''
''' expirestime integer, tag blob, UNIQUE(hash) ON CONFLICT REPLACE)''')
self.cur.execute(
'''INSERT INTO subscriptions VALUES'''
'''('Bitmessage new releases/announcements','BM-GtovgYdgs7qXPkoYaRgrLFuFKz1SFpsw',1)''')
# '''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,'''
# ''' received text, message text, folder text, encodingtype int, read bool, sighash blob,'''
# ''' UNIQUE(msgid) ON CONFLICT REPLACE)''')
# self.cur.execute(
# '''CREATE TABLE sent (msgid blob, toaddress text, toripe blob, fromaddress text, subject text,'''
# ''' message text, ackdata blob, senttime integer, lastactiontime integer,'''
# ''' sleeptill integer, status text, retrynumber integer, folder text, encodingtype int, ttl int)''')
# '''CREATE TABLE subscriptions (label text, address text, enabled bool)''')
# self.cur.execute(
# '''CREATE TABLE addressbook (label text, address text, UNIQUE(address) ON CONFLICT IGNORE)''')
# self.cur.execute(
# '''CREATE TABLE blacklist (label text, address text, enabled bool)''')
# self.cur.execute(
# '''CREATE TABLE whitelist (label text, address text, enabled bool)''')
# self.cur.execute(
# '''CREATE TABLE pubkeys (address text, addressversion int, transmitdata blob, time int,'''
# ''' usedpersonally text, UNIQUE(address) ON CONFLICT REPLACE)''')
# self.cur.execute(
# '''CREATE TABLE inventory (hash blob, objecttype int, streamnumber int, payload blob,'''
# ''' expirestime integer, tag blob, UNIQUE(hash) ON CONFLICT REPLACE)''')
# self.cur.execute(
# '''INSERT INTO subscriptions VALUES'''
# '''('Bitmessage new releases/announcements','BM-GtovgYdgs7qXPkoYaRgrLFuFKz1SFpsw',1)''')
self.cur.execute(
'''CREATE TABLE settings (key blob, value blob, UNIQUE(key) ON CONFLICT REPLACE)''')
self.cur.execute('''INSERT INTO settings VALUES('version','11')''')
self.cur.execute('''INSERT INTO settings VALUES('lastvacuumtime',?)''', (
int(time.time()),))
self.cur.execute(
'''CREATE TABLE objectprocessorqueue'''
''' (objecttype int, data blob, UNIQUE(objecttype, data) ON CONFLICT REPLACE)''')
# self.cur.execute(
# '''CREATE TABLE objectprocessorqueue'''
# ''' (objecttype int, data blob, UNIQUE(objecttype, data) ON CONFLICT REPLACE)''')
self.conn.commit()
logger.info('Created messages database file')
except Exception as err:

View File

@ -16,14 +16,16 @@ SQLite objects can only be used from one thread.
or isn't thread-safe.
"""
# import Queue
try:
import queue as Queue #python2
import queue as Queue #python3
except ImportError:
import Queue #python3
import Queue #python2
import threading
sqlSubmitQueue = Queue.Queue()
"""the queue for SQL"""
sqlReturnQueue = Queue.Queue()
@ -110,6 +112,15 @@ def sqlExecute(sql_statement, *args):
return rowcount
def sqlExecuteScript(sql_statement):
"""Execute SQL script statement"""
statements = sql_statement.split(";")
with SqlBulkExecute() as sql:
for q in statements:
sql.execute("{}".format(q))
def sqlStoredProcedure(procName):
"""Schedule procName to be run"""
assert sql_available

View File

@ -1,16 +1,5 @@
-- --
-- -- Alter table `addressbook`
-- --
ALTER TABLE addressbook RENAME TO old_addressbook;
-- --
-- -- Table structure for table `addressbook`
-- --
CREATE TABLE `addressbook` (
`label` text NOT NULL,
`address` text NOT NULL,
@ -18,15 +7,6 @@ CREATE TABLE `addressbook` (
) ;
-- --
-- -- Dumping data for table `addressbook`
-- --
INSERT INTO addressbook SELECT label, address FROM old_addressbook;
-- --
-- -- Drop table `old_addressbook`
-- --
DROP TABLE old_addressbook;

View File

@ -1,19 +1,11 @@
-- --
-- -- Table structure for table `testhash`
-- --
CREATE TEMPORARY TABLE `testhash` (
CREATE TABLE `testhash` (
`addressversion` int DEFAULT NULL,
`hash` blob DEFAULT NULL,
`address` text DEFAULT NULL,
UNIQUE(address) ON CONFLICT IGNORE
) ;
);
-- --
-- -- Dumping data for table `testhash`
-- --
INSERT INTO testhash (addressversion, hash) VALUES(1, "21122112211221122112");
INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112");

View File

@ -4,14 +4,11 @@
import os
import unittest
import sqlite3
import sys
from ..state import appdata
from ..helper_sql import sqlStoredProcedure, sql_ready
from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript
from ..class_sqlThread import (sqlThread, UpgradeDB)
from ..addresses import encodeAddress
from .common import skip_python3
import time
skip_python3()
@ -21,25 +18,31 @@ class TestSqlThread(unittest.TestCase):
Test case for SQLThread
"""
conn = sqlite3.connect(appdata + 'messages.dat')
conn.text_factory = str
cur = conn.cursor()
# query file path
root_path = os.path.dirname(os.path.dirname(__file__))
@classmethod
def setUpClass(cls):
# Start SQL thread
print("1============")
sqlLookup = sqlThread()
print("2============")
sqlLookup.daemon = False
print("3============")
# time.sleep(5)
sqlLookup.start()
print("4============")
# time.sleep(5)
print("5============")
sql_ready.wait()
print("6============")
@classmethod
def setUp(cls):
tables = list(cls.cur.execute("select name from sqlite_master where type is 'table'"))
cls.cur.executescript(';'.join(["drop table if exists %s" % i for i in tables]))
tables = list(sqlQuery("select name from sqlite_master where type is 'table'"))
with SqlBulkExecute() as sql:
for q in tables:
sql.execute("drop table if exists %s" % q)
@classmethod
def tearDown(cls):
@ -54,24 +57,9 @@ class TestSqlThread(unittest.TestCase):
"""
Initialise DB
"""
try:
sql_file = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file)))
sql_as_string = sql_file.read()
self.cur.executescript(sql_as_string)
self.conn.commit()
except Exception as err:
if str(err) == 'table inbox already exists':
return "table inbox already exists"
else:
sys.stderr.write(
'ERROR trying to create database file (message.dat). Error message: %s\n' % str(err))
os._exit(0)
def filter_table_column(self, schema, column):
for x in schema:
for y in x:
if y == column:
yield y
sql_as_string = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))).read()
sqlExecuteScript(sql_as_string)
def versioning(func):
def wrapper(*args):
@ -79,21 +67,46 @@ class TestSqlThread(unittest.TestCase):
func_name = func.__name__
version = func_name.rsplit('_', 1)[-1]
print("-------------------------===============")
print(func_name)
print(version)
print("upgrade_schema_data_", version)
print("-------------------------===============")
# Update versions DB mocking
self.initialise_database("init_version_{}".format(version))
# Test versions
upgrade_db = UpgradeDB()
upgrade_db.cur = self.cur
upgrade_db.conn = self.conn
getattr(upgrade_db, "upgrade_schema_data_{}".format(version))()
ret = func(*args)
return ret # <-- use (self, ...)
return wrapper
def change_state(self):
print("change state called")
self.normalize_version("1")
def test_create_function(self):
# call create function
encoded_str = encodeAddress(4, 1, "21122112211221122112")
# Initialise Database
self.initialise_database("create_function")
sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''')
# call function in query
sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash''')
# Assertion
query = sqlQuery('''select * from testhash;''')
self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function")
sqlExecute('''DROP TABLE testhash''')
def filter_table_column(self, schema, column):
for x in schema:
for y in x:
if y == column:
yield y
@versioning
def test_sql_thread_version_1(self):
@ -102,13 +115,15 @@ class TestSqlThread(unittest.TestCase):
"""
# Assertion after versioning
res = self.cur.execute('''PRAGMA table_info('inventory');''')
res = res.fetchall()
res = sqlQuery('''PRAGMA table_info('inventory');''')
# res = res.fetchall()
print(res)
result = list(self.filter_table_column(res, "tag"))
res = [tup for tup in res if any(i in tup for i in ["tag"])]
self.assertEqual(result, ['tag'], "Data not migrated for version 1")
self.assertEqual(res, [(5, 'tag', 'blob', 0, "''", 0)], "Data not migrated for version 1")
@versioning
def test_sql_thread_version_10(self):
"""
@ -116,135 +131,118 @@ class TestSqlThread(unittest.TestCase):
"""
# Assertion
self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''')
self.assertNotEqual(self.cur.fetchone(), 1, "Table old_addressbook not deleted")
res = sqlExecute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''')
print("res---------------------------------")
print(res)
print("res---------------------------------")
self.assertNotEqual(res, 1, "Table old_addressbook not deleted")
self.assertEqual(res, -1, "Table old_addressbook not deleted")
res = self.cur.execute('''PRAGMA table_info('addressbook');''')
res = res.fetchall()
res = sqlQuery('''PRAGMA table_info('addressbook');''')
# # res = res.fetchall()
result = list(self.filter_table_column(res, "address"))
self.assertEqual(result, ['address'], "Data not migrated for version 10")
@versioning
def test_sql_thread_version_9(self):
"""
Test with version 9
"""
# Assertion
self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''')
self.assertNotEqual(self.cur.fetchone(), 1, "Table pubkeys_backup not deleted")
res = self.cur.execute('''PRAGMA table_info('pubkeys');''')
res = res.fetchall()
result = list(self.filter_table_column(res, "address"))
self.assertEqual(result, ['address'], "Data not migrated for version 9")
@versioning
def test_sql_thread_version_8(self):
"""
Test with version 8
"""
# Assertion
res = self.cur.execute('''PRAGMA table_info('inbox');''')
res = res.fetchall()
result = list(self.filter_table_column(res, "sighash"))
self.assertEqual(result, ['sighash'], "Data not migrated for version 8")
@versioning
def test_sql_thread_version_7(self):
"""
Test with version 7
"""
# Assertion
pubkeys = self.cur.execute('''SELECT * FROM pubkeys ''')
pubkeys = pubkeys.fetchall()
self.assertEqual(pubkeys, [], "Data not migrated for version 7")
inventory = self.cur.execute('''SELECT * FROM inventory ''')
inventory = inventory.fetchall()
self.assertEqual(inventory, [], "Data not migrated for version 7")
sent = self.cur.execute('''SELECT status FROM sent ''')
sent = sent.fetchall()
self.assertEqual(sent, [('msgqueued',), ('msgqueued',)], "Data not migrated for version 7")
@versioning
def test_sql_thread_version_6(self):
"""
Test with version 6
"""
# Assertion
inventory = self.cur.execute('''PRAGMA table_info('inventory');''')
inventory = inventory.fetchall()
inventory = list(self.filter_table_column(inventory, "expirestime"))
self.assertEqual(inventory, ['expirestime'], "Data not migrated for version 6")
objectprocessorqueue = self.cur.execute('''PRAGMA table_info('inventory');''')
objectprocessorqueue = objectprocessorqueue.fetchall()
objectprocessorqueue = list(self.filter_table_column(objectprocessorqueue, "objecttype"))
self.assertEqual(objectprocessorqueue, ['objecttype'], "Data not migrated for version 6")
@versioning
def test_sql_thread_version_5(self):
"""
Test with version 5
"""
# Assertion
self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''')
self.assertNotEqual(self.cur.fetchone(), 1, "Table knownnodes not deleted in versioning 5")
self.cur.execute(
''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='objectprocessorqueue'; ''')
self.assertNotEqual(self.cur.fetchone(), 0, "Table objectprocessorqueue not created in versioning 5")
@versioning
def test_sql_thread_version_4(self):
"""
Test with version 4
"""
# Assertion
self.cur.execute('''select * from inventory where objecttype = 'pubkey';''')
self.assertNotEqual(self.cur.fetchone(), 1, "Table inventory not deleted in versioning 4")
def test_sql_thread_version_3(self):
"""
Test with version 3 and 1 both are same
"""
pass
@versioning
def test_sql_thread_version_2(self):
"""
Test with version 2
"""
# Assertion
self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='inventory_backup' ''')
self.assertNotEqual(self.cur.fetchone(), 1, "Table inventory_backup not deleted in versioning 2")
def test_create_function(self):
try:
self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True)
except Exception as err:
print("Got error while pass deterministic in sqlite create function {}, So called function directly".format(err))
self.conn.create_function("enaddr", 3, encodeAddress)
encoded_str = encodeAddress(1, 1, "21122112211221122112")
# Initialise Database
self.initialise_database("create_function")
# call function in query
self.cur.execute('''UPDATE testhash SET address=(select enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash; ''')
# Assertion
self.cur.execute('''select address from testhash;''')
hsh = self.cur.fetchone()[0]
self.assertNotEqual(hsh, 1, "test case fail for create_function")
self.assertEqual(hsh, encoded_str, "test case fail for create_function")
self.conn.execute('''DROP TABLE testhash;''')
# @versioning
# def test_sql_thread_version_9(self):
# """
# Test with version 9
# """
#
# # Assertion
# self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''')
# self.assertNotEqual(self.cur.fetchone(), 1, "Table pubkeys_backup not deleted")
#
# res = self.cur.execute('''PRAGMA table_info('pubkeys');''')
# res = res.fetchall()
# result = list(self.filter_table_column(res, "address"))
# self.assertEqual(result, ['address'], "Data not migrated for version 9")
#
# @versioning
# def test_sql_thread_version_8(self):
# """
# Test with version 8
# """
#
# # Assertion
# res = self.cur.execute('''PRAGMA table_info('inbox');''')
# res = res.fetchall()
# result = list(self.filter_table_column(res, "sighash"))
# self.assertEqual(result, ['sighash'], "Data not migrated for version 8")
#
# @versioning
# def test_sql_thread_version_7(self):
# """
# Test with version 7
# """
#
# # Assertion
# pubkeys = self.cur.execute('''SELECT * FROM pubkeys ''')
# pubkeys = pubkeys.fetchall()
# self.assertEqual(pubkeys, [], "Data not migrated for version 7")
#
# inventory = self.cur.execute('''SELECT * FROM inventory ''')
# inventory = inventory.fetchall()
# self.assertEqual(inventory, [], "Data not migrated for version 7")
#
# sent = self.cur.execute('''SELECT status FROM sent ''')
# sent = sent.fetchall()
# self.assertEqual(sent, [('msgqueued',), ('msgqueued',)], "Data not migrated for version 7")
#
# @versioning
# def test_sql_thread_version_6(self):
# """
# Test with version 6
# """
#
# # Assertion
#
# inventory = self.cur.execute('''PRAGMA table_info('inventory');''')
# inventory = inventory.fetchall()
# inventory = list(self.filter_table_column(inventory, "expirestime"))
# self.assertEqual(inventory, ['expirestime'], "Data not migrated for version 6")
#
# objectprocessorqueue = self.cur.execute('''PRAGMA table_info('inventory');''')
# objectprocessorqueue = objectprocessorqueue.fetchall()
# objectprocessorqueue = list(self.filter_table_column(objectprocessorqueue, "objecttype"))
# self.assertEqual(objectprocessorqueue, ['objecttype'], "Data not migrated for version 6")
#
# @versioning
# def test_sql_thread_version_5(self):
# """
# Test with version 5
# """
#
# # Assertion
# self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='knownnodes' ''')
# self.assertNotEqual(self.cur.fetchone(), 1, "Table knownnodes not deleted in versioning 5")
# self.cur.execute(
# ''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='objectprocessorqueue'; ''')
# self.assertNotEqual(self.cur.fetchone(), 0, "Table objectprocessorqueue not created in versioning 5")
#
# @versioning
# def test_sql_thread_version_4(self):
# """
# Test with version 4
# """
#
# # Assertion
# self.cur.execute('''select * from inventory where objecttype = 'pubkey';''')
# self.assertNotEqual(self.cur.fetchone(), 1, "Table inventory not deleted in versioning 4")
#
# def test_sql_thread_version_3(self):
# """
# Test with version 3 and 1 both are same
# """
# pass
#
# @versioning
# def test_sql_thread_version_2(self):
# """
# Test with version 2
# """
#
# # Assertion
# self.cur.execute(''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='inventory_backup' ''')
# self.assertNotEqual(self.cur.fetchone(), 1, "Table inventory_backup not deleted in versioning 2")