From 35469a540f020d077fee175fdc60ac52ed186555 Mon Sep 17 00:00:00 2001
From: Brandon Rodriguez <brodriguez8774@gmail.com>
Date: Fri, 15 Sep 2023 06:35:22 -0400
Subject: [PATCH] Various updates and corrections to get MySQL logic working
 again

---
 py_dbcn/connectors/core/clauses.py       | 50 ++++++++++++++--
 py_dbcn/connectors/mysql/records.py      | 72 +++++++++++++-----------
 py_dbcn/connectors/postgresql/records.py |  2 +-
 tests/connectors/core/test_clauses.py    |  2 +-
 tests/connectors/core/test_database.py   |  6 +-
 tests/connectors/core/test_records.py    |  1 -
 tests/connectors/mysql/constants.py      |  4 +-
 tests/connectors/mysql/test_utils.py     |  2 +-
 8 files changed, 91 insertions(+), 48 deletions(-)

diff --git a/py_dbcn/connectors/core/clauses.py b/py_dbcn/connectors/core/clauses.py
index 98357fe..7b04f33 100644
--- a/py_dbcn/connectors/core/clauses.py
+++ b/py_dbcn/connectors/core/clauses.py
@@ -181,17 +181,35 @@ class BaseClauseBuilder(object):
         new_clause = []
         for item in original_clause:
 
+            # Handle if date/datetime provided as str.
+            if isinstance(item, str):
+                temp_item = item
+                if self._base.validate._is_quoted(item):
+                    temp_item = item.strip()[1:-1]
+                # Attempt to convert to datetime object.
+                try:
+                    item = datetime.datetime.strptime(temp_item.strip(), '%Y-%m-%d %H:%M:%S')
+                except ValueError:
+                    pass
+            if isinstance(item, str):
+                # Attempt to convert to date object.
+                try:
+                    item = datetime.datetime.strptime(temp_item.strip(), '%Y-%m-%d')
+                except ValueError:
+                    pass
+
             # Handle various specific types.
-            is_datetime = False
             if isinstance(item, datetime.datetime):
                 # Is a datetime object. Convert to string.
-                item = "'{0}'".format(item.strftime('%Y-%m-%d %H:%M:%S'))
-                is_datetime = True
+                item = "{0}".format(item.strftime('%Y-%m-%d %H:%M:%S'))
+                new_clause.append(item)
+                continue
 
             elif isinstance(item, datetime.date):
                 # Is a date object. Convert to string.
-                item = "'{0}'".format(item.strftime('%Y-%m-%d'))
-                is_datetime = True
+                item = "{0}".format(item.strftime('%Y-%m-%d'))
+                new_clause.append(item)
+                continue
 
             # Skip handling for other non-str items.
             elif not isinstance(item, str):
@@ -240,6 +258,7 @@ class BaseClauseBuilder(object):
             stripped_left = ''
             stripped_right = ''
             if matches:
+
                 index = 0
                 while index < len(self._parent._reserved_function_names):
                     func_call = self._parent._reserved_function_names[index]
@@ -421,6 +440,7 @@ class WhereClauseBuilder(BaseClauseBuilder):
 
     def _to_array(self, value):
         """Converts clause to array format for initial parsing."""
+
         self._clause_connectors = []
 
         if self._clause_prefix is None:
@@ -486,6 +506,26 @@ class WhereClauseBuilder(BaseClauseBuilder):
             for index in range(len(clause)):
                 clause_item = clause[index]
 
+                # Check if we can parse sub-sections of provided item.
+                if '=' in clause_item and (clause_item.count('=') == 1):
+                    temp = clause_item.split('=')
+
+                    pt_1 = temp[0].strip()
+                    pt_2 = temp[1].strip()
+
+                    if self._base.validate._is_quoted(pt_2):
+                        # Handle each sub-section individually.
+                        original_spaces = self._allow_spaces
+                        self._allow_spaces = True
+                        pt_2 = self._validate_clause([pt_2], original_value=pt_2)[0]
+                        self._allow_spaces = original_spaces
+
+                        # Correct likely incorrect quotes for second half.
+                        if self._base.validate._is_quoted(pt_2):
+                            pt_2 = '{0}{1}{0}'.format(self._base.validate._quote_str_literal_format, pt_2[1:-1])
+
+                        clause_item = '{0} = {1}'.format(pt_1, pt_2)
+
                 # Split based on spaces. For now, we assume only the first item needs quotes.
                 clause_split = clause_item.split(' ')
                 first_item = clause_split[0]
diff --git a/py_dbcn/connectors/mysql/records.py b/py_dbcn/connectors/mysql/records.py
index 50d8cc1..3c4e3e8 100644
--- a/py_dbcn/connectors/mysql/records.py
+++ b/py_dbcn/connectors/mysql/records.py
@@ -43,6 +43,7 @@ class MysqlRecords(BaseRecords):
         :param display_query: Bool indicating if query should output to console. Defaults to True.
         :param display_results: Bool indicating if results should output to console. Defaults to True.
         """
+
         # Check provided size.
         upper_limit = 10000  # 10,000 limit for now.
         if len(values_clause) > upper_limit:
@@ -79,57 +80,62 @@ class MysqlRecords(BaseRecords):
         # Check that provided WHERE clause is valid format.
         columns_clause = self._base.validate.sanitize_columns_clause(columns_clause)
         where_columns_clause = self._base.validate.sanitize_columns_clause(where_columns_clause)
-        split_columns_clause = columns_clause.split(', ')
-        where_columns_clause = where_columns_clause.split(', ')
 
         # Verify each "where column" is present in the base columns clause.
-        for column in where_columns_clause:
-            if column not in split_columns_clause:
+        for column in where_columns_clause.array:
+            if column not in columns_clause.array:
                 raise ValueError(
-                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS.'
+                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS. '
                     'Failed to find "{0}" in {1}'.format(
                         column,
-                        split_columns_clause,
+                        columns_clause,
                     )
                 )
 
-        # Check for values that might need formatting.
-        # For example, if we find date/datetime objects, we automatically convert to a str value that won't error.
-        if isinstance(values_clause, list) or isinstance(values_clause, tuple):
-            updated_values_clause = ()
-            for value_set in values_clause:
-                updated_values_set = ()
-                for item in value_set:
-
-                    if isinstance(item, datetime.datetime):
-                        # Is a datetime object. Convert to string.
-                        item = item.strftime('%Y-%m-%d %H:%M:%S')
-                    elif isinstance(item, datetime.date):
-                        # Is a date object. Convert to string.
-                        item = item.strftime('%Y-%m-%d')
-
-                    # Add item to updated inner set.
-                    updated_values_set += (item,)
-
-                # Add item to updated clause.
-                updated_values_clause += (updated_values_set,)
-
-            # Replace original clause.
-            values_clause = ', '.join(str(x) for x in updated_values_clause)
-
         # Now format our clauses for query.
         duplicates_clause = ''
-        for column in split_columns_clause:
+        for column in columns_clause.array:
+
             # Find inverse of columns specified in where_columns. These are what we update.
-            if column not in where_columns_clause:
+            if column not in where_columns_clause.array:
+
+                # if self._base.validate._is_quoted(column):
+                #     column = column[1:-1]
+
                 if duplicates_clause != '':
                     duplicates_clause += ', '
                 duplicates_clause += '{0}=VALUES({0})'.format(column)
 
+        # Check for values that might need formatting.
+        # For example, if we find date/datetime objects, we automatically convert to a str value that won't error.
+        updated_values_clause = ()
+        for index in range(len(values_clause.array)):
+            value_set = values_clause.array[index]
+            updated_values_set = ()
+            for item in value_set:
+
+                if isinstance(item, datetime.datetime):
+                    # Is a datetime object. Convert to string.
+                    item = item.strftime('%Y-%m-%d %H:%M:%S')
+                elif isinstance(item, datetime.date):
+                    # Is a date object. Convert to string.
+                    item = item.strftime('%Y-%m-%d')
+
+                # Add item to updated inner set.
+                updated_values_set += (item,)
+
+            # Add item to updated clause.
+            updated_values_clause += (updated_values_set,)
+
+        values_clause = ', '.join(
+            '{0}'.format(x)
+            for x in updated_values_clause
+        )
+
         # Insert record.
         query = textwrap.dedent(
             """
-            INSERT INTO {0} ({1})
+            INSERT INTO {0} {1}
             VALUES {2}
             ON DUPLICATE KEY UPDATE
                 {3}
diff --git a/py_dbcn/connectors/postgresql/records.py b/py_dbcn/connectors/postgresql/records.py
index 569cf3d..02c9bb7 100644
--- a/py_dbcn/connectors/postgresql/records.py
+++ b/py_dbcn/connectors/postgresql/records.py
@@ -85,7 +85,7 @@ class PostgresqlRecords(BaseRecords):
         for column in where_columns_clause.array:
             if column not in columns_clause.array:
                 raise ValueError(
-                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS.'
+                    'All columns specified in WHERE_COLUMNS must also be present in COLUMNS. '
                     'Failed to find "{0}" in {1}'.format(
                         column,
                         columns_clause,
diff --git a/tests/connectors/core/test_clauses.py b/tests/connectors/core/test_clauses.py
index 84804e2..e4d279d 100644
--- a/tests/connectors/core/test_clauses.py
+++ b/tests/connectors/core/test_clauses.py
@@ -1049,7 +1049,7 @@ class CoreClauseTestMixin:
                 ],
                 clause_object.array)
             self.assertText(
-                """WHERE ({0}id{0} = 'test') OR ("code" = 1234) AND ("name" = 'Test User')""".format(
+                """WHERE ({0}id{0} = {1}test{1}) OR ({0}code{0} = 1234) AND ({0}name{0} = {1}Test User{1})""".format(
                     self.column_format,
                     self.str_literal_format,
                 ),
diff --git a/tests/connectors/core/test_database.py b/tests/connectors/core/test_database.py
index df28a1a..8d3b388 100644
--- a/tests/connectors/core/test_database.py
+++ b/tests/connectors/core/test_database.py
@@ -187,11 +187,7 @@ class CoreDatabaseTestMixin:
         self.assertIn(db_name.casefold(), (str(x).casefold() for x in results))
 
         # Run test query.
-        error_type = None
-        if self.connector._config.db_type == 'MySQL':
-            error_type = ValueError
-        elif self.connector._config.db_type == 'PostgreSQL':
-            error_type = self.connector.errors.database_already_exists
+        error_type = self.connector.errors.database_already_exists
         with self.assertRaises(error_type):
             self.connector.database.create(db_name)
 
diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py
index f73185e..ce1530c 100644
--- a/tests/connectors/core/test_records.py
+++ b/tests/connectors/core/test_records.py
@@ -1450,7 +1450,6 @@ class CoreRecordsTestMixin:
             self.assertNotIn(old_row_1, results)
             self.assertNotIn(old_row_2, results)
 
-
     def test__update__datetime__success(self):
         """
         Test `UPDATE` query with datetime values.
diff --git a/tests/connectors/mysql/constants.py b/tests/connectors/mysql/constants.py
index 6314a33..5d15aaf 100644
--- a/tests/connectors/mysql/constants.py
+++ b/tests/connectors/mysql/constants.py
@@ -53,6 +53,7 @@ COLUMNS_CLAUSE__AGGREGATES = """
 
 
 COLUMNS_CLAUSE__INSERT_BUG__NUMBER_OF_VALUES = """
+(
     id INT NOT NULL AUTO_INCREMENT,
     test_blank_1 VARCHAR(255),
     first_name VARCHAR(255),
@@ -72,6 +73,7 @@ COLUMNS_CLAUSE__INSERT_BUG__NUMBER_OF_VALUES = """
     date_modified DATETIME,
     is_active TINYINT,
     last_activity TIMESTAMP,
-    test_blank_5 VARCHAR(255)
+    test_blank_5 VARCHAR(255),
     PRIMARY KEY ( id )
+)
 """
diff --git a/tests/connectors/mysql/test_utils.py b/tests/connectors/mysql/test_utils.py
index 7597ab1..8f6d860 100644
--- a/tests/connectors/mysql/test_utils.py
+++ b/tests/connectors/mysql/test_utils.py
@@ -75,7 +75,7 @@ class TestMysqlUtils(TestMysqlDatabaseParent, CoreUtilsTestMixin):
         # Get "now" in Python, and also insert into database.
         detroit_now = datetime.datetime.now(tz=detroit_timezone)
         utc_now = detroit_now.astimezone(utc_timezone)
-        self.connector.records.insert(table_name, '((now()), (now()))')
+        self.connector.records.insert(table_name, '(now(), now())')
 
         # Pull values as database created them.
         records = self.connector.records.select(table_name, 'my_timestamp')
-- 
GitLab