package org.apache.flink.runtime.preaggregatedaccumulators;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.util.Preconditions;

/**
 * Collects pre-aggregated accumulators committed by the tasks of a job vertex and merges them into
 * one global value per accumulator name. It also serves queries for that global value.
 *
 * <p>A query is answered only once every task of the owning job vertex has committed. Queries that
 * arrive earlier are parked and completed when the final commit arrives.
 *
 * <p>This class performs no synchronization of its own. NOTE(review): presumably it is only
 * accessed from a single thread — confirm with callers.
 */
public class AccumulatorAggregationCoordinator {

    /** Global aggregation state, keyed by accumulator name. */
    private final Map<String, GlobalAggregatedAccumulator> aggregatedAccumulators = new HashMap<>();

    /** Futures of queries whose target accumulator is not yet fully committed, keyed by name. */
    private final Map<String, List<CompletableFuture<Accumulator>>> unfulfilledQueryFutures =
            new HashMap<>();

    /**
     * Aggregation state of one named accumulator. It holds the job vertex the accumulator belongs
     * to, the value merged so far and the indices of the tasks whose values have been merged.
     */
    public static class GlobalAggregatedAccumulator {
        private final JobVertexID jobVertexId;

        private final Set<Integer> committedTasks = new HashSet<>();

        /**
         * Merged value, {@code null} until the first commit. It is a raw type because the value
         * type differs per accumulator name.
         */
        private Accumulator aggregatedValue;

        /** Set once every task of the job vertex has committed. */
        private boolean allCommitted;

        GlobalAggregatedAccumulator(JobVertexID jobVertexId) {
            this.jobVertexId = jobVertexId;
        }

        /**
         * Merges {@code accumulator} into the aggregated value and records {@code tasks} as
         * committed.
         *
         * <p>The first committed accumulator is cloned, so later merges do not modify the instance
         * supplied by the first committer.
         *
         * <p>NOTE(review): task indices that were already committed are not rejected. The set of
         * committed tasks stays the same, but the value is merged a second time. Confirm that
         * duplicate commits cannot occur.
         *
         * @param jobVertexId the job vertex the committing tasks belong to
         * @param accumulator the partial value to merge
         * @param tasks indices of the tasks covered by {@code accumulator}
         * @throws IllegalArgumentException if {@code jobVertexId} differs from the vertex this
         *     accumulator was registered for
         */
        void commitForTasks(JobVertexID jobVertexId, Accumulator accumulator, Set<Integer> tasks) {
            Preconditions.checkArgument(
                    this.jobVertexId.equals(jobVertexId),
                    "The registered task belongs to different JobVertex with previous registered ones");

            if (aggregatedValue == null) {
                aggregatedValue = accumulator.clone();
            } else {
                aggregatedValue.merge(accumulator);
            }
            committedTasks.addAll(tasks);
        }

        /** Returns whether every task of the job vertex has committed. */
        boolean isAllCommitted() {
            return allCommitted;
        }

        /** Marks the accumulator as complete; later commits for it are ignored. */
        void markAllCommitted() {
            allCommitted = true;
        }

        /** Returns the merged value, or {@code null} if nothing has been committed yet. */
        Accumulator getAggregatedValue() {
            return aggregatedValue;
        }

        /** Returns the live set of committed task indices. */
        Set<Integer> getCommittedTasks() {
            return committedTasks;
        }
    }

    /**
     * Merges a task-side pre-aggregated accumulator into the global value for its name.
     *
     * <p>When the number of committed tasks reaches the job vertex parallelism, the accumulator is
     * marked complete. All queries parked for that name are then completed with the merged value.
     * Commits for an already-complete accumulator are ignored.
     *
     * @param executionGraph graph used to resolve the committing job vertex and its parallelism
     * @param commitAccumulator the commit to apply
     * @throws IllegalArgumentException if the job vertex id is unknown to {@code executionGraph},
     *     or differs from the vertex the accumulator name was first committed by
     * @throws IllegalStateException if more distinct tasks have committed than the vertex has
     */
    public void commitPreAggregatedAccumulator(
            ExecutionGraph executionGraph, CommitAccumulator commitAccumulator) {
        JobVertexID jobVertexId = commitAccumulator.getJobVertexId();
        ExecutionJobVertex jobVertex = executionGraph.getJobVertex(jobVertexId);
        if (jobVertex == null) {
            throw new IllegalArgumentException(
                    "Commit contains an invalid job vertex id: " + jobVertexId);
        }

        String name = commitAccumulator.getName();
        GlobalAggregatedAccumulator globalAccumulator =
                aggregatedAccumulators.computeIfAbsent(
                        name, ignored -> new GlobalAggregatedAccumulator(jobVertexId));

        // The value has already been published to queries; late commits must not change it.
        if (globalAccumulator.isAllCommitted()) {
            return;
        }

        globalAccumulator.commitForTasks(
                jobVertexId,
                commitAccumulator.getAccumulator(),
                commitAccumulator.getCommittedTasks());

        int parallelism = jobVertex.getParallelism();
        int committedCount = globalAccumulator.getCommittedTasks().size();
        Preconditions.checkState(
                committedCount <= parallelism,
                "More tasks committed than the total number of tasks.");

        if (committedCount == parallelism) {
            globalAccumulator.markAllCommitted();

            List<CompletableFuture<Accumulator>> pendingQueries =
                    unfulfilledQueryFutures.remove(name);
            if (pendingQueries != null) {
                Accumulator aggregatedValue = globalAccumulator.getAggregatedValue();
                for (CompletableFuture<Accumulator> pendingQuery : pendingQueries) {
                    pendingQuery.complete(aggregatedValue);
                }
            }
        }
    }

    /**
     * Queries the globally aggregated value of the accumulator named {@code name}.
     *
     * <p>If every task has already committed, the returned future is already completed. Otherwise
     * it completes when the final commit for {@code name} arrives.
     *
     * @param name the accumulator name
     * @param <V> the accumulator's input type
     * @param <A> the accumulator's result type
     * @return a future holding the aggregated accumulator
     */
    public <V, A extends Serializable> CompletableFuture<Accumulator<V, A>>
            queryPreAggregatedAccumulator(String name) {
        CompletableFuture<Accumulator<V, A>> resultFuture = new CompletableFuture<>();

        GlobalAggregatedAccumulator globalAccumulator = aggregatedAccumulators.get(name);
        if (globalAccumulator != null && globalAccumulator.isAllCommitted()) {
            // The caller chooses V and A; the stored value is untyped, so this cast is unchecked.
            @SuppressWarnings("unchecked")
            Accumulator<V, A> aggregatedValue =
                    (Accumulator<V, A>) globalAccumulator.getAggregatedValue();
            resultFuture.complete(aggregatedValue);
        } else {
            // Pending futures of different V/A share one untyped list, hence the raw cast.
            @SuppressWarnings({"unchecked", "rawtypes"})
            CompletableFuture<Accumulator> untypedFuture = (CompletableFuture) resultFuture;
            unfulfilledQueryFutures
                    .computeIfAbsent(name, ignored -> new ArrayList<>())
                    .add(untypedFuture);
        }

        return resultFuture;
    }

    /**
     * Discards all aggregation state and all parked queries.
     *
     * <p>Parked query futures are dropped without being completed. NOTE(review): anyone still
     * waiting on them will never be notified — confirm this is intended.
     */
    public void clear() {
        aggregatedAccumulators.clear();
        unfulfilledQueryFutures.clear();
    }

    /** Returns the live map of aggregation state, keyed by accumulator name. */
    @VisibleForTesting
    public Map<String, GlobalAggregatedAccumulator> getAggregatedAccumulators() {
        return aggregatedAccumulators;
    }

    /** Returns the live map of parked query futures, keyed by accumulator name. */
    @VisibleForTesting
    public Map<String, List<CompletableFuture<Accumulator>>> getUnfulfilledQueryFutures() {
        return unfulfilledQueryFutures;
    }
}
