/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.join.batch.hashtable;

import java.io.IOException;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.List;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.RandomAccessInputView;
import org.apache.flink.runtime.memory.AbstractPagedInputView;
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.runtime.join.batch.hashtable.BinaryHashPartition;
import org.apache.flink.table.runtime.join.batch.hashtable.BinaryHashTable;
import org.apache.flink.table.runtime.join.batch.hashtable.HashTableBloomFilter;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BinaryHashBucketArea {
    private static final Logger LOG = LoggerFactory.getLogger(BinaryHashBucketArea.class);
    static final int BUCKET_SIZE_BITS = 7;
    static final int BUCKET_SIZE = 128;
    static final int HASH_CODE_LEN = 4;
    static final int POINTER_LEN = 4;
    public static final int RECORD_BYTES = 8;
    static final int HEADER_COUNT_OFFSET = 0;
    static final int PROBED_FLAG_OFFSET = 2;
    static final int HEADER_FORWARD_OFFSET = 4;
    static final int BUCKET_HEADER_LENGTH = 8;
    static final int NUM_ENTRIES_PER_BUCKET = 15;
    static final int BUCKET_POINTER_START_OFFSET = 68;
    static final int BUCKET_FORWARD_POINTER_NOT_SET = -1;
    private static final long BUCKET_HEADER_INIT = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? -4294967296L : 0xFFFFFFFFL;
    private static final double DEFAULT_LOAD_FACTOR = 0.75;
    final BinaryHashTable table;
    private final double estimatedRowCount;
    private final double loadFactor;
    BinaryHashPartition partition;
    private int size;
    MemorySegment[] buckets;
    int numBuckets;
    private int numBucketsMask;
    MemorySegment[] overflowSegments;
    int numOverflowSegments;
    private int nextOverflowBucket;
    private int threshold;
    private boolean inReHash = false;

    BinaryHashBucketArea(BinaryHashTable table, double estimatedRowCount, int maxSegs) {
        this(table, estimatedRowCount, maxSegs, 0.75);
    }

    private BinaryHashBucketArea(BinaryHashTable table, double estimatedRowCount, int maxSegs, double loadFactor) {
        this.table = table;
        this.estimatedRowCount = estimatedRowCount;
        this.loadFactor = loadFactor;
        this.size = 0;
        int minNumBuckets = (int)Math.ceil(estimatedRowCount / loadFactor / 15.0);
        int bucketNumSegs = Math.max(1, Math.min(maxSegs, (minNumBuckets >>> table.bucketsPerSegmentBits) + ((minNumBuckets & table.bucketsPerSegmentMask) == 0 ? 0 : 1)));
        int numBuckets = MathUtils.roundDownToPowerOf2(bucketNumSegs << table.bucketsPerSegmentBits);
        int threshold = (int)((double)(numBuckets * 15) * loadFactor);
        MemorySegment[] buckets = new MemorySegment[bucketNumSegs];
        table.ensureNumBuffersReturned(bucketNumSegs);
        for (int i = 0; i < bucketNumSegs; ++i) {
            MemorySegment seg = table.getNextBuffer();
            this.initMemorySegment(seg);
            buckets[i] = seg;
        }
        this.setNewBuckets(buckets, numBuckets, threshold);
    }

    private void setNewBuckets(MemorySegment[] buckets, int numBuckets, int threshold) {
        this.buckets = buckets;
        Preconditions.checkArgument(MathUtils.isPowerOf2(numBuckets));
        this.numBuckets = numBuckets;
        this.numBucketsMask = numBuckets - 1;
        this.overflowSegments = new MemorySegment[2];
        this.numOverflowSegments = 0;
        this.nextOverflowBucket = 0;
        this.threshold = threshold;
    }

    public void setPartition(BinaryHashPartition partition2) {
        this.partition = partition2;
    }

    private void resize(boolean spillingAllowed) throws IOException {
        MemorySegment[] oldBuckets = this.buckets;
        int oldNumBuckets = this.numBuckets;
        MemorySegment[] oldOverflowSegments = this.overflowSegments;
        int newNumSegs = oldBuckets.length * 2;
        int newNumBuckets = MathUtils.roundDownToPowerOf2(newNumSegs << this.table.bucketsPerSegmentBits);
        int newThreshold = (int)((double)(newNumBuckets * 15) * this.loadFactor);
        if (!spillingAllowed && newNumSegs > this.table.remainBuffers()) {
            return;
        }
        MemorySegment[] newBuckets = new MemorySegment[newNumSegs];
        for (int i = 0; i < newNumSegs; ++i) {
            MemorySegment seg = this.table.getNextBuffer();
            if (seg == null) {
                int spilledPart = this.table.spillPartition();
                if (spilledPart == this.partition.partitionNumber) {
                    for (int j2 = 0; j2 < i; ++j2) {
                        this.table.free(newBuckets[j2]);
                    }
                    return;
                }
                seg = this.table.getNextBuffer();
                if (seg == null) {
                    throw new RuntimeException("Bug in HybridHashJoin: No memory became available after spilling a partition.");
                }
            }
            this.initMemorySegment(seg);
            newBuckets[i] = seg;
        }
        this.setNewBuckets(newBuckets, newNumBuckets, newThreshold);
        this.reHash(oldBuckets, oldNumBuckets, oldOverflowSegments);
    }

    private void reHash(MemorySegment[] oldBuckets, int oldNumBuckets, MemorySegment[] oldOverflowSegments) throws IOException {
        long reHashStartTime = System.currentTimeMillis();
        this.inReHash = true;
        int scanCount = -1;
        block0: while (++scanCount < oldNumBuckets) {
            int bucketArrayPos = scanCount >> this.table.bucketsPerSegmentBits;
            int bucketInSegOffset = (scanCount & this.table.bucketsPerSegmentMask) << 7;
            MemorySegment bucketSeg = oldBuckets[bucketArrayPos];
            int countInBucket = bucketSeg.getShort(bucketInSegOffset + 0);
            int numInBucket = 0;
            while (countInBucket != 0) {
                int hashCodeOffset = bucketInSegOffset + 8;
                int pointerOffset = bucketInSegOffset + 68;
                while (numInBucket < countInBucket) {
                    int pointer;
                    int hashCode2 = bucketSeg.getInt(hashCodeOffset);
                    if (!this.insertToBucket(hashCode2, pointer = bucketSeg.getInt(pointerOffset), true, false)) {
                        this.buildBloomFilterAndFree(oldBuckets, oldNumBuckets, oldOverflowSegments);
                        return;
                    }
                    ++numInBucket;
                    hashCodeOffset += 4;
                    pointerOffset += 4;
                }
                int forwardPointer = bucketSeg.getInt(bucketInSegOffset + 4);
                if (forwardPointer == -1) continue block0;
                int overflowSegIndex = forwardPointer >>> this.table.segmentSizeBits;
                bucketSeg = oldOverflowSegments[overflowSegIndex];
                bucketInSegOffset = forwardPointer & this.table.segmentSizeMask;
                countInBucket = bucketSeg.getShort(bucketInSegOffset + 0);
                numInBucket = 0;
            }
        }
        this.freeMemory(oldBuckets, oldOverflowSegments);
        this.inReHash = false;
        LOG.info("The rehash take {} ms for {} segments", (Object)(System.currentTimeMillis() - reHashStartTime), (Object)this.numBuckets);
    }

    private void freeMemory(MemorySegment[] buckets, MemorySegment[] overflowSegments) {
        for (MemorySegment segment : buckets) {
            this.table.free(segment);
        }
        for (MemorySegment segment : overflowSegments) {
            if (segment == null) continue;
            this.table.free(segment);
        }
    }

    private void initMemorySegment(MemorySegment seg) {
        for (int k = 0; k < this.table.bucketsPerSegment; ++k) {
            int bucketOffset = k * 128;
            seg.putLong(bucketOffset + 0, BUCKET_HEADER_INIT);
        }
    }

    private boolean insertToBucket(MemorySegment bucket, int bucketInSegmentPos, int hashCode2, int pointer, boolean spillingAllowed, boolean sizeAddAndCheckResize) throws IOException {
        short count = bucket.getShort(bucketInSegmentPos + 0);
        if (count < 15) {
            bucket.putShort(bucketInSegmentPos + 0, (short)(count + 1));
            bucket.putInt(bucketInSegmentPos + 8 + count * 4, hashCode2);
            bucket.putInt(bucketInSegmentPos + 68 + count * 4, pointer);
        } else {
            int overflowBucketNum;
            int overflowBucketOffset;
            MemorySegment overflowSeg;
            int forwardForNewBucket;
            int originalForwardPointer = bucket.getInt(bucketInSegmentPos + 4);
            if (originalForwardPointer != -1) {
                int overflowSegIndex = originalForwardPointer >>> this.table.segmentSizeBits;
                MemorySegment seg = this.overflowSegments[overflowSegIndex];
                int segOffset = originalForwardPointer & this.table.segmentSizeMask;
                short obCount = seg.getShort(segOffset + 0);
                if (obCount < 15) {
                    seg.putShort(segOffset + 0, (short)(obCount + 1));
                    seg.putInt(segOffset + 8 + obCount * 4, hashCode2);
                    seg.putInt(segOffset + 68 + obCount * 4, pointer);
                    return true;
                }
                forwardForNewBucket = originalForwardPointer;
            } else {
                forwardForNewBucket = -1;
            }
            if (this.nextOverflowBucket == 0) {
                overflowSeg = this.table.getNextBuffer();
                if (overflowSeg == null) {
                    if (!spillingAllowed) {
                        throw new IOException("Hashtable memory ran out in a non-spillable situation. This is probably related to wrong size calculations.");
                    }
                    int spilledPart = this.table.spillPartition();
                    if (spilledPart == this.partition.partitionNumber) {
                        return false;
                    }
                    overflowSeg = this.table.getNextBuffer();
                    if (overflowSeg == null) {
                        throw new RuntimeException("Bug in HybridHashJoin: No memory became available after spilling a partition.");
                    }
                }
                overflowBucketOffset = 0;
                overflowBucketNum = this.numOverflowSegments;
                if (this.overflowSegments.length <= this.numOverflowSegments) {
                    MemorySegment[] newSegsArray = new MemorySegment[this.overflowSegments.length * 2];
                    System.arraycopy(this.overflowSegments, 0, newSegsArray, 0, this.overflowSegments.length);
                    this.overflowSegments = newSegsArray;
                }
                this.overflowSegments[this.numOverflowSegments] = overflowSeg;
                ++this.numOverflowSegments;
            } else {
                overflowBucketNum = this.numOverflowSegments - 1;
                overflowSeg = this.overflowSegments[overflowBucketNum];
                overflowBucketOffset = this.nextOverflowBucket << 7;
            }
            this.nextOverflowBucket = this.nextOverflowBucket == this.table.bucketsPerSegmentMask ? 0 : this.nextOverflowBucket + 1;
            overflowSeg.putInt(overflowBucketOffset + 4, forwardForNewBucket);
            int pointerToNewBucket = (overflowBucketNum << this.table.segmentSizeBits) + overflowBucketOffset;
            bucket.putInt(bucketInSegmentPos + 4, pointerToNewBucket);
            overflowSeg.putInt(overflowBucketOffset + 8, hashCode2);
            overflowSeg.putInt(overflowBucketOffset + 68, pointer);
            overflowSeg.putShort(overflowBucketOffset + 0, (short)1);
            overflowSeg.putShort(overflowBucketOffset + 2, (short)0);
        }
        if (sizeAddAndCheckResize && ++this.size > this.threshold) {
            this.resize(spillingAllowed);
        }
        return true;
    }

    private int findBucket(int hashCode2) {
        return hashCode2 & this.numBucketsMask;
    }

    boolean insertToBucket(int hashCode2, int pointer, boolean spillingAllowed, boolean sizeAddAndCheckResize) throws IOException {
        int posHashCode = this.findBucket(hashCode2);
        int bucketArrayPos = posHashCode >> this.table.bucketsPerSegmentBits;
        int bucketInSegmentPos = (posHashCode & this.table.bucketsPerSegmentMask) << 7;
        MemorySegment bucket = this.buckets[bucketArrayPos];
        return this.insertToBucket(bucket, bucketInSegmentPos, hashCode2, pointer, spillingAllowed, sizeAddAndCheckResize);
    }

    boolean appendRecordAndInsert(BinaryRow record, int hashCode2) throws IOException {
        int posHashCode = this.findBucket(hashCode2);
        int bucketArrayPos = posHashCode >> this.table.bucketsPerSegmentBits;
        int bucketInSegmentPos = (posHashCode & this.table.bucketsPerSegmentMask) << 7;
        MemorySegment bucket = this.buckets[bucketArrayPos];
        if (!(this.table.tryDistinctBuildRow && this.partition.isInMemory() && this.findFirstSameBuildRow(bucket, hashCode2, bucketInSegmentPos, record))) {
            int pointer = this.partition.insertIntoBuildBuffer(record);
            if (pointer != -1) {
                this.insertToBucket(bucket, bucketInSegmentPos, hashCode2, pointer, true, true);
                return true;
            }
            return false;
        }
        return true;
    }

    private boolean findFirstSameBuildRow(MemorySegment bucket, int searchHashCode, int bucketInSegmentOffset, BinaryRow buildRowToInsert) {
        int posInSegment = bucketInSegmentOffset + 8;
        int countInBucket = bucket.getShort(bucketInSegmentOffset + 0);
        int numInBucket = 0;
        RandomAccessInputView view = this.partition.getBuildStateInputView();
        while (countInBucket != 0) {
            while (numInBucket < countInBucket) {
                int thisCode = bucket.getInt(posInSegment);
                posInSegment += 4;
                if (thisCode == searchHashCode) {
                    int pointer = bucket.getInt(bucketInSegmentOffset + 68 + numInBucket * 4);
                    ++numInBucket;
                    try {
                        view.setReadPosition((long)pointer);
                        BinaryRow row2 = this.table.binaryBuildSideSerializer.mapFromPages(this.table.reuseBuildRow, (AbstractPagedInputView)view);
                        if (!buildRowToInsert.equals(row2)) continue;
                        return true;
                    }
                    catch (IOException e2) {
                        throw new RuntimeException("Error deserializing key or value from the hashtable: " + e2.getMessage(), e2);
                    }
                }
                ++numInBucket;
            }
            int forwardPointer = bucket.getInt(bucketInSegmentOffset + 4);
            if (forwardPointer == -1) {
                return false;
            }
            int overflowSegIndex = forwardPointer >>> this.table.segmentSizeBits;
            bucket = this.overflowSegments[overflowSegIndex];
            bucketInSegmentOffset = forwardPointer & this.table.segmentSizeMask;
            countInBucket = bucket.getShort(bucketInSegmentOffset + 0);
            posInSegment = bucketInSegmentOffset + 8;
            numInBucket = 0;
        }
        return false;
    }

    void startLookup(int hashCode2) {
        int posHashCode = this.findBucket(hashCode2);
        int bucketArrayPos = posHashCode >> this.table.bucketsPerSegmentBits;
        int bucketInSegmentOffset = (posHashCode & this.table.bucketsPerSegmentMask) << 7;
        MemorySegment bucket = this.buckets[bucketArrayPos];
        this.table.bucketIterator.set(bucket, this.overflowSegments, this.partition, hashCode2, bucketInSegmentOffset);
    }

    void returnMemory(List<MemorySegment> target) {
        target.addAll(Arrays.asList(this.overflowSegments).subList(0, this.numOverflowSegments));
        target.addAll(Arrays.asList(this.buckets));
    }

    private void freeMemory() {
        this.table.availableMemory.addAll(Arrays.asList(this.overflowSegments).subList(0, this.numOverflowSegments));
        this.table.availableMemory.addAll(Arrays.asList(this.buckets));
    }

    void buildBloomFilterAndFree() {
        if (this.inReHash || !this.table.useBloomFilters) {
            this.freeMemory();
        } else {
            this.buildBloomFilterAndFree(this.buckets, this.numBuckets, this.overflowSegments);
        }
    }

    private void buildBloomFilterAndFree(MemorySegment[] buckets, int numBuckets, MemorySegment[] overflowSegments) {
        if (this.table.useBloomFilters) {
            long numRecords = (long)Math.max((double)this.partition.getBuildSideRecordCount() * 1.5, this.estimatedRowCount);
            int segSize = Math.min(Math.min(this.table.remainBuffers(), HashTableBloomFilter.optimalSegmentNumber(numRecords, this.table.pageSize(), 0.05)), this.table.maxInitBufferOfBucketArea(this.table.partitionsBeingBuilt.size()));
            if (segSize > 0) {
                HashTableBloomFilter filter = new HashTableBloomFilter(this.table.getNextBuffers(MathUtils.roundDownToPowerOf2(segSize)), numRecords);
                int scanCount = -1;
                block0: while (++scanCount < numBuckets) {
                    int bucketArrayPos = scanCount >> this.table.bucketsPerSegmentBits;
                    int bucketInSegOffset = (scanCount & this.table.bucketsPerSegmentMask) << 7;
                    MemorySegment bucketSeg = buckets[bucketArrayPos];
                    int countInBucket = bucketSeg.getShort(bucketInSegOffset + 0);
                    int numInBucket = 0;
                    while (countInBucket != 0) {
                        int hashCodeOffset = bucketInSegOffset + 8;
                        while (numInBucket < countInBucket) {
                            filter.addHash(bucketSeg.getInt(hashCodeOffset));
                            ++numInBucket;
                            hashCodeOffset += 4;
                        }
                        int forwardPointer = bucketSeg.getInt(bucketInSegOffset + 4);
                        if (forwardPointer == -1) continue block0;
                        int overflowSegIndex = forwardPointer >>> this.table.segmentSizeBits;
                        bucketSeg = overflowSegments[overflowSegIndex];
                        bucketInSegOffset = forwardPointer & this.table.segmentSizeMask;
                        countInBucket = bucketSeg.getShort(bucketInSegOffset + 0);
                        numInBucket = 0;
                    }
                }
                this.partition.bloomFilter = filter;
            }
        }
        this.freeMemory(buckets, overflowSegments);
    }
}

