diff --git a/py_dbcn/connectors/core/clauses.py b/py_dbcn/connectors/core/clauses.py index 97b8ea9e8e4ec53193622b1d1324ad87abd3d9e2..5bf167de6bac26f7f1b18f43048d7f161a42bd4f 100644 --- a/py_dbcn/connectors/core/clauses.py +++ b/py_dbcn/connectors/core/clauses.py @@ -79,6 +79,10 @@ class BaseClauseBuilder(object): def array(self, value): self._to_array(value) + @property + def context(self): + return ', '.join('{}' for i in range(len(self.array))) + def _to_array(self, value): """Converts clause to array format for initial parsing.""" if self._clause_prefix is None: @@ -162,13 +166,16 @@ class BaseClauseBuilder(object): for item in original_clause: # Handle various specific types. + is_datetime = False if isinstance(item, datetime.datetime): # Is a datetime object. Convert to string. item = "'{0}'".format(item.strftime('%Y-%m-%d %H:%M:%S')) + is_datetime = True elif isinstance(item, datetime.date): # Is a date object. Convert to string. item = "'{0}'".format(item.strftime('%Y-%m-%d')) + is_datetime = True # Skip handling for other non-str items. elif not isinstance(item, str): @@ -228,9 +235,23 @@ class BaseClauseBuilder(object): item = item[:-5].rstrip() order_by_descriptor = ' DESC' + # # Extra string handling for date/datetime objects. + # if is_datetime: + # item = item[1:-1] + print('') print('item: {0}'.format(item)) + # if ( + # len(item) > 0 + # and item != '*' + # and ( + # item[0] not in ['"', "'", '`'] + # or item[-1] not in ['"', "'", '`'] + # ) + # ): + # item = """'{0}'""".format(item) + # If we made it this far, item is valid. Escape with proper quote format and readd. is_quoted = False if self.is_quoted(item): @@ -486,6 +507,72 @@ class ValuesClauseBuilder(BaseClauseBuilder): self.array = clause +class ValuesManyClauseBuilder(ValuesClauseBuilder): + """""" + + def _validate_clause(self, original_clause): + """Used to validate/sanitize an array of clause values.""" + + # Handle the same as original logic, except there is one extra layer. + # So loop through each inner item and hand that to validation. + print('\n\n\n\n') + print('original_clause:') + print('{0}'.format(original_clause)) + + if len(original_clause) > 0: + for index in range(len(original_clause)): + inner_clause = original_clause[index] + print(' inner_clause:') + print(' {0}'.format(inner_clause)) + original_clause[index] = super()._validate_clause(inner_clause) + print(' updated inner_clause:') + print(' {0}'.format(original_clause[index])) + + print('final result:') + print('{0}'.format(original_clause)) + + # Return validated clause. + return original_clause + + else: + # Return empty clause. + return [] + + def __str__(self): + if len(self.array) > 0: + # Non-empty clause. Format for str output. + to_str = self.context + all_values = [] + for inner_array in self.array: + for value in inner_array: + all_values.append(value) + print('all_values:') + print('{0}'.format(all_values)) + print(to_str.format(*all_values)) + to_str = to_str.format(*all_values) + print('to_str:') + print('{0}'.format(to_str)) + if self._print_parens: + to_str = '{0}({1})'.format(self._print_prefix, to_str) + else: + to_str = '{0}{1}'.format(self._print_prefix, to_str) + return to_str + else: + # Empty clause. + return '' + + @property + def context(self): + if len(self.array) > 0: + context_line = ', '.join('{}' for i in range(len(self.array[0]))) + context_line = ' ({0})'.format(context_line) + context = ',\n'.join(context_line for i in range(len(self.array))) + context += '\n' + return context + else: + return '' + + class SetClauseBuilder(BaseClauseBuilder): """""" def __init__(self, validation_class, clause, *args, clause_type='VALUES', **kwargs): diff --git a/py_dbcn/connectors/core/validate.py b/py_dbcn/connectors/core/validate.py index f4820b1e845927f4876639ec8c328092d69208a7..3449121ccf010fde879e881397447426561f33f7 100644 --- a/py_dbcn/connectors/core/validate.py +++ b/py_dbcn/connectors/core/validate.py @@ -361,6 +361,9 @@ class BaseValidate: """ return clauses.ValuesClauseBuilder(self, clause) + def sanitize_values_many_clause(self, clause): + return clauses.ValuesManyClauseBuilder(self, clause) + def sanitize_set_clause(self, clause): return clauses.SetClauseBuilder(self, clause) diff --git a/py_dbcn/connectors/postgresql/records.py b/py_dbcn/connectors/postgresql/records.py index 86648acf20d60ee80c79770da249487c8ac948f8..74697d0551bf3aa6e2d39fbd597c0d8572aadcb0 100644 --- a/py_dbcn/connectors/postgresql/records.py +++ b/py_dbcn/connectors/postgresql/records.py @@ -75,7 +75,7 @@ class PostgresqlRecords(BaseRecords): raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.') if len(values_clause) < 1: raise ValueError('VALUES clause cannot be empty for UPDATE_MANY queries.') - values_clause = self._base.validate.sanitize_values_clause(values_clause) + values_clause = self._base.validate.sanitize_values_many_clause(values_clause) # Check that provided WHERE clause is valid format. columns_clause = self._base.validate.sanitize_columns_clause(columns_clause) @@ -110,7 +110,7 @@ class PostgresqlRecords(BaseRecords): for x in columns_clause.array ]) values_clause = ',\n'.join([ - ' {0}'.format(x) + ' {0}'.format(tuple(x)) for x in values_clause.array ]) columns_clause = ', '.join([ diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py index 5447dfbd3cdcba002fb44e65ed53f1a291e2ef80..d5c9151e47df29a981b9be444b9e4e2afc5bd27b 100644 --- a/tests/connectors/core/test_records.py +++ b/tests/connectors/core/test_records.py @@ -976,11 +976,11 @@ class CoreRecordsTestMixin: # # self.assertEqual(len(results), 2) # # self.assertIn(row, results) - def test__insert_many__success(self): + def test__insert_many__basic__success(self): """ Test execute_many `INSERT` query. """ - table_name = 'test_queries__insert_many__success' + table_name = 'test_queries__insert_many__basic__success' # Verify table exists. try: @@ -1093,6 +1093,121 @@ class CoreRecordsTestMixin: self.assertIn(row_9, results) self.assertIn(row_10, results) + def test__insert_many__datetime__success(self): + """ + Test execute_many `UPDATE` query with datetime values. + """ + table_name = 'test_queries__insert_many__datetime__success' + + # Verify table exists. + try: + self.connector.query.execute('CREATE TABLE {0}{1};'.format(table_name, self._columns_clause__datetime)) + except self.connector.errors.table_already_exists: + # Table already exists, as we want. + pass + + # Verify starting state. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 0) + + # Generate datetime objects. + test_datetime__2020 = datetime.datetime( + year=2020, + month=6, + day=15, + hour=7, + minute=12, + second=52, + microsecond=0, + ) + test_date__2020 = test_datetime__2020.date() + test_datetime__2021 = datetime.datetime( + year=2021, + month=7, + day=16, + hour=8, + minute=13, + second=53, + microsecond=0, + ) + test_date__2021 = test_datetime__2021.date() + test_datetime__2022 = datetime.datetime( + year=2022, + month=8, + day=17, + hour=9, + minute=14, + second=54, + microsecond=0, + ) + test_date__2022 = test_datetime__2022.date() + + # Generate row values. + row_1 = (1, test_datetime__2020, test_date__2020) + row_2 = (2, test_datetime__2021, test_date__2021) + row_3 = (3, test_datetime__2022, test_date__2022) + + # Generate initial rows. + rows = [ + row_1, + row_2, + row_3, + ] + + with self.subTest('With one insert'): + # Run test query. + rows = [ + row_1, + ] + self.connector.records.insert_many(table_name, rows) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 1) + self.assertIn(row_1, results) + + # Reset table. + self.connector.tables.drop(table_name) + self.connector.tables.create(table_name, self._columns_clause__datetime) + + with self.subTest('With two inserts'): + # Run test query. + rows = [ + row_1, + row_2, + ] + self.connector.records.insert_many(table_name, rows) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 2) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + + # Reset table. + self.connector.tables.drop(table_name) + self.connector.tables.create(table_name, self._columns_clause__datetime) + + with self.subTest('With three inserts'): + # Run test query. + rows = [ + row_1, + row_2, + row_3, + ] + self.connector.records.insert_many(table_name, rows) + + # Verify five records returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 3) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + + # Reset table. + self.connector.tables.drop(table_name) + self.connector.tables.create(table_name, self._columns_clause__datetime) + def test__update__basic__success(self): """ Test `UPDATE` query with basic values. @@ -1508,16 +1623,16 @@ class CoreRecordsTestMixin: with self.subTest('With ten updates'): # Run test query. - updated_row_1 = (1, '"110"', '"10010"') - updated_row_2 = (2, '"109"', '"10009"') - updated_row_3 = (3, '"108"', '"10008"') - updated_row_4 = (4, '"107"', '"10007"') - updated_row_5 = (5, '"106"', '"10006"') - updated_row_6 = (6, '"105"', '"10005"') - updated_row_7 = (7, '"104"', '"10004"') - updated_row_8 = (8, '"103"', '"10003"') - updated_row_9 = (9, '"102"', '"10002"') - updated_row_10 = (10, '"101"', '"10001"') + updated_row_1 = (1, '110', '10010') + updated_row_2 = (2, '109', '10009') + updated_row_3 = (3, '108', '10008') + updated_row_4 = (4, '107', '10007') + updated_row_5 = (5, '106', '10006') + updated_row_6 = (6, '105', '10005') + updated_row_7 = (7, '104', '10004') + updated_row_8 = (8, '103', '10003') + updated_row_9 = (9, '102', '10002') + updated_row_10 = (10, '101', '10001') columns_clause = 'id, name, description' values_clause = [ updated_row_1,