package org.apache.flink.table.resource.batch.parallelism.autoconf;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.flink.table.plan.nodes.exec.BatchExecNode;
import org.apache.flink.table.plan.nodes.exec.ExecNode;
import org.apache.flink.table.resource.batch.NodeRunningUnit;
import org.apache.flink.table.resource.batch.parallelism.ShuffleStage;

/* loaded from: input_file:org/apache/flink/table/resource/batch/parallelism/autoconf/BatchParallelismAdjuster.class */
/**
 * Scales down the parallelism of batch shuffle stages according to a total CPU budget.
 *
 * <p>For every shuffle stage whose parallelism is not yet final, each running unit that shares a
 * node with the stage is inspected. The CPU still available after accounting for the stages
 * overlapping that unit is used to scale the stage's parallelism. The smallest result across all
 * those running units is then applied to the stage.
 *
 * <p>Instances are created and used only by {@link #adjustParallelism}.
 */
public class BatchParallelismAdjuster {

    /** CPU budget shared by the shuffle stages overlapping a running unit. */
    private final double totalCpu;

    /** For each shuffle stage, the running units that contain at least one of its nodes. */
    private final Map<ShuffleStage, Set<NodeRunningUnit>> overlapRunningUnits = new LinkedHashMap<>();

    /** For each running unit, the shuffle stages that its nodes belong to. */
    private final Map<NodeRunningUnit, Set<ShuffleStage>> overlapShuffleStages = new LinkedHashMap<>();

    /**
     * Adjusts, in place, the parallelism of every shuffle stage in {@code nodeShuffleStageMap}
     * that is not already final.
     *
     * @param totalCpu total CPU budget available to the stages overlapping one running unit
     * @param nodeRunningUnitMap running units each node belongs to; expected to contain every node
     *     of every shuffle stage
     * @param nodeShuffleStageMap shuffle stage of each node; expected to contain every node of
     *     every running unit
     * @throws IllegalArgumentException if the stages with final parallelism, plus one instance of
     *     each adjustable stage, already need more than {@code totalCpu} for some running unit
     */
    public static void adjustParallelism(
            double totalCpu,
            Map<BatchExecNode<?>, Set<NodeRunningUnit>> nodeRunningUnitMap,
            Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap) {
        new BatchParallelismAdjuster(totalCpu).adjust(nodeRunningUnitMap, nodeShuffleStageMap);
    }

    private BatchParallelismAdjuster(double totalCpu) {
        this.totalCpu = totalCpu;
    }

    /** Builds the overlap indexes, then fixes the parallelism of each non-final stage in turn. */
    private void adjust(
            Map<BatchExecNode<?>, Set<NodeRunningUnit>> nodeRunningUnitMap,
            Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap) {
        buildOverlap(nodeRunningUnitMap, nodeShuffleStageMap);
        // A stage appears once per node mapped to it. NOTE(review): the `true` passed to
        // setParallelism presumably marks the stage final (isFinalParallelism), so later visits
        // of the same stage are skipped — confirm against ShuffleStage.
        for (ShuffleStage shuffleStage : nodeShuffleStageMap.values()) {
            if (shuffleStage.isFinalParallelism()) {
                continue;
            }
            int requestedParallelism = shuffleStage.getParallelism();
            int parallelism = requestedParallelism;
            // The tightest running unit wins: take the minimum over every overlapping unit.
            for (NodeRunningUnit runningUnit : overlapRunningUnits.get(shuffleStage)) {
                int candidate =
                        calculateParallelism(
                                overlapShuffleStages.get(runningUnit), requestedParallelism);
                parallelism = Math.min(parallelism, candidate);
            }
            shuffleStage.setParallelism(parallelism, true);
        }
    }

    /**
     * Populates {@link #overlapRunningUnits} (stage -> running units of its nodes) and
     * {@link #overlapShuffleStages} (running unit -> stages of its nodes), preserving encounter
     * order.
     */
    private void buildOverlap(
            Map<BatchExecNode<?>, Set<NodeRunningUnit>> nodeRunningUnitMap,
            Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap) {
        for (ShuffleStage shuffleStage : nodeShuffleStageMap.values()) {
            // Many nodes share one stage; its running units only need to be collected once.
            if (overlapRunningUnits.containsKey(shuffleStage)) {
                continue;
            }
            Set<NodeRunningUnit> runningUnits = new LinkedHashSet<>();
            for (ExecNode<?, ?> node : shuffleStage.getExecNodeSet()) {
                // A node missing from nodeRunningUnitMap yields null here and addAll throws NPE.
                runningUnits.addAll(nodeRunningUnitMap.get(node));
            }
            overlapRunningUnits.put(shuffleStage, runningUnits);
        }
        for (Set<NodeRunningUnit> runningUnits : nodeRunningUnitMap.values()) {
            for (NodeRunningUnit runningUnit : runningUnits) {
                if (overlapShuffleStages.containsKey(runningUnit)) {
                    continue;
                }
                Set<ShuffleStage> shuffleStages = new LinkedHashSet<>();
                for (BatchExecNode<?> node : runningUnit.getNodeSet()) {
                    // NOTE(review): a node missing from nodeShuffleStageMap adds a null entry,
                    // which later fails in calculateParallelism.
                    shuffleStages.add(nodeShuffleStageMap.get(node));
                }
                overlapShuffleStages.put(runningUnit, shuffleStages);
            }
        }
    }

    /**
     * Computes the parallelism one running unit allows for an adjustable stage.
     *
     * <p>Each stage with final parallelism reserves its full CPU. Each adjustable stage reserves
     * the CPU of a single instance. The CPU left over is compared with the extra CPU the
     * adjustable stages would need for their remaining {@code parallelism - 1} instances. If it
     * suffices, {@code parallelism} is returned unchanged. Otherwise the extra instances are
     * scaled down proportionally, always keeping at least one instance.
     *
     * @param shuffleStages the stages overlapping one running unit
     * @param parallelism the parallelism requested by the stage being adjusted
     * @return the allowed parallelism, between 1 and {@code parallelism} for {@code parallelism >= 1}
     * @throws IllegalArgumentException if the reserved CPU already exceeds {@link #totalCpu}
     */
    private int calculateParallelism(Set<ShuffleStage> shuffleStages, int parallelism) {
        double remainCpu = totalCpu;
        double extraCpuDemand = 0.0d;
        for (ShuffleStage shuffleStage : shuffleStages) {
            if (shuffleStage.isFinalParallelism()) {
                remainCpu -= getCpu(shuffleStage, shuffleStage.getParallelism());
            } else {
                remainCpu -= getCpu(shuffleStage, 1);
                extraCpuDemand += getCpu(shuffleStage, shuffleStage.getParallelism() - 1);
            }
        }
        if (remainCpu < 0.0d) {
            throw new IllegalArgumentException(
                    "adjust parallelism error, fixed resource > remain resource.");
        }
        // `>=` (not `>`) also covers extraCpuDemand == 0. Otherwise 0.0 / 0.0 yields NaN, which
        // narrows to 0 and would silently collapse the parallelism to 1.
        if (remainCpu >= extraCpuDemand) {
            return parallelism;
        }
        return (int) ((parallelism - 1) * (remainCpu / extraCpuDemand)) + 1;
    }

    /**
     * Returns the CPU needed to run {@code parallelism} instances of a stage. The per-instance
     * cost is the maximum CPU requested by any node of the stage.
     */
    private double getCpu(ShuffleStage shuffleStage, int parallelism) {
        double maxNodeCpu = 0.0d;
        for (ExecNode<?, ?> node : shuffleStage.getExecNodeSet()) {
            maxNodeCpu = Math.max(maxNodeCpu, node.getResource().getCpu());
        }
        return maxNodeCpu * parallelism;
    }
}
