/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Assigns the operator states of a restored checkpoint to the current execution attempts of a
 * set of {@link ExecutionJobVertex} instances.
 *
 * <p>For every job vertex the matching {@link OperatorState}s are looked up, operator (partitionable)
 * state and keyed state are redistributed for the vertex's current parallelism, and the result is
 * installed on each selected subtask as a {@link JobManagerTaskRestore}.
 */
public class StateAssignmentOperation {
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    private final Map<JobVertexID, ExecutionJobVertex> tasks;
    private final Map<OperatorID, OperatorState> operatorStates;
    private final long restoreCheckpointId;
    private final boolean allowNonRestoredState;

    /**
     * @param restoreCheckpointId id of the checkpoint being restored; passed on to every
     *     {@link JobManagerTaskRestore} that is created
     * @param tasks job vertices that may receive state, must not be null
     * @param operatorStates checkpointed state per operator id, must not be null
     * @param allowNonRestoredState whether checkpoint state without a matching operator is tolerated
     */
    public StateAssignmentOperation(long restoreCheckpointId, Map<JobVertexID, ExecutionJobVertex> tasks, Map<OperatorID, OperatorState> operatorStates, boolean allowNonRestoredState) {
        this.restoreCheckpointId = restoreCheckpointId;
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.allowNonRestoredState = allowNonRestoredState;
    }

    /**
     * Assigns state to all subtasks of all job vertices.
     *
     * @return always {@code true}
     * @throws IllegalStateException if state cannot be mapped to an operator (and non-restored
     *     state is not allowed) or if the parallelism preconditions are violated
     */
    public boolean assignStates() throws Exception {
        // Work on a copy: states are removed from it as they are matched to operators.
        Map<OperatorID, OperatorState> localOperators = new HashMap<>(this.operatorStates);
        StateAssignmentOperation.checkStateMappingCompleteness(this.allowNonRestoredState, this.operatorStates, this.tasks);

        for (ExecutionJobVertex executionJobVertex : this.tasks.values()) {
            List<OperatorState> vertexOperatorStates = StateAssignmentOperation.collectOperatorStates(executionJobVertex, localOperators);
            if (vertexOperatorStates.isEmpty()) {
                // No operator of this vertex has checkpointed state.
                continue;
            }
            Set<Integer> executionVertexIndices = new HashSet<>();
            for (ExecutionVertex executionVertex : executionJobVertex.getTaskVertices()) {
                executionVertexIndices.add(executionVertex.getParallelSubtaskIndex());
            }
            this.assignAttemptState(executionJobVertex, vertexOperatorStates, executionVertexIndices);
        }
        return true;
    }

    /**
     * Assigns state only to the given execution vertices. State is still redistributed for the
     * full parallelism of each affected job vertex, but only the listed subtasks get it installed.
     *
     * @param executionVertices subtasks that should receive their share of the restored state
     * @return always {@code true}
     * @throws IllegalStateException if state cannot be mapped to an operator (and non-restored
     *     state is not allowed) or if the parallelism preconditions are violated
     */
    public boolean assignStates(List<ExecutionVertex> executionVertices) throws Exception {
        Map<OperatorID, OperatorState> localOperators = new HashMap<>(this.operatorStates);
        // The completeness check is skipped entirely (including its logging) when non-restored state is allowed.
        if (!this.allowNonRestoredState) {
            StateAssignmentOperation.checkStateMappingCompleteness(false, this.operatorStates, this.tasks);
        }

        // Group the requested subtask indices by the job vertex they belong to.
        Map<JobVertexID, Set<Integer>> subtaskIndicesByVertex = new HashMap<>();
        for (ExecutionVertex executionVertex : executionVertices) {
            subtaskIndicesByVertex
                .computeIfAbsent(executionVertex.getJobvertexId(), key -> new HashSet<>())
                .add(executionVertex.getParallelSubtaskIndex());
        }

        for (ExecutionJobVertex executionJobVertex : this.tasks.values()) {
            List<OperatorState> vertexOperatorStates = StateAssignmentOperation.collectOperatorStates(executionJobVertex, localOperators);
            if (vertexOperatorStates.isEmpty()) {
                continue;
            }
            Set<Integer> subTaskIndices = subtaskIndicesByVertex.get(executionJobVertex.getJobVertexId());
            if (subTaskIndices == null) {
                // None of the requested execution vertices belongs to this job vertex.
                continue;
            }
            this.assignAttemptState(executionJobVertex, vertexOperatorStates, subTaskIndices);
        }
        return true;
    }

    /**
     * Collects, in operator order, the checkpointed state of every operator of the given vertex and
     * removes the matched entries from {@code remainingStates}. Operators without checkpointed state
     * get a fresh empty {@link OperatorState} placeholder.
     *
     * <p>The user-defined operator id is preferred for the lookup. The generated id is used when no
     * user-defined id exists, or when the checkpoint holds state only under the generated id, so
     * that such state is not silently dropped.
     *
     * @return one state per operator, or an empty list if no operator of the vertex has
     *     checkpointed state
     */
    private static List<OperatorState> collectOperatorStates(ExecutionJobVertex executionJobVertex, Map<OperatorID, OperatorState> remainingStates) {
        List<OperatorID> generatedIDs = executionJobVertex.getOperatorIDs();
        List<OperatorID> userDefinedIDs = executionJobVertex.getUserDefinedOperatorIDs();
        List<OperatorState> collected = new ArrayList<>(generatedIDs.size());
        boolean statelessTask = true;

        for (int i = 0; i < generatedIDs.size(); ++i) {
            OperatorID generatedID = generatedIDs.get(i);
            OperatorID userDefinedID = userDefinedIDs.get(i);
            OperatorID operatorID = userDefinedID == null ? generatedID : userDefinedID;
            if (userDefinedID != null && !remainingStates.containsKey(userDefinedID) && remainingStates.containsKey(generatedID)) {
                operatorID = generatedID;
            }

            OperatorState operatorState = remainingStates.remove(operatorID);
            if (operatorState == null) {
                operatorState = new OperatorState(operatorID, executionJobVertex.getParallelism(), executionJobVertex.getMaxParallelism());
            } else {
                statelessTask = false;
            }
            collected.add(operatorState);
        }
        return statelessTask ? Collections.<OperatorState>emptyList() : collected;
    }

    /**
     * Validates parallelism, redistributes operator and keyed state for the vertex's current
     * parallelism and installs the result on the subtasks listed in {@code subTaskIndices}.
     */
    private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates, Set<Integer> subTaskIndices) {
        List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
        // Must run first: it may override the vertex's max parallelism, which is read below.
        this.checkParallelismPreconditions(operatorStates, executionJobVertex);

        int newParallelism = executionJobVertex.getParallelism();
        List<KeyGroupRange> keyGroupPartitions = StateAssignmentOperation.createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), newParallelism);

        Map<OperatorInstanceID, List<OperatorStateHandle>> newManagedOperatorStates = new HashMap<>();
        Map<OperatorInstanceID, List<OperatorStateHandle>> newRawOperatorStates = new HashMap<>();
        this.reDistributePartitionableStates(operatorStates, newParallelism, operatorIDs, newManagedOperatorStates, newRawOperatorStates);

        Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState = new HashMap<>();
        Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState = new HashMap<>();
        this.reDistributeKeyedStates(operatorStates, newParallelism, operatorIDs, keyGroupPartitions, newManagedKeyedState, newRawKeyedState);

        this.assignTaskStateToExecutionJobVertices(executionJobVertex, subTaskIndices, newManagedOperatorStates, newRawOperatorStates, newManagedKeyedState, newRawKeyedState, newParallelism);
    }

    /**
     * Builds a {@link TaskStateSnapshot} for every selected subtask and sets it as the initial
     * state of the subtask's current execution attempt. Subtasks for which no operator has state
     * are left untouched.
     */
    private void assignTaskStateToExecutionJobVertices(ExecutionJobVertex executionJobVertex, Set<Integer> subTaskIndices, Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState, Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState, Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState, Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState, int newParallelism) {
        List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
        for (int subTaskIndex = 0; subTaskIndex < newParallelism; ++subTaskIndex) {
            if (!subTaskIndices.contains(subTaskIndex)) {
                continue;
            }
            Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();
            TaskStateSnapshot taskState = new TaskStateSnapshot();
            boolean statelessTask = true;
            for (OperatorID operatorID : operatorIDs) {
                OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, operatorID);
                OperatorSubtaskState operatorSubtaskState = StateAssignmentOperation.operatorSubtaskStateFrom(instanceID, subManagedOperatorState, subRawOperatorState, subManagedKeyedState, subRawKeyedState);
                if (operatorSubtaskState.hasState()) {
                    statelessTask = false;
                }
                taskState.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
            }
            if (statelessTask) {
                continue;
            }
            JobManagerTaskRestore taskRestore = new JobManagerTaskRestore(this.restoreCheckpointId, taskState);
            currentExecutionAttempt.setInitialState(taskRestore);
        }
    }

    /**
     * Combines the four per-instance state maps into one {@link OperatorSubtaskState}.
     *
     * @return an empty subtask state if none of the maps contains {@code instanceID}; otherwise a
     *     subtask state built from the mapped handles, with empty collections for missing entries
     * @throws IllegalStateException if raw keyed state is present without managed keyed state
     */
    public static OperatorSubtaskState operatorSubtaskStateFrom(OperatorInstanceID instanceID, Map<OperatorInstanceID, List<OperatorStateHandle>> subManagedOperatorState, Map<OperatorInstanceID, List<OperatorStateHandle>> subRawOperatorState, Map<OperatorInstanceID, List<KeyedStateHandle>> subManagedKeyedState, Map<OperatorInstanceID, List<KeyedStateHandle>> subRawKeyedState) {
        boolean hasAnyEntry = subManagedOperatorState.containsKey(instanceID)
            || subRawOperatorState.containsKey(instanceID)
            || subManagedKeyedState.containsKey(instanceID)
            || subRawKeyedState.containsKey(instanceID);
        if (!hasAnyEntry) {
            return new OperatorSubtaskState();
        }
        if (!subManagedKeyedState.containsKey(instanceID)) {
            Preconditions.checkState(!subRawKeyedState.containsKey(instanceID));
        }
        return new OperatorSubtaskState(new StateObjectCollection<OperatorStateHandle>(subManagedOperatorState.getOrDefault(instanceID, Collections.emptyList())), new StateObjectCollection<OperatorStateHandle>(subRawOperatorState.getOrDefault(instanceID, Collections.emptyList())), new StateObjectCollection<KeyedStateHandle>(subManagedKeyedState.getOrDefault(instanceID, Collections.emptyList())), new StateObjectCollection<KeyedStateHandle>(subRawKeyedState.getOrDefault(instanceID, Collections.emptyList())));
    }

    /** Returns whether {@code opIdx} is the last index of {@code operatorIDs}. */
    private static boolean isHeadOperator(int opIdx, List<OperatorID> operatorIDs) {
        return opIdx == operatorIDs.size() - 1;
    }

    /**
     * Runs the parallelism precondition check for each of the given operator states.
     *
     * @throws IllegalStateException if any operator state is incompatible with the vertex's
     *     parallelism settings
     */
    public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
        for (OperatorState operatorState : operatorStates) {
            StateAssignmentOperation.checkParallelismPreconditions(operatorState, executionJobVertex);
        }
    }

    /**
     * Fills the keyed-state maps for every new subtask. Only the operator accepted by
     * {@link #isHeadOperator} contributes keyed state; all other operators are skipped.
     */
    private void reDistributeKeyedStates(List<OperatorState> oldOperatorStates, int newParallelism, List<OperatorID> newOperatorIDs, List<KeyGroupRange> newKeyGroupPartitions, Map<OperatorInstanceID, List<KeyedStateHandle>> newManagedKeyedState, Map<OperatorInstanceID, List<KeyedStateHandle>> newRawKeyedState) {
        Preconditions.checkState(newOperatorIDs.size() == oldOperatorStates.size(), "This method still depends on the order of the new and old operators");
        for (int operatorIndex = 0; operatorIndex < newOperatorIDs.size(); ++operatorIndex) {
            // The head check does not depend on the subtask, so decide once per operator.
            if (!StateAssignmentOperation.isHeadOperator(operatorIndex, newOperatorIDs)) {
                continue;
            }
            OperatorState operatorState = oldOperatorStates.get(operatorIndex);
            int oldParallelism = operatorState.getParallelism();
            OperatorID operatorID = newOperatorIDs.get(operatorIndex);
            for (int subTaskIndex = 0; subTaskIndex < newParallelism; ++subTaskIndex) {
                OperatorInstanceID instanceID = OperatorInstanceID.of(subTaskIndex, operatorID);
                Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedStates = this.reAssignSubKeyedStates(operatorState, newKeyGroupPartitions, subTaskIndex, newParallelism, oldParallelism);
                newManagedKeyedState.computeIfAbsent(instanceID, key -> new ArrayList<>()).addAll(subKeyedStates.f0);
                newRawKeyedState.computeIfAbsent(instanceID, key -> new ArrayList<>()).addAll(subKeyedStates.f1);
            }
        }
    }

    /**
     * Determines the managed (f0) and raw (f1) keyed state for one new subtask.
     *
     * <p>With unchanged parallelism the old subtask's state is reused as-is; otherwise the handles
     * intersecting the subtask's key-group range are gathered from all old subtasks.
     */
    private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(OperatorState operatorState, List<KeyGroupRange> keyGroupPartitions, int subTaskIndex, int newParallelism, int oldParallelism) {
        Collection<KeyedStateHandle> subManagedKeyedState;
        Collection<KeyedStateHandle> subRawKeyedState;
        if (newParallelism == oldParallelism) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                subManagedKeyedState = subtaskState.getManagedKeyedState();
                subRawKeyedState = subtaskState.getRawKeyedState();
            } else {
                subManagedKeyedState = Collections.emptyList();
                subRawKeyedState = Collections.emptyList();
            }
        } else {
            KeyGroupRange subtaskKeyGroupRange = keyGroupPartitions.get(subTaskIndex);
            subManagedKeyedState = StateAssignmentOperation.getManagedKeyedStateHandles(operatorState, subtaskKeyGroupRange);
            subRawKeyedState = StateAssignmentOperation.getRawKeyedStateHandles(operatorState, subtaskKeyGroupRange);
        }
        if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
            return new Tuple2<>(Collections.<KeyedStateHandle>emptyList(), Collections.<KeyedStateHandle>emptyList());
        }
        return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
    }

    /**
     * Repartitions the managed and raw operator state of every operator for the new parallelism
     * and stores the result, keyed by operator instance, in the two output maps.
     */
    private void reDistributePartitionableStates(List<OperatorState> oldOperatorStates, int newParallelism, List<OperatorID> newOperatorIDs, Map<OperatorInstanceID, List<OperatorStateHandle>> newManagedOperatorStates, Map<OperatorInstanceID, List<OperatorStateHandle>> newRawOperatorStates) {
        Preconditions.checkState(newOperatorIDs.size() == oldOperatorStates.size(), "This method still depends on the order of the new and old operators");
        List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>();
        List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>();
        this.collectPartionableStates(oldOperatorStates, oldManagedOperatorStates, oldRawOperatorStates);

        OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
        for (int operatorIndex = 0; operatorIndex < oldOperatorStates.size(); ++operatorIndex) {
            OperatorID operatorID = newOperatorIDs.get(operatorIndex);
            int oldParallelism = oldOperatorStates.get(operatorIndex).getParallelism();
            newManagedOperatorStates.putAll(StateAssignmentOperation.applyRepartitioner(operatorID, opStateRepartitioner, oldManagedOperatorStates.get(operatorIndex), oldParallelism, newParallelism));
            newRawOperatorStates.putAll(StateAssignmentOperation.applyRepartitioner(operatorID, opStateRepartitioner, oldRawOperatorStates.get(operatorIndex), oldParallelism, newParallelism));
        }
    }

    /**
     * Flattens, per operator, the managed and raw operator state handles of all old subtasks into
     * one list each and appends those lists to the output parameters.
     *
     * <p>An operator for which no subtask has any state contributes {@code null} (not an empty
     * list); {@link #applyRepartitioner(OperatorStateRepartitioner, List, int, int)} relies on that.
     */
    private void collectPartionableStates(List<OperatorState> operatorStates, List<List<OperatorStateHandle>> managedOperatorStates, List<List<OperatorStateHandle>> rawOperatorStates) {
        for (OperatorState operatorState : operatorStates) {
            List<OperatorStateHandle> managedOperatorState = null;
            List<OperatorStateHandle> rawOperatorState = null;
            for (int i = 0; i < operatorState.getParallelism(); ++i) {
                OperatorSubtaskState operatorSubtaskState = operatorState.getState(i);
                if (operatorSubtaskState == null) {
                    continue;
                }
                // Lists are created lazily so that "no state at all" stays distinguishable as null.
                if (managedOperatorState == null) {
                    managedOperatorState = new ArrayList<>();
                }
                managedOperatorState.addAll(operatorSubtaskState.getManagedOperatorState());
                if (rawOperatorState == null) {
                    rawOperatorState = new ArrayList<>();
                }
                rawOperatorState.addAll(operatorSubtaskState.getRawOperatorState());
            }
            managedOperatorStates.add(managedOperatorState);
            rawOperatorStates.add(rawOperatorState);
        }
    }

    /**
     * Collects, across all old subtasks of the operator, the managed keyed state handles that
     * intersect the given key-group range.
     *
     * @return the intersected handles; empty if nothing intersects
     */
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
        for (int i = 0; i < operatorState.getParallelism(); ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            StateAssignmentOperation.extractIntersectingState(subtaskState.getManagedKeyedState(), subtaskKeyGroupRange, subtaskKeyedStateHandles);
        }
        return subtaskKeyedStateHandles;
    }

    /**
     * Collects, across all old subtasks of the operator, the raw keyed state handles that
     * intersect the given key-group range.
     *
     * @return the intersected handles; empty if nothing intersects
     */
    public static List<KeyedStateHandle> getRawKeyedStateHandles(OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> extractedKeyedStateHandles = new ArrayList<>();
        for (int i = 0; i < operatorState.getParallelism(); ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            StateAssignmentOperation.extractIntersectingState(subtaskState.getRawKeyedState(), subtaskKeyGroupRange, extractedKeyedStateHandles);
        }
        return extractedKeyedStateHandles;
    }

    /**
     * Adds to {@code extractedStateCollector} the non-null intersection of every non-null handle
     * with {@code rangeToExtract}.
     */
    private static void extractIntersectingState(Collection<KeyedStateHandle> originalSubtaskStateHandles, KeyGroupRange rangeToExtract, List<KeyedStateHandle> extractedStateCollector) {
        for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) {
            if (keyedStateHandle == null) {
                continue;
            }
            KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(rangeToExtract);
            if (intersectedKeyedStateHandle != null) {
                extractedStateCollector.add(intersectedKeyedStateHandle);
            }
        }
    }

    /**
     * Computes the key-group range of every subtask index in {@code [0, parallelism)}.
     *
     * @param numberKeyGroups total number of key groups; must be at least {@code parallelism}
     * @param parallelism number of subtasks
     * @return one range per subtask, in subtask-index order
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(numberKeyGroups >= parallelism);
        List<KeyGroupRange> result = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; ++i) {
            result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
        }
        return result;
    }

    /**
     * Verifies that the restored operator state fits the vertex's parallelism settings.
     *
     * <p>If the max parallelism differs and was not explicitly configured on the vertex, the
     * vertex's max parallelism is overridden with the one from the state.
     *
     * @throws IllegalStateException if the state's max parallelism is below the vertex's
     *     parallelism, or differs from an explicitly configured max parallelism
     */
    private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) {
            throw new IllegalStateException("The state for task " + executionJobVertex.getJobVertexId() + " can not be restored. The maximum parallelism (" + operatorState.getMaxParallelism() + ") of the restored state is lower than the configured parallelism (" + executionJobVertex.getParallelism() + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
        }
        if (operatorState.getMaxParallelism() == executionJobVertex.getMaxParallelism()) {
            return;
        }
        if (executionJobVertex.isMaxParallelismConfigured()) {
            throw new IllegalStateException("The maximum parallelism (" + operatorState.getMaxParallelism() + ") with which the latest checkpoint of the execution job vertex " + executionJobVertex + " has been taken and the current maximum parallelism (" + executionJobVertex.getMaxParallelism() + ") changed. This is currently not supported.");
        }
        LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}", executionJobVertex.getJobVertexId(), executionJobVertex.getMaxParallelism(), operatorState.getMaxParallelism());
        executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
    }

    /**
     * Checks that every checkpointed operator state belongs to an operator of one of the tasks.
     *
     * <p>Both generated and user-defined operator ids count as known, because state assignment
     * looks state up under the user-defined id when one is set.
     *
     * @throws IllegalStateException if unmatched state exists and {@code allowNonRestoredState}
     *     is {@code false}; with {@code true} the unmatched state is only logged
     */
    private static void checkStateMappingCompleteness(boolean allowNonRestoredState, Map<OperatorID, OperatorState> operatorStates, Map<JobVertexID, ExecutionJobVertex> tasks) {
        Set<OperatorID> allOperatorIDs = new HashSet<>();
        for (ExecutionJobVertex executionJobVertex : tasks.values()) {
            allOperatorIDs.addAll(executionJobVertex.getOperatorIDs());
            List<OperatorID> userDefinedIDs = executionJobVertex.getUserDefinedOperatorIDs();
            if (userDefinedIDs == null) {
                continue;
            }
            for (OperatorID userDefinedID : userDefinedIDs) {
                // Entries are null for operators without a user-defined id.
                if (userDefinedID != null) {
                    allOperatorIDs.add(userDefinedID);
                }
            }
        }
        for (Map.Entry<OperatorID, OperatorState> entry : operatorStates.entrySet()) {
            if (allOperatorIDs.contains(entry.getKey())) {
                continue;
            }
            OperatorState operatorState = entry.getValue();
            if (!allowNonRestoredState) {
                throw new IllegalStateException("There is no operator for the state " + operatorState.getOperatorID());
            }
            LOG.info("Skipped checkpoint state for operator {}.", (Object)operatorState.getOperatorID());
        }
    }

    /**
     * Repartitions the given handles and keys the result by operator instance, using the position
     * in the repartitioned list as the subtask index.
     *
     * @return map from operator instance to its handles; empty if {@code chainOpParallelStates}
     *     is {@code null}
     * @throws NullPointerException if the repartitioned list contains a {@code null} entry
     */
    public static Map<OperatorInstanceID, List<OperatorStateHandle>> applyRepartitioner(OperatorID operatorID, OperatorStateRepartitioner opStateRepartitioner, List<OperatorStateHandle> chainOpParallelStates, int oldParallelism, int newParallelism) {
        Map<OperatorInstanceID, List<OperatorStateHandle>> result = new HashMap<>();
        List<Collection<OperatorStateHandle>> states = StateAssignmentOperation.applyRepartitioner(opStateRepartitioner, chainOpParallelStates, oldParallelism, newParallelism);
        for (int subtaskIndex = 0; subtaskIndex < states.size(); ++subtaskIndex) {
            // Check the element itself; a boxed comparison result would never be null.
            Collection<OperatorStateHandle> subtaskStates = Preconditions.checkNotNull(states.get(subtaskIndex), "states.get(subtaskIndex) is null");
            result.computeIfAbsent(OperatorInstanceID.of(subtaskIndex, operatorID), key -> new ArrayList<>()).addAll(subtaskStates);
        }
        return result;
    }

    /**
     * Repartitions operator state handles for the new parallelism.
     *
     * <p>The repartitioner is invoked when the parallelism changed or when any handle contains
     * state with {@code UNION} distribution mode. Otherwise each non-null handle is wrapped in a
     * singleton collection, in input order.
     *
     * <p>NOTE(review): {@code null} handles are skipped in the unchanged-parallelism path, so later
     * handles move to lower positions, and callers treat the position as the subtask index.
     * Confirm that inputs never contain gaps.
     *
     * @return one collection per position; empty if {@code chainOpParallelStates} is {@code null}
     */
    public static List<Collection<OperatorStateHandle>> applyRepartitioner(OperatorStateRepartitioner opStateRepartitioner, List<OperatorStateHandle> chainOpParallelStates, int oldParallelism, int newParallelism) {
        if (chainOpParallelStates == null) {
            return Collections.emptyList();
        }
        if (newParallelism != oldParallelism) {
            return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
        }
        List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
        for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
            if (operatorStateHandle == null) {
                continue;
            }
            Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets = operatorStateHandle.getStateNameToPartitionOffsets();
            for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
                // Union state goes through the repartitioner even when parallelism is unchanged.
                if (OperatorStateHandle.Mode.UNION.equals(metaInfo.getDistributionMode())) {
                    return opStateRepartitioner.repartitionState(chainOpParallelStates, newParallelism);
                }
            }
            repackStream.add(Collections.singletonList(operatorStateHandle));
        }
        return repackStream;
    }

    /**
     * Intersects every given handle with the key-group range and returns the non-null results.
     *
     * @param keyedStateHandles handles to intersect; elements must not be null
     * @param subtaskKeyGroupRange range to intersect with
     * @return the intersected handles, in iteration order
     */
    public static List<KeyedStateHandle> getKeyedStateHandles(Collection<? extends KeyedStateHandle> keyedStateHandles, KeyGroupRange subtaskKeyGroupRange) {
        List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
        for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
            KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
            if (intersectedKeyedStateHandle != null) {
                subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
            }
        }
        return subtaskKeyedStateHandles;
    }
}

