diff --git a/src/bmconfigparser.py b/src/bmconfigparser.py index be7b3e87..ff43fd7c 100644 --- a/src/bmconfigparser.py +++ b/src/bmconfigparser.py @@ -2,22 +2,13 @@ BMConfigParser class definition and default configuration settings """ -import sys -if sys.version_info[0] == 3: - # python 3 - import configparser as ConfigParser - SafeConfigParser = ConfigParser.ConfigParser -else: - # python 2 - import ConfigParser - SafeConfigParser = ConfigParser.SafeConfigParser - -import state -from singleton import Singleton +import ConfigParser import os import shutil from datetime import datetime +import state +from singleton import Singleton BMConfigDefaults = { "bitmessagesettings": { @@ -52,8 +43,7 @@ BMConfigDefaults = { @Singleton -class BMConfigParser(SafeConfigParser): - +class BMConfigParser(ConfigParser.SafeConfigParser): """ Singleton class inherited from :class:`ConfigParser.SafeConfigParser` with additional methods specific to bitmessage config. @@ -70,47 +60,26 @@ class BMConfigParser(SafeConfigParser): raise ValueError("Invalid value %s" % value) return ConfigParser.ConfigParser.set(self, section, option, value) - def get(self, section, option, raw=False, vars=None): - if sys.version_info[0] == 3: - # pylint: disable=arguments-differ - try: - if section == "bitmessagesettings" and option == "timeformat": - return ConfigParser.ConfigParser.get( - self, section, option) - try: - return self._temp[section][option] - except KeyError: - pass + def get(self, section, option, raw=False, variables=None): + # pylint: disable=arguments-differ + try: + if section == "bitmessagesettings" and option == "timeformat": return ConfigParser.ConfigParser.get( - self, section, option) - except ConfigParser.InterpolationError: - return ConfigParser.ConfigParser.get( - self, section, option) - except (ConfigParser.NoSectionError, ConfigParser.NoOptionError) as e: - try: - return BMConfigDefaults[section][option] - except (KeyError, ValueError, AttributeError): - raise e - else: - # pylint: disable=arguments-differ + self, section, option, raw, variables) try: - if section == "bitmessagesettings" and option == "timeformat": - return ConfigParser.ConfigParser.get( - self, section, option, raw, vars) - try: - return self._temp[section][option] - except KeyError: - pass - return ConfigParser.ConfigParser.get( - self, section, option, True, vars) - except ConfigParser.InterpolationError: - return ConfigParser.ConfigParser.get( - self, section, option, True, vars) - except (ConfigParser.NoSectionError, ConfigParser.NoOptionError) as e: - try: - return BMConfigDefaults[section][option] - except (KeyError, ValueError, AttributeError): - raise e + return self._temp[section][option] + except KeyError: + pass + return ConfigParser.ConfigParser.get( + self, section, option, True, variables) + except ConfigParser.InterpolationError: + return ConfigParser.ConfigParser.get( + self, section, option, True, variables) + except (ConfigParser.NoSectionError, ConfigParser.NoOptionError) as e: + try: + return BMConfigDefaults[section][option] + except (KeyError, ValueError, AttributeError): + raise e def setTemp(self, section, option, value=None): """Temporary set option to value, not saving.""" @@ -222,4 +191,3 @@ class BMConfigParser(SafeConfigParser): if value < 0 or value > 8: return False return True - diff --git a/src/class_sqlThread.py b/src/class_sqlThread.py index 84188408..86460e7b 100644 --- a/src/class_sqlThread.py +++ b/src/class_sqlThread.py @@ -8,7 +8,6 @@ import sqlite3 import sys import threading import time - import helper_sql import helper_startup import paths @@ -18,6 +17,7 @@ import tr from bmconfigparser import BMConfigParser from debug import logger # pylint: disable=attribute-defined-outside-init,protected-access +from addresses import encodeAddress class sqlThread(threading.Thread): @@ -35,6 +35,9 @@ class sqlThread(threading.Thread): self.cur.execute('PRAGMA secure_delete = true') + # call create_function for encode address + self.create_function() + try: self.cur.execute( '''CREATE TABLE inbox (msgid blob, toaddress text, fromaddress text, subject text,''' @@ -325,6 +328,7 @@ class sqlThread(threading.Thread): # We'll also need a `sleeptill` field and a `ttl` field. Also we # can combine the pubkeyretrynumber and msgretrynumber into one. + item = '''SELECT value FROM settings WHERE key='version';''' parameters = '' self.cur.execute(item, parameters) @@ -358,16 +362,11 @@ class sqlThread(threading.Thread): logger.debug('In messages.dat database, adding address field to the pubkeys table.') # We're going to have to calculate the address for each row in the pubkeys # table. Then we can take out the hash field. - self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ''') - self.cur.execute('''SELECT hash, addressversion FROM pubkeys''') - queryResult = self.cur.fetchall() - from addresses import encodeAddress - for row in queryResult: - addressHash, addressVersion = row - address = encodeAddress(addressVersion, 1, hash) - item = '''UPDATE pubkeys SET address=? WHERE hash=?;''' - parameters = (address, addressHash) - self.cur.execute(item, parameters) + self.cur.execute('''ALTER TABLE pubkeys ADD address text DEFAULT '' ;''') + + # replica for loop to update hashed address + self.cur.execute('''UPDATE pubkeys SET address=(enaddr(pubkeys.addressversion, 1, hash)) WHERE hash=pubkeys.hash; ''') + # Now we can remove the hash field from the pubkeys table. self.cur.execute( '''CREATE TEMPORARY TABLE pubkeys_backup''' @@ -622,3 +621,12 @@ class sqlThread(threading.Thread): helper_sql.sqlReturnQueue.put((self.cur.fetchall(), rowcount)) # helper_sql.sqlSubmitQueue.task_done() + + def create_function(self): + # create_function + try: + self.conn.create_function("enaddr", 3, func=encodeAddress, deterministic=True) + except (TypeError, sqlite3.NotSupportedError) as err: + logger.debug( + "Got error while pass deterministic in sqlite create function {}, Passing 3 params".format(err)) + self.conn.create_function("enaddr", 3, encodeAddress) diff --git a/src/helper_sql.py b/src/helper_sql.py index 043bccf2..26398848 100644 --- a/src/helper_sql.py +++ b/src/helper_sql.py @@ -16,9 +16,16 @@ SQLite objects can only be used from one thread. or isn't thread-safe. """ -import Queue + +# import Queue +try: + import queue as Queue #python3 +except ImportError: + import Queue #python2 + import threading + sqlSubmitQueue = Queue.Queue() """the queue for SQL""" sqlReturnQueue = Queue.Queue() @@ -105,6 +112,15 @@ def sqlExecute(sql_statement, *args): return rowcount +def sqlExecuteScript(sql_statement): + """Execute SQL script statement""" + + statements = sql_statement.split(";") + with SqlBulkExecute() as sql: + for q in statements: + sql.execute("{}".format(q)) + + def sqlStoredProcedure(procName): """Schedule procName to be run""" assert sql_available diff --git a/src/tests/sql/create_function.sql b/src/tests/sql/create_function.sql new file mode 100644 index 00000000..b66d1aa1 --- /dev/null +++ b/src/tests/sql/create_function.sql @@ -0,0 +1,11 @@ +CREATE TABLE `testhash` ( + `addressversion` int DEFAULT NULL, + `hash` blob DEFAULT NULL, + `address` text DEFAULT NULL, + UNIQUE(address) ON CONFLICT IGNORE +); + + + +INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112"); + diff --git a/src/tests/test_config.py b/src/tests/test_config.py index d44ec738..a3b90a4c 100644 --- a/src/tests/test_config.py +++ b/src/tests/test_config.py @@ -2,7 +2,11 @@ Various tests for config """ +import os import unittest +import tempfile + +from .test_process import TestProcessProto from pybitmessage.bmconfigparser import BMConfigParser @@ -34,3 +38,32 @@ class TestConfig(unittest.TestCase): BMConfigParser().safeGetInt('nonexistent', 'nonexistent'), 0) self.assertEqual( BMConfigParser().safeGetInt('nonexistent', 'nonexistent', 42), 42) + + +class TestProcessConfig(TestProcessProto): + """A test case for keys.dat""" + home = tempfile.mkdtemp() + + def test_config_defaults(self): + """Test settings in the generated config""" + self._stop_process() + self._kill_process() + config = BMConfigParser() + config.read(os.path.join(self.home, 'keys.dat')) + + self.assertEqual(config.safeGetInt( + 'bitmessagesettings', 'settingsversion'), 10) + self.assertEqual(config.safeGetInt( + 'bitmessagesettings', 'port'), 8444) + # don't connect + self.assertTrue(config.safeGetBoolean( + 'bitmessagesettings', 'dontconnect')) + # API disabled + self.assertFalse(config.safeGetBoolean( + 'bitmessagesettings', 'apienabled')) + + # extralowdifficulty is false + self.assertEqual(config.safeGetInt( + 'bitmessagesettings', 'defaultnoncetrialsperbyte'), 1000) + self.assertEqual(config.safeGetInt( + 'bitmessagesettings', 'defaultpayloadlengthextrabytes'), 1000) diff --git a/src/tests/test_config_process.py b/src/tests/test_config_process.py deleted file mode 100644 index f3cf19f2..00000000 --- a/src/tests/test_config_process.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -Various tests for config -""" - -import os -import tempfile -from pybitmessage.bmconfigparser import BMConfigParser -from .test_process import TestProcessProto - - -class TestProcessConfig(TestProcessProto): - """A test case for keys.dat""" - home = tempfile.mkdtemp() - - - def test_config_defaults(self): - """Test settings in the generated config""" - config = BMConfigParser() - self._stop_process() - self._kill_process() - config.read(os.path.join(self.home, 'keys.dat')) - - self.assertEqual(config.safeGetInt( - 'bitmessagesettings', 'settingsversion'), 10) - self.assertEqual(config.safeGetInt( - 'bitmessagesettings', 'port'), 8444) - # don't connect - self.assertTrue(config.safeGetBoolean( - 'bitmessagesettings', 'dontconnect')) - # API disabled - self.assertFalse(config.safeGetBoolean( - 'bitmessagesettings', 'apienabled')) - - # extralowdifficulty is false - self.assertEqual(config.safeGetInt( - 'bitmessagesettings', 'defaultnoncetrialsperbyte'), 1000) - self.assertEqual(config.safeGetInt( - 'bitmessagesettings', 'defaultpayloadlengthextrabytes'), 1000) diff --git a/src/tests/test_sqlthread.py b/src/tests/test_sqlthread.py new file mode 100644 index 00000000..d51c6623 --- /dev/null +++ b/src/tests/test_sqlthread.py @@ -0,0 +1,73 @@ +""" + Test for sqlThread +""" + +import os +import unittest +from ..helper_sql import sqlStoredProcedure, sql_ready, sqlExecute, SqlBulkExecute, sqlQuery, sqlExecuteScript +from ..class_sqlThread import (sqlThread) +from ..addresses import encodeAddress +from .common import skip_python3 + + +skip_python3() + + +class TestSqlThread(unittest.TestCase): + """ + Test case for SQLThread + """ + + # query file path + root_path = os.path.dirname(os.path.dirname(__file__)) + + @classmethod + def setUpClass(cls): + # Start SQL thread + sqlLookup = sqlThread() + sqlLookup.daemon = False + sqlLookup.start() + sql_ready.wait() + + @classmethod + def setUp(cls): + tables = list(sqlQuery("select name from sqlite_master where type is 'table'")) + with SqlBulkExecute() as sql: + for q in tables: + sql.execute("drop table if exists %s" % q) + + @classmethod + def tearDown(cls): + pass + + @classmethod + def tearDownClass(cls): + # Stop sql thread + sqlStoredProcedure('exit') + + def initialise_database(self, file): + """ + Initialise DB + """ + + sql_as_string = open(os.path.join(self.root_path, "tests/sql/{}.sql".format(file))).read() + sqlExecuteScript(sql_as_string) + + + def test_create_function(self): + # call create function + + encoded_str = encodeAddress(4, 1, "21122112211221122112") + + # Initialise Database + self.initialise_database("create_function") + + sqlExecute('''INSERT INTO testhash (addressversion, hash) VALUES(4, "21122112211221122112")''') + # call function in query + + sqlExecute('''UPDATE testhash SET address=(enaddr(testhash.addressversion, 1, hash)) WHERE hash=testhash.hash''') + + # Assertion + query = sqlQuery('''select * from testhash;''') + self.assertEqual(query[0][-1], encoded_str, "test case fail for create_function") + sqlExecute('''DROP TABLE testhash''')