package org.apache.flink.table.resource.batch.parallelism;

import org.apache.flink.table.plan.nodes.exec.ExecNode;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecExchange;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSink;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecUnion;

import org.apache.calcite.rel.RelDistribution;

import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Groups the nodes of a batch exec-node DAG into {@link ShuffleStage}s.
 *
 * <p>Nodes are visited bottom-up from the given roots. A node without inputs starts its own
 * stage. Any other node joins a merge of its inputs' stages. Sinks and exchanges whose
 * distribution is not {@code RANGE_DISTRIBUTED} are the exception: they never receive a stage
 * themselves, so they separate their inputs' stages from their consumers' stages.
 *
 * <p>A stage may carry a final parallelism taken from the node-to-final-parallelism map. When
 * stages are merged, an input stage whose final parallelism differs from the chosen one is left
 * unmerged.
 *
 * <p>Virtual nodes ({@link BatchExecUnion}) take part in stage building. Afterwards they are
 * removed from every stage and from the result.
 */
public class ShuffleStageGenerator {

    /** Node to the stage it currently belongs to; insertion-ordered, which fixes the result order. */
    private final Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap = new LinkedHashMap<>();

    /** Nodes whose parallelism is fixed; each such node gets a stage pinned to that value. */
    private final Map<ExecNode<?, ?>, Integer> nodeToFinalParallelismMap;

    /**
     * Every node already passed to {@link #buildShuffleStages}. This is needed in addition to
     * {@link #nodeShuffleStageMap} because stage-breaking nodes (sinks, non-range exchanges) are
     * never added to that map. Without it they would be re-traversed every time they are reached
     * through a shared subgraph.
     */
    private final Set<ExecNode<?, ?>> visitedNodes = new HashSet<>();

    private ShuffleStageGenerator(Map<ExecNode<?, ?>, Integer> nodeToFinalParallelismMap) {
        this.nodeToFinalParallelismMap = nodeToFinalParallelismMap;
    }

    /**
     * Builds the shuffle stages for the DAG reachable from {@code rootNodes}.
     *
     * @param rootNodes nodes to start from; all of their transitive inputs are visited
     * @param nodeToFinalParallelismMap nodes whose parallelism is fixed, with that parallelism
     * @return node to shuffle stage, in the order nodes were first assigned a stage. Virtual
     *     nodes are excluded. Sinks and non-range exchanges are absent unless they have no inputs.
     */
    public static Map<ExecNode<?, ?>, ShuffleStage> generate(
            List<ExecNode<?, ?>> rootNodes, Map<ExecNode<?, ?>, Integer> nodeToFinalParallelismMap) {
        ShuffleStageGenerator generator = new ShuffleStageGenerator(nodeToFinalParallelismMap);
        rootNodes.forEach(generator::buildShuffleStages);

        // Strip virtual nodes out of every stage. Collect them first so the stage's node set is
        // not modified while it is being streamed.
        for (ShuffleStage stage : generator.getNodeShuffleStageMap().values()) {
            List<ExecNode<?, ?>> virtualNodes =
                    stage.getExecNodeSet().stream()
                            .filter(ShuffleStageGenerator::isVirtualNode)
                            .collect(Collectors.toList());
            virtualNodes.forEach(stage::removeNode);
        }

        Map<ExecNode<?, ?>, ShuffleStage> result = new LinkedHashMap<>();
        generator.getNodeShuffleStageMap().forEach((node, stage) -> {
            if (!isVirtualNode(node)) {
                result.put(node, stage);
            }
        });
        return result;
    }

    /** Recursively assigns {@code execNode} and its inputs to shuffle stages (inputs first). */
    private void buildShuffleStages(ExecNode<?, ?> execNode) {
        if (!visitedNodes.add(execNode)) {
            return;
        }
        for (ExecNode<?, ?> input : execNode.getInputNodes()) {
            buildShuffleStages(input);
        }

        Integer finalParallelism = nodeToFinalParallelismMap.get(execNode);
        if (execNode.getInputNodes().isEmpty()) {
            // A node without inputs opens its own stage, pinned if its parallelism is fixed.
            ShuffleStage stage = new ShuffleStage();
            stage.addNode(execNode);
            if (finalParallelism != null) {
                stage.setParallelism(finalParallelism, true);
            }
            nodeShuffleStageMap.put(execNode, stage);
            return;
        }
        if (breaksShuffleStage(execNode)) {
            return;
        }
        ShuffleStage stage =
                mergeInputShuffleStages(getInputShuffleStages(execNode), finalParallelism);
        stage.addNode(execNode);
        nodeShuffleStageMap.put(execNode, stage);
    }

    /**
     * Returns true for nodes that do not join any stage: sinks, and exchanges whose distribution
     * is anything other than {@code RANGE_DISTRIBUTED}.
     */
    private static boolean breaksShuffleStage(ExecNode<?, ?> execNode) {
        if (execNode instanceof BatchExecSink) {
            return true;
        }
        return execNode instanceof BatchExecExchange
                && ((BatchExecExchange) execNode).getDistribution().getType()
                        != RelDistribution.Type.RANGE_DISTRIBUTED;
    }

    /**
     * Chooses or creates a target stage and merges every compatible input stage into it.
     *
     * <p>If {@code finalParallelism} is non-null, the target is a new stage pinned to it.
     * Otherwise the target is the input stage with the highest final parallelism, or a new stage
     * if no input stage is final. An input stage is merged when it has no final parallelism, or
     * when its parallelism equals the target's.
     *
     * @param inputStages stages of the current node's inputs
     * @param finalParallelism the current node's fixed parallelism, or {@code null} if not fixed
     * @return the target stage
     */
    private ShuffleStage mergeInputShuffleStages(
            Set<ShuffleStage> inputStages, Integer finalParallelism) {
        ShuffleStage target;
        if (finalParallelism != null) {
            target = new ShuffleStage();
            target.setParallelism(finalParallelism, true);
        } else {
            target =
                    inputStages.stream()
                            .filter(ShuffleStage::isFinalParallelism)
                            .max(Comparator.comparingInt(ShuffleStage::getParallelism))
                            // orElseGet: only allocate a stage when no pinned input stage exists.
                            .orElseGet(ShuffleStage::new);
        }
        for (ShuffleStage inputStage : inputStages) {
            // The target's parallelism is re-read on each iteration in case a merge changes it.
            int requiredParallelism =
                    finalParallelism != null ? finalParallelism : target.getParallelism();
            if (!inputStage.isFinalParallelism()
                    || inputStage.getParallelism() == requiredParallelism) {
                mergeShuffleStage(target, inputStage);
            }
        }
        return target;
    }

    /** Moves all nodes of {@code source} into {@code target} and repoints them in the node map. */
    private void mergeShuffleStage(ShuffleStage target, ShuffleStage source) {
        Set<ExecNode<?, ?>> nodes = source.getExecNodeSet();
        target.addNodeSet(nodes);
        for (ExecNode<?, ?> node : nodes) {
            nodeShuffleStageMap.put(node, target);
        }
    }

    /**
     * Collects the distinct stages of {@code execNode}'s inputs, skipping inputs without a stage.
     * A {@link LinkedHashSet} keeps them in input order, so merge order and tie-breaking do not
     * depend on hash order.
     */
    private Set<ShuffleStage> getInputShuffleStages(ExecNode<?, ?> execNode) {
        Set<ShuffleStage> inputStages = new LinkedHashSet<>();
        for (ExecNode<?, ?> input : execNode.getInputNodes()) {
            ShuffleStage stage = nodeShuffleStageMap.get(input);
            if (stage != null) {
                inputStages.add(stage);
            }
        }
        return inputStages;
    }

    /** Virtual nodes are removed from stages and from the result of {@link #generate}. */
    private static boolean isVirtualNode(ExecNode<?, ?> execNode) {
        return execNode instanceof BatchExecUnion;
    }

    private Map<ExecNode<?, ?>, ShuffleStage> getNodeShuffleStageMap() {
        return nodeShuffleStageMap;
    }
}
