/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.join.stream.state.match;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.flink.runtime.state.keyed.KeyedMapState;
import org.apache.flink.table.codegen.Projection;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.runtime.join.stream.state.match.JoinMatchStateHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link JoinMatchStateHandler} that keeps one match counter per row.
 *
 * <p>Counters are stored in a two-level {@link KeyedMapState}: the outer key is the join key and
 * the inner map key is the row's primary key, obtained by applying {@code pkProjection} to the
 * row. Every method that receives a full row projects it to its primary key before touching state.
 *
 * <p>{@link #extractCurrentRowMatchJoinCount} caches the join key, the projected primary key and
 * the looked-up count of a "current" row; {@link #getCurrentRowMatchJoinCnt} and
 * {@link #resetCurrentRowMatchJoinCnt} then operate on that cached row.
 */
public class JoinKeyNotContainPrimaryKeyMatchStateHandler implements JoinMatchStateHandler {

    private static final Logger LOG =
            LoggerFactory.getLogger(JoinKeyNotContainPrimaryKeyMatchStateHandler.class);

    /** Join key -> (primary key -> match count). */
    private final KeyedMapState<BaseRow, BaseRow, Long> keyedMapState;

    /** Projects a full row onto its primary key, which is used as the inner map key. */
    private final Projection<BaseRow, BaseRow> pkProjection;

    /** Join key passed to the most recent {@link #extractCurrentRowMatchJoinCount} call. */
    private transient BaseRow currentJoinKey;

    /**
     * Primary key projected from the row passed to the most recent
     * {@link #extractCurrentRowMatchJoinCount} call.
     *
     * <p>NOTE(review): this holds the object returned by {@code pkProjection.apply}; if the
     * projection reuses a single output row across calls, a later {@code apply} would overwrite
     * it before {@link #resetCurrentRowMatchJoinCnt} reads it — confirm the projection's reuse
     * semantics.
     */
    private transient BaseRow pk;

    /** Match count cached for the current row ({@code 0} when no state entry existed). */
    private transient long currentRowMatchJoinCnt;

    /**
     * Creates a handler backed by the given state.
     *
     * @param keyedMapState state mapping join key -> primary key -> match count
     * @param pkProjection projection extracting the primary key from a row
     * @throws NullPointerException if either argument is {@code null}
     */
    public JoinKeyNotContainPrimaryKeyMatchStateHandler(
            KeyedMapState<BaseRow, BaseRow, Long> keyedMapState,
            Projection<BaseRow, BaseRow> pkProjection) {
        this.keyedMapState = Objects.requireNonNull(keyedMapState, "keyedMapState");
        this.pkProjection = Objects.requireNonNull(pkProjection, "pkProjection");
    }

    /**
     * Loads the stored match count of {@code row2} under {@code joinKey} and caches it, together
     * with the join key and the row's projected primary key, as the "current" row.
     *
     * @param joinKey join key the row belongs to
     * @param row2 row whose primary key is looked up
     * @param possibleJoinCnt not used by this implementation
     */
    @Override
    public void extractCurrentRowMatchJoinCount(
            BaseRow joinKey, BaseRow row2, long possibleJoinCnt) {
        this.currentJoinKey = joinKey;
        this.pk = pkProjection.apply(row2);
        Long count = keyedMapState.get(joinKey, pk);
        // A missing entry is reported as zero matches.
        this.currentRowMatchJoinCnt = count == null ? 0L : count;
    }

    /**
     * Returns the count cached by the last {@link #extractCurrentRowMatchJoinCount} or
     * {@link #resetCurrentRowMatchJoinCnt} call.
     */
    @Override
    public long getCurrentRowMatchJoinCnt() {
        return currentRowMatchJoinCnt;
    }

    /**
     * Overwrites the stored count of the cached current row with {@code joinCnt} and updates the
     * cached value.
     *
     * <p>Uses the join key and primary key cached by {@link #extractCurrentRowMatchJoinCount};
     * before the first such call both are {@code null}.
     *
     * @param joinCnt new match count
     */
    @Override
    public void resetCurrentRowMatchJoinCnt(long joinCnt) {
        keyedMapState.add(currentJoinKey, pk, joinCnt);
        this.currentRowMatchJoinCnt = joinCnt;
    }

    /**
     * Overwrites the stored count of {@code baseRow} under {@code joinKey}.
     *
     * @param joinKey join key the row belongs to
     * @param baseRow row whose primary key identifies the counter
     * @param joinCnt new match count
     */
    @Override
    public void updateRowMatchJoinCnt(BaseRow joinKey, BaseRow baseRow, long joinCnt) {
        keyedMapState.add(joinKey, pkProjection.apply(baseRow), joinCnt);
    }

    /**
     * Adds {@code joinCnt} to the stored count of {@code baseRow} under {@code joinKey}.
     *
     * <p>If no entry exists, {@code joinCnt} is stored as the new count and a warning is logged
     * (the message attributes the missing entry to state TTL expiry).
     *
     * @param joinKey join key the row belongs to
     * @param baseRow row whose primary key identifies the counter
     * @param joinCnt delta to add
     */
    @Override
    public void addRowMatchJoinCnt(BaseRow joinKey, BaseRow baseRow, long joinCnt) {
        BaseRow mapKey = pkProjection.apply(baseRow);
        Long count = keyedMapState.get(joinKey, mapKey);
        if (count != null) {
            keyedMapState.add(joinKey, mapKey, joinCnt + count);
        } else {
            LOG.warn("The state is cleared because of state ttl. This will result in incorrect result. You can increase the state ttl to avoid this.");
            keyedMapState.add(joinKey, mapKey, joinCnt);
        }
    }

    /**
     * Removes the counter of {@code baseRow} under {@code joinKey}.
     *
     * @param joinKey join key the row belongs to
     * @param baseRow row whose primary key identifies the counter
     */
    @Override
    public void remove(BaseRow joinKey, BaseRow baseRow) {
        keyedMapState.remove(joinKey, pkProjection.apply(baseRow));
    }

    /**
     * Removes every counter stored under {@code joinKey}.
     *
     * @param joinKey join key whose entries are removed
     */
    @Override
    public void remove(BaseRow joinKey) {
        keyedMapState.remove(joinKey);
    }

    /**
     * Removes the counters of all {@code keys} rows under {@code joinKey}.
     *
     * <p>NOTE(review): the projected keys are collected into a set; if {@code pkProjection}
     * reuses its output row, all entries would alias one object — confirm it returns a fresh row
     * per call.
     *
     * @param joinKey join key the rows belong to
     * @param keys rows whose primary keys identify the counters to remove
     */
    @Override
    public void removeAll(BaseRow joinKey, Set<BaseRow> keys) {
        Set<BaseRow> pks = new HashSet<>();
        for (BaseRow row : keys) {
            pks.add(pkProjection.apply(row));
        }
        keyedMapState.removeAll(joinKey, pks);
    }

    /**
     * Stores the given row -> count mappings under {@code joinKey}, keyed by each row's primary
     * key.
     *
     * <p>NOTE(review): if {@code pkProjection} reuses its output row, the collected map keys
     * would alias one object — confirm it returns a fresh row per call.
     *
     * @param joinKey join key the rows belong to
     * @param kvs rows mapped to the counts to store
     */
    @Override
    public void addAll(BaseRow joinKey, Map<BaseRow, Long> kvs) {
        Map<BaseRow, Long> pkMap = new HashMap<>();
        for (Map.Entry<BaseRow, Long> entry : kvs.entrySet()) {
            pkMap.put(pkProjection.apply(entry.getKey()), entry.getValue());
        }
        keyedMapState.addAll(joinKey, pkMap);
    }
}

