/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.resource.batch.schedule;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.event.ExecutionVertexFailoverEvent;
import org.apache.flink.runtime.event.ExecutionVertexStateChangedEvent;
import org.apache.flink.runtime.event.ResultPartitionConsumableEvent;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.jobgraph.ExecutionVertexID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.schedule.GraphManagerPlugin;
import org.apache.flink.runtime.schedule.SchedulingConfig;
import org.apache.flink.runtime.schedule.VertexScheduler;
import org.apache.flink.table.resource.batch.BatchExecNodeStage;
import org.apache.flink.table.resource.batch.NodeRunningUnit;
import org.apache.flink.table.resource.batch.schedule.LogicalJobVertex;
import org.apache.flink.table.resource.batch.schedule.LogicalJobVertexRunningUnit;
import org.apache.flink.util.InstantiationUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link GraphManagerPlugin} that schedules a batch job one running unit at a time.
 *
 * <p>Each {@link NodeRunningUnit} produced by the planner is mapped onto the set of job vertices
 * that contain its transformations, yielding a {@link LogicalJobVertexRunningUnit}.
 *
 * <p>A unit is put into the schedule queue once {@code allDependReady()} holds for it. Its
 * dependencies are satisfied in two ways:
 * <ul>
 *   <li>DATA_TRIGGER dependencies (input job vertices) are satisfied as their result partitions
 *       become consumable.</li>
 *   <li>PRIORITY dependencies (e.g. a join's build side) are satisfied once every task of the
 *       depended-on unit has reached {@link ExecutionState#DEPLOYING}.</li>
 * </ul>
 *
 * <p>A new unit is only scheduled when there is no current unit or the current unit has all of
 * its tasks deploying.
 *
 * <p>Job vertices not covered by any running unit form a final "left" unit. It is scheduled once
 * the queue is empty and every running unit has been marked as scheduled.
 *
 * <p>NOTE(review): several methods are {@code synchronized}, but {@link #onSchedulingStarted()},
 * {@link #onResultPartitionConsumable} and {@link #onExecutionVertexFailover} are not. Confirm
 * that the caller delivers these events from a single thread.
 */
public class RunningUnitGraphManagerPlugin
implements GraphManagerPlugin {
    private static final Logger LOG = LoggerFactory.getLogger(RunningUnitGraphManagerPlugin.class);

    /** Configuration key under which the serialized {@code List<NodeRunningUnit>} is stored. */
    public static final String RUNNING_UNIT_CONF_KEY = "runningUnit.key";

    /** One wrapper per job vertex of the job graph, in job-graph order. */
    private final Map<JobVertexID, LogicalJobVertex> logicalJobVertices = new LinkedHashMap<>();

    /** Planner running unit -> its job-vertex-level counterpart. */
    private final Map<NodeRunningUnit, LogicalJobVertexRunningUnit> runningUnitMap = new LinkedHashMap<>();

    /** Job vertices that belong to no running unit; scheduled last. */
    private LogicalJobVertexRunningUnit leftJobVertexRunningUnit;

    /** Units whose dependencies are ready, waiting to be scheduled (FIFO). */
    private final LinkedList<LogicalJobVertexRunningUnit> scheduleQueue = new LinkedList<>();

    /** The unit most recently handed to the scheduler, or null if none. */
    private LogicalJobVertexRunningUnit currentScheduledUnit;

    /** Units already queued or already fully deploying; identity-based on purpose. */
    private final Set<LogicalJobVertexRunningUnit> scheduledRunningUnitSet =
            Collections.newSetFromMap(new IdentityHashMap<>());

    /** Producer job vertex -> units that consume its output as a DATA_TRIGGER input. */
    private final Map<JobVertexID, Set<LogicalJobVertexRunningUnit>> inputRunningUnitMap = new LinkedHashMap<>();

    /** Build-side unit -> units (probe sides) that have a PRIORITY dependency on it. */
    private final Map<LogicalJobVertexRunningUnit, Set<LogicalJobVertexRunningUnit>> joinDependRunningUnitMap = new LinkedHashMap<>();

    private VertexScheduler scheduler;
    private JobGraph jobGraph;

    /**
     * Reads the vertex-to-stream-node mapping and the running units from the scheduling
     * configuration, then builds the internal scheduling structures.
     *
     * @throws RuntimeException wrapping any {@link IOException} or {@link ClassNotFoundException}
     *     raised while deserializing the configuration entries
     */
    public void open(VertexScheduler scheduler, JobGraph jobGraph, SchedulingConfig schedulingConfig) {
        try {
            @SuppressWarnings("unchecked") // the planner serializes exactly these types under these keys
            Map<JobVertexID, ArrayList<Integer>> vertexToStreamNodeIds = (Map<JobVertexID, ArrayList<Integer>>) InstantiationUtil.readObjectFromConfig(
                    schedulingConfig.getConfiguration(), "jobVertexToStreamNodeMap", schedulingConfig.getUserClassLoader());
            @SuppressWarnings("unchecked")
            List<NodeRunningUnit> nodeRunningUnits = (List<NodeRunningUnit>) InstantiationUtil.readObjectFromConfig(
                    schedulingConfig.getConfiguration(), RUNNING_UNIT_CONF_KEY, schedulingConfig.getUserClassLoader());
            this.open(scheduler, jobGraph, vertexToStreamNodeIds, nodeRunningUnits);
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Failed to read running unit scheduling information from the scheduling config.", e);
        }
    }

    /**
     * Builds the scheduling structures from already-deserialized inputs.
     *
     * @param scheduler target for {@code scheduleExecutionVertices} calls
     * @param jobGraph the job being scheduled
     * @param vertexToStreamNodeIds job vertex id -> ids of the transformations it contains
     * @param nodeRunningUnits running units computed by the planner
     */
    @VisibleForTesting
    public void open(VertexScheduler scheduler, JobGraph jobGraph, Map<JobVertexID, ArrayList<Integer>> vertexToStreamNodeIds, List<NodeRunningUnit> nodeRunningUnits) {
        this.scheduler = scheduler;
        this.jobGraph = jobGraph;
        Map<Integer, JobVertexID> streamNodeIdToVertex = this.reverseMap(vertexToStreamNodeIds);
        this.buildJobVertices(Arrays.asList(jobGraph.getVerticesAsArray()));
        this.buildJobRunningUnits(nodeRunningUnits, streamNodeIdToVertex);
    }

    /**
     * Converts planner running units into job-vertex running units.
     *
     * <p>It also records which units wait on which producer vertices (DATA_TRIGGER) and on which
     * build units (PRIORITY). Finally it collects all uncovered job vertices into the "left" unit.
     */
    private void buildJobRunningUnits(List<NodeRunningUnit> nodeRunningUnits, Map<Integer, JobVertexID> streamNodeIdToVertex) {
        // Drop cyclic stage dependencies first; otherwise units could wait on each other forever.
        this.avoidDeadLockDepend(nodeRunningUnits);

        // Pass 1: create the units and wire up DATA_TRIGGER input dependencies.
        for (NodeRunningUnit nodeRunningUnit : nodeRunningUnits) {
            LogicalJobVertexRunningUnit jobVertexRunningUnit = this.transformRunningUnit(nodeRunningUnit, streamNodeIdToVertex);
            this.joinDependRunningUnitMap.computeIfAbsent(jobVertexRunningUnit, k -> new LinkedHashSet<>());
            for (BatchExecNodeStage stage : nodeRunningUnit.getAllNodeStages()) {
                for (BatchExecNodeStage dependStage : stage.getDependStageList(BatchExecNodeStage.DependType.DATA_TRIGGER)) {
                    for (Integer id : dependStage.getTransformationIDList()) {
                        JobVertexID inputJobVertexID = streamNodeIdToVertex.get(id);
                        if (inputJobVertexID == null) {
                            // Transformation not materialized as a job vertex: nothing to wait for.
                            continue;
                        }
                        this.inputRunningUnitMap.computeIfAbsent(inputJobVertexID, k -> new LinkedHashSet<>()).add(jobVertexRunningUnit);
                        jobVertexRunningUnit.addInputDepend(inputJobVertexID);
                    }
                }
            }
            this.runningUnitMap.put(nodeRunningUnit, jobVertexRunningUnit);
        }

        // Pass 2: wire up PRIORITY (build-before-probe) dependencies.
        // This needs every unit from pass 1 to exist already.
        for (NodeRunningUnit unit : nodeRunningUnits) {
            for (BatchExecNodeStage stage : unit.getAllNodeStages()) {
                List<BatchExecNodeStage> joinDependStageList = stage.getDependStageList(BatchExecNodeStage.DependType.PRIORITY);
                if (joinDependStageList.isEmpty()) {
                    continue;
                }
                LogicalJobVertexRunningUnit probeRunningUnit = this.runningUnitMap.get(unit);
                for (BatchExecNodeStage dependStage : joinDependStageList) {
                    for (NodeRunningUnit dependUnit : dependStage.getRunningUnitList()) {
                        LogicalJobVertexRunningUnit buildRunningUnit = this.runningUnitMap.get(dependUnit);
                        if (buildRunningUnit == null) {
                            throw new IllegalStateException("Running unit " + dependUnit
                                    + " is a PRIORITY dependency but is not among the running units being scheduled.");
                        }
                        this.joinDependRunningUnitMap.get(buildRunningUnit).add(probeRunningUnit);
                        probeRunningUnit.addJoinDepend(buildRunningUnit);
                    }
                }
            }
        }

        // Vertices that no running unit claimed are scheduled together at the very end.
        Set<LogicalJobVertex> leftJobVertices = this.logicalJobVertices.values().stream()
                .filter(jobVertex -> jobVertex.getRunningUnitSet().isEmpty())
                .collect(Collectors.toCollection(LinkedHashSet::new));
        this.leftJobVertexRunningUnit = new LogicalJobVertexRunningUnit(leftJobVertices);
    }

    /**
     * Builds the job-vertex running unit that covers every job vertex containing one of
     * {@code nodeRunningUnit}'s transformations, and registers the unit on each such vertex.
     *
     * <p>Transformation IDs that map to no job vertex are skipped, consistent with the
     * dependency wiring in {@code buildJobRunningUnits}. Previously they put {@code null} into
     * the vertex set and caused a NullPointerException.
     */
    private LogicalJobVertexRunningUnit transformRunningUnit(NodeRunningUnit nodeRunningUnit, Map<Integer, JobVertexID> streamNodeIdToVertex) {
        Set<LogicalJobVertex> jobVertexSet = new LinkedHashSet<>();
        for (BatchExecNodeStage stage : nodeRunningUnit.getAllNodeStages()) {
            for (Integer transformationID : stage.getTransformationIDList()) {
                JobVertexID jobVertexID = streamNodeIdToVertex.get(transformationID);
                LogicalJobVertex jobVertex = jobVertexID == null ? null : this.logicalJobVertices.get(jobVertexID);
                if (jobVertex == null) {
                    LOG.debug("Transformation {} is not mapped to any job vertex; skipping it.", transformationID);
                    continue;
                }
                jobVertexSet.add(jobVertex);
            }
        }
        LogicalJobVertexRunningUnit jobVertexRunningUnit = new LogicalJobVertexRunningUnit(jobVertexSet);
        jobVertexSet.forEach(jobVertex -> jobVertex.addRunningUnit(jobVertexRunningUnit));
        return jobVertexRunningUnit;
    }

    /**
     * Wraps every job vertex and gives each one an (initially empty) set of consuming units.
     */
    private void buildJobVertices(List<JobVertex> jobVertices) {
        for (JobVertex jobVertex : jobVertices) {
            this.logicalJobVertices.put(jobVertex.getID(), new LogicalJobVertex(jobVertex));
            this.inputRunningUnitMap.put(jobVertex.getID(), new LinkedHashSet<>());
        }
    }

    /**
     * Inverts {@code job vertex -> transformation ids} into {@code transformation id -> job vertex}.
     * If a transformation id appears under several vertices, the last one iterated wins.
     */
    private Map<Integer, JobVertexID> reverseMap(Map<JobVertexID, ArrayList<Integer>> vertexToStreamNodeIds) {
        Map<Integer, JobVertexID> streamNodeIdToVertex = new LinkedHashMap<>();
        for (Map.Entry<JobVertexID, ArrayList<Integer>> entry : vertexToStreamNodeIds.entrySet()) {
            for (Integer transId : entry.getValue()) {
                streamNodeIdToVertex.put(transId, entry.getKey());
            }
        }
        return streamNodeIdToVertex;
    }

    /** No resources to release. */
    public void close() {
    }

    /** Intentionally a no-op; failover state is rebuilt in {@link #onExecutionVertexFailover}. */
    public void reset() {
    }

    /** Queues every unit whose dependencies are already satisfied and schedules the first one. */
    public void onSchedulingStarted() {
        this.runningUnitMap.values().stream()
                .filter(LogicalJobVertexRunningUnit::allDependReady)
                .forEach(this::addToScheduleQueue);
        this.checkScheduleNewRunningUnit();
    }

    /** Records that the producing vertex's output became consumable, then tries to schedule. */
    public void onResultPartitionConsumable(ResultPartitionConsumableEvent event) {
        JobVertexID producerJobVertexID = this.jobGraph.getResultProducerID(event.getResultID());
        this.produceResultPartition(producerJobVertexID);
        this.checkScheduleNewRunningUnit();
    }

    /**
     * Notifies every unit consuming {@code producerJobVertexID} of one received input, and queues
     * those whose dependencies are now all ready.
     */
    private void produceResultPartition(JobVertexID producerJobVertexID) {
        Set<LogicalJobVertexRunningUnit> consumers = this.inputRunningUnitMap.get(producerJobVertexID);
        for (LogicalJobVertexRunningUnit jobVertexRunningUnit : consumers) {
            jobVertexRunningUnit.receiveInput(producerJobVertexID);
            if (jobVertexRunningUnit.allDependReady()) {
                this.addToScheduleQueue(jobVertexRunningUnit);
            }
        }
    }

    /**
     * Rebuilds the scheduling state after a failover and restarts scheduling.
     *
     * <p>Vertices that still have all tasks deploying are treated as having produced their
     * output again.
     */
    public void onExecutionVertexFailover(ExecutionVertexFailoverEvent event) {
        this.scheduleQueue.clear();
        this.currentScheduledUnit = null;
        this.scheduledRunningUnitSet.clear();
        event.getAffectedExecutionVertexIDs().forEach(id -> this.logicalJobVertices.get(id.getJobVertexID()).failoverTask());
        this.runningUnitMap.values().forEach(LogicalJobVertexRunningUnit::reset);
        this.logicalJobVertices.values().stream().filter(LogicalJobVertex::allTasksDeploying).forEach(jobVertex -> {
            this.allTaskDeploying(jobVertex);
            // One notification per subtask.
            // NOTE(review): presumably receiveInput counts one partition per subtask; confirm.
            for (int i = 0; i < jobVertex.getParallelism(); ++i) {
                this.produceResultPartition(jobVertex.getJobVertexID());
            }
        });
        this.onSchedulingStarted();
    }

    /**
     * Schedules the next queued unit if no unit is in flight or the current one is fully
     * deploying. Queued units that are already fully deploying are marked done rather than
     * rescheduled. When the queue is drained and every unit has been handled, the "left" unit
     * is scheduled.
     */
    private synchronized void checkScheduleNewRunningUnit() {
        if (this.currentScheduledUnit != null && !this.currentScheduledUnit.allTasksDeploying()) {
            return;
        }
        LogicalJobVertexRunningUnit next;
        while ((next = this.scheduleQueue.pollFirst()) != null && next.allTasksDeploying()) {
            this.runningUnitAllDeploying(next);
        }
        if (next != null) {
            this.scheduleRunningUnit(next);
        } else if (this.scheduledRunningUnitSet.size() >= this.runningUnitMap.size() && !this.leftJobVertexRunningUnit.allTasksDeploying()) {
            this.scheduleRunningUnit(this.leftJobVertexRunningUnit);
        }
    }

    /** Enqueues {@code jobVertexRunningUnit} unless it was already queued or handled. */
    private synchronized void addToScheduleQueue(LogicalJobVertexRunningUnit jobVertexRunningUnit) {
        if (this.scheduledRunningUnitSet.add(jobVertexRunningUnit)) {
            this.scheduleQueue.add(jobVertexRunningUnit);
        }
    }

    /**
     * Makes {@code jobVertexRunningUnit} the current unit and asks the scheduler to schedule
     * every subtask of its to-be-scheduled vertices, one execution vertex per call.
     */
    private synchronized void scheduleRunningUnit(LogicalJobVertexRunningUnit jobVertexRunningUnit) {
        this.currentScheduledUnit = jobVertexRunningUnit;
        LOG.info("begin to schedule runningUnit: ");
        this.currentScheduledUnit.getJobVertexSet().forEach(x -> LOG.info(x.toString()));
        jobVertexRunningUnit.getToScheduleJobVertices().forEach(jobVertex -> {
            for (int i = 0; i < jobVertex.getParallelism(); ++i) {
                this.scheduler.scheduleExecutionVertices(Collections.singletonList(new ExecutionVertexID(jobVertex.getJobVertexID(), i)));
            }
        });
    }

    /** Tracks DEPLOYING transitions; other state changes are ignored. */
    public synchronized void onExecutionVertexStateChanged(ExecutionVertexStateChangedEvent event) {
        if (event.getNewExecutionState() != ExecutionState.DEPLOYING) {
            return;
        }
        LogicalJobVertex jobVertex = this.logicalJobVertices.get(event.getExecutionVertexID().getJobVertexID());
        jobVertex.deployingTask();
        if (jobVertex.allTasksDeploying()) {
            this.allTaskDeploying(jobVertex);
        }
        this.checkScheduleNewRunningUnit();
    }

    /** Propagates "all tasks deploying" from a vertex to each of its units that is now fully deploying. */
    private synchronized void allTaskDeploying(LogicalJobVertex jobVertex) {
        jobVertex.getRunningUnitSet().stream()
                .filter(LogicalJobVertexRunningUnit::allTasksDeploying)
                .forEach(this::runningUnitAllDeploying);
    }

    /**
     * Marks {@code jobVertexRunningUnit} as handled, releases the PRIORITY dependency it holds on
     * its probe-side units, and queues any of those that became ready.
     */
    private synchronized void runningUnitAllDeploying(LogicalJobVertexRunningUnit jobVertexRunningUnit) {
        this.scheduledRunningUnitSet.add(jobVertexRunningUnit);
        for (LogicalJobVertexRunningUnit runningUnit : this.joinDependRunningUnitMap.get(jobVertexRunningUnit)) {
            runningUnit.joinDependDeploying(jobVertexRunningUnit);
            if (runningUnit.allDependReady()) {
                this.addToScheduleQueue(runningUnit);
            }
        }
    }

    /**
     * Removes every stage dependency whose depended-on running units (transitively) depend back
     * on the owning unit, or that points into the owning unit itself. Such dependencies could
     * never be satisfied.
     */
    private void avoidDeadLockDepend(List<NodeRunningUnit> nodeRunningUnits) {
        for (NodeRunningUnit unit : nodeRunningUnits) {
            for (BatchExecNodeStage stage : unit.getAllNodeStages()) {
                // Collected first and removed afterwards so the depend list is not modified mid-iteration.
                List<BatchExecNodeStage> toRemoveStages = new ArrayList<>();
                for (BatchExecNodeStage dependStage : stage.getAllDependStageList()) {
                    for (NodeRunningUnit dependUnit : dependStage.getRunningUnitList()) {
                        if (this.loopDepend(dependUnit, unit, new HashSet<>())) {
                            toRemoveStages.add(dependStage);
                            // One looping unit is enough; checking the others only duplicates work.
                            break;
                        }
                    }
                }
                for (BatchExecNodeStage toRemove : toRemoveStages) {
                    stage.removeDependStage(toRemove);
                }
            }
        }
    }

    /**
     * Returns true if {@code preUnit} is reachable from {@code dependUnit} by following stage
     * dependencies (including {@code dependUnit == preUnit}). The search is a depth-first search
     * that uses {@code visitedRunningUnit} to avoid revisiting units.
     */
    private boolean loopDepend(NodeRunningUnit dependUnit, NodeRunningUnit preUnit, Set<NodeRunningUnit> visitedRunningUnit) {
        if (dependUnit == preUnit) {
            return true;
        }
        if (!visitedRunningUnit.add(dependUnit)) {
            return false;
        }
        for (BatchExecNodeStage containStage : dependUnit.getAllNodeStages()) {
            for (BatchExecNodeStage dependStage : containStage.getAllDependStageList()) {
                for (NodeRunningUnit unit : dependStage.getRunningUnitList()) {
                    if (this.loopDepend(unit, preUnit, visitedRunningUnit)) {
                        return true;
                    }
                }
            }
        }
        return false;
    }
}

