import { FlowAssignment } from "./flow-assignment";
import Graph, { Edge } from "./graph";

/** Default iteration budget for `solveFlow`: passes made before it throws a non-convergence error. */
export const FLOW_SOLVER_MAX_ITERATIONS = 1e3;
/**
 * Default `escapeDelta` for `solveFlow`: the solve is considered converged once
 * the summed absolute adjustment of one pass drops below this value.
 */
export const FLOW_SOLVER_MAX_SOLUTION_DELTA = 1e-4;

/** Result of evaluating a single cycle, as returned by `computeCycleImbalance` in `solveFlow`. */
export interface CycleResult {
  /**
   * Correction to apply around the cycle. Optional in the type, but `solveFlow`
   * throws when it is missing.
   */
  suggestedAdjustment?: number;
  /** Imbalance measured on the cycle; `solveFlow` accumulates its absolute value. */
  imbalance: number;
}

// Running count of `computeCycleImbalance` calls made by `solveFlow`.
// NOTE(review): only ever incremented here; confirm whether anything reads it.
let evaluations = 0;

/** With `shortCycles`, the initial `shortest: true` cycles are kept for this many iterations. */
const SHORT_CYCLE_WARMUP_ITERATIONS = 100;
/** Outside the warm-up, a fresh random set of cycles is generated every this many iterations. */
const CYCLE_REGENERATION_PERIOD = 4;
/** Factor applied to the learning rate when the summed imbalance grows. */
const LEARNING_RATE_DECAY = 0.7;
/** Learning-rate decay is only considered while the summed imbalance is below this. */
const LEARNING_RATE_DECAY_IMBALANCE_THRESHOLD = 100;

/**
 * Iteratively balances flow around the cycles of `graph`, starting from the
 * provided assignment, and returns the balanced assignment.
 *
 * Each iteration walks the cycles returned by `graph.edgeCycleCover`
 * (`directed: false`, `random: true`), asks `computeCycleImbalance` for a
 * correction per cycle and applies it. The solve converges when the summed
 * absolute correction of one full pass falls below `escapeDelta`.
 *
 * @param options.assignment - Initial flow assignment. The default adjustment
 *   mutates it in place via `addFlow`; when `adjustFlow` is supplied, its return
 *   value becomes the current assignment.
 * @param options.computeCycleImbalance - Evaluates one cycle; must return a
 *   finite `suggestedAdjustment`.
 * @param options.graph - Graph whose cycles are balanced.
 * @param options.adjustFlow - Custom way to apply an adjustment to a cycle.
 *   Defaults to `assignment.addFlow(edge.uid, graph.sn(edge.from), adjustment)`
 *   for every edge of the cycle.
 * @param options.escapeDelta - Convergence threshold; defaults to
 *   `FLOW_SOLVER_MAX_SOLUTION_DELTA`.
 * @param options.onCycles - Hook that may replace each freshly generated set of cycles.
 * @param options.shortCycles - Start from cycles generated with `shortest: true`
 *   and keep them for `SHORT_CYCLE_WARMUP_ITERATIONS` iterations before switching
 *   to periodically regenerated random cycles.
 * @param options.maxIterations - Iteration budget; defaults to
 *   `FLOW_SOLVER_MAX_ITERATIONS`.
 * @param options.onIteration - Optional progress hook called after every pass.
 * @returns The balanced flow assignment.
 * @throws Error if `computeCycleImbalance` omits the adjustment or returns a
 *   non-finite one, or if the solver does not converge within `maxIterations`.
 */
export function solveFlow<N, E>(options: {
  assignment: FlowAssignment;
  computeCycleImbalance: (
    flowAssignment: FlowAssignment,
    cycle: Edge<N, E>[],
    iteration: number,
    learningRate?: number,
  ) => CycleResult;
  graph: Graph<N, E>;
  adjustFlow?: (
    flowAssignment: FlowAssignment,
    cycle: Edge<N, E>[],
    adjustment: number,
  ) => FlowAssignment;
  escapeDelta?: number;
  onCycles?: (cycles: Edge<N, E>[][]) => Edge<N, E>[][];
  shortCycles?: boolean;
  maxIterations?: number;
  onIteration?: (progress: {
    iteration: number;
    delta: number;
    imbalance: number;
  }) => void;
}): FlowAssignment {
  const {
    computeCycleImbalance,
    graph,
    onCycles,
    adjustFlow,
    shortCycles,
    onIteration,
    escapeDelta = FLOW_SOLVER_MAX_SOLUTION_DELTA,
    maxIterations = FLOW_SOLVER_MAX_ITERATIONS,
  } = options;
  // Reassigned when a custom `adjustFlow` hands back the updated assignment.
  let assignment = options.assignment;

  const generateCycles = (shortest = false): Edge<N, E>[][] => {
    const cover: Edge<N, E>[][] = graph.edgeCycleCover({
      shortest,
      directed: false,
      random: true,
    });
    return onCycles ? onCycles(cover) : cover;
  };

  // Without `shortCycles`, iteration 0 always generates a random set below, so
  // building a shortest-cycle set up front would only be thrown away.
  let cycles: Edge<N, E>[][] = shortCycles ? generateCycles(true) : [];

  let delta = 0;
  let lastImbalance = Number.POSITIVE_INFINITY;
  let learningRate = 1;

  for (let iteration = 0; iteration < maxIterations; iteration++) {
    const warmingUp =
      Boolean(shortCycles) && iteration < SHORT_CYCLE_WARMUP_ITERATIONS;
    if (!warmingUp && iteration % CYCLE_REGENERATION_PERIOD === 0) {
      cycles = generateCycles();
    }

    delta = 0;
    let currentImbalance = 0;

    for (const cycle of cycles) {
      evaluations++;
      const { imbalance, suggestedAdjustment } = computeCycleImbalance(
        assignment,
        cycle,
        iteration,
        learningRate,
      );

      if (suggestedAdjustment === undefined) {
        throw new Error(
          "computeCycleImbalance must return a suggested adjustment",
        );
      }
      // A NaN/Infinity correction would corrupt the assignment and keep
      // `delta < escapeDelta` false forever; fail fast instead of spinning.
      if (!Number.isFinite(suggestedAdjustment)) {
        throw new Error(
          `computeCycleImbalance returned a non-finite adjustment (${suggestedAdjustment}) at iteration ${iteration}`,
        );
      }

      if (adjustFlow) {
        // Use the declared return value so non-mutating implementations work;
        // fall back to the current object for JS callers that return nothing.
        assignment =
          adjustFlow(assignment, cycle, suggestedAdjustment) ?? assignment;
      } else {
        for (const edge of cycle) {
          assignment.addFlow(
            edge.uid,
            graph.sn(edge.from),
            suggestedAdjustment,
          );
        }
      }

      delta += Math.abs(suggestedAdjustment);
      currentImbalance += Math.abs(imbalance);
    }

    // Decay the learning rate passed to computeCycleImbalance when the summed
    // imbalance is already small but grew compared to the previous pass.
    if (
      currentImbalance < LEARNING_RATE_DECAY_IMBALANCE_THRESHOLD &&
      currentImbalance > lastImbalance
    ) {
      learningRate *= LEARNING_RATE_DECAY;
    }
    lastImbalance = currentImbalance;

    if (onIteration) {
      onIteration({ iteration, delta, imbalance: currentImbalance });
    }

    if (delta < escapeDelta) {
      return assignment;
    }
  }

  throw new Error(
    `Flow solver failed to converge after ${maxIterations} iterations, delta: ${delta}`,
  );
}
