# PyBitmessage-2021-04-27/src/tests/test_sqlthread.py

# 172 lines
# 5.1 KiB
# Python

"""
Tests for the database schema upgrades run by class_sqlThread
"""
import os
import unittest
import shutil # used for moving the messages.dat file
import sqlite3
import sys
import threading
import time
# print sys.path
# import state
# import sql
# from sql import init_version_1.sql
from ..state import appdata
from ..helper_sql import sqlStoredProcedure
# import threading
# from hashlib import sha256
# from ..helper_sql import sqlQuery
# from ..version import softwareVersion
# from common import cleanup
from ..class_sqlThread import (sqlThread, UpgradeDB)
class TestSqlThread(unittest.TestCase):
    """
    Test case for SQLThread

    Each test loads a schema fixture from ``tests/sql/init_version_<N>.sql``
    into ``messages.dat``, runs the matching
    ``UpgradeDB.upgrade_schema_data_<N>()`` step on the shared cursor and
    then inspects the resulting table layout.
    """

    # Connection and cursor shared by all tests; opened when the class
    # body is executed.
    conn = sqlite3.connect(appdata + 'messages.dat')
    conn.text_factory = str
    cur = conn.cursor()

    @classmethod
    def setUpClass(cls):
        """Start the SQL thread once for the whole test case"""
        sqlLookup = sqlThread()
        # Non-daemon thread: it is stopped explicitly in tearDownClass
        sqlLookup.daemon = False
        sqlLookup.start()

    @classmethod
    def setUp(cls):
        """Drop every table so that each test starts from an empty database"""
        rows = cls.cur.execute(
            "select name from sqlite_master where type is 'table'").fetchall()
        # Every row is a 1-tuple (name,); unpack it instead of relying on
        # %-formatting accepting the tuple as its argument list.
        cls.cur.executescript(
            ';'.join("drop table if exists %s" % name for (name,) in rows))

    @classmethod
    def tearDown(cls):
        """Nothing to clean up per test; setUp resets the database"""
        pass

    @classmethod
    def tearDownClass(cls):
        """Stop the SQL thread started in setUpClass"""
        sqlStoredProcedure('exit')

    def db_connection(self):
        """Open a new connection to messages.dat and return a cursor on it"""
        conn = sqlite3.connect(appdata + 'messages.dat')
        conn.text_factory = str
        return conn.cursor()

    def normalize_version(self, file):
        """
        Load the fixture ``tests/sql/init_version_<file>.sql`` into the
        database through the shared cursor and commit it.

        :param file: schema version as a string, e.g. ``"1"``
        :return: ``"table inbox already exists"`` when the script failed
            with exactly that error, otherwise ``None``
        :raises Exception: any other error is reported on stderr and
            re-raised so that the calling test fails
        """
        root_path = os.path.dirname(os.path.dirname(__file__))
        sql_file_path = os.path.join(
            root_path, 'tests/sql/', "init_version_" + file + ".sql")
        try:
            # Context manager closes the fixture file even on failure
            with open(sql_file_path) as sql_file:
                sql_as_string = sql_file.read()
            self.cur.executescript(sql_as_string)
            self.conn.commit()
        except Exception as err:
            if str(err) == 'table inbox already exists':
                return "table inbox already exists"
            sys.stderr.write(
                'ERROR trying to create database file (message.dat). Error message: %s\n' % str(err))
            # Previously os._exit(0) ended the whole test run here with a
            # success status, which hid the failure; propagate it instead.
            raise
        return None

    def filter_table_column(self, schema, column):
        """
        Yield every cell of the rows in *schema* that equals *column*.

        :param schema: iterable of rows, e.g. ``PRAGMA table_info`` output
        :param column: value to look for, normally a column name
        """
        for row in schema:
            for cell in row:
                if cell == column:
                    yield cell

    def test_sql_thread_version_1(self):
        """
        Test with version 1: the inventory table must have a 'tag' column
        """
        # normalise Db for version 1
        self.normalize_version("1")

        # Test versioning
        upgrade_db = UpgradeDB()
        upgrade_db.cur = self.cur
        upgrade_db.upgrade_schema_data_1()

        # Assertion after versioning
        res = self.cur.execute(
            '''PRAGMA table_info('inventory');''').fetchall()
        result = list(self.filter_table_column(res, "tag"))
        res = [tup for tup in res if "tag" in tup]
        self.assertEqual(result, ['tag'], "Data not migrated for version 1")
        self.assertEqual(
            res, [(5, 'tag', 'blob', 0, "''", 0)],
            "Data not migrated for version 1")

    def test_sql_thread_version_10(self):
        """
        Test with version 10: old_addressbook must be gone and addressbook
        must have an 'address' column
        """
        # Update version 10
        self.normalize_version("10")

        # Test versioning
        upgrade_db = UpgradeDB()
        upgrade_db.cur = self.cur
        upgrade_db.upgrade_schema_data_10()

        # Assertion
        self.cur.execute(
            ''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='old_addressbook' ''')
        # fetchone() returns a row tuple such as (0,); compare the count
        # itself, otherwise the check can never fail.
        self.assertEqual(
            self.cur.fetchone()[0], 0, "Table old_addressbook not deleted")

        res = self.cur.execute(
            '''PRAGMA table_info('addressbook');''').fetchall()
        result = list(self.filter_table_column(res, "address"))
        self.assertEqual(
            result, ['address'], "Data not migrated for version 10")

    def test_sql_thread_version_9(self):
        """
        Test with version 9: pubkeys_backup must be gone and pubkeys must
        have an 'address' column
        """
        # Update version 9
        self.normalize_version("9")

        # Test versioning
        upgrade_db = UpgradeDB()
        upgrade_db.cur = self.cur
        upgrade_db.upgrade_schema_data_9()

        # Assertion
        self.cur.execute(
            ''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='pubkeys_backup' ''')
        # fetchone() returns a row tuple such as (0,); compare the count
        # itself, otherwise the check can never fail.
        self.assertEqual(
            self.cur.fetchone()[0], 0, "Table pubkeys_backup not deleted")

        res = self.cur.execute(
            '''PRAGMA table_info('pubkeys');''').fetchall()
        result = list(self.filter_table_column(res, "address"))
        self.assertEqual(
            result, ['address'], "Data not migrated for version 9")

    def test_sql_thread_version_8(self):
        """
        Test with version 8: the inbox table must have a 'sighash' column
        """
        # Update version 8
        self.normalize_version("8")

        # Test versioning
        upgrade_db = UpgradeDB()
        upgrade_db.cur = self.cur
        upgrade_db.upgrade_schema_data_8()

        # Assertion
        res = self.cur.execute(
            '''PRAGMA table_info('inbox');''').fetchall()
        result = list(self.filter_table_column(res, "sighash"))
        self.assertEqual(
            result, ['sighash'], "Data not migrated for version 8")