package org.apache.flink.table.resource.batch.schedule;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.event.ExecutionVertexFailoverEvent;
import org.apache.flink.runtime.event.ExecutionVertexStateChangedEvent;
import org.apache.flink.runtime.event.ResultPartitionConsumableEvent;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.jobgraph.ExecutionVertexID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmaster.ExecutionSlotAllocator;
import org.apache.flink.runtime.jobmaster.GraphManager;
import org.apache.flink.runtime.schedule.GraphManagerPlugin;
import org.apache.flink.runtime.schedule.SchedulingConfig;
import org.apache.flink.runtime.schedule.VertexScheduler;
import org.apache.flink.table.resource.batch.BatchExecNodeStage;
import org.apache.flink.table.resource.batch.NodeRunningUnit;
import org.apache.flink.util.InstantiationUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link GraphManagerPlugin} that schedules a batch job one running unit at a time.
 *
 * <p>Each {@link NodeRunningUnit} produced by the planner is translated into a
 * {@link LogicalJobVertexRunningUnit} over the job vertices that run its transformations. A unit is
 * queued once {@code allDependReady()} reports its dependencies satisfied: DATA_TRIGGER inputs are
 * tracked through result-partition-consumable events, and PRIORITY ("join") dependencies are
 * tracked through the depended-on unit reaching the all-tasks-deploying state. A new unit is only
 * scheduled once every task of the currently scheduled unit is deploying. Job vertices that belong
 * to no running unit are grouped into one leftover unit, which is scheduled only when the queue is
 * empty and every running unit has already been queued or has fully deployed.
 *
 * <p>NOTE(review): some callbacks ({@code onExecutionVertexStateChanged}) and helpers are
 * {@code synchronized} while {@code onSchedulingStarted}, {@code onResultPartitionConsumable} and
 * {@code onExecutionVertexFailover} are not, although they mutate the same state. Confirm that all
 * callbacks are delivered from a single thread.
 */
public class RunningUnitGraphManagerPlugin implements GraphManagerPlugin {
    private static final Logger LOG = LoggerFactory.getLogger(RunningUnitGraphManagerPlugin.class);

    /** Configuration key under which the serialized list of {@link NodeRunningUnit}s is stored. */
    public static final String RUNNING_UNIT_CONF_KEY = "runningUnit.key";

    /** Unit holding every job vertex that belongs to no running unit; it is scheduled last. */
    private LogicalJobVertexRunningUnit leftJobVertexRunningUnit;

    private VertexScheduler scheduler;

    private JobGraph jobGraph;

    /** Logical wrapper for every job vertex of the job graph, keyed by job vertex ID. */
    private final Map<JobVertexID, LogicalJobVertex> logicalJobVertices = new LinkedHashMap<>();

    /** Planner running unit to its job-vertex-level counterpart. */
    private final Map<NodeRunningUnit, LogicalJobVertexRunningUnit> runningUnitMap = new LinkedHashMap<>();

    /** Units whose dependencies are ready, in the order they became ready. */
    private final LinkedList<LogicalJobVertexRunningUnit> scheduleQueue = new LinkedList<>();

    /** Unit most recently handed to the scheduler, or {@code null} if none since the last failover. */
    private LogicalJobVertexRunningUnit currentScheduledUnit = null;

    /**
     * Units that were already queued or have fully deployed. The set is identity-based, so each
     * unit is enqueued at most once until the next failover clears it.
     */
    private final Set<LogicalJobVertexRunningUnit> scheduledRunningUnitSet =
            Collections.newSetFromMap(new IdentityHashMap<>());

    /** Producer job vertex to the units that consume its output as a DATA_TRIGGER input. */
    private final Map<JobVertexID, Set<LogicalJobVertexRunningUnit>> inputRunningUnitMap = new LinkedHashMap<>();

    /** Unit to the units that have a PRIORITY dependency on it. */
    private final Map<LogicalJobVertexRunningUnit, Set<LogicalJobVertexRunningUnit>> joinDependRunningUnitMap =
            new LinkedHashMap<>();

    /**
     * Reads the planner output from the scheduling configuration and builds the scheduling state.
     *
     * <p>Two objects are deserialized with the user class loader: the job-vertex-to-stream-node map
     * (key {@code "jobVertexToStreamNodeMap"}) and the running units (key
     * {@link #RUNNING_UNIT_CONF_KEY}).
     *
     * @throws RuntimeException wrapping the {@link IOException} or {@link ClassNotFoundException}
     *     raised when either object cannot be deserialized
     */
    @SuppressWarnings("unchecked")
    public void open(VertexScheduler vertexScheduler, JobGraph jobGraph, SchedulingConfig schedulingConfig, ExecutionGraph executionGraph, GraphManager graphManager, ExecutionSlotAllocator executionSlotAllocator) {
        final Map<JobVertexID, ArrayList<Integer>> jobVertexToStreamNodes;
        final List<NodeRunningUnit> runningUnits;
        try {
            jobVertexToStreamNodes = (Map<JobVertexID, ArrayList<Integer>>) InstantiationUtil.readObjectFromConfig(
                    schedulingConfig.getConfiguration(), "jobVertexToStreamNodeMap", schedulingConfig.getUserClassLoader());
            runningUnits = (List<NodeRunningUnit>) InstantiationUtil.readObjectFromConfig(
                    schedulingConfig.getConfiguration(), RUNNING_UNIT_CONF_KEY, schedulingConfig.getUserClassLoader());
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to read running-unit scheduling information from the scheduling configuration.", e);
        }
        open(vertexScheduler, jobGraph, jobVertexToStreamNodes, runningUnits);
    }

    /**
     * Builds the scheduling state from already-deserialized planner output.
     *
     * @param vertexScheduler scheduler used to submit execution vertices
     * @param jobGraph job graph being scheduled
     * @param map job vertex ID to the stream node (transformation) IDs it contains
     * @param list planner running units
     */
    @VisibleForTesting
    public void open(VertexScheduler vertexScheduler, JobGraph jobGraph, Map<JobVertexID, ArrayList<Integer>> map, List<NodeRunningUnit> list) {
        this.scheduler = vertexScheduler;
        this.jobGraph = jobGraph;
        Map<Integer, JobVertexID> transformationToJobVertex = reverseMap(map);
        buildJobVertices(Arrays.asList(jobGraph.getVerticesAsArray()));
        buildJobRunningUnits(list, transformationToJobVertex);
    }

    /**
     * Translates the planner running units into {@link LogicalJobVertexRunningUnit}s and records
     * their dependencies. Job vertices that end up in no unit are gathered into
     * {@link #leftJobVertexRunningUnit}.
     */
    private void buildJobRunningUnits(List<NodeRunningUnit> nodeRunningUnits, Map<Integer, JobVertexID> transformationToJobVertex) {
        // Drop dependency edges that would close a cycle between running units before building anything.
        avoidDeadLockDepend(nodeRunningUnits);
        for (NodeRunningUnit nodeRunningUnit : nodeRunningUnits) {
            LogicalJobVertexRunningUnit unit = transformRunningUnit(nodeRunningUnit, transformationToJobVertex);
            // Every unit gets a (possibly empty) dependents entry so later lookups never see null.
            joinDependRunningUnitMap.computeIfAbsent(unit, key -> new LinkedHashSet<>());
            registerInputDepends(nodeRunningUnit, unit, transformationToJobVertex);
            runningUnitMap.put(nodeRunningUnit, unit);
        }
        // Needs a second pass: PRIORITY dependencies may point at units created later in the list.
        for (NodeRunningUnit nodeRunningUnit : nodeRunningUnits) {
            registerJoinDepends(nodeRunningUnit);
        }
        Set<LogicalJobVertex> uncoveredVertices = logicalJobVertices.values().stream()
                .filter(vertex -> vertex.getRunningUnitSet().isEmpty())
                .collect(Collectors.toCollection(LinkedHashSet::new));
        leftJobVertexRunningUnit = new LogicalJobVertexRunningUnit(uncoveredVertices);
    }

    /**
     * Registers every job vertex behind a DATA_TRIGGER dependency of {@code nodeRunningUnit} as an
     * input of {@code unit}. Transformation IDs with no mapped job vertex are skipped.
     */
    private void registerInputDepends(NodeRunningUnit nodeRunningUnit, LogicalJobVertexRunningUnit unit, Map<Integer, JobVertexID> transformationToJobVertex) {
        for (BatchExecNodeStage stage : nodeRunningUnit.getAllNodeStages()) {
            for (BatchExecNodeStage dependStage : stage.getDependStageList(BatchExecNodeStage.DependType.DATA_TRIGGER)) {
                for (Integer transformationId : dependStage.getTransformationIDList()) {
                    JobVertexID producer = transformationToJobVertex.get(transformationId);
                    if (producer != null) {
                        inputRunningUnitMap.computeIfAbsent(producer, id -> new LinkedHashSet<>()).add(unit);
                        unit.addInputDepend(producer);
                    }
                }
            }
        }
    }

    /**
     * Records each PRIORITY dependency of {@code nodeRunningUnit}. The corresponding logical unit
     * then waits until every depended-on unit has all its tasks deploying.
     */
    private void registerJoinDepends(NodeRunningUnit nodeRunningUnit) {
        LogicalJobVertexRunningUnit dependent = runningUnitMap.get(nodeRunningUnit);
        for (BatchExecNodeStage stage : nodeRunningUnit.getAllNodeStages()) {
            for (BatchExecNodeStage dependStage : stage.getDependStageList(BatchExecNodeStage.DependType.PRIORITY)) {
                for (NodeRunningUnit dependNodeUnit : dependStage.getRunningUnitList()) {
                    LogicalJobVertexRunningUnit dependency = runningUnitMap.get(dependNodeUnit);
                    joinDependRunningUnitMap.get(dependency).add(dependent);
                    dependent.addJoinDepend(dependency);
                }
            }
        }
    }

    /**
     * Creates the logical unit covering the job vertices of all transformations in the given
     * running unit, and registers that unit on each of those vertices.
     *
     * <p>Transformation IDs that map to no known job vertex are skipped, consistent with
     * {@link #registerInputDepends}. Previously they added a {@code null} vertex, which caused an
     * NPE during registration.
     */
    private LogicalJobVertexRunningUnit transformRunningUnit(NodeRunningUnit nodeRunningUnit, Map<Integer, JobVertexID> transformationToJobVertex) {
        Set<LogicalJobVertex> jobVertexSet = new LinkedHashSet<>();
        for (BatchExecNodeStage stage : nodeRunningUnit.getAllNodeStages()) {
            for (Integer transformationId : stage.getTransformationIDList()) {
                JobVertexID jobVertexID = transformationToJobVertex.get(transformationId);
                LogicalJobVertex vertex = jobVertexID == null ? null : logicalJobVertices.get(jobVertexID);
                if (vertex != null) {
                    jobVertexSet.add(vertex);
                }
            }
        }
        LogicalJobVertexRunningUnit unit = new LogicalJobVertexRunningUnit(jobVertexSet);
        jobVertexSet.forEach(vertex -> vertex.addRunningUnit(unit));
        return unit;
    }

    /**
     * Wraps every job vertex in a {@link LogicalJobVertex} and gives each an empty set of input
     * consumers.
     */
    private void buildJobVertices(List<JobVertex> jobVertices) {
        for (JobVertex jobVertex : jobVertices) {
            logicalJobVertices.put(jobVertex.getID(), new LogicalJobVertex(jobVertex));
            inputRunningUnitMap.put(jobVertex.getID(), new LinkedHashSet<>());
        }
    }

    /**
     * Inverts the job-vertex-to-IDs map into ID-to-job-vertex. If an ID is listed under several
     * job vertices, the last one encountered wins.
     */
    private Map<Integer, JobVertexID> reverseMap(Map<JobVertexID, ArrayList<Integer>> jobVertexToIds) {
        Map<Integer, JobVertexID> reversed = new LinkedHashMap<>();
        for (Map.Entry<JobVertexID, ArrayList<Integer>> entry : jobVertexToIds.entrySet()) {
            for (Integer id : entry.getValue()) {
                reversed.put(id, entry.getKey());
            }
        }
        return reversed;
    }

    public void close() {
    }

    public void reset() {
    }

    /** Queues every unit whose dependencies are already ready, then tries to schedule one. */
    public void onSchedulingStarted() {
        runningUnitMap.values().stream()
                .filter(LogicalJobVertexRunningUnit::allDependReady)
                .forEach(this::addToScheduleQueue);
        checkScheduleNewRunningUnit();
    }

    /** Credits the producing job vertex's consumers with one input, then tries to schedule a unit. */
    public void onResultPartitionConsumable(ResultPartitionConsumableEvent resultPartitionConsumableEvent) {
        produceResultPartition(jobGraph.getResultProducerID(resultPartitionConsumableEvent.getResultID()));
        checkScheduleNewRunningUnit();
    }

    /**
     * Notifies every unit consuming {@code jobVertexID} that one of its inputs was produced, and
     * queues those whose dependencies became ready.
     */
    private void produceResultPartition(JobVertexID jobVertexID) {
        for (LogicalJobVertexRunningUnit consumer : inputRunningUnitMap.get(jobVertexID)) {
            consumer.receiveInput(jobVertexID);
            if (consumer.allDependReady()) {
                addToScheduleQueue(consumer);
            }
        }
    }

    /**
     * Discards all scheduling progress, marks the affected tasks as failed over, and rebuilds the
     * state from the vertices whose tasks are all still deploying before restarting scheduling.
     */
    public void onExecutionVertexFailover(ExecutionVertexFailoverEvent executionVertexFailoverEvent) {
        scheduleQueue.clear();
        currentScheduledUnit = null;
        scheduledRunningUnitSet.clear();
        executionVertexFailoverEvent.getAffectedExecutionVertexIDs().forEach(executionVertexID ->
                logicalJobVertices.get(executionVertexID.getJobVertexID()).failoverTask());
        runningUnitMap.values().forEach(LogicalJobVertexRunningUnit::reset);
        // For vertices still fully deploying, replay the "all deploying" notification and one
        // consumable-partition notification per subtask so dependent units regain their inputs.
        logicalJobVertices.values().stream()
                .filter(LogicalJobVertex::allTasksDeploying)
                .forEach(vertex -> {
                    allTaskDeploying(vertex);
                    for (int subtask = 0; subtask < vertex.getParallelism(); subtask++) {
                        produceResultPartition(vertex.getJobVertexID());
                    }
                });
        onSchedulingStarted();
    }

    /**
     * Schedules the next ready unit if the current one (if any) has all its tasks deploying.
     *
     * <p>Queued units that are already fully deploying are retired without being rescheduled. If
     * the queue runs dry after every running unit has been queued or deployed, the leftover unit is
     * scheduled, unless it is itself already fully deploying.
     */
    private synchronized void checkScheduleNewRunningUnit() {
        if (currentScheduledUnit != null && !currentScheduledUnit.allTasksDeploying()) {
            return;
        }
        LogicalJobVertexRunningUnit next;
        while ((next = scheduleQueue.pollFirst()) != null && next.allTasksDeploying()) {
            runningUnitAllDeploying(next);
        }
        if (next != null) {
            scheduleRunningUnit(next);
        } else if (scheduledRunningUnitSet.size() >= runningUnitMap.size()
                && !leftJobVertexRunningUnit.allTasksDeploying()) {
            scheduleRunningUnit(leftJobVertexRunningUnit);
        }
    }

    /** Appends the unit to the schedule queue unless it was already queued or fully deployed. */
    private synchronized void addToScheduleQueue(LogicalJobVertexRunningUnit logicalJobVertexRunningUnit) {
        if (scheduledRunningUnitSet.add(logicalJobVertexRunningUnit)) {
            scheduleQueue.add(logicalJobVertexRunningUnit);
        }
    }

    /**
     * Makes the unit current and submits every subtask of its to-be-scheduled job vertices, one
     * execution vertex per scheduler call.
     */
    private synchronized void scheduleRunningUnit(LogicalJobVertexRunningUnit logicalJobVertexRunningUnit) {
        currentScheduledUnit = logicalJobVertexRunningUnit;
        LOG.info("begin to schedule runningUnit: ");
        logicalJobVertexRunningUnit.getJobVertexSet().forEach(vertex -> LOG.info("{}", vertex));
        logicalJobVertexRunningUnit.getToScheduleJobVertices().forEach(vertex -> {
            for (int subtask = 0; subtask < vertex.getParallelism(); subtask++) {
                scheduler.scheduleExecutionVertices(
                        Collections.singletonList(new ExecutionVertexID(vertex.getJobVertexID(), subtask)));
            }
        });
    }

    /**
     * Counts DEPLOYING transitions per job vertex. When a vertex becomes fully deploying, its units
     * are re-evaluated, and then scheduling of the next unit is attempted.
     */
    public synchronized void onExecutionVertexStateChanged(ExecutionVertexStateChangedEvent executionVertexStateChangedEvent) {
        if (executionVertexStateChangedEvent.getNewExecutionState() != ExecutionState.DEPLOYING) {
            return;
        }
        LogicalJobVertex vertex = logicalJobVertices.get(executionVertexStateChangedEvent.getExecutionVertexID().getJobVertexID());
        vertex.deployingTask();
        if (vertex.allTasksDeploying()) {
            allTaskDeploying(vertex);
        }
        checkScheduleNewRunningUnit();
    }

    /** Retires every unit of the given vertex whose tasks are now all deploying. */
    private synchronized void allTaskDeploying(LogicalJobVertex logicalJobVertex) {
        logicalJobVertex.getRunningUnitSet().stream()
                .filter(LogicalJobVertexRunningUnit::allTasksDeploying)
                .forEach(this::runningUnitAllDeploying);
    }

    /**
     * Marks the unit as done and releases its PRIORITY dependents, queueing those whose
     * dependencies are now all ready.
     */
    private synchronized void runningUnitAllDeploying(LogicalJobVertexRunningUnit logicalJobVertexRunningUnit) {
        scheduledRunningUnitSet.add(logicalJobVertexRunningUnit);
        for (LogicalJobVertexRunningUnit dependent : joinDependRunningUnitMap.get(logicalJobVertexRunningUnit)) {
            dependent.joinDependDeploying(logicalJobVertexRunningUnit);
            if (dependent.allDependReady()) {
                addToScheduleQueue(dependent);
            }
        }
    }

    /**
     * Removes, stage by stage, every dependency whose running unit can reach the dependent's own
     * running unit, which would otherwise form a cycle. Removals are collected first and applied
     * afterwards so the depend list is not modified while it is being iterated.
     */
    private void avoidDeadLockDepend(List<NodeRunningUnit> nodeRunningUnits) {
        for (NodeRunningUnit runningUnit : nodeRunningUnits) {
            for (BatchExecNodeStage stage : runningUnit.getAllNodeStages()) {
                List<BatchExecNodeStage> cyclicDepends = new ArrayList<>();
                for (BatchExecNodeStage dependStage : stage.getAllDependStageList()) {
                    for (NodeRunningUnit dependUnit : dependStage.getRunningUnitList()) {
                        if (loopDepend(dependUnit, runningUnit, new HashSet<>())) {
                            cyclicDepends.add(dependStage);
                        }
                    }
                }
                for (BatchExecNodeStage cyclicDepend : cyclicDepends) {
                    stage.removeDependStage(cyclicDepend);
                }
            }
        }
    }

    /**
     * Depth-first search through stage dependencies.
     *
     * @param from running unit to start from
     * @param target running unit being looked for (compared by identity)
     * @param visited units already explored in this search
     * @return {@code true} if {@code target} is {@code from} itself or is reachable from it
     */
    private boolean loopDepend(NodeRunningUnit from, NodeRunningUnit target, Set<NodeRunningUnit> visited) {
        if (from == target) {
            return true;
        }
        if (!visited.add(from)) {
            return false;
        }
        for (BatchExecNodeStage stage : from.getAllNodeStages()) {
            for (BatchExecNodeStage dependStage : stage.getAllDependStageList()) {
                for (NodeRunningUnit next : dependStage.getRunningUnitList()) {
                    if (loopDepend(next, target, visited)) {
                        return true;
                    }
                }
            }
        }
        return false;
    }
}
