From 3752f2037387e53de15d7a9d562a9178b09dc617 Mon Sep 17 00:00:00 2001
From: Brandon Rodriguez <brodriguez8774@gmail.com>
Date: Sun, 13 Nov 2022 14:49:59 -0500
Subject: [PATCH] Properly implement UPDATE MANY query, now both in MySQL and PostgreSQL

---
 py_dbcn/connectors/core/records.py       | 151 +----------------------
 py_dbcn/connectors/mysql/records.py      | 126 +++++++++++++++++-
 py_dbcn/connectors/postgresql/records.py | 151 ++++++++++++++++++++++
 3 files changed, 278 insertions(+), 150 deletions(-)

diff --git a/py_dbcn/connectors/core/records.py b/py_dbcn/connectors/core/records.py
index 593953b..a9cedc4 100644
--- a/py_dbcn/connectors/core/records.py
+++ b/py_dbcn/connectors/core/records.py
@@ -242,154 +242,9 @@ class BaseRecords:
 
         return results
 
-    def update_many(
-        self,
-        table_name, columns_clause, values_clause, where_columns_clause,
-        column_types_clause=None,
-        display_query=True, display_results=True,
-    ):
-        """Updates record in provided table.
-
-        :param table_name: Name of table to insert into.
-        :param columns_clause: Clause to specify columns to insert into.
-        :param values_clause: Clause to specify values to insert.
-        :param where_columns_clause: NOT STANDARD WHERE CLAUSE. Columns to use as WHERE in provided values.
-        :param column_types_clause: Optional clause to provide type hinting for column types. Not required if all
-                                    columns are basic types such as text or integer.
-        :param display_query: Bool indicating if query should output to console. Defaults to True.
-        :param display_results: Bool indicating if results should output to console. Defaults to True.
-        """
-        # Check provided size.
-        upper_limit = 10000  # 10,000 limit for now.
-        if len(values_clause) > upper_limit:
-            print('Subdividing query.')
-            # Exceeds upper limit. Recursively call self on smaller subsets.
-            for index in range(0, len(values_clause), upper_limit):
-                print(' Range [{0}:{1}]'.format(index, index + upper_limit))
-                self.update_many(
-                    table_name,
-                    columns_clause,
-                    values_clause[index:index + upper_limit],
-                    where_columns_clause,
-                    column_types_clause=column_types_clause,
-                    display_query=display_query,
-                    display_results=display_results,
-                )
-
-            # Terminate once all recursive calls have finished.
-            return
-
-        # Check that provided table name is valid format.
-        if not self._base.validate.table_name(table_name):
-            raise ValueError('Invalid table name of "{0}".'.format(table_name))
-
-        # Check that provided VALUES clause is valid format.
-        # Must be array format.
-        if not isinstance(values_clause, list) and not isinstance(values_clause, tuple):
-            raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.')
-        if len(values_clause) < 1:
-            raise ValueError('VALUES clause cannot be empty for UPDATE_MANY queries.')
-        values_clause = self._base.validate.sanitize_values_clause(values_clause)
-
-        # Check that provided WHERE clause is valid format.
-        columns_clause = self._base.validate.sanitize_columns_clause(columns_clause)
-        where_columns_clause = self._base.validate.sanitize_columns_clause(where_columns_clause)
-        columns_clause = columns_clause.split(', ')
-        where_columns_clause = where_columns_clause.split(', ')
-
-        # Verify each "where column" is present in the base columns clause.
-        for column in where_columns_clause:
-            if column not in columns_clause:
-                raise ValueError(
-                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS.'
-                    'Failed to find "{0}" in {1}'.format(
-                        column,
-                        columns_clause,
-                    )
-                )
-
-        # Check for values that might need formatting.
-        # For example, if we find date/datetime objects, we automatically convert to a str value that won't error.
-        if isinstance(values_clause, list) or isinstance(values_clause, tuple):
-            updated_values_clause = ()
-            for value_set in values_clause:
-                updated_values_set = ()
-                for item in value_set:
-
-                    if isinstance(item, datetime.datetime):
-                        # Is a datetime object. Convert to string.
-                        item = item.strftime('%Y-%m-%d %H:%M:%S')
-                    elif isinstance(item, datetime.date):
-                        # Is a date object. Convert to string.
-                        item = item.strftime('%Y-%m-%d')
-
-                    # Add item to updated inner set.
-                    updated_values_set += (item,)
-
-                # Add item to updated clause.
-                updated_values_clause += (updated_values_set,)
-
-            # Replace original clause.
-            values_clause = updated_values_clause
-
-        # Now format our clauses for query.
-        if column_types_clause is not None:
-            # Provide type hinting for columns.
-            set_clause = ''
-            for index in range(len(columns_clause)):
-                if set_clause != '':
-                    set_clause += ',\n'
-                set_clause += ' {0} = pydbcn_temp.{0}::{1}'.format(
-                    columns_clause[index].strip(self._base.validate._quote_column_format),
-                    column_types_clause[index],
-                )
-        else:
-            # No type hinting. Provide columns as-is.
-            set_clause = ',\n'.join([
-                ' {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format))
-                for x in columns_clause
-            ])
-        values_clause = ',\n'.join([
-            ' {0}'.format(x)
-            for x in values_clause
-        ])
-        columns_clause = ', '.join([
-            x.strip(self._base.validate._quote_column_format)
-            for x in columns_clause
-        ])
-        where_columns_clause = ' AND\n'.join([
-            ' pydbcn_update_table.{0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format))
-            for x in where_columns_clause
-        ])
-
-        # print('\n\n\n\n')
-        # print('\ntable_name:\n{0}'.format(table_name))
-        # print('\nset_clause:\n{0}'.format(set_clause))
-        # print('\nvalues_clause:\n{0}'.format(values_clause))
-        # print('\nwhere_columns_clause:\n{0}'.format(where_columns_clause))
-        # print('\ncolumns_clause:\n{0}'.format(columns_clause))
-
-        # Update records.
-        query = f'UPDATE {table_name} AS pydbcn_update_table SET\n'
-        query += f'{set_clause}\n'
-        query += f'FROM (VALUES\n'
-        query += f'{values_clause}\n'
-        query += f') AS pydbcn_temp ({columns_clause})\n'
-        query += f'WHERE (\n'
-        query += f'{where_columns_clause}\n'
-        query += f');'
-        results = self._base.query.execute(query, display_query=display_query)
-
-        # # Do a select to get the updated values as results.
-        # # TODO: Currently doesn't get any results. Not sure how to dynamically get them at this time.
-        # results = self.select(
-        #     table_name,
-        #     where_clause=where_clause,
-        #     display_query=False,
-        #     display_results=display_results,
-        # )
-
-        return results
+    def update_many(self, *args, **kwargs):
+        """Updates multiple records in provided table. Implemented by database-specific subclasses."""
+        raise NotImplementedError('Currently not implemented for {0}.'.format(self._base._config.db_type))
 
     def delete(self, table_name, where_clause, display_query=True, display_results=True):
         """Deletes record(s) in given table.
diff --git a/py_dbcn/connectors/mysql/records.py b/py_dbcn/connectors/mysql/records.py
index b6fb70c..50d8cc1 100644
--- a/py_dbcn/connectors/mysql/records.py
+++ b/py_dbcn/connectors/mysql/records.py
@@ -5,6 +5,8 @@ Contains database connection logic specific to MySQL databases.
 """
 
 # System Imports.
+import datetime
+import textwrap
 
 # Internal Imports.
 from py_dbcn.connectors.core.records import BaseRecords
@@ -25,5 +27,125 @@ class MysqlRecords(BaseRecords):
 
         logger.debug('Generating related (MySQL) Records class.')
 
-    def update_many(self, *args, **kwargs):
-        raise NotImplementedError('Currently not implemented for MySQL.')
+    def update_many(
+        self,
+        table_name, columns_clause, values_clause, where_columns_clause,
+        column_types_clause=None,
+        display_query=True, display_results=True,
+    ):
+        """Updates multiple records in provided table.
+
+        Generated as an "INSERT ... ON DUPLICATE KEY UPDATE" query. MySQL matches rows by PRIMARY KEY / UNIQUE index,
+        so where_columns_clause should correspond to such a key. Value sets that match no existing row are inserted.
+
+        :param table_name: Name of table to update.
+        :param columns_clause: Clause to specify columns provided in each value set.
+        :param values_clause: List/tuple of value sets. One set per record to update.
+        :param where_columns_clause: NOT STANDARD WHERE CLAUSE. Columns to use as WHERE in provided values.
+        :param column_types_clause: Used in PostgreSQL, but ignored in MySQL.
+        :param display_query: Bool indicating if query should output to console. Defaults to True.
+        :param display_results: Bool indicating if results should output to console. Defaults to True.
+        :return: Results of query execution. None if values were subdivided into multiple queries.
+        """
+        # Check that provided VALUES clause is valid format.
+        # Must be array format. Verified first so the size check below can safely call len().
+        if not isinstance(values_clause, (list, tuple)):
+            raise ValueError('VALUES clause for UPDATE_MANY queries must be in list/tuple format.')
+
+        # Check provided size.
+        upper_limit = 10000  # 10,000 limit for now.
+        if len(values_clause) > upper_limit:
+            if display_query:
+                print('Subdividing query.')
+            # Exceeds upper limit. Recursively call self on smaller subsets.
+            for index in range(0, len(values_clause), upper_limit):
+                if display_query:
+                    print(' Range [{0}:{1}]'.format(index, index + upper_limit))
+                self.update_many(
+                    table_name,
+                    columns_clause,
+                    values_clause[index:index + upper_limit],
+                    where_columns_clause,
+                    display_query=display_query,
+                    display_results=display_results,
+                )
+
+            # Terminate once all recursive calls have finished.
+            return
+
+        # Check that provided table name is valid format.
+        if not self._base.validate.table_name(table_name):
+            raise ValueError('Invalid table name of "{0}".'.format(table_name))
+
+        # Check that provided VALUES clause is populated.
+        if len(values_clause) < 1:
+            raise ValueError('VALUES clause cannot be empty for UPDATE_MANY queries.')
+        values_clause = self._base.validate.sanitize_values_clause(values_clause)
+
+        # Check that provided WHERE clause is valid format.
+        columns_clause = self._base.validate.sanitize_columns_clause(columns_clause)
+        where_columns_clause = self._base.validate.sanitize_columns_clause(where_columns_clause)
+        split_columns_clause = columns_clause.split(', ')
+        where_columns_clause = where_columns_clause.split(', ')
+
+        # Verify each "where column" is present in the base columns clause.
+        for column in where_columns_clause:
+            if column not in split_columns_clause:
+                raise ValueError(
+                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS. '
+                    'Failed to find "{0}" in {1}'.format(
+                        column,
+                        split_columns_clause,
+                    )
+                )
+
+        # Columns not specified in where_columns are the ones we update.
+        # NOTE: VALUES() within ON DUPLICATE KEY UPDATE is deprecated as of MySQL 8.0.20, in favor of row aliases.
+        duplicates_clause = ', '.join(
+            '{0}=VALUES({0})'.format(column)
+            for column in split_columns_clause
+            if column not in where_columns_clause
+        )
+        if not duplicates_clause:
+            # Every column is a WHERE column. Nothing to update, and the query would be invalid SQL.
+            raise ValueError('UPDATE_MANY queries require at least one column that is not in WHERE_COLUMNS.')
+
+        # Check for values that might need formatting.
+        # For example, if we find date/datetime objects, we automatically convert to a str value that won't error.
+        if isinstance(values_clause, (list, tuple)):
+            updated_values_clause = []
+            for value_set in values_clause:
+                updated_values_set = []
+                for item in value_set:
+
+                    if isinstance(item, datetime.datetime):
+                        # Is a datetime object. Convert to string.
+                        item = item.strftime('%Y-%m-%d %H:%M:%S')
+                    elif isinstance(item, datetime.date):
+                        # Is a date object. Convert to string.
+                        item = item.strftime('%Y-%m-%d')
+
+                    # Add item to updated inner set.
+                    updated_values_set.append(item)
+
+                # Add item to updated clause. Tuple form, so str() renders the set in parentheses.
+                updated_values_clause.append(tuple(updated_values_set))
+
+            # Replace original clause.
+            values_clause = ', '.join(str(x) for x in updated_values_clause)
+
+        # Insert record.
+        query = textwrap.dedent(
+            """
+            INSERT INTO {0} ({1})
+            VALUES {2}
+            ON DUPLICATE KEY UPDATE
+            {3}
+            ;
+            """.format(table_name, columns_clause, values_clause, duplicates_clause)
+        )
+        results = self._base.query.execute(query, display_query=display_query)
+        if display_results:
+            self._base.display.results('{0}'.format(results))
+
+        return results
diff --git a/py_dbcn/connectors/postgresql/records.py b/py_dbcn/connectors/postgresql/records.py
index 4e977de..4c40f03 100644
--- a/py_dbcn/connectors/postgresql/records.py
+++ b/py_dbcn/connectors/postgresql/records.py
@@ -5,6 +5,7 @@ Contains database connection logic specific to PostgreSQL databases.
 """
 
 # System Imports.
+import datetime
 
 # Internal Imports.
 from py_dbcn.connectors.core.records import BaseRecords
@@ -24,3 +25,153 @@ class PostgresqlRecords(BaseRecords):
         super().__init__(parent, *args, **kwargs)
 
         logger.debug('Generating related (PostgreSQL) Records class.')
+
+    def update_many(
+        self,
+        table_name, columns_clause, values_clause, where_columns_clause,
+        column_types_clause=None,
+        display_query=True, display_results=True,
+    ):
+        """Updates multiple records in provided table.
+
+        Generated as a single "UPDATE ... FROM (VALUES ...)" query, matching rows on the where_columns_clause columns.
+
+        :param table_name: Name of table to update.
+        :param columns_clause: Clause to specify columns provided in each value set.
+        :param values_clause: List/tuple of value sets. One set per record to update.
+        :param where_columns_clause: NOT STANDARD WHERE CLAUSE. Columns to use as WHERE in provided values.
+        :param column_types_clause: Optional clause to provide type hinting for column types. Not required if all
+                                    columns are basic types such as text or integer.
+        :param display_query: Bool indicating if query should output to console. Defaults to True.
+        :param display_results: Bool indicating if results should output to console. Currently has no effect.
+        :return: Results of query execution. None if values were subdivided into multiple queries.
+        """
+        # Check that provided VALUES clause is valid format.
+        # Must be array format. Verified first so the size check below can safely call len().
+        if not isinstance(values_clause, (list, tuple)):
+            raise ValueError('VALUES clause for UPDATE_MANY queries must be in list/tuple format.')
+
+        # Check provided size.
+        upper_limit = 10000  # 10,000 limit for now.
+        if len(values_clause) > upper_limit:
+            if display_query:
+                print('Subdividing query.')
+            # Exceeds upper limit. Recursively call self on smaller subsets.
+            for index in range(0, len(values_clause), upper_limit):
+                if display_query:
+                    print(' Range [{0}:{1}]'.format(index, index + upper_limit))
+                self.update_many(
+                    table_name,
+                    columns_clause,
+                    values_clause[index:index + upper_limit],
+                    where_columns_clause,
+                    column_types_clause=column_types_clause,
+                    display_query=display_query,
+                    display_results=display_results,
+                )
+
+            # Terminate once all recursive calls have finished.
+            return
+
+        # Check that provided table name is valid format.
+        if not self._base.validate.table_name(table_name):
+            raise ValueError('Invalid table name of "{0}".'.format(table_name))
+
+        # Check that provided VALUES clause is populated.
+        if len(values_clause) < 1:
+            raise ValueError('VALUES clause cannot be empty for UPDATE_MANY queries.')
+        values_clause = self._base.validate.sanitize_values_clause(values_clause)
+
+        # Check that provided WHERE clause is valid format.
+        columns_clause = self._base.validate.sanitize_columns_clause(columns_clause)
+        where_columns_clause = self._base.validate.sanitize_columns_clause(where_columns_clause)
+        columns_clause = columns_clause.split(', ')
+        where_columns_clause = where_columns_clause.split(', ')
+
+        # Verify each "where column" is present in the base columns clause.
+        for column in where_columns_clause:
+            if column not in columns_clause:
+                raise ValueError(
+                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS. '
+                    'Failed to find "{0}" in {1}'.format(
+                        column,
+                        columns_clause,
+                    )
+                )
+
+        # Check for values that might need formatting.
+        # For example, if we find date/datetime objects, we automatically convert to a str value that won't error.
+        if isinstance(values_clause, (list, tuple)):
+            updated_values_clause = []
+            for value_set in values_clause:
+                updated_values_set = []
+                for item in value_set:
+
+                    if isinstance(item, datetime.datetime):
+                        # Is a datetime object. Convert to string.
+                        item = item.strftime('%Y-%m-%d %H:%M:%S')
+                    elif isinstance(item, datetime.date):
+                        # Is a date object. Convert to string.
+                        item = item.strftime('%Y-%m-%d')
+
+                    # Add item to updated inner set.
+                    updated_values_set.append(item)
+
+                # Add item to updated clause. Tuple form, so str() renders the set in parentheses.
+                updated_values_clause.append(tuple(updated_values_set))
+
+            # Replace original clause.
+            values_clause = updated_values_clause
+
+        # Now format our clauses for query.
+        if column_types_clause is not None:
+            # Provide type hinting for columns. Each column needs a corresponding type.
+            if len(column_types_clause) < len(columns_clause):
+                raise ValueError('COLUMN_TYPES clause must provide a type for each column in COLUMNS.')
+            set_clause = ',\n'.join([
+                ' {0} = pydbcn_temp.{0}::{1}'.format(
+                    column.strip(self._base.validate._quote_column_format),
+                    column_type,
+                )
+                for column, column_type in zip(columns_clause, column_types_clause)
+            ])
+        else:
+            # No type hinting. Provide columns as-is.
+            set_clause = ',\n'.join([
+                ' {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format))
+                for x in columns_clause
+            ])
+        values_clause = ',\n'.join([
+            ' {0}'.format(x)
+            for x in values_clause
+        ])
+        columns_clause = ', '.join([
+            x.strip(self._base.validate._quote_column_format)
+            for x in columns_clause
+        ])
+        where_columns_clause = ' AND\n'.join([
+            ' pydbcn_update_table.{0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format))
+            for x in where_columns_clause
+        ])
+
+        # Update records.
+        query = f'UPDATE {table_name} AS pydbcn_update_table SET\n'
+        query += f'{set_clause}\n'
+        query += f'FROM (VALUES\n'
+        query += f'{values_clause}\n'
+        query += f') AS pydbcn_temp ({columns_clause})\n'
+        query += f'WHERE (\n'
+        query += f'{where_columns_clause}\n'
+        query += f');'
+        results = self._base.query.execute(query, display_query=display_query)
+
+        # # Do a select to get the updated values as results.
+        # # TODO: Currently doesn't get any results. Not sure how to dynamically get them at this time.
+        # results = self.select(
+        #     table_name,
+        #     where_clause=where_clause,
+        #     display_query=False,
+        #     display_results=display_results,
+        # )
+
+        return results
-- 
GitLab