/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.join.batch;

import java.io.IOException;
import java.util.BitSet;
import java.util.List;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.streaming.api.operators.TwoInputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.codegen.CodeGenUtils;
import org.apache.flink.table.codegen.GeneratedJoinConditionFunction;
import org.apache.flink.table.codegen.GeneratedProjection;
import org.apache.flink.table.codegen.GeneratedSorter;
import org.apache.flink.table.codegen.JoinConditionFunction;
import org.apache.flink.table.codegen.Projection;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.dataformat.JoinedRow;
import org.apache.flink.table.dataformat.util.BinaryRowUtil;
import org.apache.flink.table.plan.FlinkJoinRelType;
import org.apache.flink.table.runtime.AbstractStreamOperatorWithMetrics;
import org.apache.flink.table.runtime.join.batch.NullAwareJoinHelper;
import org.apache.flink.table.runtime.sort.RecordComparator;
import org.apache.flink.table.runtime.util.ResettableExternalBuffer;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.table.typeutils.AbstractRowSerializer;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.codehaus.commons.compiler.CompileException;

public class MergeJoinOperator
extends AbstractStreamOperatorWithMetrics<BaseRow>
implements TwoInputStreamOperator<BaseRow, BaseRow, BaseRow> {
    private final long leftBufferMemory;
    private final long rightBufferMemory;
    private final FlinkJoinRelType type;
    private final GeneratedJoinConditionFunction condFuncCode;
    private final GeneratedProjection projectionCode1;
    private final GeneratedProjection projectionCode2;
    private final GeneratedSorter keyGSorter;
    private transient JoinConditionFunction condFunc;
    private transient RecordComparator keyComparator;
    private transient Collector<BaseRow> collector;
    private transient boolean isFinished1;
    private transient boolean isFinished2;
    private transient AbstractRowSerializer<BaseRow> serializer1;
    private transient AbstractRowSerializer<BaseRow> serializer2;
    private transient WrappedBuffer buffer1;
    private transient WrappedBuffer buffer2;
    private transient BaseRow leftNullRow;
    private transient BaseRow rightNullRow;
    private transient JoinedRow joinedRow;
    private transient boolean isFindStage;
    private transient boolean advance1;
    private transient boolean advance2;
    private transient BinaryRow mergeKey;
    private transient int mergeCount1;
    private transient int mergeCount2;
    private transient BitSet mergeBs1;
    private transient BitSet mergeBs2;
    private transient boolean leftIsBuild;
    private transient MemoryManager memManager;
    private transient IOManager ioManager;
    private final int[] nullFilterKeys;
    private final boolean nullSafe;
    private final boolean filterAllNulls;

    public MergeJoinOperator(long leftBufferMemory, long rightBufferMemory, FlinkJoinRelType type, GeneratedJoinConditionFunction condFuncCode, GeneratedProjection projectionCode1, GeneratedProjection projectionCode2, GeneratedSorter keyGSorter, boolean[] filterNulls) {
        if (type != FlinkJoinRelType.INNER && type != FlinkJoinRelType.LEFT && type != FlinkJoinRelType.RIGHT && type != FlinkJoinRelType.FULL) {
            throw new RuntimeException("Merge join operator only supports inner/left outer/right outer/full outer join currently.");
        }
        LOG.info("Initializing merge join operator...\nleftBufferMemory = " + leftBufferMemory + ", rightBufferMemory = " + rightBufferMemory);
        this.leftBufferMemory = leftBufferMemory;
        this.rightBufferMemory = rightBufferMemory;
        this.type = type;
        this.condFuncCode = condFuncCode;
        this.projectionCode1 = projectionCode1;
        this.projectionCode2 = projectionCode2;
        this.keyGSorter = keyGSorter;
        this.nullFilterKeys = NullAwareJoinHelper.getNullFilterKeys(filterNulls);
        this.nullSafe = this.nullFilterKeys.length == 0;
        this.filterAllNulls = this.nullFilterKeys.length == filterNulls.length;
    }

    /**
     * Prepares the operator for processing.
     *
     * <p>Resets the end-of-input flags, creates the output collector, resolves both input
     * serializers from the operator config, creates the padding rows used for outer-join output
     * (fresh {@link GenericRow}s sized to each input's field count), instantiates the compiled
     * condition / comparator / projections, allocates memory pages for both input buffers and
     * finally registers the gauges and resets the join state.
     *
     * @throws Exception if the parent {@code open()}, code compilation, instantiation or page
     *     allocation fails
     */
    @Override
    public void open() throws Exception {
        super.open();

        isFinished1 = false;
        isFinished2 = false;
        collector = new StreamRecordCollector<BaseRow>(output);

        serializer1 = (AbstractRowSerializer) getOperatorConfig().getTypeSerializerIn1(getUserCodeClassloader());
        serializer2 = (AbstractRowSerializer) getOperatorConfig().getTypeSerializerIn2(getUserCodeClassloader());
        leftNullRow = new GenericRow(serializer1.getNumFields());
        rightNullRow = new GenericRow(serializer2.getNumFields());
        joinedRow = new JoinedRow();

        CookedClasses cooked = cookGeneratedClasses(getContainingTask().getUserCodeClassLoader());
        condFunc = cooked.condFuncClass.newInstance();
        keyComparator = cooked.keyComparatorClass.newInstance();
        keyComparator.init(keyGSorter.serializers(), keyGSorter.comparators());

        // The managers must be set before any WrappedBuffer is built: its constructor reads them.
        memManager = getContainingTask().getEnvironment().getMemoryManager();
        ioManager = getContainingTask().getEnvironment().getIOManager();

        Projection leftProjection = cooked.projectionClass1.newInstance();
        Projection rightProjection = cooked.projectionClass2.newInstance();
        buffer1 = allocateBuffer(leftBufferMemory, serializer1, leftProjection);
        buffer2 = allocateBuffer(rightBufferMemory, serializer2, rightProjection);

        initGauge();
        initJoin();
    }

    /**
     * Allocates {@code bufferMemory / pageSize} pages (integer division, truncated to int) from
     * the memory manager and wraps them in a new {@link WrappedBuffer}.
     */
    private WrappedBuffer allocateBuffer(long bufferMemory, AbstractRowSerializer<BaseRow> serializer, Projection projection) throws Exception {
        int pageCount = (int) (bufferMemory / memManager.getPageSize());
        List<MemorySegment> pages = memManager.allocatePages(getContainingTask(), pageCount);
        return new WrappedBuffer(pages, serializer, projection);
    }

    /**
     * Compiles the four generated code snippets (join condition, key comparator, left and right
     * key projections) with the given class loader.
     *
     * @param cl class loader passed to {@code CodeGenUtils.compile}
     * @return holder of the compiled classes
     * @throws CompileException if any of the generated sources fails to compile
     */
    protected CookedClasses cookGeneratedClasses(ClassLoader cl) throws CompileException {
        Class<JoinConditionFunction> conditionClass =
                CodeGenUtils.compile(cl, condFuncCode.name(), condFuncCode.code());
        Class<RecordComparator> comparatorClass =
                CodeGenUtils.compile(cl, keyGSorter.comparator().name(), keyGSorter.comparator().code());
        Class<Projection> leftProjectionClass =
                CodeGenUtils.compile(cl, projectionCode1.name(), projectionCode1.code());
        Class<Projection> rightProjectionClass =
                CodeGenUtils.compile(cl, projectionCode2.name(), projectionCode2.code());
        return new CookedClasses(conditionClass, comparatorClass, leftProjectionClass, rightProjectionClass);
    }

    /**
     * Resets the join state machine: the operator starts in the find stage with both sides
     * needing to advance. A match-tracking bit set is created only for a side that is outer
     * (left outer tracks left rows, right outer tracks right rows).
     */
    private void initJoin() {
        isFindStage = true;
        advance1 = true;
        advance2 = true;

        if (type.isLeftOuter()) {
            mergeBs1 = new BitSet();
        }
        if (type.isRightOuter()) {
            mergeBs2 = new BitSet();
        }
    }

    /**
     * Registers three gauges on the operator metric group, each reporting the sum over both
     * input buffers: used memory in bytes, number of spill files and spilled bytes.
     */
    private void initGauge() {
        getMetricGroup().gauge("memoryUsedSizeInBytes", () -> {
            long leftUsed = buffer1.getUsedMemoryInBytes();
            long rightUsed = buffer2.getUsedMemoryInBytes();
            return leftUsed + rightUsed;
        });
        getMetricGroup().gauge("numSpillFiles", () -> {
            int leftFiles = buffer1.getNumSpillFiles();
            int rightFiles = buffer2.getNumSpillFiles();
            return leftFiles + rightFiles;
        });
        getMetricGroup().gauge("spillInBytes", () -> {
            long leftSpilled = buffer1.getSpillInBytes();
            long rightSpilled = buffer2.getSpillInBytes();
            return leftSpilled + rightSpilled;
        });
    }

    /**
     * Tells the runtime which input to read first.
     *
     * @return always {@link TwoInputSelection#ANY}; this operator states no preference
     */
    public TwoInputSelection firstInputSelection() {
        return TwoInputSelection.ANY;
    }

    /**
     * Handles one row from the left input: the row is placed in the left buffer's one-row cache,
     * the join is driven as far as currently possible, and whatever is still in the cache
     * afterwards is materialized into the buffer (nothing happens if the join already discarded
     * it).
     *
     * @param record the incoming left record
     * @return always {@link TwoInputSelection#ANY}
     * @throws Exception if joining or buffering fails
     */
    public TwoInputSelection processElement1(StreamRecord<BaseRow> record) throws Exception {
        BaseRow leftRow = record.getValue();
        buffer1.add(leftRow);
        runJoin();
        buffer1.materializeCache();
        return TwoInputSelection.ANY;
    }

    /**
     * Handles one row from the right input: the row is placed in the right buffer's one-row
     * cache, the join is driven as far as currently possible, and whatever is still in the cache
     * afterwards is materialized into the buffer (nothing happens if the join already discarded
     * it).
     *
     * @param record the incoming right record
     * @return always {@link TwoInputSelection#ANY}
     * @throws Exception if joining or buffering fails
     */
    public TwoInputSelection processElement2(StreamRecord<BaseRow> record) throws Exception {
        BaseRow rightRow = record.getValue();
        buffer2.add(rightRow);
        runJoin();
        buffer2.materializeCache();
        return TwoInputSelection.ANY;
    }

    /**
     * Signals the end of the left input. {@code BinaryRowUtil.EMPTY_ROW} is queued in the left
     * buffer as the end marker (the join detects it through {@code isEndRow}) and the left side
     * is flagged as finished. The join is only driven once both inputs have ended.
     *
     * @throws Exception if the final join run fails
     */
    public void endInput1() throws Exception {
        buffer1.add(BinaryRowUtil.EMPTY_ROW);
        isFinished1 = true;
        if (!isAllFinished()) {
            return;
        }
        runJoin();
    }

    /**
     * Signals the end of the right input. {@code BinaryRowUtil.EMPTY_ROW} is queued in the right
     * buffer as the end marker (the join detects it through {@code isEndRow}) and the right side
     * is flagged as finished. The join is only driven once both inputs have ended.
     *
     * @throws Exception if the final join run fails
     */
    public void endInput2() throws Exception {
        buffer2.add(BinaryRowUtil.EMPTY_ROW);
        isFinished2 = true;
        if (!isAllFinished()) {
            return;
        }
        runJoin();
    }

    /**
     * Closes the operator and releases both input buffers.
     *
     * <p>Each step runs in a {@code finally} chain so that a failure in {@code super.close()} or
     * in closing the left buffer can no longer prevent the remaining buffer(s) from being closed.
     * If several steps fail, the exception of the last failing step is the one propagated.
     * Buffers that were never created (for example because {@code open()} failed early) are
     * skipped.
     *
     * @throws Exception if the parent close or a buffer close fails
     */
    @Override
    public void close() throws Exception {
        try {
            super.close();
        } finally {
            try {
                if (this.buffer1 != null) {
                    this.buffer1.close();
                }
            } finally {
                if (this.buffer2 != null) {
                    this.buffer2.close();
                }
            }
        }
    }

    /** Returns {@code true} once end-of-input has been signalled for both the left and the right side. */
    private boolean isAllFinished() {
        return isFinished1 && isFinished2;
    }

    /**
     * Drives the join state machine: repeatedly executes one step of the current stage (find or
     * merge) until a step reports that it could not make progress.
     *
     * @throws Exception if a merge step fails
     */
    private void runJoin() throws Exception {
        boolean progressed;
        do {
            progressed = isFindStage ? runFindStep() : runMergeStep();
        } while (progressed);
    }

    /**
     * Executes one step of the find stage, which looks for the next pair of rows with equal keys.
     *
     * <p>The head rows of both sides are compared. An end marker on one side makes the other side
     * the smaller one; a row whose key must be null-filtered is treated as the smaller one on its
     * own side. The smaller side is emitted padded with the opposite padding row if that side is
     * outer, its buffer is asked to discard, and only that side advances next time. On equal keys
     * the merge stage is initialised.
     *
     * @return {@code false} if a side that must advance has no row available or both sides have
     *     reached their end marker; {@code true} if the step made progress
     */
    private boolean runFindStep() {
        // Stall until every side that has to advance actually has a row to offer.
        if (advance1 && !buffer1.hasNext()) {
            return false;
        }
        if (advance2 && !buffer2.hasNext()) {
            return false;
        }

        BaseRow leftRow;
        if (advance1) {
            leftRow = buffer1.nextRow();
        } else {
            leftRow = buffer1.current;
        }
        BaseRow rightRow;
        if (advance2) {
            rightRow = buffer2.nextRow();
        } else {
            rightRow = buffer2.current;
        }

        boolean leftEnded = isEndRow(leftRow);
        boolean rightEnded = isEndRow(rightRow);
        if (leftEnded && rightEnded) {
            return false;
        }

        int order = leftEnded ? 1
                : rightEnded ? -1
                : buffer1.currentKeyShouldFilter ? -1
                : buffer2.currentKeyShouldFilter ? 1
                : keyComparator.compare(buffer1.currentKey, buffer2.currentKey);

        if (order == 0) {
            initMergeStage();
            return true;
        }

        if (order < 0) {
            // Left row has no partner: skip it (emitting it padded for left/full outer joins).
            advance1 = true;
            advance2 = false;
            if (type.isLeftOuter()) {
                collect(leftRow, rightNullRow);
            }
            buffer1.discard();
        } else {
            // Right row has no partner: skip it (emitting it padded for right/full outer joins).
            advance1 = false;
            advance2 = true;
            if (type.isRightOuter()) {
                collect(leftNullRow, rightRow);
            }
            buffer2.discard();
        }
        return true;
    }

    /**
     * Switches from the find stage to the merge stage for the key both head rows share.
     *
     * <p>The left side becomes the build side when {@code buffer1.isBuild()} says so (its current
     * row is the same reference as its cache slot); otherwise the right side is the build side.
     * The build side is the one that advances first and its buffer enters build mode. The shared
     * key is copied, and the per-stage row counters and match bit sets are reset.
     */
    private void initMergeStage() {
        isFindStage = false;
        leftIsBuild = buffer1.isBuild();

        // Only the build side advances at the start of the merge stage.
        advance1 = leftIsBuild;
        advance2 = !leftIsBuild;
        WrappedBuffer buildSide = leftIsBuild ? buffer1 : buffer2;
        buildSide.enterBuildMode();

        mergeKey = buffer1.currentKey.copy();
        mergeCount1 = 0;
        mergeCount2 = 0;
        if (mergeBs1 != null) {
            mergeBs1.clear();
        }
        if (mergeBs2 != null) {
            mergeBs2.clear();
        }
    }

    /**
     * Executes one step of the merge stage.
     *
     * <p>A side is processed when it is flagged to advance and its buffer has a next row; the
     * left side is tried first. For the build side the step counts the row and checks whether the
     * following row still has the merge key; once it does not, the probe side is flagged to
     * advance. For the probe side the step counts the row and joins the current probe row against
     * the build rows counted so far. When neither side is flagged to advance any more, the merge
     * stage is finished.
     *
     * @return {@code true} if a row was processed or the merge stage was ended; {@code false} if
     *     a side still has to advance but no row is available for it yet
     * @throws Exception if joining a probe row or ending the merge stage fails
     */
    private boolean runMergeStep() throws Exception {
        if (advance1 && buffer1.hasNext()) {
            mergeCount1++;
            if (!leftIsBuild) {
                advance1 = joinCurrentProbeRow(buffer1, buffer2, mergeCount2);
            } else {
                advance1 = checkNextRowIsSameKey(buffer1);
                if (!advance1) {
                    // Build side exhausted for this key: start probing with the right side.
                    advance2 = true;
                }
            }
            return true;
        }

        if (advance2 && buffer2.hasNext()) {
            mergeCount2++;
            if (leftIsBuild) {
                advance2 = joinCurrentProbeRow(buffer2, buffer1, mergeCount1);
            } else {
                advance2 = checkNextRowIsSameKey(buffer2);
                if (!advance2) {
                    // Build side exhausted for this key: start probing with the left side.
                    advance1 = true;
                }
            }
            return true;
        }

        if (advance1 || advance2) {
            // Some side still has to advance but has nothing to read yet.
            return false;
        }
        endMergeStage();
        return true;
    }

    /**
     * Moves the given buffer to its next row and tells whether that row belongs to the current
     * merge key.
     *
     * @param buffer the buffer to advance
     * @return {@code false} if the next row is the end marker, if its key must be null-filtered,
     *     or if its key differs from the merge key; {@code true} otherwise
     */
    private boolean checkNextRowIsSameKey(WrappedBuffer buffer) {
        BaseRow next = buffer.nextRow();
        if (isEndRow(next) || buffer.currentKeyShouldFilter) {
            return false;
        }
        return mergeKey.equals(buffer.currentKey);
    }

    /**
     * Joins the probe buffer's current row against the first {@code buildCount} rows of the build
     * buffer.
     *
     * <p>Every build row passing the join condition is emitted together with the probe row (left
     * row first) and, if the build side is tracked for outer output, its index is marked as
     * matched. If no build row matched and the probe side is outer, the probe row is emitted with
     * the opposite padding row. Finally the probe buffer is advanced.
     *
     * @param probeBuffer buffer whose current row is probed
     * @param buildBuffer buffer whose external iterator supplies the build rows
     * @param buildCount number of build rows to visit, starting from the iterator's reset position
     * @return whether the probe buffer's next row still has the merge key
     * @throws Exception if the condition function or the build iterator fails
     */
    private boolean joinCurrentProbeRow(WrappedBuffer probeBuffer, WrappedBuffer buildBuffer, int buildCount) throws Exception {
        final BaseRow probeRow = probeBuffer.current;
        final ResettableExternalBuffer.BufferIterator buildRows = buildBuffer.externalIterator;
        buildRows.reset();

        boolean anyMatch = false;
        for (int position = 0; position < buildCount; position++) {
            boolean advanced = buildRows.advanceNext();
            Preconditions.checkState(advanced, "There is no next row in build buffer. This is a bug.");
            BinaryRow buildRow = buildRows.getRow();

            if (leftIsBuild) {
                if (condFunc.apply(buildRow, probeRow)) {
                    anyMatch = true;
                    collect(buildRow, probeRow);
                    if (mergeBs1 != null) {
                        mergeBs1.set(position);
                    }
                }
            } else if (condFunc.apply(probeRow, buildRow)) {
                anyMatch = true;
                collect(probeRow, buildRow);
                if (mergeBs2 != null) {
                    mergeBs2.set(position);
                }
            }
        }

        if (!anyMatch) {
            if (leftIsBuild) {
                // Probe row is a right row: pad it on the left for right/full outer joins.
                if (mergeBs2 != null) {
                    collect(leftNullRow, probeRow);
                }
            } else if (mergeBs1 != null) {
                // Probe row is a left row: pad it on the right for left/full outer joins.
                collect(probeRow, rightNullRow);
            }
        }
        return checkNextRowIsSameKey(probeBuffer);
    }

    /**
     * Finishes the merge stage for the current key and returns to the find stage.
     *
     * <p>A warning is logged if the build iterator reports that its begin row is in a spill file.
     * If the build side is tracked for outer output, every build row of this key that was never
     * marked as matched is emitted with the opposite padding row. Afterwards neither side is
     * flagged to advance and the build buffer leaves build mode.
     *
     * @throws IOException if iterating the build buffer fails
     */
    private void endMergeStage() throws IOException {
        final WrappedBuffer build = leftIsBuild ? buffer1 : buffer2;
        if (build.externalIterator.rowInSpill(build.externalIterator.getBeginRow())) {
            LOG.warn("(In merge join operator) Build side iterator is in spilled file, this may decrease performance.");
        }

        final BitSet matchedBuildRows = leftIsBuild ? mergeBs1 : mergeBs2;
        if (matchedBuildRows != null) {
            final int buildCount = leftIsBuild ? mergeCount1 : mergeCount2;
            build.externalIterator.reset();
            for (int position = 0; position < buildCount; position++) {
                build.externalIterator.advanceNext();
                if (matchedBuildRows.get(position)) {
                    continue;
                }
                BaseRow unmatched = build.externalIterator.getRow();
                if (leftIsBuild) {
                    collect(unmatched, rightNullRow);
                } else {
                    collect(leftNullRow, unmatched);
                }
            }
        }

        isFindStage = true;
        advance1 = false;
        advance2 = false;
        build.leaveBuildMode();
    }

    /**
     * Emits one joined result. The shared {@code joinedRow} instance is re-pointed at the two
     * given rows and handed to the collector, so no new row object is created per result.
     *
     * @param row1 left part of the output
     * @param row2 right part of the output
     */
    private void collect(BaseRow row1, BaseRow row2) {
        BaseRow joined = joinedRow.replace(row1, row2);
        collector.collect(joined);
    }

    /**
     * Tells whether the given row is the end-of-input marker, i.e. a row with zero fields.
     *
     * @param row the row to test; must not be {@code null}
     * @return {@code true} if the row's arity is 0
     */
    private static boolean isEndRow(BaseRow row) {
        final int arity = row.getArity();
        return arity == 0;
    }

    /**
     * Asks {@code NullAwareJoinHelper.shouldFilter} whether a row with the given key has to be
     * excluded from matching, based on this operator's null-filter configuration.
     *
     * @param key the projected join key
     * @return the helper's verdict
     */
    private boolean shouldFilterNull(BinaryRow key) {
        return NullAwareJoinHelper.shouldFilter(
                nullSafe,
                filterAllNulls,
                nullFilterKeys,
                key);
    }

    /**
     * One join input: a {@link ResettableExternalBuffer} with a single iterator over it, plus a
     * one-row cache holding the most recently received row until it is either materialized into
     * the external buffer or dropped.
     *
     * <p>{@code current} is the row the join is currently looking at; it is either a row obtained
     * from the external iterator or the very same reference as {@code cache}. Several methods
     * rely on that reference comparison.
     */
    private class WrappedBuffer {
        /** Row the join currently looks at; {@code null} after {@link #reset()}. */
        private BaseRow current = null;
        /** Key projected from {@code current}; {@code null} for no row or the end marker. */
        private BinaryRow currentKey = null;
        private final Projection<BaseRow, BinaryRow> projection;
        /** Whether {@code currentKey} must be excluded from matching; {@code true} when there is no key. */
        private boolean currentKeyShouldFilter = true;
        /** Most recently added row that has not been materialized yet, or {@code null}. */
        private BaseRow cache = null;
        private final ResettableExternalBuffer externalBuffer;
        /** Replaced with a fresh iterator whenever the external buffer is cleared. */
        private ResettableExternalBuffer.BufferIterator externalIterator;

        /**
         * @param memory pages backing the external buffer
         * @param serializer row serializer handed to the external buffer
         * @param projection projection producing the join key of a row
         */
        private WrappedBuffer(List<MemorySegment> memory, AbstractRowSerializer serializer, Projection<BaseRow, BinaryRow> projection) {
            this.projection = projection;
            this.externalBuffer = new ResettableExternalBuffer(MergeJoinOperator.this.memManager, MergeJoinOperator.this.ioManager, memory, serializer);
            this.externalIterator = this.externalBuffer.newIterator();
        }

        /** Stores the row in the cache slot; the previous cache content must already be gone. */
        private void add(BaseRow row2) {
            Preconditions.checkState(this.cache == null, "Old cache must be materialized to buffer. This is a bug.");
            this.cache = row2;
        }

        /**
         * Moves the cached row (if any) into the external buffer. If the join is currently
         * looking at the cached row, the iterator is advanced onto the stored copy and
         * {@code current} is re-pointed at it, so {@code current} no longer refers to the cache.
         */
        private void materializeCache() throws IOException {
            if (this.cache == null) {
                return;
            }
            this.externalBuffer.add(this.cache);
            if (this.isCacheJustRead()) {
                this.externalIterator.advanceNext();
                this.current = this.externalIterator.getRow();
            }
            this.cache = null;
        }

        /** Drops everything held by this buffer, but only if the current row is the cached row. */
        private void discard() {
            if (this.isCacheJustRead()) {
                this.reset();
            }
        }

        /** Reference comparison: is the current row the very object sitting in the cache slot? */
        private boolean isCacheJustRead() {
            return this.current == this.cache;
        }

        /** True if an unread cached row exists or the external iterator has more rows. */
        private boolean hasNext() {
            boolean unreadCache = this.cache != null && this.current != this.cache;
            return unreadCache || this.externalIterator.hasNext();
        }

        /**
         * Advances to the next row, preferring the external iterator over the cache, and
         * recomputes the key and its filter flag. For a missing row or the end marker the key is
         * cleared and the filter flag is set.
         *
         * @return the new current row (may be {@code null} if nothing was available)
         */
        private BaseRow nextRow() {
            if (this.externalIterator.hasNext()) {
                this.externalIterator.advanceNext();
                this.current = this.externalIterator.getRow();
            } else {
                this.current = this.cache;
            }
            if (this.current == null || MergeJoinOperator.isEndRow(this.current)) {
                this.currentKey = null;
                this.currentKeyShouldFilter = true;
            } else {
                this.currentKey = this.projection.apply(this.current);
                this.currentKeyShouldFilter = MergeJoinOperator.this.shouldFilterNull(this.currentKey);
            }
            return this.current;
        }

        /** Forgets the current row, its key and the cached row, and clears the external buffer. */
        private void reset() {
            this.current = null;
            this.currentKey = null;
            this.cache = null;
            this.clearBuffer();
        }

        /**
         * Closes the iterator and the external buffer. The buffer is closed in a {@code finally}
         * block so that a failure while closing the iterator cannot leave the buffer (and the
         * resources it holds) unreleased.
         */
        private void close() {
            try {
                this.externalIterator.close();
            } finally {
                this.externalBuffer.close();
            }
        }

        /** If the external buffer holds rows, resets it and replaces the iterator with a fresh one. */
        private void clearBuffer() {
            if (this.externalBuffer.size() > 0) {
                this.externalIterator.close();
                this.externalBuffer.reset();
                this.externalIterator = this.externalBuffer.newIterator();
            }
        }

        /**
         * Whether this side should act as the build side of a merge stage. This is the same
         * test as {@link #isCacheJustRead()} and therefore delegates to it.
         */
        private boolean isBuild() {
            return this.isCacheJustRead();
        }

        /** Clears previously buffered rows so the build rows of the new key start the buffer. */
        private void enterBuildMode() {
            this.clearBuffer();
        }

        /**
         * Leaves build mode: steps onto the next buffered row if the iterator has one, otherwise
         * clears the external buffer.
         */
        private void leaveBuildMode() {
            if (this.externalIterator.hasNext()) {
                this.externalIterator.advanceNext();
            } else {
                this.clearBuffer();
            }
        }

        private long getUsedMemoryInBytes() {
            return this.externalBuffer.getUsedMemoryInBytes();
        }

        private int getNumSpillFiles() {
            return this.externalBuffer.getNumSpillFiles();
        }

        private long getSpillInBytes() {
            return this.externalBuffer.getSpillInBytes();
        }
    }

    /**
     * Immutable holder of the classes compiled from the operator's generated code: the join
     * condition, the key comparator and the key projections of the left and right input.
     */
    protected static class CookedClasses {
        protected final Class<JoinConditionFunction> condFuncClass;
        protected final Class<RecordComparator> keyComparatorClass;
        protected final Class<Projection> projectionClass1;
        protected final Class<Projection> projectionClass2;

        /**
         * @param condFunc compiled join condition class
         * @param keyComparator compiled key comparator class
         * @param projection1 compiled key projection class for the left input
         * @param projection2 compiled key projection class for the right input
         */
        protected CookedClasses(
                Class<JoinConditionFunction> condFunc,
                Class<RecordComparator> keyComparator,
                Class<Projection> projection1,
                Class<Projection> projection2) {
            condFuncClass = condFunc;
            keyComparatorClass = keyComparator;
            projectionClass1 = projection1;
            projectionClass2 = projection2;
        }
    }
}

