From 35469a540f020d077fee175fdc60ac52ed186555 Mon Sep 17 00:00:00 2001 From: Brandon Rodriguez <brodriguez8774@gmail.com> Date: Fri, 15 Sep 2023 06:35:22 -0400 Subject: [PATCH] Various updates and corrections to get MySQL logic working again --- py_dbcn/connectors/core/clauses.py | 50 ++++++++++++++-- py_dbcn/connectors/mysql/records.py | 72 +++++++++++++----------- py_dbcn/connectors/postgresql/records.py | 2 +- tests/connectors/core/test_clauses.py | 2 +- tests/connectors/core/test_database.py | 6 +- tests/connectors/core/test_records.py | 1 - tests/connectors/mysql/constants.py | 4 +- tests/connectors/mysql/test_utils.py | 2 +- 8 files changed, 91 insertions(+), 48 deletions(-) diff --git a/py_dbcn/connectors/core/clauses.py b/py_dbcn/connectors/core/clauses.py index 98357fe..7b04f33 100644 --- a/py_dbcn/connectors/core/clauses.py +++ b/py_dbcn/connectors/core/clauses.py @@ -181,17 +181,35 @@ class BaseClauseBuilder(object): new_clause = [] for item in original_clause: + # Handle if date/datetime provided as str. + if isinstance(item, str): + temp_item = item + if self._base.validate._is_quoted(item): + temp_item = item.strip()[1:-1] + # Attempt to convert to datetime object. + try: + item = datetime.datetime.strptime(temp_item.strip(), '%Y-%m-%d %H:%M:%S') + except ValueError: + pass + if isinstance(item, str): + # Attempt to convert to date object. + try: + item = datetime.datetime.strptime(temp_item.strip(), '%Y-%m-%d') + except ValueError: + pass + # Handle various specific types. - is_datetime = False if isinstance(item, datetime.datetime): # Is a datetime object. Convert to string. - item = "'{0}'".format(item.strftime('%Y-%m-%d %H:%M:%S')) - is_datetime = True + item = "{0}".format(item.strftime('%Y-%m-%d %H:%M:%S')) + new_clause.append(item) + continue elif isinstance(item, datetime.date): # Is a date object. Convert to string. - item = "'{0}'".format(item.strftime('%Y-%m-%d')) - is_datetime = True + item = "{0}".format(item.strftime('%Y-%m-%d')) + new_clause.append(item) + continue # Skip handling for other non-str items. elif not isinstance(item, str): @@ -240,6 +258,7 @@ class BaseClauseBuilder(object): stripped_left = '' stripped_right = '' if matches: + index = 0 while index < len(self._parent._reserved_function_names): func_call = self._parent._reserved_function_names[index] @@ -421,6 +440,7 @@ class WhereClauseBuilder(BaseClauseBuilder): def _to_array(self, value): """Converts clause to array format for initial parsing.""" + self._clause_connectors = [] if self._clause_prefix is None: @@ -486,6 +506,26 @@ class WhereClauseBuilder(BaseClauseBuilder): for index in range(len(clause)): clause_item = clause[index] + # Check if we can parse sub-sections of provided item. + if '=' in clause_item and (clause_item.count('=') == 1): + temp = clause_item.split('=') + + pt_1 = temp[0].strip() + pt_2 = temp[1].strip() + + if self._base.validate._is_quoted(pt_2): + # Handle each sub-section individually. + original_spaces = self._allow_spaces + self._allow_spaces = True + pt_2 = self._validate_clause([pt_2], original_value=pt_2)[0] + self._allow_spaces = original_spaces + + # Correct likely incorrect quotes for second half. + if self._base.validate._is_quoted(pt_2): + pt_2 = '{0}{1}{0}'.format(self._base.validate._quote_str_literal_format, pt_2[1:-1]) + + clause_item = '{0} = {1}'.format(pt_1, pt_2) + # Split based on spaces. For now, we assume only the first item needs quotes. clause_split = clause_item.split(' ') first_item = clause_split[0] diff --git a/py_dbcn/connectors/mysql/records.py b/py_dbcn/connectors/mysql/records.py index 50d8cc1..3c4e3e8 100644 --- a/py_dbcn/connectors/mysql/records.py +++ b/py_dbcn/connectors/mysql/records.py @@ -43,6 +43,7 @@ class MysqlRecords(BaseRecords): :param display_query: Bool indicating if query should output to console. Defaults to True. :param display_results: Bool indicating if results should output to console. Defaults to True. """ + # Check provided size. upper_limit = 10000 # 10,000 limit for now. if len(values_clause) > upper_limit: @@ -79,57 +80,62 @@ class MysqlRecords(BaseRecords): # Check that provided WHERE clause is valid format. columns_clause = self._base.validate.sanitize_columns_clause(columns_clause) where_columns_clause = self._base.validate.sanitize_columns_clause(where_columns_clause) - split_columns_clause = columns_clause.split(', ') - where_columns_clause = where_columns_clause.split(', ') # Verify each "where column" is present in the base columns clause. - for column in where_columns_clause: - if column not in split_columns_clause: + for column in where_columns_clause.array: + if column not in columns_clause.array: raise ValueError( - 'All columns specified in WHERE_COLUMNS must also be present in COLUMNS.' + 'All columns specified in WHERE_COLUMNS must also be present in COLUMNS. ' 'Failed to find "{0}" in {1}'.format( column, - split_columns_clause, + columns_clause, ) ) - # Check for values that might need formatting. - # For example, if we find date/datetime objects, we automatically convert to a str value that won't error. - if isinstance(values_clause, list) or isinstance(values_clause, tuple): - updated_values_clause = () - for value_set in values_clause: - updated_values_set = () - for item in value_set: - - if isinstance(item, datetime.datetime): - # Is a datetime object. Convert to string. - item = item.strftime('%Y-%m-%d %H:%M:%S') - elif isinstance(item, datetime.date): - # Is a date object. Convert to string. - item = item.strftime('%Y-%m-%d') - - # Add item to updated inner set. - updated_values_set += (item,) - - # Add item to updated clause. - updated_values_clause += (updated_values_set,) - - # Replace original clause. - values_clause = ', '.join(str(x) for x in updated_values_clause) - # Now format our clauses for query. duplicates_clause = '' - for column in split_columns_clause: + for column in columns_clause.array: + # Find inverse of columns specified in where_columns. These are what we update. - if column not in where_columns_clause: + if column not in where_columns_clause.array: + + # if self._base.validate._is_quoted(column): + # column = column[1:-1] + if duplicates_clause != '': duplicates_clause += ', ' duplicates_clause += '{0}=VALUES({0})'.format(column) + # Check for values that might need formatting. + # For example, if we find date/datetime objects, we automatically convert to a str value that won't error. + updated_values_clause = () + for index in range(len(values_clause.array)): + value_set = values_clause.array[index] + updated_values_set = () + for item in value_set: + + if isinstance(item, datetime.datetime): + # Is a datetime object. Convert to string. + item = item.strftime('%Y-%m-%d %H:%M:%S') + elif isinstance(item, datetime.date): + # Is a date object. Convert to string. + item = item.strftime('%Y-%m-%d') + + # Add item to updated inner set. + updated_values_set += (item,) + + # Add item to updated clause. + updated_values_clause += (updated_values_set,) + + values_clause = ', '.join( + '{0}'.format(x) + for x in updated_values_clause + ) + # Insert record. query = textwrap.dedent( """ - INSERT INTO {0} ({1}) + INSERT INTO {0} {1} VALUES {2} ON DUPLICATE KEY UPDATE {3} diff --git a/py_dbcn/connectors/postgresql/records.py b/py_dbcn/connectors/postgresql/records.py index 569cf3d..02c9bb7 100644 --- a/py_dbcn/connectors/postgresql/records.py +++ b/py_dbcn/connectors/postgresql/records.py @@ -85,7 +85,7 @@ class PostgresqlRecords(BaseRecords): for column in where_columns_clause.array: if column not in columns_clause.array: raise ValueError( - 'All columns specified in WHERE_COLUMNS must also be present in COLUMNS.' + 'All columns specified in WHERE_COLUMNS must also be present in COLUMNS. ' 'Failed to find "{0}" in {1}'.format( column, columns_clause, diff --git a/tests/connectors/core/test_clauses.py b/tests/connectors/core/test_clauses.py index 84804e2..e4d279d 100644 --- a/tests/connectors/core/test_clauses.py +++ b/tests/connectors/core/test_clauses.py @@ -1049,7 +1049,7 @@ class CoreClauseTestMixin: ], clause_object.array) self.assertText( - """WHERE ({0}id{0} = 'test') OR ("code" = 1234) AND ("name" = 'Test User')""".format( + """WHERE ({0}id{0} = {1}test{1}) OR ({0}code{0} = 1234) AND ({0}name{0} = {1}Test User{1})""".format( self.column_format, self.str_literal_format, ), diff --git a/tests/connectors/core/test_database.py b/tests/connectors/core/test_database.py index df28a1a..8d3b388 100644 --- a/tests/connectors/core/test_database.py +++ b/tests/connectors/core/test_database.py @@ -187,11 +187,7 @@ class CoreDatabaseTestMixin: self.assertIn(db_name.casefold(), (str(x).casefold() for x in results)) # Run test query. - error_type = None - if self.connector._config.db_type == 'MySQL': - error_type = ValueError - elif self.connector._config.db_type == 'PostgreSQL': - error_type = self.connector.errors.database_already_exists + error_type = self.connector.errors.database_already_exists with self.assertRaises(error_type): self.connector.database.create(db_name) diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py index f73185e..ce1530c 100644 --- a/tests/connectors/core/test_records.py +++ b/tests/connectors/core/test_records.py @@ -1450,7 +1450,6 @@ class CoreRecordsTestMixin: self.assertNotIn(old_row_1, results) self.assertNotIn(old_row_2, results) - def test__update__datetime__success(self): """ Test `UPDATE` query with datetime values. diff --git a/tests/connectors/mysql/constants.py b/tests/connectors/mysql/constants.py index 6314a33..5d15aaf 100644 --- a/tests/connectors/mysql/constants.py +++ b/tests/connectors/mysql/constants.py @@ -53,6 +53,7 @@ COLUMNS_CLAUSE__AGGREGATES = """ COLUMNS_CLAUSE__INSERT_BUG__NUMBER_OF_VALUES = """ +( id INT NOT NULL AUTO_INCREMENT, test_blank_1 VARCHAR(255), first_name VARCHAR(255), @@ -72,6 +73,7 @@ COLUMNS_CLAUSE__INSERT_BUG__NUMBER_OF_VALUES = """ date_modified DATETIME, is_active TINYINT, last_activity TIMESTAMP, - test_blank_5 VARCHAR(255) + test_blank_5 VARCHAR(255), PRIMARY KEY ( id ) +) """ diff --git a/tests/connectors/mysql/test_utils.py b/tests/connectors/mysql/test_utils.py index 7597ab1..8f6d860 100644 --- a/tests/connectors/mysql/test_utils.py +++ b/tests/connectors/mysql/test_utils.py @@ -75,7 +75,7 @@ class TestMysqlUtils(TestMysqlDatabaseParent, CoreUtilsTestMixin): # Get "now" in Python, and also insert into database. detroit_now = datetime.datetime.now(tz=detroit_timezone) utc_now = detroit_now.astimezone(utc_timezone) - self.connector.records.insert(table_name, '((now()), (now()))') + self.connector.records.insert(table_name, '(now(), now())') # Pull values as database created them. records = self.connector.records.select(table_name, 'my_timestamp') -- GitLab