/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.preaggregatedaccumulators;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.preaggregatedaccumulators.CommitAccumulator;
import org.apache.flink.util.Preconditions;

/**
 * Aggregates pre-aggregated accumulators committed by tasks into one global value per name.
 *
 * <p>Each accumulator is identified by its name and is bound to the job vertex of the first
 * commit seen for that name. Partial values are merged into a single global value. Once the
 * number of distinct committed tasks equals the parallelism of the job vertex, the accumulator
 * is marked as fully committed. Every query waiting for it is then completed with the aggregated
 * value.
 *
 * <p>NOTE(review): this class performs no synchronization of its own. Presumably all calls come
 * from a single thread; confirm with the callers.
 */
public class AccumulatorAggregationCoordinator {

    /** Aggregation state, keyed by accumulator name. */
    private final Map<String, GlobalAggregatedAccumulator> aggregatedAccumulators = new HashMap<>();

    /** Queries issued before the named accumulator was fully committed, keyed by accumulator name. */
    private final Map<String, List<CompletableFuture<Accumulator>>> unfulfilledQueryFutures = new HashMap<>();

    /**
     * Merges a partial accumulator value committed by one or more tasks into the global value.
     *
     * <p>Commits for an accumulator that is already fully committed are ignored. When this
     * commit makes the number of committed tasks reach the vertex parallelism, the accumulator is
     * marked as fully committed. All pending queries for it are then completed with the
     * aggregated value.
     *
     * @param executionGraph graph used to resolve the committing job vertex and its parallelism
     * @param commitAccumulator the commit, carrying the accumulator name, job vertex id, partial
     *     value and the set of committed tasks
     * @throws IllegalArgumentException if the job vertex id is unknown to {@code executionGraph},
     *     or differs from the vertex this accumulator was first committed for
     * @throws IllegalStateException if more distinct tasks committed than the vertex parallelism
     */
    public void commitPreAggregatedAccumulator(ExecutionGraph executionGraph, CommitAccumulator commitAccumulator) {
        JobVertexID jobVertexId = commitAccumulator.getJobVertexId();
        ExecutionJobVertex jobVertex = executionGraph.getJobVertex(jobVertexId);
        if (jobVertex == null) {
            throw new IllegalArgumentException("Commit contains an invalid job vertex id: " + jobVertexId);
        }

        String name = commitAccumulator.getName();
        GlobalAggregatedAccumulator globalAccumulator =
            aggregatedAccumulators.computeIfAbsent(name, k -> new GlobalAggregatedAccumulator(jobVertexId));

        // Once all tasks have committed, the global value is final; later commits are ignored.
        if (globalAccumulator.isAllCommitted()) {
            return;
        }

        globalAccumulator.commitForTasks(
            jobVertexId, commitAccumulator.getAccumulator(), commitAccumulator.getCommittedTasks());

        int committedCount = globalAccumulator.getCommittedTasks().size();
        int parallelism = jobVertex.getParallelism();
        Preconditions.checkState(committedCount <= parallelism, "More tasks committed than the total number of tasks.");

        if (committedCount == parallelism) {
            globalAccumulator.markAllCommitted();
            // Detach the pending queries before completing them. Dependent stages may run
            // synchronously inside complete(), and must not see a list that is still registered.
            List<CompletableFuture<Accumulator>> queryFutures = unfulfilledQueryFutures.remove(name);
            if (queryFutures != null) {
                Accumulator aggregatedValue = globalAccumulator.getAggregatedValue();
                for (CompletableFuture<Accumulator> queryFuture : queryFutures) {
                    queryFuture.complete(aggregatedValue);
                }
            }
        }
    }

    /**
     * Returns a future for the globally aggregated value of the named accumulator.
     *
     * <p>If the accumulator has already been committed by all tasks, the returned future is
     * completed immediately. Otherwise it is registered and completed when the last task commits.
     * Every caller receives the same aggregated {@code Accumulator} instance.
     *
     * @param name accumulator name
     * @param <V> accumulator input type expected by the caller (not checked at runtime)
     * @param <A> accumulator result type expected by the caller (not checked at runtime)
     * @return future completed with the aggregated accumulator
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <V, A extends Serializable> CompletableFuture<Accumulator<V, A>> queryPreAggregatedAccumulator(String name) {
        CompletableFuture<Accumulator<V, A>> queryFuture = new CompletableFuture<>();
        GlobalAggregatedAccumulator globalAccumulator = aggregatedAccumulators.get(name);
        if (globalAccumulator != null && globalAccumulator.isAllCommitted()) {
            queryFuture.complete(globalAccumulator.getAggregatedValue());
        } else {
            // The pending-query map stores raw accumulator futures (generics are invariant), so the
            // typed future must be registered through a raw cast; the caller owns the V/A choice.
            unfulfilledQueryFutures.computeIfAbsent(name, k -> new ArrayList<>()).add((CompletableFuture) queryFuture);
        }
        return queryFuture;
    }

    /**
     * Discards all aggregation state and all pending queries.
     *
     * <p>NOTE(review): pending query futures are dropped without being completed. Anyone still
     * waiting on them is never notified; confirm this is intended.
     */
    public void clear() {
        aggregatedAccumulators.clear();
        unfulfilledQueryFutures.clear();
    }

    /** Returns the live (mutable) per-name aggregation state. */
    @VisibleForTesting
    public Map<String, GlobalAggregatedAccumulator> getAggregatedAccumulators() {
        return aggregatedAccumulators;
    }

    /** Returns the live (mutable) map of queries still waiting for their accumulator. */
    public Map<String, List<CompletableFuture<Accumulator>>> getUnfulfilledQueryFutures() {
        return unfulfilledQueryFutures;
    }

    /** Aggregation state of a single named accumulator. */
    static class GlobalAggregatedAccumulator {
        /** Job vertex that every commit for this accumulator must come from. */
        private final JobVertexID jobVertexId;
        /** Tasks that have committed so far (presumably subtask indices; verify against callers). */
        private final Set<Integer> committedTasks = new HashSet<>();
        /** Merged value of all commits so far; {@code null} until the first commit. */
        private Accumulator aggregatedValue;
        /** Set once the committed task count has reached the vertex parallelism. */
        private boolean allCommitted;

        GlobalAggregatedAccumulator(JobVertexID jobVertexId) {
            this.jobVertexId = jobVertexId;
        }

        /**
         * Merges {@code value} into the aggregate and records {@code committedTasks}.
         *
         * <p>NOTE(review): the value is merged even if some of {@code committedTasks} were already
         * recorded. A repeated commit from the same task is therefore counted twice in the value,
         * while the committed-task count does not change.
         *
         * @throws IllegalArgumentException if {@code jobVertexId} differs from the registered one
         */
        void commitForTasks(JobVertexID jobVertexId, Accumulator value, Set<Integer> committedTasks) {
            Preconditions.checkArgument(this.jobVertexId.equals(jobVertexId), "The registered task belongs to different JobVertex with previous registered ones");
            if (aggregatedValue == null) {
                // Work on a clone rather than the committed instance, so later merges do not
                // modify the object supplied by the first commit.
                aggregatedValue = value.clone();
            } else {
                aggregatedValue.merge(value);
            }
            this.committedTasks.addAll(committedTasks);
        }

        boolean isAllCommitted() {
            return allCommitted;
        }

        void markAllCommitted() {
            allCommitted = true;
        }

        Accumulator getAggregatedValue() {
            return aggregatedValue;
        }

        Set<Integer> getCommittedTasks() {
            return committedTasks;
        }
    }
}

