import csv
import heapq
import logging
import time
import numpy as np

import pulp

# Configure logging
# Send INFO-and-above records through the root logger with a timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def read_lengths_from_csv(file_path):
    """Read ``length,count`` rows from a CSV file and expand them into a flat list.

    Each row ``L,C`` contributes ``C`` copies of ``L``. For example, the rows
    ``100,2`` and ``50,1`` produce ``[100, 100, 50]``. Blank or
    whitespace-only lines are skipped.

    Args:
        file_path: Path of a header-less CSV file with two integer columns.

    Returns:
        List of integer lengths, in file order.

    Raises:
        ValueError: if a non-blank row does not hold exactly two integers.
    """
    lengths = []
    # newline="" is what the csv module expects for file objects; "utf-8-sig"
    # silently drops a leading BOM (common in spreadsheet exports) that would
    # otherwise end up glued to the first number and break int().
    with open(file_path, "r", newline="", encoding="utf-8-sig") as file:
        for row in csv.reader(file):
            if not any(cell.strip() for cell in row):
                continue  # blank line (csv yields []) or whitespace-only line
            length, count = map(int, row)
            lengths.extend([length] * count)
    return lengths


def best_fit_decreasing(stock_lengths, demand_lengths):
    """Assign pieces to stock rolls with the Best Fit Decreasing heuristic.

    Pieces are processed longest first. Each piece goes into the roll whose
    current remaining length is the smallest one that still fits it; untouched
    rolls take part with their full length.

    If no roll can hold a piece, an extra "virtual" roll of the longest stock
    length is opened. Virtual rolls get indices ``>= len(stock_lengths)``,
    chosen so that ``stock_lengths[index % len(stock_lengths)]`` is that
    longest length and no two rolls share an index. A piece longer than every
    stock roll still gets a virtual roll, with a negative remaining length.

    Args:
        stock_lengths: Length of every available stock roll (one entry per roll).
        demand_lengths: Length of every piece to cut (one entry per piece).

    Returns:
        ``(solution, remaining_lengths)``.

        ``solution`` holds one ``(roll_index, demand_index, available_length)``
        tuple per piece. ``demand_index`` indexes the caller's
        ``demand_lengths``. ``available_length`` is what the roll had left just
        before the piece was cut.

        ``remaining_lengths`` is a min-heap of ``(remaining_length, roll_index)``
        for every roll, used or not.

    Raises:
        ValueError: if a roll must be opened but ``stock_lengths`` is empty.
    """
    num_rolls = len(stock_lengths)
    longest_idx = (
        max(range(num_rolls), key=stock_lengths.__getitem__) if num_rolls else None
    )
    # Visit pieces longest-first, but keep their ORIGINAL positions: every
    # consumer of the solution indexes demand_lengths directly.
    order = sorted(
        range(len(demand_lengths)), key=lambda j: demand_lengths[j], reverse=True
    )

    remaining_lengths = [(length, idx) for idx, length in enumerate(stock_lengths)]
    heapq.heapify(remaining_lengths)
    solution = []
    virtual_rolls = 0

    for demand_idx in order:
        demand = demand_lengths[demand_idx]

        # Pop in ascending remaining length: the first roll that fits is the best fit.
        too_small = []
        fit = None
        while remaining_lengths:
            candidate = heapq.heappop(remaining_lengths)
            if candidate[0] >= demand:
                fit = candidate
                break
            too_small.append(candidate)
        for item in too_small:
            heapq.heappush(remaining_lengths, item)

        if fit is None:
            if longest_idx is None:
                raise ValueError("No stock lengths available to open a new roll.")
            virtual_rolls += 1
            roll_idx = longest_idx + virtual_rolls * num_rolls
            available = stock_lengths[longest_idx]
        else:
            available, roll_idx = fit

        solution.append((roll_idx, demand_idx, available))
        # Exactly one heap entry per roll, reinserted with its new remainder.
        heapq.heappush(remaining_lengths, (available - demand, roll_idx))

    return solution, remaining_lengths


def worst_fit_decreasing(stock_lengths, demand_lengths):
    """Assign pieces to stock rolls with the Worst Fit Decreasing heuristic.

    Pieces are processed longest first. Each piece goes into the roll with the
    LARGEST remaining length. If even that roll is too short, an extra
    "virtual" roll of the longest stock length is opened. Virtual rolls get
    indices ``>= len(stock_lengths)``, chosen so that
    ``stock_lengths[index % len(stock_lengths)]`` is that longest length.
    Rolls whose remainder drops to zero (or below, for a piece longer than
    every roll) leave the pool.

    Args:
        stock_lengths: Length of every available stock roll (one entry per roll).
        demand_lengths: Length of every piece to cut (one entry per piece).

    Returns:
        ``(solution, remaining_lengths)``.

        ``solution`` holds one ``(roll_index, demand_index, available_length)``
        tuple per piece. ``demand_index`` indexes the caller's
        ``demand_lengths``. ``available_length`` is what the roll had left just
        before the piece was cut.

        ``remaining_lengths`` is an ascending list (hence also a valid
        min-heap) of positive ``(remaining_length, roll_index)`` pairs, one for
        each roll with length left.

    Raises:
        ValueError: if a roll must be opened but ``stock_lengths`` is empty.
    """
    num_rolls = len(stock_lengths)
    longest_idx = (
        max(range(num_rolls), key=stock_lengths.__getitem__) if num_rolls else None
    )
    # Longest-first, but keep ORIGINAL positions for the solution tuples.
    order = sorted(
        range(len(demand_lengths)), key=lambda j: demand_lengths[j], reverse=True
    )

    # Max-heap via negated lengths.
    heap = [(-length, idx) for idx, length in enumerate(stock_lengths)]
    heapq.heapify(heap)
    solution = []
    virtual_rolls = 0

    for demand_idx in order:
        demand = demand_lengths[demand_idx]
        # The top of the max-heap is the roomiest roll; if it cannot hold the
        # piece, no other roll can either.
        if heap and -heap[0][0] >= demand:
            neg_available, roll_idx = heapq.heappop(heap)
            available = -neg_available
        else:
            if longest_idx is None:
                raise ValueError("No stock lengths available to open a new roll.")
            virtual_rolls += 1
            roll_idx = longest_idx + virtual_rolls * num_rolls
            available = stock_lengths[longest_idx]

        solution.append((roll_idx, demand_idx, available))
        left = available - demand
        if left > 0:
            heapq.heappush(heap, (-left, roll_idx))

    remaining_lengths = sorted((-neg_length, idx) for neg_length, idx in heap)
    return solution, remaining_lengths


def calculate_score(solution, stock_lengths, demand_lengths):
    """Score a heuristic cutting plan; higher is better.

    The plan is aggregated per roll. The score is::

        total_length_cut - rolls_used * used_stock_length - total_wastage

    ``used_stock_length`` is the summed length of the distinct rolls used.
    ``total_wastage`` is the summed leftover (never negative) of those rolls.
    Roll indices past the end of ``stock_lengths`` (extra rolls opened by the
    fit heuristics) resolve to ``stock_lengths[index % len(stock_lengths)]``.

    Args:
        solution: ``(roll_index, demand_index, _)`` tuples, one per cut piece.
        stock_lengths: Length of every stock roll.
        demand_lengths: Length of every piece; indexed by ``demand_index``.

    Returns:
        The numeric score (0 for an empty plan).
    """
    num_stock = len(stock_lengths)
    cut_per_roll = {}
    for roll, demand_idx, _ in solution:
        cut_per_roll[roll] = cut_per_roll.get(roll, 0) + demand_lengths[demand_idx]

    num_rolls_used = len(cut_per_roll)
    total_length_cut = sum(cut_per_roll.values())
    used_stock_lengths = sum(stock_lengths[roll % num_stock] for roll in cut_per_roll)
    # Leftover is measured once per roll; an over-full roll counts as zero waste.
    total_wastage = sum(
        max(stock_lengths[roll % num_stock] - cut, 0)
        for roll, cut in cut_per_roll.items()
    )

    score = total_length_cut - (num_rolls_used * used_stock_lengths) - total_wastage
    return score


def balance_cuts_heuristic(stock_lengths, demand_lengths):
    """Spread pieces over rolls by always cutting from the least-loaded roll that fits.

    Pieces are processed longest first. Among the rolls that can still hold a
    piece, the one with the smallest current load wins; ties go to the lowest
    index. A piece that fits nowhere is skipped and does not appear in the
    solution. Callers can detect this by comparing
    ``len(solution)`` with ``len(demand_lengths)``.

    Args:
        stock_lengths: Length of every stock roll.
        demand_lengths: Length of every piece to cut.

    Returns:
        ``(solution, current_loads)``.

        ``solution`` holds ``(roll_index, demand_index, remaining_after_cut)``
        tuples; ``demand_index`` indexes the caller's ``demand_lengths``.

        ``current_loads[i]`` is the total length assigned to roll ``i``.
    """
    # Longest-first, but keep ORIGINAL positions for the solution tuples.
    order = sorted(
        range(len(demand_lengths)), key=lambda j: demand_lengths[j], reverse=True
    )
    solution = []
    current_loads = [0] * len(stock_lengths)

    for demand_idx in order:
        demand = demand_lengths[demand_idx]
        feasible = [
            i for i, load in enumerate(current_loads)
            if load + demand <= stock_lengths[i]
        ]
        if not feasible:
            continue  # no roll can take this piece; leave it unassigned
        # min() returns the first minimum, i.e. the lowest index among ties.
        best_roll = min(feasible, key=current_loads.__getitem__)
        current_loads[best_roll] += demand
        solution.append(
            (best_roll, demand_idx, stock_lengths[best_roll] - current_loads[best_roll])
        )

    return solution, current_loads


def _restrict_to_used_rolls(stock_lengths, solution):
    """Return the rolls a heuristic solution uses, plus the solution renumbered to them.

    Roll indices in ``solution`` may exceed ``len(stock_lengths)`` (extra
    rolls opened by the fit heuristics). They resolve to a stock length via
    ``index % len(stock_lengths)``.

    Args:
        stock_lengths: Length of every available stock roll.
        solution: ``(roll_index, demand_index, extra)`` tuples from a heuristic.

    Returns:
        ``(trimmed_lengths, remapped_solution)``. ``trimmed_lengths[k]`` is the
        length of the k-th distinct roll used, ordered by ascending original
        index. ``remapped_solution`` refers to rolls by that position ``k``.
    """
    used_rolls = sorted({roll for roll, _, _ in solution})
    position = {roll: k for k, roll in enumerate(used_rolls)}
    trimmed_lengths = [stock_lengths[roll % len(stock_lengths)] for roll in used_rolls]
    remapped_solution = [
        (position[roll], demand_idx, extra) for roll, demand_idx, extra in solution
    ]
    return trimmed_lengths, remapped_solution


def solve_pulp(
        stock_lengths_file,
        demand_lengths_file,
        minimize_wastage=True,
        use_initial_solution=True,
        balance_cuts=False,
        print_func=print
):
    """Solve the cutting-stock problem exactly with PuLP/CBC and print a report.

    Every stock entry (after count expansion) is one physical roll, and every
    demand entry is one piece. Binary ``x[i, j]`` assigns piece ``j`` to roll
    ``i``, and binary ``y[i]`` marks roll ``i`` as used. The primary objective
    is the number of rolls used.

    Args:
        stock_lengths_file: CSV of ``length,count`` rows for the stock rolls.
        demand_lengths_file: CSV of ``length,count`` rows for the pieces.
        minimize_wastage: Also minimise total leftover length, as a strictly
            secondary objective.
        use_initial_solution: Run BFD/WFD first, restrict the model to the
            rolls the better heuristic used, and warm-start CBC from it.
        balance_cuts: Replace the initial solution with the
            ``balance_cuts_heuristic`` assignment.
        print_func: Callable used for all user-facing output.

    Returns:
        True when an optimal plan was found and reported, False otherwise.
    """
    logging.info("Reading stock and demand lengths from CSV files.")
    stock_lengths = read_lengths_from_csv(stock_lengths_file)
    demand_lengths = read_lengths_from_csv(demand_lengths_file)

    print_func("Preprocessing demand lengths.")
    # default=0: with no stock at all, every positive demand is "too large"
    # instead of max() raising on an empty list.
    max_stock_length = max(stock_lengths, default=0)
    too_large_demands = [d for d in demand_lengths if d > max_stock_length]
    if too_large_demands:
        print_func("The following demand lengths are too large to be satisfied by any stock length:")
        for demand in too_large_demands:
            print_func(f"Demand length: {demand}")
        return False

    if sum(stock_lengths) < sum(demand_lengths):
        logging.error("The total stock lengths are insufficient to satisfy the total demand lengths.")
        print_func("The total stock lengths are insufficient to satisfy the total demand lengths.")
        return False

    start_time = time.time()

    initial_solution = None

    if use_initial_solution:
        print_func("Generating initial solution using BFD and WFD.")
        initial_solution_bfd, remaining_bfd = best_fit_decreasing(stock_lengths, demand_lengths)
        initial_solution_wfd, remaining_wfd = worst_fit_decreasing(stock_lengths, demand_lengths)
        # abs() keeps the totals positive even if a heap stores lengths negated
        # (the usual max-heap trick).
        print_func("Remaining lengths: BFD = %i, WFD = %i" % (
            sum(abs(r) for r, _ in remaining_bfd),
            sum(abs(r) for r, _ in remaining_wfd),
        ))

        score_bfd = calculate_score(initial_solution_bfd, stock_lengths, demand_lengths)
        score_wfd = calculate_score(initial_solution_wfd, stock_lengths, demand_lengths)

        print_func("BFD = %i, WFD = %i" % (score_bfd, score_wfd))
        # Keep the heuristic with the higher calculate_score value.
        if score_bfd > score_wfd:
            print_func("Choosing BFD initial solution.")
            initial_solution = initial_solution_bfd
        else:
            print_func("Choosing WFD initial solution.")
            initial_solution = initial_solution_wfd
        print_func("Trimming stock lengths array based on initial solution.")
        # One entry per DISTINCT roll the heuristic used, and the solution
        # renumbered to match, so warm-start indices line up with the model.
        stock_lengths, initial_solution = _restrict_to_used_rolls(stock_lengths, initial_solution)

    if balance_cuts:
        print_func("Balancing cuts using the heuristic.")
        initial_solution, _ = balance_cuts_heuristic(stock_lengths, demand_lengths)
        if len(initial_solution) == len(demand_lengths):
            # Every piece was placed, so the rolls it used are provably enough.
            stock_lengths, initial_solution = _restrict_to_used_rolls(stock_lengths, initial_solution)
        # Otherwise keep the full roll set: shrinking it could make the model infeasible.

    # Each model roll is one physical roll; stock is not reused cyclically.
    max_stock_used = len(stock_lengths)
    rolls = range(max_stock_used)
    pieces = range(len(demand_lengths))

    prob = pulp.LpProblem("CuttingStock", pulp.LpMinimize)

    print_func("Creating decision variables.")
    x = pulp.LpVariable.dicts("x", ((i, j) for i in rolls for j in pieces), cat="Binary")
    y = pulp.LpVariable.dicts("y", rolls, cat="Binary")
    if minimize_wastage:
        # w[i]: leftover length of roll i (forced to 0 for unused rolls).
        w = pulp.LpVariable.dicts("w", rolls, lowBound=0, cat="Continuous")

    print_func("Setting up the objective function to minimize the number of stock rolls used.")
    rolls_used = pulp.lpSum(y[i] for i in rolls)
    if minimize_wastage:
        # Total wastage never exceeds the total stock length, so this weight
        # keeps the roll count strictly primary. Wastage only separates plans
        # that use the same number of rolls.
        roll_weight = sum(stock_lengths) + 1
        prob += roll_weight * rolls_used + pulp.lpSum(w[i] for i in rolls)
    else:
        prob += rolls_used

    print_func("Adding constraints to the problem.")
    # Each piece is cut exactly once.
    for j in pieces:
        prob += pulp.lpSum(x[i, j] for i in rolls) == 1

    # A roll can only be cut if it is used, and never beyond its length.
    for i in rolls:
        prob += (
            pulp.lpSum(demand_lengths[j] * x[i, j] for j in pieces)
            <= stock_lengths[i] * y[i]
        )
        if minimize_wastage:
            prob += w[i] == stock_lengths[i] * y[i] - pulp.lpSum(
                demand_lengths[j] * x[i, j] for j in pieces
            )

    warm_start = use_initial_solution and initial_solution is not None
    if warm_start:
        print_func("Setting initial solution values for faster convergence.")
        used_stocks = set()
        for stock_index, demand_index, _ in initial_solution:
            if stock_index < max_stock_used and demand_index < len(demand_lengths):
                x[stock_index, demand_index].setInitialValue(1)
                used_stocks.add(stock_index)
        for stock_index in used_stocks:
            y[stock_index].setInitialValue(1)

    print_func("Solving the optimization problem.")
    # CBC only reads the setInitialValue() hints when warmStart=True.
    prob.solve(pulp.PULP_CBC_CMD(msg=False, warmStart=warm_start))

    if pulp.LpStatus[prob.status] != "Optimal":
        print_func("Optimal solution not found.")
        return False

    logging.info("Collecting and sorting the results.")
    results = []
    total_demand_cut_length = 0
    total_wastage = 0
    used_rolls = set()
    roll_counter = 1
    unique_cut_patterns = set()

    for i in rolls:
        # Solver values of binaries are floats (e.g. 0.9999999 or 1e-9), so
        # round at 0.5 rather than testing > 0.
        if (y[i].varValue or 0) > 0.5:
            stock_length = stock_lengths[i]
            demand_cuts = [
                demand_lengths[j]
                for j in pieces
                if (x[i, j].varValue or 0) > 0.5
            ]
            total_demand_cut = sum(demand_cuts)
            wastage = stock_length - total_demand_cut
            total_demand_cut_length += total_demand_cut
            total_wastage += wastage
            results.append(
                (stock_length, demand_cuts, total_demand_cut, wastage, roll_counter)
            )
            used_rolls.add(i)
            roll_counter += 1
            unique_cut_patterns.add(tuple(sorted(demand_cuts)))

    results.sort(key=lambda r: (r[0], -len(r[1])))

    logging.info("Printing the results.")
    for stock_length, demand_cuts, total_demand_cut, wastage, roll_counter in results:
        result_line = f"Roll number {roll_counter} ({stock_length})\n\t{' + '.join(map(str, demand_cuts))} = {total_demand_cut}\n\tWastage = {wastage}"
        print_func(result_line)

    num_unique_cut_patterns = len(unique_cut_patterns)
    total_cuts = sum(len(demand_cuts) for _, demand_cuts, _, _, _ in results)
    average_cuts_per_roll = total_cuts / len(used_rolls) if used_rolls else 0
    max_cuts_roll = max(len(demand_cuts) for _, demand_cuts, _, _, _ in results) if results else 0

    print_func("\nSummary:")
    print_func("---------")
    print_func(f"Number of rolls used: {len(used_rolls)}")
    print_func(f"Total length cut: {total_demand_cut_length}")
    print_func(f"Total wastage: {total_wastage}")

    print_func("\nDetailed Statistics:")
    print_func("---------------------")
    stock_count = {}
    for stock_length, _, _, _, _ in results:
        stock_count[stock_length] = stock_count.get(stock_length, 0) + 1
    print_func("Stock Lengths Used:")
    for stock_length, count in stock_count.items():
        print_func(f"    - {stock_length} ({count} time{'s' if count > 1 else ''})")

    print_func("\nWastage Distribution:")
    min_wastage = min(w for _, _, _, w, _ in results) if results else 0
    max_wastage = max(w for _, _, _, w, _ in results) if results else 0
    avg_wastage = total_wastage / len(used_rolls) if used_rolls else 0
    print_func(f"    - Minimum wastage: {min_wastage}")
    print_func(f"    - Maximum wastage: {max_wastage}")
    print_func(f"    - Average wastage per roll: {avg_wastage:.2f}")

    print_func("\nDemand Lengths Statistics:")
    total_demand_lengths = len(demand_lengths)
    # Optimal status with the "== 1" constraints guarantees every piece is cut.
    fulfilled_demand_lengths = total_demand_lengths
    unfulfilled_demand_lengths = 0
    print_func(f"    - Total demand lengths: {total_demand_lengths}")
    print_func(f"    - Fulfilled demand lengths: {fulfilled_demand_lengths}")
    print_func(f"    - Unfulfilled demand lengths: {unfulfilled_demand_lengths}")

    print_func("\nPerformance Metrics:")
    print_func(f"    - Total time taken: {time.time() - start_time:.2f} seconds")
    print_func(f"    - Solver status: {pulp.LpStatus[prob.status]}")

    print_func("\nUtilization Analysis:")
    print_func("---------------------")
    for stock_length in stock_count:
        print_func(f"    - Roll Utilization (Stock Length: {stock_length}):")
        for stock, demand_cuts, total_demand_cut, _, roll_counter in results:
            if stock == stock_length:
                utilization = (total_demand_cut / stock) * 100
                print_func(f"        - Roll {roll_counter}: {utilization:.2f}%")

    print_func("\nBalance of Cuts:")
    print_func(f"    - Number of unique cut patterns: {num_unique_cut_patterns}")
    print_func(f"    - Average cuts per roll: {average_cuts_per_roll:.2f}")
    print_func(f"    - Rolls with most cuts: {max_cuts_roll}")

    return True


def solve_heuristics(
        stock_lengths_file,
        demand_lengths_file,
        print_func=print
):
    """Build a cutting plan with the better of BFD and WFD and print a report.

    The heuristics return one tuple per cut piece. Those tuples are regrouped
    into one entry per roll, so the report's roll count, per-roll wastage,
    cut patterns and utilisation describe physical rolls. Roll indices past
    the end of the stock list (extra rolls opened by the heuristics) resolve
    to ``stock_lengths[index % len(stock_lengths)]``.

    Args:
        stock_lengths_file: CSV of ``length,count`` rows for the stock rolls.
        demand_lengths_file: CSV of ``length,count`` rows for the pieces.
        print_func: Callable used for the user-facing report.

    Returns:
        True when a plan was produced and reported. False when some demand is
        longer than every stock roll, since no plan can cut it.
    """
    stock_lengths = read_lengths_from_csv(stock_lengths_file)
    demand_lengths = read_lengths_from_csv(demand_lengths_file)

    # Same guard as solve_pulp: an oversized piece would otherwise produce a
    # roll with negative wastage.
    max_stock_length = max(stock_lengths, default=0)
    too_large_demands = [d for d in demand_lengths if d > max_stock_length]
    if too_large_demands:
        print_func("The following demand lengths are too large to be satisfied by any stock length:")
        for demand in too_large_demands:
            print_func(f"Demand length: {demand}")
        return False

    start_time = time.time()

    logging.info("Generating initial solution using BFD and WFD.")
    initial_solution_bfd, remaining_bfd = best_fit_decreasing(stock_lengths, demand_lengths)
    initial_solution_wfd, remaining_wfd = worst_fit_decreasing(stock_lengths, demand_lengths)
    # abs() keeps the totals positive even if a heap stores lengths negated.
    logging.info("Remaining lengths: BFD = %i, WFD = %i" % (
        sum(abs(r) for r, _ in remaining_bfd),
        sum(abs(r) for r, _ in remaining_wfd),
    ))

    score_bfd = calculate_score(initial_solution_bfd, stock_lengths, demand_lengths)
    score_wfd = calculate_score(initial_solution_wfd, stock_lengths, demand_lengths)

    logging.info("BFD = %i, WFD = %i" % (score_bfd, score_wfd))
    if score_bfd > score_wfd:
        logging.info("Choosing BFD solution.")
        best_solution = initial_solution_bfd
    else:
        logging.info("Choosing WFD solution.")
        best_solution = initial_solution_wfd

    logging.info("Collecting and sorting the results.")
    # Regroup the per-piece tuples into one list of cut lengths per roll.
    cuts_by_roll = {}
    for stock_idx, demand_idx, _ in best_solution:
        cuts_by_roll.setdefault(stock_idx, []).append(demand_lengths[demand_idx])

    results = []
    total_demand_cut_length = 0
    total_wastage = 0
    unique_cut_patterns = set()
    for roll_counter, (stock_idx, demand_cuts) in enumerate(cuts_by_roll.items(), start=1):
        stock_length = stock_lengths[stock_idx % len(stock_lengths)]
        total_demand_cut = sum(demand_cuts)
        wastage = stock_length - total_demand_cut
        total_demand_cut_length += total_demand_cut
        total_wastage += wastage
        results.append(
            (stock_length, demand_cuts, total_demand_cut, wastage, roll_counter)
        )
        unique_cut_patterns.add(tuple(sorted(demand_cuts)))
    used_rolls = set(cuts_by_roll)

    results.sort(key=lambda r: (r[0], -len(r[1])))

    logging.info("Printing the results.")
    for stock_length, demand_cuts, total_demand_cut, wastage, roll_counter in results:
        result_line = f"Roll number {roll_counter} ({stock_length})\n\t{' + '.join(map(str, demand_cuts))} = {total_demand_cut}\n\tWastage = {wastage}"
        print_func(result_line)

    num_unique_cut_patterns = len(unique_cut_patterns)
    total_cuts = sum(len(demand_cuts) for _, demand_cuts, _, _, _ in results)
    average_cuts_per_roll = total_cuts / len(used_rolls) if used_rolls else 0
    max_cuts_roll = max(len(demand_cuts) for _, demand_cuts, _, _, _ in results) if results else 0

    print_func("\nSummary:")
    print_func("---------")
    print_func(f"Number of rolls used: {len(used_rolls)}")
    print_func(f"Total length cut: {total_demand_cut_length}")
    print_func(f"Total wastage: {total_wastage}")

    print_func("\nDetailed Statistics:")
    print_func("---------------------")
    stock_count = {}
    for stock_length, _, _, _, _ in results:
        stock_count[stock_length] = stock_count.get(stock_length, 0) + 1
    print_func("Stock Lengths Used:")
    for stock_length, count in stock_count.items():
        print_func(f"    - {stock_length} ({count} time{'s' if count > 1 else ''})")

    print_func("\nWastage Distribution:")
    min_wastage = min(w for _, _, _, w, _ in results) if results else 0
    max_wastage = max(w for _, _, _, w, _ in results) if results else 0
    avg_wastage = total_wastage / len(used_rolls) if used_rolls else 0
    print_func(f"    - Minimum wastage: {min_wastage}")
    print_func(f"    - Maximum wastage: {max_wastage}")
    print_func(f"    - Average wastage per roll: {avg_wastage:.2f}")

    print_func("\nDemand Lengths Statistics:")
    total_demand_lengths = len(demand_lengths)
    # Both fit heuristics place every piece (opening extra rolls if needed).
    fulfilled_demand_lengths = total_demand_lengths
    unfulfilled_demand_lengths = 0
    print_func(f"    - Total demand lengths: {total_demand_lengths}")
    print_func(f"    - Fulfilled demand lengths: {fulfilled_demand_lengths}")
    print_func(f"    - Unfulfilled demand lengths: {unfulfilled_demand_lengths}")

    print_func("\nPerformance Metrics:")
    print_func(f"    - Total time taken: {time.time() - start_time:.2f} seconds")

    print_func("\nUtilization Analysis:")
    print_func("---------------------")
    for stock_length in stock_count:
        print_func(f"    - Roll Utilization (Stock Length: {stock_length}):")
        for stock, demand_cuts, total_demand_cut, _, roll_counter in results:
            if stock == stock_length:
                utilization = (total_demand_cut / stock) * 100
                print_func(f"        - Roll {roll_counter}: {utilization:.2f}%")

    print_func("\nBalance of Cuts:")
    print_func(f"    - Number of unique cut patterns: {num_unique_cut_patterns}")
    print_func(f"    - Average cuts per roll: {average_cuts_per_roll:.2f}")
    print_func(f"    - Rolls with most cuts: {max_cuts_roll}")

    return True
