From 3ac7a0325b43f233e31b3e4d33361292b55a3928 Mon Sep 17 00:00:00 2001
From: Brandon Rodriguez <brodriguez8774@gmail.com>
Date: Tue, 11 Oct 2022 08:45:35 -0400
Subject: [PATCH] Add type hinting option for update_many PostgreSQL query

MySQL equivalent still not implemented.
---
 py_dbcn/connectors/core/records.py    |  49 +++--
 tests/connectors/core/test_records.py | 281 +++++++++++++++++++++++++-
 2 files changed, 312 insertions(+), 18 deletions(-)

diff --git a/py_dbcn/connectors/core/records.py b/py_dbcn/connectors/core/records.py
index 7d23fe0..1566d5b 100644
--- a/py_dbcn/connectors/core/records.py
+++ b/py_dbcn/connectors/core/records.py
@@ -143,6 +143,8 @@ class BaseRecords:
         # Must be array format.
         if not isinstance(values_clause, list) and not isinstance(values_clause, tuple):
             raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.')
+        if len(values_clause) < 1:
+            raise ValueError('VALUES clause cannot be empty for INSERT_MANY queries.')
         values_clause = self._base.validate.sanitize_values_clause(values_clause)
 
         # Check for values that might need formatting.
@@ -243,6 +245,7 @@ class BaseRecords:
     def update_many(
         self,
         table_name, columns_clause, values_clause, where_columns_clause,
+        column_types_clause=None,
         display_query=True, display_results=True,
     ):
         """Updates record in provided table.
@@ -251,6 +254,8 @@ class BaseRecords:
         :param columns_clause: Clause to specify columns to insert into.
         :param values_clause: Clause to specify values to insert.
         :param where_columns_clause: NOT STANDARD WHERE CLAUSE. Columns to use as WHERE in provided values.
+        :param column_types_clause: Optional clause to provide type hinting for column types. Not required if all
+                                    columns are basic types such as text or integer.
         :param display_query: Bool indicating if query should output to console. Defaults to True.
         :param display_results: Bool indicating if results should output to console. Defaults to True.
         """
@@ -262,6 +267,8 @@ class BaseRecords:
         # Must be array format.
         if not isinstance(values_clause, list) and not isinstance(values_clause, tuple):
             raise ValueError('VALUES clause for INSERT_MANY queries must be in list/tuple format.')
+        if len(values_clause) < 1:
+            raise ValueError('VALUES clause cannot be empty for UPDATE_MANY queries.')
         values_clause = self._base.validate.sanitize_values_clause(values_clause)
 
         # Check that provided WHERE clause is valid format.
@@ -285,26 +292,43 @@ class BaseRecords:
         # For example, if we find date/datetime objects, we automatically convert to a str value that won't error.
         if isinstance(values_clause, list) or isinstance(values_clause, tuple):
             updated_values_clause = ()
-            for item in values_clause:
+            for value_set in values_clause:
+                updated_values_set = ()
+                for item in value_set:
 
-                if isinstance(item, datetime.datetime):
-                    # Is a datetime object. Convert to string.
-                    item = item.strftime('%Y-%m-%d %H:%M:%S')
-                elif isinstance(item, datetime.date):
-                    # Is a date object. Convert to string.
-                    item = item.strftime('%Y-%m-%d')
+                    if isinstance(item, datetime.datetime):
+                        # Is a datetime object. Convert to string.
+                        item = item.strftime('%Y-%m-%d %H:%M:%S')
+                    elif isinstance(item, datetime.date):
+                        # Is a date object. Convert to string.
+                        item = item.strftime('%Y-%m-%d')
+
+                    # Add item to updated inner set.
+                    updated_values_set += (item,)
 
                 # Add item to updated clause.
-                updated_values_clause += (item,)
+                updated_values_clause += (updated_values_set,)
 
             # Replace original clause.
             values_clause = updated_values_clause
 
         # Now format our clauses for query.
-        set_clause = ',\n'.join([
-            '    {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format))
-            for x in columns_clause
-        ])
+        if column_types_clause is not None:
+            # Provide type hinting for columns.
+            set_clause = ''
+            for index in range(len(columns_clause)):
+                if set_clause != '':
+                    set_clause += ',\n'
+                set_clause += '    {0} = pydbcn_temp.{0}::{1}'.format(
+                    columns_clause[index].strip(self._base.validate._quote_column_format),
+                    column_types_clause[index],
+                )
+        else:
+            # No type hinting. Provide columns as-is.
+            set_clause = ',\n'.join([
+                '    {0} = pydbcn_temp.{0}'.format(x.strip(self._base.validate._quote_column_format))
+                for x in columns_clause
+            ])
         values_clause = ',\n'.join([
             '    {0}'.format(x)
             for x in values_clause
@@ -338,6 +362,7 @@ class BaseRecords:
             );
             """.format(table_name, set_clause, values_clause, columns_clause, where_columns_clause)
         )
+        # print('\n\nquery:\n{0}'.format(query))
 
         results = self._base.query.execute(query, display_query=display_query)
 
diff --git a/tests/connectors/core/test_records.py b/tests/connectors/core/test_records.py
index 15ddbce..8b30e4e 100644
--- a/tests/connectors/core/test_records.py
+++ b/tests/connectors/core/test_records.py
@@ -997,7 +997,7 @@ class CoreRecordsTestMixin:
 
     def test__update__basic__success(self):
         """
-        Test `UPDATE` query.
+        Test `UPDATE` query with basic values.
         """
         table_name = 'test_queries__update__basic__success'
 
@@ -1097,7 +1097,7 @@ class CoreRecordsTestMixin:
 
     def test__update__datetime__success(self):
         """
-        Test `UPDATE` query.
+        Test `UPDATE` query with datetime values.
         """
         table_name = 'test_queries__update__datetime__success'
 
@@ -1246,11 +1246,11 @@ class CoreRecordsTestMixin:
             self.assertNotIn(old_row_2, results)
             self.assertNotIn(old_row_3, results)
 
-    def test__update_many__success(self):
+    def test__update_many__basic__success(self):
         """
-        Test execute_many `UPDATE` query.
+        Test execute_many `UPDATE` query with basic values.
         """
-        table_name = 'test_queries__update_many__success'
+        table_name = 'test_queries__update_many__basic__success'
 
         # Verify table exists.
         try:
@@ -1364,7 +1364,7 @@ class CoreRecordsTestMixin:
             row_2 = updated_row_2
             row_3 = updated_row_3
 
-        with self.subTest('Run with five updates'):
+        with self.subTest('Run with five updates and alternate where column'):
             # Run test query. Update by non-PK.
             updated_row_4 = (4, 'test_name_4', 'four')
             updated_row_5 = (5, 'test_name_5', 'five')
@@ -1460,6 +1460,275 @@ class CoreRecordsTestMixin:
             self.assertNotIn(row_9, results)
             self.assertNotIn(row_10, results)
 
+            # Update row variables.
+            row_1 = updated_row_1
+            row_2 = updated_row_2
+            row_3 = updated_row_3
+            row_4 = updated_row_4
+            row_5 = updated_row_5
+            row_6 = updated_row_6
+            row_7 = updated_row_7
+            row_8 = updated_row_8
+            row_9 = updated_row_9
+            row_10 = updated_row_10
+
+        with self.subTest('Run with columns in alternate order'):
+            # Run test query.
+            updated_row_3 = (3, 'name as first', 'desc as second')
+            columns_clause = ['name', 'description', 'id']
+            values_clause = [
+                ('name as first', 'desc as second', 3)
+            ]
+            where_columns_clause = ['id']
+            self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause)
+
+            # Verify one record returned.
+            results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+            self.assertEqual(len(results), 10)
+            self.assertIn(row_1, results)
+            self.assertIn(row_2, results)
+            self.assertIn(updated_row_3, results)
+            self.assertIn(row_4, results)
+            self.assertIn(row_5, results)
+            self.assertIn(row_6, results)
+            self.assertIn(row_7, results)
+            self.assertIn(row_8, results)
+            self.assertIn(row_9, results)
+            self.assertIn(row_10, results)
+            self.assertNotIn(row_3, results)
+
+            # Update row variables.
+            row_3 = updated_row_3
+
+        with self.subTest('Run with skipping unused columns'):
+            # Run test query.
+            updated_row_5 = (5, 'this is')
+            updated_row_6 = (6, 'a')
+            updated_row_7 = (7, 'test')
+            columns_clause = ['id', 'description']
+            values_clause = [
+                updated_row_5,
+                updated_row_6,
+                updated_row_7,
+            ]
+            where_columns_clause = ['id']
+            self.connector.records.update_many(table_name, columns_clause, values_clause, where_columns_clause)
+
+            # Verify one record returned.
+            results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+            self.assertEqual(len(results), 10)
+            self.assertIn(row_1, results)
+            self.assertIn(row_2, results)
+            self.assertIn(row_3, results)
+            self.assertIn(row_4, results)
+            self.assertIn((updated_row_5[0], row_5[1], updated_row_5[1]), results)
+            self.assertIn((updated_row_6[0], row_6[1], updated_row_6[1]), results)
+            self.assertIn((updated_row_7[0], row_7[1], updated_row_7[1]), results)
+            self.assertIn(row_8, results)
+            self.assertIn(row_9, results)
+            self.assertIn(row_10, results)
+            self.assertNotIn(row_5, results)
+            self.assertNotIn(row_6, results)
+            self.assertNotIn(row_7, results)
+
+            # Update row variables.
+            row_5 = updated_row_5
+            row_6 = updated_row_6
+            row_7 = updated_row_7
+
+    def test__update_many__datetime__success(self):
+        """
+        Test execute_many `UPDATE` query with datetime values.
+        """
+        table_name = 'test_queries__update_many__datetime__success'
+
+        # Verify table exists.
+        try:
+            self.connector.query.execute('CREATE TABLE {0}{1};'.format(table_name, self._columns_clause__datetime))
+        except self.connector.errors.table_already_exists:
+            # Table already exists, as we want.
+            pass
+
+        # Generate datetime objects.
+        test_datetime__2020 = datetime.datetime(
+            year=2020,
+            month=6,
+            day=15,
+            hour=7,
+            minute=12,
+            second=52,
+            microsecond=0,
+        )
+        test_date__2020 = test_datetime__2020.date()
+        test_datetime__2021 = datetime.datetime(
+            year=2021,
+            month=7,
+            day=16,
+            hour=8,
+            minute=13,
+            second=53,
+            microsecond=0,
+        )
+        test_date__2021 = test_datetime__2021.date()
+        test_datetime__2022 = datetime.datetime(
+            year=2022,
+            month=8,
+            day=17,
+            hour=9,
+            minute=14,
+            second=54,
+            microsecond=0,
+        )
+        test_date__2022 = test_datetime__2022.date()
+
+        # Generate row values.
+        row_1 = (1, test_datetime__2020, test_date__2020)
+        row_2 = (2, test_datetime__2021, test_date__2021)
+        row_3 = (3, test_datetime__2022, test_date__2022)
+
+        # Generate initial rows.
+        rows = [
+            row_1,
+            row_2,
+            row_3,
+        ]
+        self.connector.records.insert_many(table_name, rows)
+
+        # Verify expected state.
+        results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+        self.assertEqual(len(results), 3)
+        self.assertIn(row_1, results)
+        self.assertIn(row_2, results)
+        self.assertIn(row_3, results)
+
+        columns_clause = ['id', 'test_datetime', 'test_date']
+        column_types_clause = ('integer', 'timestamp', 'date')
+        where_columns_clause = ['id']
+
+        with self.subTest('Run with one update'):
+            # Run test query.
+            updated_row_1 = (1, test_datetime__2020.replace(month=1), test_date__2020.replace(month=2))
+            values_clause = [
+                updated_row_1,
+            ]
+            self.connector.records.update_many(
+                table_name,
+                columns_clause,
+                values_clause,
+                where_columns_clause,
+                column_types_clause=column_types_clause,
+            )
+
+            # Verify one record returned.
+            results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+            self.assertEqual(len(results), 3)
+            self.assertIn(updated_row_1, results)
+            self.assertIn(row_2, results)
+            self.assertIn(row_3, results)
+            self.assertNotIn(row_1, results)
+
+            # Update row variables.
+            row_1 = updated_row_1
+
+        with self.subTest('Run with all updated'):
+            # Run test query.
+            updated_row_1 = (1, test_datetime__2020.replace(day=20), test_date__2020.replace(day=10))
+            updated_row_2 = (2, test_datetime__2021.replace(day=20), test_date__2021.replace(day=10))
+            updated_row_3 = (3, test_datetime__2022.replace(day=20), test_date__2022.replace(day=10))
+            values_clause = [
+                updated_row_1,
+                updated_row_2,
+                updated_row_3,
+            ]
+            self.connector.records.update_many(
+                table_name,
+                columns_clause,
+                values_clause,
+                where_columns_clause,
+                column_types_clause=column_types_clause,
+            )
+
+            # Verify one record returned.
+            results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+            self.assertEqual(len(results), 3)
+            self.assertIn(updated_row_1, results)
+            self.assertIn(updated_row_2, results)
+            self.assertIn(updated_row_3, results)
+            self.assertNotIn(row_1, results)
+            self.assertNotIn(row_2, results)
+            self.assertNotIn(row_3, results)
+
+            # Update row variables.
+            row_1 = updated_row_1
+            row_2 = updated_row_2
+            row_3 = updated_row_3
+
+        with self.subTest('Run with columns in alternate order'):
+            # Run test query.
+            columns_clause = ['test_date', 'id', 'test_datetime']
+            column_types_clause = ['date', 'integer', 'timestamp']
+            updated_row_2 = (2, test_datetime__2021.replace(month=3), test_date__2021.replace(month=4))
+            updated_row_3 = (3, test_datetime__2022.replace(month=6), test_date__2022.replace(month=5))
+            values_clause = [
+                (updated_row_2[2], updated_row_2[0], updated_row_2[1]),
+                (updated_row_3[2], updated_row_3[0], updated_row_3[1]),
+            ]
+            self.connector.records.update_many(
+                table_name,
+                columns_clause,
+                values_clause,
+                where_columns_clause,
+                column_types_clause=column_types_clause,
+            )
+
+            # Verify one record returned.
+            results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+            self.assertEqual(len(results), 3)
+            self.assertIn(row_1, results)
+            self.assertIn(updated_row_2, results)
+            self.assertIn(updated_row_3, results)
+            self.assertNotIn(row_2, results)
+            self.assertNotIn(row_3, results)
+
+            # Update row variables.
+            row_2 = updated_row_2
+            row_3 = updated_row_3
+
+        with self.subTest('Run with skipping unused columns'):
+            # Run test query.
+            columns_clause = ['id', 'test_datetime']
+            column_types_clause = ['integer', 'timestamp']
+            updated_row_1 = (1, test_datetime__2020)
+            updated_row_2 = (2, test_datetime__2021)
+            updated_row_3 = (3, test_datetime__2022)
+            values_clause = [
+                updated_row_1,
+                updated_row_2,
+                updated_row_3,
+            ]
+            self.connector.records.update_many(
+                table_name,
+                columns_clause,
+                values_clause,
+                where_columns_clause,
+                column_types_clause=column_types_clause,
+            )
+
+            # Verify one record returned.
+            results = self.connector.query.execute('SELECT * FROM {0};'.format(table_name))
+            self.assertEqual(len(results), 3)
+            self.assertIn(updated_row_1 + (row_1[2],), results)
+            self.assertIn(updated_row_2 + (row_2[2],), results)
+            self.assertIn(updated_row_3 + (row_3[2],), results)
+            self.assertNotIn(row_1, results)
+            self.assertNotIn(row_2, results)
+            self.assertNotIn(row_3, results)
+
+            # Update row variables.
+            row_1 = updated_row_1
+            row_2 = updated_row_2
+            row_3 = updated_row_3
+
     def test__delete__success(self):
         """
         Test `DELETE` query.
-- 
GitLab