/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.join.batch.hashtable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.ChannelReaderInputViewIterator;
import org.apache.flink.runtime.io.disk.iomanager.ChannelReaderInputView;
import org.apache.flink.runtime.io.disk.iomanager.FileIOChannel;
import org.apache.flink.runtime.io.disk.iomanager.HeaderlessChannelReaderInputView;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.util.BitSet;
import org.apache.flink.table.codegen.JoinConditionFunction;
import org.apache.flink.table.codegen.Projection;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.dataformat.util.BinaryRowUtil;
import org.apache.flink.table.runtime.join.batch.HashJoinType;
import org.apache.flink.table.runtime.join.batch.NullAwareJoinHelper;
import org.apache.flink.table.runtime.join.batch.hashtable.BaseHybridHashTable;
import org.apache.flink.table.runtime.join.batch.hashtable.BinaryHashBucketArea;
import org.apache.flink.table.runtime.join.batch.hashtable.BinaryHashPartition;
import org.apache.flink.table.runtime.join.batch.hashtable.BuildSideIterator;
import org.apache.flink.table.runtime.join.batch.hashtable.LookupBucketIterator;
import org.apache.flink.table.runtime.join.batch.hashtable.ProbeIterator;
import org.apache.flink.table.runtime.util.ChannelWithMeta;
import org.apache.flink.table.runtime.util.FileChannelUtil;
import org.apache.flink.table.runtime.util.PagedChannelReaderInputViewIterator;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.runtime.util.WrappedRowIterator;
import org.apache.flink.table.typeutils.AbstractRowSerializer;
import org.apache.flink.table.typeutils.BinaryRowSerializer;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.Preconditions;

/**
 * Hybrid hash table backing the batch hash join.
 *
 * <p>Build-side rows are hashed on their join key and distributed over {@link BinaryHashPartition}s.
 * When memory runs out, the largest in-memory partition is spilled ({@link #spillPartition()}).
 * While probing, rows whose partition is in memory are looked up directly. Rows whose partition was
 * spilled are written to that partition's probe buffer (in {@link #tryProbe(BaseRow)} only if the
 * partition's hash bloom filter admits them). They are joined later, when the spilled partition is
 * read back. A spilled partition is either loaded as a single in-memory table or re-partitioned one
 * recursion level deeper, up to a maximum of 3 levels.
 */
public class BinaryHashTable extends BaseHybridHashTable {

    /** Serializer for build-side rows in their binary form (used when reading spilled build data). */
    final BinaryRowSerializer binaryBuildSideSerializer;
    /** Serializer of incoming build rows; converts them to {@link BinaryRow} before insertion. */
    private final AbstractRowSerializer originBuildSideSerializer;
    /** Serializer for probe-side rows in their binary form (used when reading spilled probe data). */
    private final BinaryRowSerializer binaryProbeSideSerializer;
    /** Serializer of incoming probe rows; converts them to {@link BinaryRow} before spilling. */
    private final AbstractRowSerializer originProbeSideSerializer;
    /** Extracts the join key of a build-side row. */
    private final Projection<BaseRow, BinaryRow> buildSideProjection;
    /** Extracts the join key of a probe-side row. */
    private final Projection<BaseRow, BinaryRow> probeSideProjection;
    /** Number of 128-byte buckets that fit into one memory segment. */
    final int bucketsPerSegment;
    /** {@code bucketsPerSegment - 1}. */
    final int bucketsPerSegmentMask;
    /** {@code log2(bucketsPerSegment)}. */
    final int bucketsPerSegmentBits;
    /** Caller-supplied flag; not read inside this class (presumably consumed by partitions/bucket areas). */
    final boolean useBloomFilters;
    /** Partitions of the table currently being built and probed. */
    final ArrayList<BinaryHashPartition> partitionsBeingBuilt;
    /** Passed to {@link BuildSideIterator}; not otherwise used in this class. */
    final BitSet probedSet = new BitSet(2);
    /** Spilled partitions that still have to be joined. */
    private final ArrayList<BinaryHashPartition> partitionsPending;
    /** Optional extra join condition; {@code null} means key equality alone decides a match. */
    private final JoinConditionFunction condFunc;
    /** If true, {@link #condFunc} is called as (probeRow, buildRow) instead of (buildRow, probeRow). */
    private final boolean reverseJoin;
    /** Key positions at which a null build key value prevents a match. */
    private final int[] nullFilterKeys;
    /** True when no key position filters nulls. */
    private final boolean nullSafe;
    /** True when every key position filters nulls. */
    private final boolean filterAllNulls;
    /** Iterator over build rows in the bucket looked up for the current probe row. */
    LookupBucketIterator bucketIterator;
    private ProbeIterator probeIterator;
    final HashJoinType type;
    /** Build-side rows emitted after probing (see {@link #processBuildIter()} / {@link #prepareNextPartition()}). */
    private RowIterator<BinaryRow> buildIterator;
    /** True while matching probe rows; false while emitting rows from {@link #buildIterator}. */
    private boolean probeMatchedPhase = true;
    private boolean buildIterVisited = false;
    /** Join key of the probe row currently being looked up. */
    private BinaryRow probeKey;
    /** Probe row currently being looked up. */
    private BaseRow probeRow;
    BinaryRow reuseBuildRow;

    /**
     * Creates a table whose preferred memory equals its reserved memory and which uses a
     * per-request memory size of {@code 0}.
     */
    public BinaryHashTable(Configuration conf, Object owner, AbstractRowSerializer buildSideSerializer, AbstractRowSerializer probeSideSerializer, Projection<BaseRow, BinaryRow> buildSideProjection, Projection<BaseRow, BinaryRow> probeSideProjection, MemoryManager memManager, long reservedMemorySize, IOManager ioManager, int avgRecordLen, int buildRowCount, boolean useBloomFilters, HashJoinType type, JoinConditionFunction condFunc, boolean reverseJoin, boolean[] filterNulls, boolean tryDistinctBuildRow) {
        this(conf, owner, buildSideSerializer, probeSideSerializer, buildSideProjection, probeSideProjection, memManager, reservedMemorySize, reservedMemorySize, 0L, ioManager, avgRecordLen, buildRowCount, useBloomFilters, type, condFunc, reverseJoin, filterNulls, tryDistinctBuildRow);
    }

    /**
     * Creates the table and its initial set of partitions.
     *
     * <p>{@code tryDistinctBuildRow} is forwarded to the base table only when {@code type} is not a
     * build-left semi/anti join.
     *
     * @throws IllegalArgumentException if a memory segment is smaller than 128 bytes
     */
    public BinaryHashTable(Configuration conf, Object owner, AbstractRowSerializer buildSideSerializer, AbstractRowSerializer probeSideSerializer, Projection<BaseRow, BinaryRow> buildSideProjection, Projection<BaseRow, BinaryRow> probeSideProjection, MemoryManager memManager, long reservedMemorySize, long preferredMemorySize, long perRequestMemorySize, IOManager ioManager, int avgRecordLen, long buildRowCount, boolean useBloomFilters, HashJoinType type, JoinConditionFunction condFunc, boolean reverseJoin, boolean[] filterNulls, boolean tryDistinctBuildRow) {
        super(conf, owner, memManager, reservedMemorySize, preferredMemorySize, perRequestMemorySize, ioManager, avgRecordLen, buildRowCount, !type.buildLeftSemiOrAnti() && tryDistinctBuildRow);
        this.originBuildSideSerializer = buildSideSerializer;
        this.binaryBuildSideSerializer = new BinaryRowSerializer(buildSideSerializer.getTypes());
        this.reuseBuildRow = this.binaryBuildSideSerializer.createInstance();
        this.originProbeSideSerializer = probeSideSerializer;
        this.binaryProbeSideSerializer = new BinaryRowSerializer(this.originProbeSideSerializer.getTypes());
        this.buildSideProjection = buildSideProjection;
        this.probeSideProjection = probeSideProjection;
        this.useBloomFilters = useBloomFilters;
        this.type = type;
        this.condFunc = condFunc;
        this.reverseJoin = reverseJoin;
        this.nullFilterKeys = NullAwareJoinHelper.getNullFilterKeys(filterNulls);
        this.nullSafe = this.nullFilterKeys.length == 0;
        this.filterAllNulls = this.nullFilterKeys.length == filterNulls.length;
        // A bucket is 128 (= 2^7) bytes wide.
        this.bucketsPerSegment = this.segmentSize >> 7;
        Preconditions.checkArgument(this.bucketsPerSegment != 0, "Hash Table requires buffers of at least 128 bytes.");
        this.bucketsPerSegmentMask = this.bucketsPerSegment - 1;
        // log2strict presumably rejects non-power-of-two bucket counts -- confirm against MathUtils.
        this.bucketsPerSegmentBits = MathUtils.log2strict(this.bucketsPerSegment);
        this.partitionsBeingBuilt = new ArrayList<>();
        this.partitionsPending = new ArrayList<>();
        this.createPartitions(this.initPartitionFanOut, 0);
    }

    /**
     * Inserts one build-side row, hashing its join key at recursion level 0.
     *
     * @param row the build-side row; converted to binary form before insertion
     */
    public void putBuildRow(BaseRow row) throws IOException {
        int keyHash = hash(this.buildSideProjection.apply(row).hashCode(), 0);
        this.insertIntoTable(this.originBuildSideSerializer.baseRowToBinary(row), keyHash);
    }

    /** Finishes the build phase of all partitions and prepares the probe/lookup iterators. */
    public void endBuild() throws IOException {
        this.finalizeBuildPhaseOfPartitions();
        this.probeIterator = new ProbeIterator(this.binaryProbeSideSerializer.createInstance());
        this.bucketIterator = new LookupBucketIterator(this);
    }

    /**
     * Probes the table with one probe-side row.
     *
     * @param record the probe-side row
     * @return {@code true} if the row's partition is in memory and a bucket lookup was started
     *     (matches are then available via {@link #getBuildSideIterator()}); {@code false} if the
     *     partition is spilled, in which case the row is buffered for later processing if the
     *     partition's bloom filter admits its hash
     */
    public boolean tryProbe(BaseRow record) throws IOException {
        if (!this.probeIterator.hasSource()) {
            this.probeIterator.setInstance(record);
        }
        BinaryRow key = this.probeSideProjection.apply(record);
        int keyHash = hash(key.hashCode(), this.currentRecursionDepth);
        BinaryHashPartition partition = this.partitionFor(keyHash);
        if (partition.isInMemory()) {
            this.beginLookup(partition, key, record, keyHash);
            return true;
        }
        if (partition.testHashBloomFilter(keyHash)) {
            partition.insertIntoProbeBuffer(this.originProbeSideSerializer.baseRowToBinary(record));
        }
        return false;
    }

    /**
     * Advances to the next unit of join output: the next probe row with a started lookup, the
     * build-side iterator (only when the join type needs probed flags), or the next spilled partition.
     *
     * @return {@code false} once everything has been processed
     */
    public boolean nextMatching() throws IOException {
        if (this.type.needSetProbed()) {
            return this.processProbeIter() || this.processBuildIter() || this.prepareNextPartition();
        }
        return this.processProbeIter() || this.prepareNextPartition();
    }

    /** Returns the current probe row while matching probe rows, otherwise {@code null}. */
    public BaseRow getCurrentProbeRow() {
        return this.probeMatchedPhase ? this.probeIterator.current() : null;
    }

    /** Returns the bucket iterator while matching probe rows, otherwise the build-side row iterator. */
    public RowIterator<BinaryRow> getBuildSideIterator() {
        return this.probeMatchedPhase ? this.bucketIterator : this.buildIterator;
    }

    /**
     * Computes {@code (int) (log4(numBuffers) - 1.5)}, capped at 6.
     *
     * <p>Note: the result is negative for {@code numBuffers < 4} (and for {@code 0} it is
     * {@code Integer.MIN_VALUE}); callers are expected to pass realistic buffer counts.
     */
    @VisibleForTesting
    static int getNumWriteBehindBuffers(int numBuffers) {
        int numIOBufs = (int) (Math.log(numBuffers) / Math.log(4.0) - 1.5);
        return Math.min(numIOBufs, 6);
    }

    /**
     * Reads re-played (spilled) probe rows until one hits an in-memory partition; rows hitting a
     * spilled partition are buffered in that partition.
     *
     * @return {@code true} if a lookup was started for a probe row
     */
    private boolean processProbeIter() throws IOException {
        ProbeIterator probeIter = this.probeIterator;
        if (!probeIter.hasSource() || !this.probeMatchedPhase) {
            return false;
        }
        BinaryRow next;
        while ((next = probeIter.next()) != null) {
            BinaryRow key = this.probeSideProjection.apply(next);
            int keyHash = hash(key.hashCode(), this.currentRecursionDepth);
            BinaryHashPartition partition = this.partitionFor(keyHash);
            if (partition.isInMemory()) {
                this.beginLookup(partition, key, next, keyHash);
                return true;
            }
            partition.insertIntoProbeBuffer(next);
        }
        return false;
    }

    /**
     * Switches (once per table) to emitting build-side rows through a {@link BuildSideIterator}.
     *
     * @return {@code true} the first time it is called after a (re)build, {@code false} afterwards
     */
    private boolean processBuildIter() throws IOException {
        if (this.buildIterVisited) {
            return false;
        }
        this.probeMatchedPhase = false;
        this.buildIterator = new BuildSideIterator(this.binaryBuildSideSerializer, this.reuseBuildRow, this.partitionsBeingBuilt, this.probedSet, this.type.equals(HashJoinType.BUILD_LEFT_SEMI));
        this.buildIterVisited = true;
        return true;
    }

    /**
     * Finishes the current round of partitions and starts processing the next pending (spilled)
     * partition, if any.
     *
     * @return {@code false} when no pending partition is left
     */
    private boolean prepareNextPartition() throws IOException {
        // finalizeProbePhase receives the free-memory list and the pending list; spilled partitions
        // presumably register themselves in partitionsPending there.
        for (BinaryHashPartition p : this.partitionsBeingBuilt) {
            p.finalizeProbePhase(this.availableMemory, this.partitionsPending, this.type.needSetProbed());
        }
        this.partitionsBeingBuilt.clear();
        if (this.currentSpilledBuildSide != null) {
            this.currentSpilledBuildSide.closeAndDelete();
            this.currentSpilledBuildSide = null;
        }
        if (this.currentSpilledProbeSide != null) {
            this.currentSpilledProbeSide.closeAndDelete();
            this.currentSpilledProbeSide = null;
        }
        if (this.partitionsPending.isEmpty()) {
            return false;
        }
        BinaryHashPartition p = this.partitionsPending.get(0);
        LOG.info(String.format("Begin to process spilled partition [%d]", p.getPartitionNumber()));
        if (p.probeSideRecordCounter == 0L) {
            // No probe rows were buffered: only the spilled build rows need to be iterated, no table is built.
            this.currentSpilledBuildSide = this.createInputView(p.getBuildSideChannel().getChannelID(), p.getBuildSideBlockCount(), p.getLastSegmentLimit());
            this.buildIterator = new WrappedRowIterator<BinaryRow>(new PagedChannelReaderInputViewIterator<BinaryRow>((ChannelReaderInputView) this.currentSpilledBuildSide, this.binaryBuildSideSerializer), this.binaryBuildSideSerializer.createInstance());
            this.partitionsPending.remove(0);
            return true;
        }
        this.probeMatchedPhase = true;
        this.buildIterVisited = false;
        this.buildTableFromSpilledPartition(p);
        // Replay the probe rows that were buffered for this partition.
        ChannelWithMeta channelWithMeta = new ChannelWithMeta(p.probeSideBuffer.getChannelID(), p.probeSideBuffer.getBlockCount(), p.probeNumBytesInLastSeg);
        this.currentSpilledProbeSide = FileChannelUtil.createInputView(this.ioManager, channelWithMeta, new ArrayList<FileIOChannel>(), this.compressionEnable, this.compressionCodecFactory, this.compressionBlockSize, this.segmentSize);
        @SuppressWarnings({"unchecked", "rawtypes"})
        ChannelReaderInputViewIterator<BinaryRow> probeReader = new ChannelReaderInputViewIterator(this.currentSpilledProbeSide, (TypeSerializer) this.binaryProbeSideSerializer);
        this.probeIterator.set(probeReader);
        this.probeIterator.setReuse(this.binaryProbeSideSerializer.createInstance());
        this.partitionsPending.remove(0);
        this.currentRecursionDepth = p.getRecursionLevel() + 1;
        return this.nextMatching();
    }

    /**
     * Rebuilds a table from a spilled partition: either entirely in memory (if the estimated buffer
     * need fits) or by re-partitioning it one recursion level deeper.
     *
     * @throws RuntimeException if the maximum recursion depth (3) would be exceeded, or if the
     *     memory accounting does not add up
     */
    private void buildTableFromSpilledPartition(BinaryHashPartition p) throws IOException {
        int nextRecursionLevel = p.getRecursionLevel() + 1;
        if (nextRecursionLevel == 2) {
            LOG.info("Recursive hash join: partition number is " + p.getPartitionNumber());
        } else if (nextRecursionLevel > 3) {
            throw new RuntimeException("Hash join exceeded maximum number of recursions, without reducing partitions enough to be memory resident. Probably cause: Too many duplicate keys.");
        }
        if (p.getBuildSideBlockCount() > p.getProbeSideBlockCount()) {
            LOG.info(String.format("Hash join: Partition(%d) build side block [%d] more than probe side block [%d]", p.getPartitionNumber(), p.getBuildSideBlockCount(), p.getProbeSideBlockCount()));
        }
        // All buffers must be back at this point: free ones plus those still returning from spill writes.
        int totalBuffersAvailable = this.availableMemory.size() + this.buildSpillRetBufferNumbers;
        if (totalBuffersAvailable != this.reservedNumBuffers + this.allocatedFloatingNum) {
            throw new RuntimeException(String.format("Hash Join bug in memory management: Memory buffers leaked. availableMemory(%s), buildSpillRetBufferNumbers(%s), reservedNumBuffers(%s), allocatedFloatingNum(%s)", this.availableMemory.size(), this.buildSpillRetBufferNumbers, this.reservedNumBuffers, this.allocatedFloatingNum));
        }
        // 15 is the assumed number of entries per bucket (inlined constant) -- confirm against BinaryHashBucketArea.
        long numBuckets = p.getBuildSideRecordCount() / 15L + 1L;
        int maxBucketAreaBuffers = Math.max((int) (2L * (numBuckets / (long) this.bucketsPerSegment)), 1);
        long totalBuffersNeeded = maxBucketAreaBuffers + p.getBuildSideBlockCount() + 2;
        if (totalBuffersNeeded < (long) totalBuffersAvailable) {
            this.buildInMemoryTableFromSpilledPartition(p, nextRecursionLevel, maxBucketAreaBuffers);
        } else {
            int splits = (int) (totalBuffersNeeded / (long) totalBuffersAvailable) + 1;
            this.repartitionSpilledPartition(p, nextRecursionLevel, splits);
        }
    }

    /** Reads all build buffers of {@code p} into memory and indexes them in a single new partition. */
    private void buildInMemoryTableFromSpilledPartition(BinaryHashPartition p, int nextRecursionLevel, int maxBucketAreaBuffers) throws IOException {
        LOG.info(String.format("Build in memory hash table from spilled partition [%d]", p.getPartitionNumber()));
        List<MemorySegment> partitionBuffers = this.readAllBuffers(p.getBuildSideChannel().getChannelID(), p.getBuildSideBlockCount());
        BinaryHashBucketArea area = new BinaryHashBucketArea(this, (int) p.getBuildSideRecordCount(), maxBucketAreaBuffers);
        BinaryHashPartition newPart = new BinaryHashPartition(area, this.binaryBuildSideSerializer, this.binaryProbeSideSerializer, 0, nextRecursionLevel, partitionBuffers, p.getBuildSideRecordCount(), this.segmentSize, p.getLastSegmentLimit());
        area.setPartition(newPart);
        this.partitionsBeingBuilt.add(newPart);
        BinaryHashPartition.PartitionIterator pIter = newPart.newPartitionIterator();
        while (pIter.advanceNext()) {
            int keyHash = hash(this.buildSideProjection.apply(pIter.getRow()).hashCode(), nextRecursionLevel);
            int pointer = (int) pIter.getPointer();
            area.insertToBucket(keyHash, pointer, false, true);
        }
    }

    /** Re-reads the spilled build rows of {@code p} and distributes them over freshly created partitions. */
    private void repartitionSpilledPartition(BinaryHashPartition p, int nextRecursionLevel, int splits) throws IOException {
        int partitionFanOut = Math.min(Math.min(10 * splits, 127), this.maxNumPartition());
        this.createPartitions(partitionFanOut, nextRecursionLevel);
        LOG.info(String.format("Build hybrid hash table from spilled partition [%d] with recursion level [%d]", p.getPartitionNumber(), nextRecursionLevel));
        HeaderlessChannelReaderInputView inView = this.createInputView(p.getBuildSideChannel().getChannelID(), p.getBuildSideBlockCount(), p.getLastSegmentLimit());
        PagedChannelReaderInputViewIterator<BinaryRow> inIter = new PagedChannelReaderInputViewIterator<BinaryRow>((ChannelReaderInputView) inView, this.binaryBuildSideSerializer);
        BinaryRow rec = this.binaryBuildSideSerializer.createInstance();
        while ((rec = inIter.next(rec)) != null) {
            int keyHash = hash(this.buildSideProjection.apply(rec).hashCode(), nextRecursionLevel);
            this.insertIntoTable(rec, keyHash);
        }
        inView.closeAndDelete();
        this.finalizeBuildPhaseOfPartitions();
    }

    /** Finalizes the build phase of every partition and records the write buffers still to be returned. */
    private void finalizeBuildPhaseOfPartitions() throws IOException {
        int buildWriteBuffers = 0;
        for (BinaryHashPartition p : this.partitionsBeingBuilt) {
            buildWriteBuffers += p.finalizeBuildPhase(this.ioManager, this.currentEnumerator);
        }
        this.buildSpillRetBufferNumbers += buildWriteBuffers;
    }

    /** Maps a hash produced by {@code hash(...)} onto one of the partitions currently being built. */
    private BinaryHashPartition partitionFor(int keyHash) {
        return this.partitionsBeingBuilt.get(keyHash % this.partitionsBeingBuilt.size());
    }

    /** Remembers the probe key/row for {@link #applyCondition(BinaryRow)} and starts the bucket lookup. */
    private void beginLookup(BinaryHashPartition partition, BinaryRow key, BaseRow row, int keyHash) {
        this.probeKey = key;
        this.probeRow = row;
        partition.bucketArea.startLookup(keyHash);
    }

    /**
     * Inserts a binary build row into the partition selected by its hash. The hash is added to the
     * partition's bloom filter when the partition is spilled, or when the in-memory insert reports
     * {@code false}.
     */
    private void insertIntoTable(BinaryRow record, int keyHash) throws IOException {
        BinaryHashPartition p = this.partitionFor(keyHash);
        if (p.isInMemory()) {
            if (!p.bucketArea.appendRecordAndInsert(record, keyHash)) {
                p.addHashBloomFilter(keyHash);
            }
        } else {
            p.insertIntoBuildBuffer(record);
            p.addHashBloomFilter(keyHash);
        }
    }

    /** Replaces the current partitions with {@code numPartitions} new partitions at {@code recursionLevel}. */
    private void createPartitions(int numPartitions, int recursionLevel) {
        this.ensureNumBuffersReturned(numPartitions);
        this.currentEnumerator = this.ioManager.createChannelEnumerator();
        this.partitionsBeingBuilt.clear();
        double numRecordPerPartition = (double) this.buildRowCount / (double) numPartitions;
        int maxBuffer = this.maxInitBufferOfBucketArea(numPartitions);
        for (int i = 0; i < numPartitions; ++i) {
            BinaryHashBucketArea area = new BinaryHashBucketArea(this, numRecordPerPartition, maxBuffer);
            BinaryHashPartition p = new BinaryHashPartition(area, this.binaryBuildSideSerializer, this.binaryProbeSideSerializer, i, recursionLevel, this.getNotNullNextBuffer(), this, this.segmentSize, this.compressionEnable, this.compressionCodecFactory, this.compressionBlockSize);
            area.setPartition(p);
            this.partitionsBeingBuilt.add(p);
        }
    }

    /**
     * Releases the memory of all current and pending partitions back to {@code availableMemory}.
     * A failure on one partition is logged and does not stop the cleanup of the others. Both
     * partition lists are emptied, so a repeated call does not release the same partitions twice.
     */
    @Override
    public void clearPartitions() {
        this.bucketIterator = null;
        this.probeIterator = null;
        for (int i = this.partitionsBeingBuilt.size() - 1; i >= 0; --i) {
            this.releasePartitionMemory(this.partitionsBeingBuilt.get(i));
        }
        this.partitionsBeingBuilt.clear();
        for (BinaryHashPartition p : this.partitionsPending) {
            this.releasePartitionMemory(p);
        }
        this.partitionsPending.clear();
    }

    /** Best-effort release of one partition's memory; errors are logged, not propagated. */
    private void releasePartitionMemory(BinaryHashPartition p) {
        try {
            p.clearAllMemory(this.availableMemory);
        } catch (Exception e) {
            LOG.error("Error during partition cleanup.", e);
        }
    }

    /**
     * Spills the in-memory partition that occupies the most memory segments, and moves returned
     * write-behind buffers back into {@code availableMemory}.
     *
     * @return index of the spilled partition in {@link #partitionsBeingBuilt}
     * @throws RuntimeException if no in-memory partition occupies any memory segment
     */
    @Override
    protected int spillPartition() throws IOException {
        int largestNumBlocks = 0;
        int largestPartNum = -1;
        for (int i = 0; i < this.partitionsBeingBuilt.size(); ++i) {
            BinaryHashPartition candidate = this.partitionsBeingBuilt.get(i);
            if (candidate.isInMemory() && candidate.getNumOccupiedMemorySegments() > largestNumBlocks) {
                largestNumBlocks = candidate.getNumOccupiedMemorySegments();
                largestPartNum = i;
            }
        }
        if (largestPartNum < 0) {
            // Previously surfaced as an IndexOutOfBoundsException from get(-1).
            throw new RuntimeException("Hash join: no in-memory partition with occupied memory segments is left to spill.");
        }
        BinaryHashPartition p = this.partitionsBeingBuilt.get(largestPartNum);
        int numBuffersFreed = p.spillPartition(this.ioManager, this.currentEnumerator.next(), this.buildSpillReturnBuffers);
        this.buildSpillRetBufferNumbers += numBuffersFreed;
        LOG.info(String.format("Grace hash join: Ran out memory, choosing partition [%d] to spill, %d memory segments being freed", largestPartNum, numBuffersFreed));
        MemorySegment currBuff;
        while (this.buildSpillRetBufferNumbers > 0 && (currBuff = (MemorySegment) this.buildSpillReturnBuffers.poll()) != null) {
            this.availableMemory.add(currBuff);
            --this.buildSpillRetBufferNumbers;
        }
        ++this.numSpillFiles;
        // Widen before multiplying: the int product overflows once a partition exceeds ~2 GiB.
        this.spillInBytes += (long) numBuffersFreed * this.segmentSize;
        p.buildBloomFilterAndFreeBucket();
        return largestPartNum;
    }

    /**
     * Decides whether a candidate build row matches the current probe row.
     *
     * <p>The join keys must be byte-wise equal. Unless the join is null-safe, the build key must
     * also have no null at the filtered key positions. If a join condition is set, it must hold too.
     *
     * @param candidate build-side row from the current bucket lookup
     * @return {@code true} if the candidate joins with the current probe row
     */
    boolean applyCondition(BinaryRow candidate) throws Exception {
        BinaryRow buildKey = this.buildSideProjection.apply(candidate);
        int keySize = buildKey.getSizeInBytes();
        boolean equal = keySize == this.probeKey.getSizeInBytes()
                && BinaryRowUtil.byteArrayEquals(buildKey.getMemorySegment().getHeapMemory(), this.probeKey.getMemorySegment().getHeapMemory(), keySize);
        if (equal && !this.nullSafe) {
            equal = !(this.filterAllNulls ? buildKey.anyNull() : buildKey.anyNull(this.nullFilterKeys));
        }
        if (!equal) {
            return false;
        }
        if (this.condFunc == null) {
            return true;
        }
        return this.reverseJoin ? this.condFunc.apply(this.probeRow, candidate) : this.condFunc.apply(candidate, this.probeRow);
    }
}

