/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.resource.batch.parallelism.autoconf;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.flink.table.plan.nodes.exec.BatchExecNode;
import org.apache.flink.table.plan.nodes.exec.ExecNode;
import org.apache.flink.table.resource.batch.NodeRunningUnit;
import org.apache.flink.table.resource.batch.parallelism.ShuffleStage;

/**
 * Shrinks the parallelism of batch {@link ShuffleStage}s so that, for every {@link NodeRunningUnit},
 * the CPU required by the shuffle stages overlapping that running unit does not exceed a total CPU
 * budget.
 *
 * <p>For each shuffle stage that is not yet final, every running unit containing one of the stage's
 * nodes proposes a parallelism (see {@link #calculateParallelism}). The smallest proposal wins and is
 * written back with {@code setParallelism(parallelism, true)}. Stages are visited in the iteration
 * order of the node-to-stage map. Stages already final count with their full CPU when later stages
 * are computed (NOTE(review): this assumes {@code setParallelism(p, true)} makes
 * {@code isFinalParallelism()} return true — confirm in {@link ShuffleStage}).
 *
 * <p>A fresh instance is created for every call of {@link #adjustParallelism}.
 */
public final class BatchParallelismAdjuster {
    /** CPU budget shared by all shuffle stages that overlap a single running unit. */
    private final double totalCpu;
    /** For each shuffle stage: the running units that contain at least one of the stage's nodes. */
    private final Map<ShuffleStage, Set<NodeRunningUnit>> overlapRunningUnits = new LinkedHashMap<>();
    /** For each running unit: the shuffle stages of the nodes it contains. */
    private final Map<NodeRunningUnit, Set<ShuffleStage>> overlapShuffleStages = new LinkedHashMap<>();

    /**
     * Adjusts, in place, the parallelism of every non-final shuffle stage in {@code nodeShuffleStageMap}.
     *
     * @param totalCpu CPU budget available to the stages overlapping any one running unit
     * @param nodeRunningUnitMap running units each batch node belongs to; must contain every node of
     *     every shuffle stage, otherwise a {@link NullPointerException} is raised
     * @param nodeShuffleStageMap shuffle stage of each node; a stage with several nodes appears once
     *     per node
     * @throws IllegalArgumentException if, for some running unit, the CPU of already-final stages plus
     *     one instance of every other overlapping stage exceeds {@code totalCpu}
     */
    public static void adjustParallelism(
            double totalCpu,
            Map<BatchExecNode<?>, Set<NodeRunningUnit>> nodeRunningUnitMap,
            Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap) {
        new BatchParallelismAdjuster(totalCpu).adjust(nodeRunningUnitMap, nodeShuffleStageMap);
    }

    private BatchParallelismAdjuster(double totalCpu) {
        this.totalCpu = totalCpu;
    }

    /** Builds the overlap indexes, then fixes each non-final stage to its most constrained parallelism. */
    private void adjust(
            Map<BatchExecNode<?>, Set<NodeRunningUnit>> nodeRunningUnitMap,
            Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap) {
        buildOverlap(nodeRunningUnitMap, nodeShuffleStageMap);
        for (ShuffleStage shuffleStage : nodeShuffleStageMap.values()) {
            // A stage shared by several nodes is visited once per node; it is only adjusted on the
            // first visit, because it is written back as final below.
            if (shuffleStage.isFinalParallelism()) {
                continue;
            }
            int requested = shuffleStage.getParallelism();
            int parallelism = requested;
            for (NodeRunningUnit runningUnit : overlapRunningUnits.get(shuffleStage)) {
                // The tightest running unit decides the stage's parallelism.
                int proposal = calculateParallelism(overlapShuffleStages.get(runningUnit), requested);
                parallelism = Math.min(parallelism, proposal);
            }
            shuffleStage.setParallelism(parallelism, true);
        }
    }

    /**
     * Populates {@link #overlapRunningUnits} (stage -> running units touching any of its nodes) and
     * {@link #overlapShuffleStages} (running unit -> stages of its nodes).
     */
    private void buildOverlap(
            Map<BatchExecNode<?>, Set<NodeRunningUnit>> nodeRunningUnitMap,
            Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap) {
        for (ShuffleStage shuffleStage : nodeShuffleStageMap.values()) {
            // values() repeats a stage once per node; its running-unit set only needs computing once.
            if (overlapRunningUnits.containsKey(shuffleStage)) {
                continue;
            }
            Set<NodeRunningUnit> runningUnitSet = new LinkedHashSet<>();
            for (ExecNode<?, ?> node : shuffleStage.getExecNodeSet()) {
                // NPE here if a stage's node is missing from nodeRunningUnitMap.
                runningUnitSet.addAll(nodeRunningUnitMap.get(node));
            }
            overlapRunningUnits.put(shuffleStage, runningUnitSet);
        }
        for (Set<NodeRunningUnit> runningUnits : nodeRunningUnitMap.values()) {
            for (NodeRunningUnit runningUnit : runningUnits) {
                if (overlapShuffleStages.containsKey(runningUnit)) {
                    continue;
                }
                Set<ShuffleStage> shuffleStageSet = new LinkedHashSet<>();
                for (ExecNode<?, ?> execNode : runningUnit.getNodeSet()) {
                    // NOTE(review): a node without a stage adds null here and fails later in
                    // calculateParallelism; presumably every node is mapped — confirm with caller.
                    shuffleStageSet.add(nodeShuffleStageMap.get(execNode));
                }
                overlapShuffleStages.put(runningUnit, shuffleStageSet);
            }
        }
    }

    /**
     * Proposes a parallelism for a stage given the stages overlapping one running unit.
     *
     * <p>Final stages consume their full CPU. Every non-final stage reserves CPU for one instance, and
     * the CPU left over is shared proportionally among the extra instances they request.
     *
     * @param shuffleStages stages overlapping the running unit
     * @param parallelism parallelism currently requested by the stage being adjusted
     * @return {@code parallelism} if everything fits, otherwise a value scaled down to at least 1
     * @throws IllegalArgumentException if the fixed CPU already exceeds {@link #totalCpu}
     */
    private int calculateParallelism(Set<ShuffleStage> shuffleStages, int parallelism) {
        double remain = totalCpu;
        double need = 0.0;
        for (ShuffleStage shuffleStage : shuffleStages) {
            if (shuffleStage.isFinalParallelism()) {
                remain -= getCpu(shuffleStage, shuffleStage.getParallelism());
            } else {
                remain -= getCpu(shuffleStage, 1);
                need += getCpu(shuffleStage, shuffleStage.getParallelism() - 1);
            }
        }
        if (remain < 0.0) {
            throw new IllegalArgumentException("adjust parallelism error, fixed resource > remain resource.");
        }
        // '>=' rather than '>': equality yields ratio 1.0 (same result), and it avoids 0.0/0.0 = NaN
        // when nothing extra is needed, which would otherwise truncate the parallelism to 1.
        if (remain >= need) {
            return parallelism;
        }
        // Here 0 <= remain < need, so need > 0 and ratio lies in [0, 1).
        double ratio = remain / need;
        return (int) ((parallelism - 1) * ratio) + 1;
    }

    /**
     * Returns the CPU used by {@code parallelism} instances of a stage, where one instance costs the
     * maximum CPU of the stage's nodes.
     */
    private double getCpu(ShuffleStage shuffleStage, int parallelism) {
        double maxNodeCpu = 0.0;
        for (ExecNode<?, ?> node : shuffleStage.getExecNodeSet()) {
            maxNodeCpu = Math.max(maxNodeCpu, node.getResource().getCpu());
        }
        return maxNodeCpu * parallelism;
    }
}

