diff --git a/py_dbcn/connectors/core/records.py b/py_dbcn/connectors/core/records.py index 7d23fe051699d84ac288a3fb03447d5eb3671aa5..1566d5b56172d7032bb98b215ac6c9275cd9c4d3 100644 --- a/py_dbcn/connectors/core/records.py +++ b/py_dbcn/connectors/core/records.py @@ -143,6 +143,8 @@ class BaseRecords: # Must be array format. if not isinstance(values_clause, list) and not isinstance(values_clause, tuple): raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.') + if len(values_clause) < 1: + raise ValueError('VALUES clause cannot be empty for INSERT_MANY queries.') values_clause = self._base.validate.sanitize_values_clause(values_clause) # Check for values that might need formatting. @@ -243,6 +245,7 @@ class BaseRecords: def update_many( self, table_name, columns_clause, values_clause, where_columns_clause, + column_types_clause=None, display_query=True, display_results=True, ): """Updates record in provided table. @@ -251,6 +254,8 @@ class BaseRecords: :param columns_clause: Clause to specify columns to insert into. :param values_clause: Clause to specify values to insert. :param where_columns_clause: NOT STANDARD WHERE CLAUSE. Columns to use as WHERE in provided values. + :param column_types_clause: Optional clause to provide type hinting for column types. Not required if all + columns are basic types such as text or integer. :param display_query: Bool indicating if query should output to console. Defaults to True. :param display_results: Bool indicating if results should output to console. Defaults to True. """ @@ -262,6 +267,8 @@ class BaseRecords: # Must be array format. if not isinstance(values_clause, list) and not isinstance(values_clause, tuple): raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.') + if len(values_clause) < 1: + raise ValueError('VALUES clause cannot be empty for UPDATE_MANY queries.') values_clause = self._base.validate.sanitize_values_clause(values_clause) # Check that provided WHERE clause is valid format. @@ -285,26 +292,43 @@ class BaseRecords: # For example, if we find date/datetime objects, we automatically convert to a str value that won't error. if isinstance(values_clause, list) or isinstance(values_clause, tuple): updated_values_clause = () - for item in values_clause: + for value_set in values_clause: + updated_values_set = () + for item in value_set: - if isinstance(item, datetime.datetime): - # Is a datetime object. Convert to string. - item = item.strftime('%Y-%m-%d %H:%M:%S') - elif isinstance(item, datetime.date): - # Is a date object. Convert to string. - item = item.strftime('%Y-%m-%d') + if isinstance(item, datetime.datetime): + # Is a datetime object. Convert to string. + item = item.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(item, datetime.date): + # Is a date object. Convert to string. + item = item.strftime('%Y-%m-%d') + + # Add item to updated inner set. + updated_values_set += (item,) # Add item to updated clause. - updated_values_clause += (item,) + updated_values_clause += (updated_values_set,) # Replace original clause. values_clause = updated_values_clause # Now format our clauses for query. - set_clause = ',\n'.join([ - ' {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format)) - for x in columns_clause - ]) + if column_types_clause is not None: + # Provide type hinting for columns. + set_clause = '' + for index in range(len(columns_clause)): + if set_clause != '': + set_clause += ',\n' + set_clause += ' {0} = pydbcn_temp.{0}::{1}'.format( + columns_clause[index].strip(self._base.validate._quote_column_format), + column_types_clause[index], + ) + else: + # No type hinting. Provide columns as-is. + set_clause = ',\n'.join([ + ' {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format)) + for x in columns_clause + ]) values_clause = ',\n'.join([ ' {0}'.format(x) for x in values_clause @@ -338,6 +362,7 @@ class BaseRecords: ); """.format(table_name, set_clause, values_clause, columns_clause, where_columns_clause) ) + # print('\n\nquery:\n{0}'.format(query)) results = self._base.query.execute(query, display_query=display_query) diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py index 15ddbced6d47fdd12fc16d3a4655c7c0ca882ad2..8b30e4e24fb9fa46f2e4e34aedff93c9c783d865 100644 --- a/tests/connectors/core/test_records.py +++ b/tests/connectors/core/test_records.py @@ -997,7 +997,7 @@ class CoreRecordsTestMixin: def test__update__basic__success(self): """ - Test `UPDATE` query. + Test `UPDATE` query with basic values. """ table_name = 'test_queries__update__basic__success' @@ -1097,7 +1097,7 @@ class CoreRecordsTestMixin: def test__update__datetime__success(self): """ - Test `UPDATE` query. + Test `UPDATE` query with datetime values. """ table_name = 'test_queries__update__datetime__success' @@ -1246,11 +1246,11 @@ class CoreRecordsTestMixin: self.assertNotIn(old_row_2, results) self.assertNotIn(old_row_3, results) - def test__update_many__success(self): + def test__update_many__basic__success(self): """ - Test execute_many `UPDATE` query. + Test execute_many `UPDATE` query with basic values. """ - table_name = 'test_queries__update_many__success' + table_name = 'test_queries__update_many__basic__success' # Verify table exists. try: @@ -1364,7 +1364,7 @@ class CoreRecordsTestMixin: row_2 = updated_row_2 row_3 = updated_row_3 - with self.subTest('Run with five updates'): + with self.subTest('Run with five updates and alternate where column'): # Run test query. Update by non-PK. updated_row_4 = (4, 'test_name_4', 'four') updated_row_5 = (5, 'test_name_5', 'five') @@ -1460,6 +1460,275 @@ class CoreRecordsTestMixin: self.assertNotIn(row_9, results) self.assertNotIn(row_10, results) + # Update row variables. + row_1 = updated_row_1 + row_2 = updated_row_2 + row_3 = updated_row_3 + row_4 = updated_row_4 + row_5 = updated_row_5 + row_6 = updated_row_6 + row_7 = updated_row_7 + row_8 = updated_row_8 + row_9 = updated_row_9 + row_10 = updated_row_10 + + with self.subTest('Run with columns in alternate order'): + # Run test query. + updated_row_3 = (3, 'name as first', 'desc as second') + columns_clause = ['name', 'description', 'id'] + values_clause = [ + ('name as first', 'desc as second', 3) + ] + where_columns_clause = ['id'] + self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + self.assertIn(updated_row_3, results) + self.assertIn(row_4, results) + self.assertIn(row_5, results) + self.assertIn(row_6, results) + self.assertIn(row_7, results) + self.assertIn(row_8, results) + self.assertIn(row_9, results) + self.assertIn(row_10, results) + self.assertNotIn(row_3, results) + + # Update row variables. + row_3 = updated_row_3 + + with self.subTest('Run with skipping unused columns'): + # Run test query. + updated_row_5 = (5, 'this is') + updated_row_6 = (6, 'a') + updated_row_7 = (7, 'test') + columns_clause = ['id', 'description'] + values_clause = [ + updated_row_5, + updated_row_6, + updated_row_7, + ] + where_columns_clause = ['id'] + self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + self.assertIn(row_4, results) + self.assertIn((updated_row_5[0], row_5[1], updated_row_5[1]), results) + self.assertIn((updated_row_6[0], row_6[1], updated_row_6[1]), results) + self.assertIn((updated_row_7[0], row_7[1], updated_row_7[1]), results) + self.assertIn(row_8, results) + self.assertIn(row_9, results) + self.assertIn(row_10, results) + self.assertNotIn(row_5, results) + self.assertNotIn(row_6, results) + self.assertNotIn(row_7, results) + + # Update row variables. + row_5 = updated_row_5 + row_6 = updated_row_6 + row_7 = updated_row_7 + + def test__update_many__datetime__success(self): + """ + Test execute_many `UPDATE` query with datetime values. + """ + table_name = 'test_queries__update_many__datetime__success' + + # Verify table exists. + try: + self.connector.query.execute('CREATE TABLE {0}{1};'.format(table_name, self._columns_clause__datetime)) + except self.connector.errors.table_already_exists: + # Table already exists, as we want. + pass + + # Generate datetime objects. + test_datetime__2020 = datetime.datetime( + year=2020, + month=6, + day=15, + hour=7, + minute=12, + second=52, + microsecond=0, + ) + test_date__2020 = test_datetime__2020.date() + test_datetime__2021 = datetime.datetime( + year=2021, + month=7, + day=16, + hour=8, + minute=13, + second=53, + microsecond=0, + ) + test_date__2021 = test_datetime__2021.date() + test_datetime__2022 = datetime.datetime( + year=2022, + month=8, + day=17, + hour=9, + minute=14, + second=54, + microsecond=0, + ) + test_date__2022 = test_datetime__2022.date() + + # Generate row values. + row_1 = (1, test_datetime__2020, test_date__2020) + row_2 = (2, test_datetime__2021, test_date__2021) + row_3 = (3, test_datetime__2022, test_date__2022) + + # Generate initial rows. + rows = [ + row_1, + row_2, + row_3, + ] + self.connector.records.insert_many(table_name, rows) + + # Verify expected state. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 3) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + + columns_clause = ['id', 'test_datetime', 'test_date'] + column_types_clause = ('integer', 'timestamp', 'date') + where_columns_clause = ['id'] + + with self.subTest('Run with one update'): + # Run test query. + updated_row_1 = (1, test_datetime__2020.replace(month=1), test_date__2020.replace(month=2)) + values_clause = [ + updated_row_1, + ] + self.connector.records.update_many( + table_name, + columns_clause, + values_clause, + where_columns_clause, + column_types_clause=column_types_clause, + ) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 3) + self.assertIn(updated_row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + self.assertNotIn(row_1, results) + + # Update row variables. + row_1 = updated_row_1 + + with self.subTest('Run with all updated'): + # Run test query. + updated_row_1 = (1, test_datetime__2020.replace(day=20), test_date__2020.replace(day=10)) + updated_row_2 = (2, test_datetime__2021.replace(day=20), test_date__2021.replace(day=10)) + updated_row_3 = (3, test_datetime__2022.replace(day=20), test_date__2022.replace(day=10)) + values_clause = [ + updated_row_1, + updated_row_2, + updated_row_3, + ] + self.connector.records.update_many( + table_name, + columns_clause, + values_clause, + where_columns_clause, + column_types_clause=column_types_clause, + ) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 3) + self.assertIn(updated_row_1, results) + self.assertIn(updated_row_2, results) + self.assertIn(updated_row_3, results) + self.assertNotIn(row_1, results) + self.assertNotIn(row_2, results) + self.assertNotIn(row_3, results) + + # Update row variables. + row_1 = updated_row_1 + row_2 = updated_row_2 + row_3 = updated_row_3 + + with self.subTest('Run with columns in alternate order'): + # Run test query. + columns_clause = ['test_date', 'id', 'test_datetime'] + column_types_clause = ['date', 'integer', 'timestamp'] + updated_row_2 = (2, test_datetime__2021.replace(month=3), test_date__2021.replace(month=4)) + updated_row_3 = (3, test_datetime__2022.replace(month=6), test_date__2022.replace(month=5)) + values_clause = [ + (updated_row_2[2], updated_row_2[0], updated_row_2[1]), + (updated_row_3[2], updated_row_3[0], updated_row_3[1]), + ] + self.connector.records.update_many( + table_name, + columns_clause, + values_clause, + where_columns_clause, + column_types_clause=column_types_clause, + ) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 3) + self.assertIn(row_1, results) + self.assertIn(updated_row_2, results) + self.assertIn(updated_row_3, results) + self.assertNotIn(row_2, results) + self.assertNotIn(row_3, results) + + # Update row variables. + row_2 = updated_row_2 + row_3 = updated_row_3 + + with self.subTest('Run with skipping unused columns'): + # Run test query. + columns_clause = ['id', 'test_datetime'] + column_types_clause = ['integer', 'timestamp'] + updated_row_1 = (1, test_datetime__2020) + updated_row_2 = (2, test_datetime__2021) + updated_row_3 = (3, test_datetime__2022) + values_clause = [ + updated_row_1, + updated_row_2, + updated_row_3, + ] + self.connector.records.update_many( + table_name, + columns_clause, + values_clause, + where_columns_clause, + column_types_clause=column_types_clause, + ) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 3) + self.assertIn(updated_row_1 + (row_1[2],), results) + self.assertIn(updated_row_2 + (row_2[2],), results) + self.assertIn(updated_row_3 + (row_3[2],), results) + self.assertNotIn(row_1, results) + self.assertNotIn(row_2, results) + self.assertNotIn(row_3, results) + + # Update row variables. + row_1 = updated_row_1 + row_2 = updated_row_2 + row_3 = updated_row_3 + def test__delete__success(self): """ Test `DELETE` query.