diff --git a/py_dbcn/connectors/core/query.py b/py_dbcn/connectors/core/query.py index b6153e786214af208785cd73652756aad0c156f5..642f7eed53c74e7fca106866f83a4aa124146a48 100644 --- a/py_dbcn/connectors/core/query.py +++ b/py_dbcn/connectors/core/query.py @@ -56,14 +56,14 @@ class BaseQuery: results = [] return results - def execute_many(self, query, values, display_query=True): + def execute_many(self, query, data, display_query=True): """""" if display_query: self._base.display.query(query) # Create connection and execute query. cursor = self._base._connection.cursor() - cursor.executemany(query, values) + cursor.executemany(query, data) # Get results. results = self._fetch_results(cursor) diff --git a/py_dbcn/connectors/core/records.py b/py_dbcn/connectors/core/records.py index 52b187a00d43e7f0975e8ccb3504560f014d94a4..d3f570cc347636c71f35529dbd98cc21e2e91db0 100644 --- a/py_dbcn/connectors/core/records.py +++ b/py_dbcn/connectors/core/records.py @@ -233,6 +233,118 @@ class BaseRecords: return results + def update_many( + self, + table_name, columns_clause, values_clause, where_columns_clause, + display_query=True, display_results=True, + ): + """Updates record in provided table. + + :param table_name: Name of table to insert into. + :param columns_clause: Clause to specify columns to insert into. + :param values_clause: Clause to specify values to insert. + :param where_columns_clause: NOT STANDARD WHERE CLAUSE. Columns to use as WHERE in provided values. + :param display_query: Bool indicating if query should output to console. Defaults to True. + :param display_results: Bool indicating if results should output to console. Defaults to True. + """ + # Check that provided table name is valid format. + if not self._base.validate.table_name(table_name): + raise ValueError('Invalid table name of "{0}".'.format(table_name)) + + # Check that provided VALUES clause is valid format. + # Must be array format. + if not isinstance(values_clause, list) and not isinstance(values_clause, tuple): + raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.') + values_clause = self._base.validate.sanitize_values_clause(values_clause) + + # Check that provided WHERE clause is valid format. + columns_clause = self._base.validate.sanitize_columns_clause(columns_clause) + where_columns_clause = self._base.validate.sanitize_columns_clause(where_columns_clause) + columns_clause = columns_clause.split(', ') + where_columns_clause = where_columns_clause.split(', ') + + # Verify each "where column" is present in the base columns clause. + for column in where_columns_clause: + if column not in columns_clause: + raise ValueError( + 'All columns specified in WHERE_COLUMNS must also be present in COLUMNS.' + 'Failed to find "{0}" in {1}'.format( + column, + columns_clause, + ) + ) + + # Check for values that might need formatting. + # For example, if we find date/datetime objects, we automatically convert to a str value that won't error. + if isinstance(values_clause, list) or isinstance(values_clause, tuple): + updated_values_clause = () + for item in values_clause: + + if isinstance(item, datetime.datetime): + # Is a datetime object. Convert to string. + item = item.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(item, datetime.date): + # Is a date object. Convert to string. + item = item.strftime('%Y-%m-%d') + + # Add item to updated clause. + updated_values_clause += (item,) + + # Replace original clause. + values_clause = updated_values_clause + + # Now format our clauses for query. + set_clause = ',\n'.join([ + ' {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format)) + for x in columns_clause + ]) + values_clause = ',\n'.join([ + ' {0}'.format(x) + for x in values_clause + ]) + columns_clause = ', '.join( + x.strip(self._base.validate._quote_column_format) + for x in columns_clause + ) + where_columns_clause = ',\n'.join([ + ' pydbcn_update_table.{0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format)) + for x in where_columns_clause + ]) + + # print('\n\n\n\n') + # print('\ntable_name:\n{0}'.format(table_name)) + # print('\nset_clause:\n{0}'.format(set_clause)) + # print('\nvalues_clause:\n{0}'.format(values_clause)) + # print('\nwhere_columns_clause:\n{0}'.format(where_columns_clause)) + # print('\ncolumns_clause:\n{0}'.format(columns_clause)) + + # Update records. + query = """ + UPDATE {0} AS pydbcn_update_table SET + {1} + FROM (VALUES + {2} + ) AS pydbcn_temp ({3}) + WHERE ( + {4} + ); + """.format(table_name, set_clause, values_clause, columns_clause, where_columns_clause) + + results = self._base.query.execute(query, display_query=display_query) + + # # Do a select to get the updated values as results. + # # TODO: Currently doesn't get any results. Not sure how to dynamically get them at this time. + # results = self.select( + # table_name, + # where_clause=where_clause, + # display_query=False, + # display_results=display_results, + # ) + + self.select(table_name) + + return results + def delete(self, table_name, where_clause, display_query=True, display_results=True): """Deletes record(s) in given table. diff --git a/py_dbcn/connectors/mysql/records.py b/py_dbcn/connectors/mysql/records.py index f7968da8821936b48bba85154f5db7577a078312..b6fb70cfb1f5e81fb4ffbd751c08ba1c22b4d082 100644 --- a/py_dbcn/connectors/mysql/records.py +++ b/py_dbcn/connectors/mysql/records.py @@ -24,3 +24,6 @@ class MysqlRecords(BaseRecords): super().__init__(parent, *args, **kwargs) logger.debug('Generating related (MySQL) Records class.') + + def update_many(self, *args, **kwargs): + raise NotImplementedError('Currently not implemented for MySQL.') diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py index c1de4bc6341750fe1e47cb9213806b9be8336be2..15ddbced6d47fdd12fc16d3a4655c7c0ca882ad2 100644 --- a/tests/connectors/core/test_records.py +++ b/tests/connectors/core/test_records.py @@ -1246,6 +1246,220 @@ class CoreRecordsTestMixin: self.assertNotIn(old_row_2, results) self.assertNotIn(old_row_3, results) + def test__update_many__success(self): + """ + Test execute_many `UPDATE` query. + """ + table_name = 'test_queries__update_many__success' + + # Verify table exists. + try: + self.connector.query.execute('CREATE TABLE {0}{1};'.format(table_name, self._columns_clause__basic)) + except self.connector.errors.table_already_exists: + # Table already exists, as we want. + pass + + # Verify starting state. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 0) + + # Generate row values. + row_1 = (1, 'test_name_1', 'test_desc_1') + row_2 = (2, 'test_name_2', 'test_desc_2') + row_3 = (3, 'test_name_3', 'test_desc_3') + row_4 = (4, 'test_name_4', 'test_desc_4') + row_5 = (5, 'test_name_5', 'test_desc_5') + row_6 = (6, 'test_name_6', 'test_desc_6') + row_7 = (7, 'test_name_7', 'test_desc_7') + row_8 = (8, 'test_name_8', 'test_desc_8') + row_9 = (9, 'test_name_9', 'test_desc_9') + row_10 = (10, 'test_name_10', 'test_desc_10') + + # Generate initial rows. + rows = [ + row_1, + row_2, + row_3, + row_4, + row_5, + row_6, + row_7, + row_8, + row_9, + row_10, + ] + self.connector.records.insert_many(table_name, rows) + + # Verify expected state. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + self.assertIn(row_4, results) + self.assertIn(row_5, results) + self.assertIn(row_6, results) + self.assertIn(row_7, results) + self.assertIn(row_8, results) + self.assertIn(row_9, results) + self.assertIn(row_10, results) + + with self.subTest('Run with one update'): + # Run test query. + updated_row_1 = (1, 'test_name_1_updated', 'test_desc_1') + columns_clause = ['id', 'name', 'description'] + values_clause = [ + updated_row_1, + ] + where_columns_clause = ['id'] + self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(updated_row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + self.assertIn(row_4, results) + self.assertIn(row_5, results) + self.assertIn(row_6, results) + self.assertIn(row_7, results) + self.assertIn(row_8, results) + self.assertIn(row_9, results) + self.assertIn(row_10, results) + self.assertNotIn(row_1, results) + + # Update row variables. + row_1 = updated_row_1 + + with self.subTest('Run with two updates'): + # Run test query. Update by PK. + updated_row_2 = (2, 'aaa', 'test_desc_2') + updated_row_3 = (3, 'bbb', 'test_desc_3') + columns_clause = ['id', 'name', 'description'] + values_clause = [ + updated_row_2, + updated_row_3, + ] + where_columns_clause = ['id'] + self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause) + + # Verify one record returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(row_1, results) + self.assertIn(updated_row_2, results) + self.assertIn(updated_row_3, results) + self.assertIn(row_4, results) + self.assertIn(row_5, results) + self.assertIn(row_6, results) + self.assertIn(row_7, results) + self.assertIn(row_8, results) + self.assertIn(row_9, results) + self.assertIn(row_10, results) + self.assertNotIn(row_2, results) + self.assertNotIn(row_3, results) + + # Update row variables. + row_2 = updated_row_2 + row_3 = updated_row_3 + + with self.subTest('Run with five updates'): + # Run test query. Update by non-PK. + updated_row_4 = (4, 'test_name_4', 'four') + updated_row_5 = (5, 'test_name_5', 'five') + updated_row_6 = (6, 'test_name_6', 'six') + updated_row_7 = (7, 'test_name_7', 'seven') + updated_row_8 = (8, 'test_name_8', 'eight') + columns_clause = ['id', 'name', 'description'] + values_clause = [ + updated_row_4, + updated_row_5, + updated_row_6, + updated_row_7, + updated_row_8, + ] + where_columns_clause = ['name'] + self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause) + + # Verify five records returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(row_1, results) + self.assertIn(row_2, results) + self.assertIn(row_3, results) + self.assertIn(updated_row_4, results) + self.assertIn(updated_row_5, results) + self.assertIn(updated_row_6, results) + self.assertIn(updated_row_7, results) + self.assertIn(updated_row_8, results) + self.assertIn(row_9, results) + self.assertIn(row_10, results) + self.assertNotIn(row_4, results) + self.assertNotIn(row_5, results) + self.assertNotIn(row_6, results) + self.assertNotIn(row_7, results) + self.assertNotIn(row_8, results) + + # Update row variables. + row_4 = updated_row_4 + row_5 = updated_row_5 + row_6 = updated_row_6 + row_7 = updated_row_7 + row_8 = updated_row_8 + + with self.subTest('Run with ten updates'): + # Run test query. + updated_row_1 = (1, '"110"', '"10010"') + updated_row_2 = (2, '"109"', '"10009"') + updated_row_3 = (3, '"108"', '"10008"') + updated_row_4 = (4, '"107"', '"10007"') + updated_row_5 = (5, '"106"', '"10006"') + updated_row_6 = (6, '"105"', '"10005"') + updated_row_7 = (7, '"104"', '"10004"') + updated_row_8 = (8, '"103"', '"10003"') + updated_row_9 = (9, '"102"', '"10002"') + updated_row_10 = (10, '"101"', '"10001"') + columns_clause = 'id, name, description' + values_clause = [ + updated_row_1, + updated_row_2, + updated_row_3, + updated_row_4, + updated_row_5, + updated_row_6, + updated_row_7, + updated_row_8, + updated_row_9, + updated_row_10, + ] + where_columns_clause = 'id' + self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause) + + # Verify ten records returned. + results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name)) + self.assertEqual(len(results), 10) + self.assertIn(updated_row_1, results) + self.assertIn(updated_row_2, results) + self.assertIn(updated_row_3, results) + self.assertIn(updated_row_4, results) + self.assertIn(updated_row_5, results) + self.assertIn(updated_row_6, results) + self.assertIn(updated_row_7, results) + self.assertIn(updated_row_8, results) + self.assertIn(updated_row_9, results) + self.assertIn(updated_row_10, results) + self.assertNotIn(row_1, results) + self.assertNotIn(row_2, results) + self.assertNotIn(row_3, results) + self.assertNotIn(row_4, results) + self.assertNotIn(row_5, results) + self.assertNotIn(row_6, results) + self.assertNotIn(row_7, results) + self.assertNotIn(row_8, results) + self.assertNotIn(row_9, results) + self.assertNotIn(row_10, results) + def test__delete__success(self): """ Test `DELETE` query.