/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.resource.batch.parallelism;

import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelDistribution;
import org.apache.flink.table.plan.nodes.exec.ExecNode;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecExchange;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSink;
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecUnion;
import org.apache.flink.table.resource.batch.parallelism.ShuffleStage;

/**
 * Groups the nodes of a batch execution plan into {@link ShuffleStage}s.
 *
 * <p>The plan is walked depth-first from the sink nodes. A node without inputs starts a
 * stage of its own; most other nodes are placed in a stage obtained by merging the stages
 * of their inputs. {@link BatchExecSink} nodes and {@link BatchExecExchange} nodes whose
 * distribution type is not {@code RANGE_DISTRIBUTED} are not assigned to any stage, so the
 * nodes on either side of them do not get merged through them.
 *
 * <p>{@link BatchExecUnion} nodes take part in the merge but are treated as virtual: they
 * are removed from every stage and from the returned map.
 */
public class ShuffleStageGenerator {
    /** Stage currently assigned to each visited node, kept in insertion order. */
    private final Map<ExecNode<?, ?>, ShuffleStage> nodeShuffleStageMap = new LinkedHashMap<>();
    /** Nodes whose parallelism has already been decided, mapped to that parallelism. */
    private final Map<ExecNode<?, ?>, Integer> nodeToFinalParallelismMap;

    private ShuffleStageGenerator(Map<ExecNode<?, ?>, Integer> nodeToFinalParallelismMap) {
        this.nodeToFinalParallelismMap = nodeToFinalParallelismMap;
    }

    /**
     * Builds the node-to-stage mapping for the plan reachable from the given sinks.
     *
     * @param sinkNodes roots from which the plan is traversed towards its inputs
     * @param finalParallelismNodeMap nodes with an already fixed parallelism; a stage that
     *     contains such a node is marked as having that final parallelism
     * @return a new insertion-ordered map from every non-virtual node that was assigned a
     *     stage to that stage
     */
    public static Map<ExecNode<?, ?>, ShuffleStage> generate(List<ExecNode<?, ?>> sinkNodes, Map<ExecNode<?, ?>, Integer> finalParallelismNodeMap) {
        ShuffleStageGenerator generator = new ShuffleStageGenerator(finalParallelismNodeMap);
        sinkNodes.forEach(generator::buildShuffleStages);

        Map<ExecNode<?, ?>, ShuffleStage> nodeToStage = generator.getNodeShuffleStageMap();
        nodeToStage.values().forEach(stage -> {
            // Collect the virtual nodes before removing them so the stage's node set is not
            // modified while it is being streamed (presumably removeNode mutates that set).
            List<ExecNode<?, ?>> virtualNodes = stage.getExecNodeSet().stream()
                    .filter(ShuffleStageGenerator::isVirtualNode)
                    .collect(Collectors.toList());
            virtualNodes.forEach(stage::removeNode);
        });

        return nodeToStage.entrySet().stream()
                .filter(entry -> !isVirtualNode(entry.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (first, second) -> first, LinkedHashMap::new));
    }

    /**
     * Recursively assigns stages to {@code execNode} and everything upstream of it.
     * Nodes that already have a stage are skipped.
     */
    private void buildShuffleStages(ExecNode<?, ?> execNode) {
        if (this.nodeShuffleStageMap.containsKey(execNode)) {
            return;
        }
        // Inputs first, so their stages exist when this node is processed.
        for (ExecNode<?, ?> input : execNode.getInputNodes()) {
            this.buildShuffleStages(input);
        }

        if (execNode.getInputNodes().isEmpty()) {
            // A node without inputs starts a fresh stage.
            ShuffleStage sourceStage = new ShuffleStage();
            sourceStage.addNode(execNode);
            if (this.nodeToFinalParallelismMap.containsKey(execNode)) {
                sourceStage.setParallelism(this.nodeToFinalParallelismMap.get(execNode), true);
            }
            this.nodeShuffleStageMap.put(execNode, sourceStage);
            return;
        }

        boolean isNonRangeExchange = execNode instanceof BatchExecExchange
                && ((BatchExecExchange) execNode).getDistribution().getType() != RelDistribution.Type.RANGE_DISTRIBUTED;
        boolean isSink = execNode instanceof BatchExecSink;
        if (isNonRangeExchange || isSink) {
            // These nodes are deliberately left without a stage.
            return;
        }

        Set<ShuffleStage> inputShuffleStages = this.getInputShuffleStages(execNode);
        Integer parallelism = this.nodeToFinalParallelismMap.get(execNode);
        ShuffleStage mergedStage = this.mergeInputShuffleStages(inputShuffleStages, parallelism);
        mergedStage.addNode(execNode);
        this.nodeShuffleStageMap.put(execNode, mergedStage);
    }

    /**
     * Merges the given input stages into the stage that the consuming node will join.
     *
     * <p>Input stages whose parallelism is final and differs from the target stage's
     * parallelism are left untouched.
     *
     * @param shuffleStageSet stages of the consuming node's inputs
     * @param parallelism the consuming node's fixed parallelism, or {@code null} if it has none
     * @return the stage holding the nodes of all merged input stages
     */
    private ShuffleStage mergeInputShuffleStages(Set<ShuffleStage> shuffleStageSet, Integer parallelism) {
        if (parallelism != null) {
            // The node dictates the parallelism: start a new final stage and absorb every
            // compatible input stage into it.
            ShuffleStage fixedStage = new ShuffleStage();
            fixedStage.setParallelism(parallelism, true);
            for (ShuffleStage candidate : shuffleStageSet) {
                if (candidate.isFinalParallelism() && candidate.getParallelism() != parallelism.intValue()) {
                    continue;
                }
                this.mergeShuffleStage(fixedStage, candidate);
            }
            return fixedStage;
        }

        // No fixed parallelism on the node: reuse the final input stage with the largest
        // parallelism, or a brand-new stage when no input stage is final. The stage is only
        // allocated when it is actually needed.
        ShuffleStage targetStage = shuffleStageSet.stream()
                .filter(ShuffleStage::isFinalParallelism)
                .max(Comparator.comparingInt(ShuffleStage::getParallelism))
                .orElseGet(ShuffleStage::new);
        for (ShuffleStage candidate : shuffleStageSet) {
            // When targetStage comes from the set it is also visited here and merged with
            // itself, exactly as before.
            if (candidate.isFinalParallelism() && candidate.getParallelism() != targetStage.getParallelism()) {
                continue;
            }
            this.mergeShuffleStage(targetStage, candidate);
        }
        return targetStage;
    }

    /** Adds all nodes of {@code other} to {@code shuffleStage} and re-points them to it. */
    private void mergeShuffleStage(ShuffleStage shuffleStage, ShuffleStage other) {
        Set<ExecNode<?, ?>> absorbedNodes = other.getExecNodeSet();
        shuffleStage.addNodeSet(absorbedNodes);
        for (ExecNode<?, ?> node : absorbedNodes) {
            this.nodeShuffleStageMap.put(node, shuffleStage);
        }
    }

    /**
     * Returns the stages of the direct inputs of {@code node}; inputs that were not
     * assigned a stage are ignored.
     */
    private Set<ShuffleStage> getInputShuffleStages(ExecNode<?, ?> node) {
        Set<ShuffleStage> inputStages = new HashSet<>();
        for (ExecNode<?, ?> input : node.getInputNodes()) {
            ShuffleStage inputStage = this.nodeShuffleStageMap.get(input);
            if (inputStage != null) {
                inputStages.add(inputStage);
            }
        }
        return inputStages;
    }

    /** A node is virtual when it is a {@link BatchExecUnion}. */
    private static boolean isVirtualNode(ExecNode<?, ?> node) {
        return node instanceof BatchExecUnion;
    }

    private Map<ExecNode<?, ?>, ShuffleStage> getNodeShuffleStageMap() {
        return this.nodeShuffleStageMap;
    }
}

