"""
Date: 10-30-19
Class: CS5310
Assignment: Peterson's Algorithm
Author: Brandon Rodriguez


Runs a shared-counter demonstration, with or without Peterson's Algorithm, using either multi-processing or multi-threading.
"""

# System Imports.
import time
from functools import partial
from multiprocessing import Manager
from multiprocessing import Pool as MultiProcessPool
from multiprocessing.dummy import Pool as MultiThreadPool

# User Class Imports.
from resources import logging as init_logging


# Initialize Logger.
# Module-level logger obtained from the project's "resources.logging" helper by passing this module's __name__.
logger = init_logging.get_logger(__name__)


class Parallelism():
    """
    Class to run multi-processing/multi-threading logic.

    All work happens in the constructor. It creates a pool of two workers and a set of shared variables, then runs
    thread "a" and thread "b" once each. It waits for both to finish and logs their results plus the final counter
    value. In "safe" mode, both threads guard their counter increments with Peterson's Algorithm. In "unsafe" mode,
    they increment the shared counter with no synchronization at all.

    Note: Because we're using "multiprocessing" and "multiprocessing.dummy" Python libraries, we can create either
    multi-processing or multi-threading using the same logic and syntax. Very useful.
    """
    # Number of times each thread increments the shared counter.
    _INCREMENTS_PER_THREAD = 100

    def __init__(self, as_safe, *args, multi_process=False, multi_thread=False, **kwargs):
        """
        Run threads "a" and "b" in parallel and log their results.
        :param as_safe: Bool indicating if "safe" (Peterson's Algorithm) or "unsafe" methods should be ran.
        :param args: Unused. Accepted for interface compatibility.
        :param multi_process: Bool indicating to run the two threads as separate processes.
        :param multi_thread: Bool indicating to run the two threads as threads of this process.
        :param kwargs: Unused. Accepted for interface compatibility.
        :raises ValueError: If neither or both of multi_process and multi_thread are set.
        """
        # Check that user defined exactly one parallelism type to run.
        if not multi_process and not multi_thread:
            raise ValueError('Must define to run as either multi-processing or multi-threading. Got None.')

        if multi_process and multi_thread:
            raise ValueError('Must define to run as either multi-processing or multi-threading. Got both.')

        # Pick the pool type. Both pools expose the same API, so everything below is shared.
        if multi_process:
            logger.info('Running program with multi-processing.')
            pool_class = MultiProcessPool
        else:
            logger.info('Running program with multi-threading.')
            pool_class = MultiThreadPool

        # Create a parallel process/thread "manager".
        # Note, normally this is used to safely manage things like locks, shared variables, and more.
        # However, the point of this program is specifically to create our own safe lock for "critical sections" of
        # code.
        #
        # Unfortunately, we still need a manager to be able to access shared variables at all. So we only use the
        # manager minimally, through manager.Array() and manager.Value(), to create shared variables that all
        # processes/threads can access. The "lock=False" argument states that no lock is wanted. Manager-created
        # Value/Array objects are not wrapped in a lock either way, so all mutual exclusion comes from the code below.
        #
        # Both the pool and the manager are used as context managers. When this block exits, even because something
        # inside it raised, the manager's server process is shut down and the pool is terminated.
        with pool_class(2) as worker_pool, Manager() as manager:
            # Peterson's Algorithm requires both "interested" flags to start as False. Otherwise, a thread cannot
            # enter its critical section until the other thread has started and handed over the turn.
            shared_flag = manager.Array('b', [False, False], lock=False)  # Shared array of bools, of size 2.
            shared_lock = manager.Value('i', 0, lock=False)  # Shared int holding whose turn it is.
            shared_counter = manager.Value('i', 0, lock=False)  # Shared int that both threads increment.

            logger.info('Created shared Flag: {0}'.format(shared_flag))
            logger.info('Created shared Lock: {0}'.format(shared_lock))
            logger.info('Created shared Counter: {0}'.format(shared_counter))
            logger.info('Creating processes/threads.')
            logger.info('')

            # Bind the shared variables with partial. map_async then calls each partial once per item of range(1),
            # i.e. exactly once. It passes that item (0) as an extra trailing argument.
            partial_a = partial(self.thread_a_wrapper, as_safe, shared_flag, shared_lock, shared_counter)
            partial_b = partial(self.thread_b_wrapper, as_safe, shared_flag, shared_lock, shared_counter)
            results_a = worker_pool.map_async(partial_a, range(1))
            results_b = worker_pool.map_async(partial_b, range(1))

            # Wait until all processes/threads complete.
            results_a.wait()
            results_b.wait()

            # Close pool and join all workers.
            worker_pool.close()
            worker_pool.join()

            # Get process/thread results. get() re-raises any exception a worker hit.
            logger.info('')
            logger.info('Thread Results:')
            logger.info('   A: {0}'.format(results_a.get()))
            logger.info('   B: {0}'.format(results_b.get()))
            # Read the counter while the manager is still running; it is shut down when the "with" block exits.
            logger.info('   Final Counter: {0}'.format(shared_counter.value))
            logger.info('')

    def thread_a_wrapper(self, as_safe, shared_flag, shared_lock, shared_counter, *args, **kwargs):
        """
        Entry point for thread "a". Dispatches to its "safe" or "unsafe" method.

        functools.partial binds every argument except the trailing item that map_async supplies. That item arrives
        in "args" and is forwarded unchanged. Alternatively, we could condense our args into a single value (Ex: a
        tuple or dictionary) and pass that as the one mapped item instead.
        :param as_safe: Bool indicating if "safe" or "unsafe" method should be ran.
        :param shared_flag: Shared memory "flag" variable.
        :param shared_lock: Shared memory "lock" variable.
        :param shared_counter: Shared memory "counter" variable.
        :return: The returned value from our given thread method.
        """
        method = self.thread_a_method_safe if as_safe else self.thread_a_method_unsafe
        return method(shared_flag, shared_lock, shared_counter, *args, **kwargs)

    def thread_b_wrapper(self, as_safe, shared_flag, shared_lock, shared_counter, *args, **kwargs):
        """
        Entry point for thread "b". Dispatches to its "safe" or "unsafe" method.

        Same calling convention as "thread_a_wrapper".
        :param as_safe: Bool indicating if "safe" or "unsafe" method should be ran.
        :param shared_flag: Shared memory "flag" variable.
        :param shared_lock: Shared memory "lock" variable.
        :param shared_counter: Shared memory "counter" variable.
        :return: The returned value from our given thread method.
        """
        method = self.thread_b_method_safe if as_safe else self.thread_b_method_unsafe
        return method(shared_flag, shared_lock, shared_counter, *args, **kwargs)

    def _run_unsafe(self, name, counter):
        """
        Add to the shared counter with no synchronization.
        :param name: Thread name used in log messages ("a" or "b").
        :param counter: Shared memory "counter" variable.
        :return: The counter value read after this thread's increments.
        """
        logger.info('Running unsafe thread "{0}" method.'.format(name))

        # "counter.value += 1" is a separate read and a separate write through the manager proxy. The two threads can
        # interleave between them and lose updates, which is the race this "unsafe" version exists to show.
        for _ in range(self._INCREMENTS_PER_THREAD):
            counter.value += 1

        logger.info('Exiting unsafe thread "{0}" method.'.format(name))
        return counter.value

    def _run_safe(self, name, own, other, flag, turn, counter):
        """
        Add to the shared counter inside a critical section guarded by Peterson's Algorithm.
        :param name: Thread name used in log messages ("a" or "b").
        :param own: Index of this thread in "flag" (0 or 1).
        :param other: Index of the other thread in "flag" (1 or 0).
        :param flag: Shared memory "flag" variable, one "interested" entry per thread.
        :param turn: Shared memory "lock" variable, holding the index of the thread whose turn it is.
        :param counter: Shared memory "counter" variable.
        :return: The counter value read while still inside this thread's critical section.
        """
        logger.info('Running safe thread "{0}" method.'.format(name))

        # Announce interest in the critical section, then give priority to the other thread.
        flag[own] = True
        turn.value = other

        # Busy-wait while the other thread is interested and it is the other thread's turn.
        while flag[other] and turn.value == other:
            # Normally, we'd want to time.sleep here to wait for some amount of time.
            # However, since this program is so small, we just use pass instead.
            pass

        # If we made it this far, then it's this thread's turn.
        # Enter critical section.
        logger.info('Entering thread "{0}" critical section.'.format(name))

        for _ in range(self._INCREMENTS_PER_THREAD):
            counter.value += 1

        # Read the result before leaving the critical section, so the other thread cannot change it in between.
        result = counter.value

        logger.info('Exiting thread "{0}" critical section.'.format(name))
        # Critical section over. Withdraw interest so the other thread may enter.
        flag[own] = False

        logger.info('Exiting safe thread "{0}" method.'.format(name))
        return result

    def thread_a_method_unsafe(self, flag, turn, counter, *args, **kwargs):
        """
        Function for thread "a" to run without any synchronization.
        :param flag: Shared memory "flag" variable. Unused here.
        :param turn: Shared memory "lock" variable. Unused here.
        :param counter: Shared memory "counter" variable.
        :return: The counter value read after this thread's increments.
        """
        return self._run_unsafe('a', counter)

    def thread_b_method_unsafe(self, flag, turn, counter, *args, **kwargs):
        """
        Function for thread "b" to run without any synchronization.
        :param flag: Shared memory "flag" variable. Unused here.
        :param turn: Shared memory "lock" variable. Unused here.
        :param counter: Shared memory "counter" variable.
        :return: The counter value read after this thread's increments.
        """
        return self._run_unsafe('b', counter)

    def thread_a_method_safe(self, flag, turn, counter, *args, **kwargs):
        """
        Function for thread "a" (flag index 0) to run, guarding its increments with Peterson's Algorithm.
        :param flag: Shared memory "flag" variable.
        :param turn: Shared memory "lock" variable.
        :param counter: Shared memory "counter" variable.
        :return: The counter value read inside this thread's critical section.
        """
        return self._run_safe('a', 0, 1, flag, turn, counter)

    def thread_b_method_safe(self, flag, turn, counter, *args, **kwargs):
        """
        Function for thread "b" (flag index 1) to run, guarding its increments with Peterson's Algorithm.
        :param flag: Shared memory "flag" variable.
        :param turn: Shared memory "lock" variable.
        :param counter: Shared memory "counter" variable.
        :return: The counter value read inside this thread's critical section.
        """
        return self._run_safe('b', 1, 0, flag, turn, counter)
