diff --git a/py_dbcn/connectors/core/display.py b/py_dbcn/connectors/core/display.py index e2b3c1c850ac9dc386e0429383e4f8cb40fe3751..160edf442461abe3597a2cc61577fec6c20998f9 100644 --- a/py_dbcn/connectors/core/display.py +++ b/py_dbcn/connectors/core/display.py @@ -334,14 +334,18 @@ class RecordDisplay: total_col_len = 0 for table_col in table_cols: col_len = len(table_col) - record_len = self._base.query.execute( - self._parent.max_col_length_query .format( - table_col, - table_name, - self._base.validate._quote_identifier_format, - ), - display_query=False, - )[0][0] + if not any(keyword_str in table_col for keyword_str in self._base.validate._reserved_function_names): + record_len = self._base.query.execute( + self._parent.max_col_length_query.format( + table_col, + table_name, + self._base.validate._quote_identifier_format, + ), + display_query=False, + )[0][0] + else: + # Keyword found in str. For now, to prevent errors, default to size of 19 and skip query. + record_len = 19 length = max(col_len, record_len or 0) col_len_array.append(length) total_col_len += length + 2 diff --git a/py_dbcn/connectors/core/records.py b/py_dbcn/connectors/core/records.py index a9cedc452d44072dde618a63e6ead64aeaf29d09..5426d6ddeebe9fc9e6fcb578e10ebb9e1206b320 100644 --- a/py_dbcn/connectors/core/records.py +++ b/py_dbcn/connectors/core/records.py @@ -131,13 +131,15 @@ class BaseRecords: return results def insert_many(self, table_name, values_clause, columns_clause=None, display_query=True, display_results=True): - """""" + """"Inserts multiple records into provided table with one query.""" # Check that provided table name is valid format. if not self._base.validate.table_name(table_name): raise ValueError('Invalid table name of "{0}".'.format(table_name)) # Check that provided COLUMNS clause is valid format. columns_clause = self._base.validate.sanitize_columns_clause(columns_clause) + if columns_clause != '': + columns_clause = ' ({0})'.format(columns_clause) # Check that provided VALUES clause is valid format. # Must be array format. diff --git a/py_dbcn/connectors/core/validate.py b/py_dbcn/connectors/core/validate.py index 8c8fd450b8d99daef4409763c7d4dc3385400720..dc8fdfa4285464cbfbc92aad6bd3038fdf3fc3d4 100644 --- a/py_dbcn/connectors/core/validate.py +++ b/py_dbcn/connectors/core/validate.py @@ -48,6 +48,7 @@ class BaseValidate: self._quote_identifier_format = None self._quote_order_by_format = None self._quote_str_literal_format = None + self._reserved_function_names = None # region Name Validation diff --git a/py_dbcn/connectors/postgresql/validate.py b/py_dbcn/connectors/postgresql/validate.py index b81048c191a147ac0632458b9e329be7689b6942..ee6d5e44312df5c58c41c83788cc3bfca0d4d890 100644 --- a/py_dbcn/connectors/postgresql/validate.py +++ b/py_dbcn/connectors/postgresql/validate.py @@ -16,7 +16,7 @@ logger = init_logging(__name__) # Module Variables. -QUOTE_COLUMN_FORMAT = """'""" # Used for quoting table columns. +QUOTE_COLUMN_FORMAT = """\"""" # Used for quoting table columns. QUOTE_IDENTIFIER_FORMAT = """\"""" # Used for quoting identifiers (such as SELECT clause field id's). QUOTE_ORDER_BY_FORMAT = """\"""" # Used for quoting values in ORDER BY clause. QUOTE_STR_LITERAL_FORMAT = """'""" # Used for quoting actual strings. @@ -38,8 +38,12 @@ class PostgresqlValidate(BaseValidate): # https://www.postgresql.org/docs/current/sql-keywords-appendix.html self._reserved_function_names = [ 'ABS', - 'AVG,' + 'AVG', + 'BIT_AND', + 'BIT_OR', 'BIT_LENGTH', + 'BOOL_AND', + 'BOOL_OR', 'CASE', 'CAST', 'CEIL', @@ -82,6 +86,7 @@ class PostgresqlValidate(BaseValidate): 'SESSION_USER', 'SPACE', 'SQRT', + 'STDDEV', 'STDDEV_POP', 'STDDEV_SAMP', 'SUBSTRING', @@ -91,6 +96,7 @@ class PostgresqlValidate(BaseValidate): 'UPPER', 'VAR_POP', 'VAR_SAMP', + 'VARIANCE', 'YEAR', ] diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py index 874a61cf988fce9cdaa926d9384c836b62774f0b..cfb0809218637e76b2b13294eccfd42dc9889800 100644 --- a/tests/connectors/core/test_records.py +++ b/tests/connectors/core/test_records.py @@ -8,7 +8,7 @@ various specific database test classes. This ensures that all databases types ru # System Imports. import datetime -import textwrap +from decimal import Decimal # Internal Imports. @@ -29,6 +29,7 @@ class CoreRecordsTestMixin: cls._columns_clause__basic = None cls._columns_clause__datetime = None + cls._columns_clause__aggregates = None def test_error_catch_types(self): """Tests to ensure database ERROR types are properly caught. @@ -779,6 +780,61 @@ class CoreRecordsTestMixin: self.assertIn(row_4, results) self.assertIn(row_5, results) + def test__select__aggregates(self): + """""" + table_name = 'test_queries__select__aggregate' + # Verify table exists. + try: + self.connector.query.execute('CREATE TABLE {0}{1};'.format(table_name, self._columns_clause__aggregates)) + except self.connector.errors.table_already_exists: + # Table already exists, as we want. + pass + + # Prepopulate with a few records. + self.connector.records.insert_many( + table_name, + [ + ('test one', 10, False), + ('test two', 12, False), + ('test three', 5, False), + ('test four', 3, False), + ('test five', 22, False), + ], + columns_clause=('test_str, test_int, test_bool'), + ) + + with self.subTest('SELECT with AVG aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'AVG(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], Decimal('10.4')) + + with self.subTest('SELECT with MAX aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'MAX(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], 22) + + with self.subTest('SELECT with MIN aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'MIN(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], 3) + + with self.subTest('SELECT with SUM aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'SUM(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], 52) + def test__insert__basic__success(self): """ Test `INSERT` query with basic values. diff --git a/tests/connectors/mysql/constants.py b/tests/connectors/mysql/constants.py index deac2545d3af09f4d10be86950ca53732a9489b7..ef69418088fae3fffd03bd65273169cb086d75e9 100644 --- a/tests/connectors/mysql/constants.py +++ b/tests/connectors/mysql/constants.py @@ -39,3 +39,14 @@ COLUMNS_CLAUSE__DATETIME = """ PRIMARY KEY ( id ) ) """.strip() + + +COLUMNS_CLAUSE__AGGREGATES = """ +( + id INT NOT NULL AUTO_INCREMENT, + test_str VARCHAR(100), + test_int INT, + test_bool TINYINT, + PRIMARY KEY ( id ) +) +""".strip() diff --git a/tests/connectors/mysql/test_records.py b/tests/connectors/mysql/test_records.py index a06aa1cba666fb4e93822a55aef0764aead91f67..d617c2f18202b236e3d049182a90666506065944 100644 --- a/tests/connectors/mysql/test_records.py +++ b/tests/connectors/mysql/test_records.py @@ -3,10 +3,10 @@ Tests for "records" logic of "MySQL" DB Connector class. """ # System Imports. -import textwrap +from decimal import Decimal # Internal Imports. -from .constants import COLUMNS_CLAUSE__BASIC, COLUMNS_CLAUSE__DATETIME +from .constants import COLUMNS_CLAUSE__BASIC, COLUMNS_CLAUSE__DATETIME, COLUMNS_CLAUSE__AGGREGATES from .test_core import TestMysqlDatabaseParent from tests.connectors.core.test_records import CoreRecordsTestMixin @@ -39,6 +39,7 @@ class TestMysqlRecords(TestMysqlDatabaseParent, CoreRecordsTestMixin): # Define default table columns. cls._columns_clause__basic = COLUMNS_CLAUSE__BASIC cls._columns_clause__datetime = COLUMNS_CLAUSE__DATETIME + cls._columns_clause__aggregates = COLUMNS_CLAUSE__AGGREGATES def test_error_catch_types(self): """Tests to ensure database ERROR types are properly caught. @@ -62,3 +63,66 @@ class TestMysqlRecords(TestMysqlDatabaseParent, CoreRecordsTestMixin): # Check that we use the correct handler. with self.assertRaises(self.connector.errors.table_already_exists): self.connector.query.execute('CREATE TABLE {0} {1};'.format(table_name, self._columns_clause__basic)) + + def test__select__aggregates(self): + """""" + table_name = 'test_queries__select__aggregate' + + # Run parent tests. + super().test__select__aggregates() + + # Tests that require slightly different syntax in different database types. + with self.subTest('SELECT with BIT_OR aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'BIT_OR(test_bool)') + + # Verify return aggregate result. + # No records returned True, so should be False. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], False) + + # Upset a single record to be true, and test again. + results = self.connector.records.update(table_name, 'test_bool = True', where_clause='WHERE id = 2') + results = self.connector.records.select(table_name, 'BIT_OR(test_bool)') + + # Verify return aggregate result. + # At least one record returned True so should be True. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], True) + + with self.subTest('SELECT with BIT_AND aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'BIT_AND(test_bool)') + + # Verify return aggregate result. + # Not all records returned True, so should be False. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], False) + + # Update all records to be true, and test again. + results = self.connector.records.update(table_name, 'test_bool = True', where_clause='') + results = self.connector.records.select(table_name, 'bit_or(test_bool)') + + # Verify return aggregate result. + # All records returned True so should be True. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], True) + + # Reset booleans to be False. + self.connector.records.update(table_name, 'test_bool = False', where_clause='') + + with self.subTest('SELECT with STDDEV aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'STDDEV(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], 6.651315659326356) + + with self.subTest('SELECT with VARIANCE aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'VARIANCE(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], 44.239999999999995) diff --git a/tests/connectors/postgresql/constants.py b/tests/connectors/postgresql/constants.py index 90c0cf9837a964440c553e568ae53775716ceaec..046dea67dd82f1bcf74075abd1fa50ab099b2119 100644 --- a/tests/connectors/postgresql/constants.py +++ b/tests/connectors/postgresql/constants.py @@ -38,3 +38,13 @@ COLUMNS_CLAUSE__DATETIME = """ test_date DATE ) """.strip() + + +COLUMNS_CLAUSE__AGGREGATES = """ +( + id serial PRIMARY KEY, + test_str VARCHAR(100), + test_int INTEGER, + test_bool BOOLEAN +) +""".strip() diff --git a/tests/connectors/postgresql/test_records.py b/tests/connectors/postgresql/test_records.py index 3f07b1b4c56b1d691bd94369c3091f5bfc77532a..0027928ec3ba03914ea8f859e4a5b11d262779eb 100644 --- a/tests/connectors/postgresql/test_records.py +++ b/tests/connectors/postgresql/test_records.py @@ -3,10 +3,10 @@ Tests for "records" logic of "PostgreSQL" DB Connector class. """ # System Imports. -import textwrap +from decimal import Decimal # Internal Imports. -from .constants import COLUMNS_CLAUSE__BASIC, COLUMNS_CLAUSE__DATETIME +from .constants import COLUMNS_CLAUSE__BASIC, COLUMNS_CLAUSE__DATETIME, COLUMNS_CLAUSE__AGGREGATES from .test_core import TestPostgresqlDatabaseParent from tests.connectors.core.test_records import CoreRecordsTestMixin @@ -39,6 +39,7 @@ class TestPostgresqlRecords(TestPostgresqlDatabaseParent, CoreRecordsTestMixin): # Define default table columns. cls._columns_clause__basic = COLUMNS_CLAUSE__BASIC cls._columns_clause__datetime = COLUMNS_CLAUSE__DATETIME + cls._columns_clause__aggregates = COLUMNS_CLAUSE__AGGREGATES def test_error_catch_types(self): """Tests to ensure database ERROR types are properly caught. @@ -62,3 +63,106 @@ class TestPostgresqlRecords(TestPostgresqlDatabaseParent, CoreRecordsTestMixin): # Check that we use the correct handler. with self.assertRaises(self.connector.errors.table_already_exists): self.connector.query.execute('CREATE TABLE {0} {1};'.format(table_name, self._columns_clause__basic)) + + def test__select__aggregates(self): + """""" + table_name = 'test_queries__select__aggregate' + + # Run parent tests. + super().test__select__aggregates() + + # Tests that require slightly different syntax in different database types. + with self.subTest('SELECT with BIT_OR aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'BIT_OR(test_bool::int)') + + # Verify return aggregate result. + # No records returned True, so should be False. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], False) + + # Upset a single record to be true, and test again. + results = self.connector.records.update(table_name, 'test_bool = True', where_clause='WHERE id = 2') + results = self.connector.records.select(table_name, 'BIT_OR(test_bool::int)') + + # Verify return aggregate result. + # At least one record returned True so should be True. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], True) + + with self.subTest('SELECT with BIT_AND aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'BIT_AND(test_bool::int)') + + # Verify return aggregate result. + # Not all records returned True, so should be False. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], False) + + # Update all records to be true, and test again. + results = self.connector.records.update(table_name, 'test_bool = True', where_clause='') + results = self.connector.records.select(table_name, 'bit_or(test_bool::int)') + + # Verify return aggregate result. + # All records returned True so should be True. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], True) + + # Reset booleans to be False. + self.connector.records.update(table_name, 'test_bool = False', where_clause='') + + with self.subTest('SELECT with STDDEV aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'STDDEV(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], Decimal('7.4363969770312827')) + + with self.subTest('SELECT with VARIANCE aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'VARIANCE(test_int)') + + # Verify return aggregate result. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], Decimal('55.3')) + + # Aggregate functions that don't exist outside of PostgreSQL. + with self.subTest('SELECT with BOOL_OR aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'BOOL_OR(test_bool)') + + # Verify return aggregate result. + # No records returned True, so should be False. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], False) + + # Upset a single record to be true, and test again. + results = self.connector.records.update(table_name, 'test_bool = True', where_clause='WHERE id = 2') + results = self.connector.records.select(table_name, 'BOOL_OR(test_bool)') + + # Verify return aggregate result. + # At least one record returned True so should be True. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], True) + + with self.subTest('SELECT with BOOL_AND aggregation'): + # Run test query. + results = self.connector.records.select(table_name, 'BOOL_AND(test_bool)') + + # Verify return aggregate result. + # Not all records returned True, so should be False. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], False) + + # Update all records to be true, and test again. + self.connector.records.update(table_name, 'test_bool = True', where_clause='') + results = self.connector.records.select(table_name, 'BOOL_or(test_bool)') + + # Verify return aggregate result. + # All records returned True so should be True. + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], True) + + # Reset booleans to be False. + self.connector.records.update(table_name, 'test_bool = False', where_clause='') diff --git a/tests/connectors/postgresql/test_validate.py b/tests/connectors/postgresql/test_validate.py index 5131d318b10870f387ead60e5a58b8c0d062ecd4..f551bdb0ebc54cfcaf0d8e4a8ed22da7c5c5360d 100644 --- a/tests/connectors/postgresql/test_validate.py +++ b/tests/connectors/postgresql/test_validate.py @@ -50,7 +50,7 @@ class TestPostgresqlValidate(TestPostgresqlDatabaseParent, CoreValidateTestMixin def test__column_quote_format(self): # Verify quote str is as we expect. - self.assertText("'{0}'", self._quote_columns_format) + self.assertText('"{0}"', self._quote_columns_format) def test__select_identifier_quote_format(self): # Verify quote str is as we expect.