/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.join.batch;

import java.io.Serializable;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.streaming.api.operators.TwoInputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.codegen.CodeGenUtils;
import org.apache.flink.table.codegen.GeneratedJoinConditionFunction;
import org.apache.flink.table.codegen.GeneratedProjection;
import org.apache.flink.table.codegen.JoinConditionFunction;
import org.apache.flink.table.codegen.Projection;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.dataformat.JoinedRow;
import org.apache.flink.table.runtime.AbstractStreamOperatorWithMetrics;
import org.apache.flink.table.runtime.join.batch.HashJoinType;
import org.apache.flink.table.runtime.join.batch.hashtable.BinaryHashTable;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.table.types.RowType;
import org.apache.flink.table.typeutils.AbstractRowSerializer;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.codehaus.commons.compiler.CompileException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base operator for batch hash joins.
 *
 * <p>The first input is the build side: every build row is inserted into a {@link BinaryHashTable}.
 * The second input is the probe side: every probe row is looked up in that table. Concrete
 * subclasses, chosen by {@link #newHashJoinOperator} according to the {@link HashJoinType},
 * implement {@link #join} to decide which rows are emitted for each key.
 *
 * <p>{@link #collect} emits joined rows as (build, probe), or as (probe, build) when
 * {@code reverseJoinFunction} is set.
 */
public abstract class HashJoinOperator
extends AbstractStreamOperatorWithMetrics<BaseRow>
implements TwoInputStreamOperator<BaseRow, BaseRow, BaseRow> {
    private static final Logger LOG = LoggerFactory.getLogger(HashJoinOperator.class);
    /** Construction-time configuration; its generated-code fields are cleared after compilation. */
    private final HashJoinParameter parameter;
    /** If true, output rows are laid out as (probe, build) instead of (build, probe). */
    private final boolean reverseJoinFunction;
    /** Join type; also handed to the hash table in {@link #open()}. */
    final HashJoinType type;
    /** Compiled join-condition class, set by {@link #cookGeneratedClasses}. */
    transient Class<JoinConditionFunction> condFuncClass;
    /** Compiled build-side key projection class, set by {@link #cookGeneratedClasses}. */
    transient Class<Projection<BaseRow, BinaryRow>> buildProjectionClass;
    /** Compiled probe-side key projection class, set by {@link #cookGeneratedClasses}. */
    transient Class<Projection<BaseRow, BinaryRow>> probeProjectionClass;
    /** Hash table holding the build side; created in {@link #open()}, released in {@link #close()}. */
    private transient BinaryHashTable table;
    /** Collector wrapping the operator output. */
    transient Collector<BaseRow> collector;
    /** All-null row with the build side's arity, used to pad unmatched probe rows. */
    transient BaseRow buildSideNullRow;
    /** All-null row with the probe side's arity, used to pad unmatched build rows. */
    transient BaseRow probeSideNullRow;
    /** Single JoinedRow instance reused for every row emitted by {@link #collect}. */
    private transient JoinedRow joinedRow;

    HashJoinOperator(HashJoinParameter parameter) {
        this.parameter = parameter;
        this.type = parameter.type;
        this.reverseJoinFunction = parameter.reverseJoinFunction;
    }

    /**
     * Compiles the generated code, creates the hash table and output helpers, and registers
     * gauges for the table's memory usage and spilling.
     *
     * @throws Exception if compilation, instantiation of the generated classes, or table
     *     creation fails
     */
    @Override
    public void open() throws Exception {
        super.open();
        this.cookGeneratedClasses(this.getContainingTask().getUserCodeClassLoader());
        IOManager ioManager = this.getContainingTask().getEnvironment().getIOManager();
        AbstractRowSerializer buildSerializer =
                (AbstractRowSerializer) this.getOperatorConfig().getTypeSerializerIn1(this.getUserCodeClassloader());
        AbstractRowSerializer probeSerializer =
                (AbstractRowSerializer) this.getOperatorConfig().getTypeSerializerIn2(this.getUserCodeClassloader());
        boolean hashJoinUseBitMaps = this.getContainingTask().getEnvironment().getTaskConfiguration()
                .getBoolean("taskmanager.runtime.hashjoin-bloom-filters", false);
        int parallel = this.getRuntimeContext().getNumberOfParallelSubtasks();
        // Arguments are kept in their original order, so the generated classes are
        // instantiated in the same sequence as before.
        this.table = new BinaryHashTable(
                this.getSqlConf(),
                this.getContainingTask(),
                buildSerializer,
                probeSerializer,
                this.buildProjectionClass.newInstance(),
                this.probeProjectionClass.newInstance(),
                this.getContainingTask().getEnvironment().getMemoryManager(),
                this.parameter.reservedMemorySize,
                this.parameter.maxMemorySize,
                this.parameter.perRequestMemorySize,
                ioManager,
                this.parameter.buildRowSize,
                // the configured build row count is divided evenly across the parallel subtasks
                this.parameter.buildRowCount / (long) parallel,
                hashJoinUseBitMaps,
                this.type,
                this.condFuncClass.newInstance(),
                this.reverseJoinFunction,
                this.parameter.filterNullKeys,
                this.parameter.tryDistinctBuildRow);
        this.collector = new StreamRecordCollector<BaseRow>(this.output);
        this.buildSideNullRow = new GenericRow(buildSerializer.getNumFields());
        this.probeSideNullRow = new GenericRow(probeSerializer.getNumFields());
        this.joinedRow = new JoinedRow();
        this.getMetricGroup().gauge("memoryUsedSizeInBytes", this.table::getUsedMemoryInBytes);
        this.getMetricGroup().gauge("numSpillFiles", this.table::getNumSpillFiles);
        this.getMetricGroup().gauge("spillInBytes", this.table::getSpillInBytes);
    }

    /**
     * Compiles the generated join condition and the build/probe projections with the given
     * class loader, then clears the generated code from {@code parameter}.
     *
     * <p>Because the code fields are set to null afterwards, calling this a second time on the
     * same instance fails with a NullPointerException.
     *
     * @param cl class loader used to compile and load the generated classes
     * @throws CompileException if any of the generated sources fails to compile
     */
    protected void cookGeneratedClasses(ClassLoader cl) throws CompileException {
        long startTime = System.currentTimeMillis();
        this.condFuncClass = CodeGenUtils.compile(cl, this.parameter.condFuncCode.name(), this.parameter.condFuncCode.code());
        this.buildProjectionClass = CodeGenUtils.compile(cl, this.parameter.buildProjectionCode.name(), this.parameter.buildProjectionCode.code());
        this.probeProjectionClass = CodeGenUtils.compile(cl, this.parameter.probeProjectionCode.name(), this.parameter.probeProjectionCode.code());
        // NOTE(review): presumably nulled so the generated source can be garbage-collected.
        this.parameter.condFuncCode = null;
        this.parameter.buildProjectionCode = null;
        this.parameter.probeProjectionCode = null;
        long endTime = System.currentTimeMillis();
        LOG.info("Compiling generated codes, used time: {}ms.", endTime - startTime);
    }

    /** Selects the first (build) input to be read first. */
    public TwoInputSelection firstInputSelection() {
        return TwoInputSelection.FIRST;
    }

    /**
     * Inserts a build-side row into the hash table.
     *
     * @return {@link TwoInputSelection#FIRST}, i.e. keep reading the build input
     */
    public TwoInputSelection processElement1(StreamRecord<BaseRow> element) throws Exception {
        this.table.putBuildRow(element.getValue());
        return TwoInputSelection.FIRST;
    }

    /**
     * Probes the hash table with a probe-side row. If {@code tryProbe} returns true, the row is
     * joined immediately; otherwise it is left to the table (presumably for later processing in
     * {@link #endInput2()} — confirm against BinaryHashTable).
     *
     * @return {@link TwoInputSelection#SECOND}, i.e. keep reading the probe input
     */
    public TwoInputSelection processElement2(StreamRecord<BaseRow> element) throws Exception {
        if (this.table.tryProbe(element.getValue())) {
            this.joinWithNextKey();
        }
        return TwoInputSelection.SECOND;
    }

    /** Ends the build phase by calling {@code endBuild()} on the hash table. */
    public void endInput1() throws Exception {
        LOG.info("Finish build phase.");
        this.table.endBuild();
    }

    /**
     * Ends the probe phase: repeatedly advances the table with {@code nextMatching()} and joins
     * every remaining match it reports.
     */
    public void endInput2() throws Exception {
        LOG.info("Finish probe phase.");
        while (this.table.nextMatching()) {
            this.joinWithNextKey();
        }
        LOG.info("Finish rebuild phase.");
    }

    /** Joins the table's current probe row with its current build-side iterator. */
    private void joinWithNextKey() throws Exception {
        this.join(this.table.getBuildSideIterator(), this.table.getCurrentProbeRow());
    }

    /**
     * Emits the output for one key.
     *
     * <p>{@code buildIter} has not been advanced yet: implementations call
     * {@code advanceNext()} before the first {@code getRow()}. {@code probeRow} may be null
     * (the implementations below handle that case); NOTE(review): presumably null when the
     * table reports build rows without a probe row during {@link #endInput2()} — confirm.
     *
     * @param buildIter iterator over the build rows for the current key
     * @param probeRow the current probe row, possibly null
     */
    public abstract void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception;

    /**
     * Emits (build, probe) for the current build row and every remaining one.
     * The caller must already have advanced {@code buildIter} onto a valid row.
     */
    void innerJoin(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
        this.collect(buildIter.getRow(), probeRow);
        while (buildIter.advanceNext()) {
            this.collect(buildIter.getRow(), probeRow);
        }
    }

    /**
     * Emits each build row padded with {@link #probeSideNullRow}, starting at the current row.
     * The caller must already have advanced {@code buildIter} onto a valid row.
     */
    void buildOuterJoin(RowIterator<BinaryRow> buildIter) throws Exception {
        this.collect(buildIter.getRow(), this.probeSideNullRow);
        while (buildIter.advanceNext()) {
            this.collect(buildIter.getRow(), this.probeSideNullRow);
        }
    }

    /**
     * Emits the concatenation of a build-side row and a probe-side row.
     *
     * <p>The emitted object is the single, reused {@link #joinedRow}.
     *
     * @param row1 build-side row (or {@link #buildSideNullRow})
     * @param row2 probe-side row (or {@link #probeSideNullRow})
     */
    void collect(BaseRow row1, BaseRow row2) throws Exception {
        if (this.reverseJoinFunction) {
            this.collector.collect(this.joinedRow.replace(row2, row1));
        } else {
            this.collector.collect(this.joinedRow.replace(row1, row2));
        }
    }

    /**
     * Closes the operator and releases the hash table.
     *
     * <p>The table is closed and freed even if {@code super.close()} or {@code table.close()}
     * throws, so its memory is not leaked on a failing shutdown path.
     */
    @Override
    public void close() throws Exception {
        try {
            super.close();
        } finally {
            if (this.table != null) {
                try {
                    this.table.close();
                } finally {
                    this.table.free();
                    this.table = null;
                }
            }
        }
    }

    /**
     * Creates the concrete hash join operator for {@code type}.
     *
     * <p>{@code minMemorySize} is stored as {@link HashJoinParameter#reservedMemorySize}.
     * {@code probeRowCount} and {@code keyType} are stored in the parameter object but are not
     * read anywhere in this class.
     *
     * @throws IllegalArgumentException if {@code type} has no matching operator
     */
    public static HashJoinOperator newHashJoinOperator(long minMemorySize, long maxMemorySize, long eachRequestMemorySize, HashJoinType type, GeneratedJoinConditionFunction condFuncCode, boolean reverseJoinFunction, boolean[] filterNullKeys, GeneratedProjection buildProjectionCode, GeneratedProjection probeProjectionCode, boolean tryDistinctBuildRow, int buildRowSize, long buildRowCount, long probeRowCount, RowType keyType) {
        HashJoinParameter parameter = new HashJoinParameter(minMemorySize, maxMemorySize, eachRequestMemorySize, type, condFuncCode, reverseJoinFunction, filterNullKeys, buildProjectionCode, probeProjectionCode, tryDistinctBuildRow, buildRowSize, buildRowCount, probeRowCount, keyType);
        switch (type) {
            case INNER:
                return new InnerHashJoinOperator(parameter);
            case BUILD_OUTER:
                return new BuildOuterHashJoinOperator(parameter);
            case PROBE_OUTER:
                return new ProbeOuterHashJoinOperator(parameter);
            case FULL_OUTER:
                return new FullOuterHashJoinOperator(parameter);
            case SEMI:
                return new SemiHashJoinOperator(parameter);
            case ANTI:
                return new AntiHashJoinOperator(parameter);
            case BUILD_LEFT_SEMI:
            case BUILD_LEFT_ANTI:
                return new BuildLeftSemiOrAntiHashJoinOperator(parameter);
            default:
                break;
        }
        throw new IllegalArgumentException("invalid: " + type);
    }

    /**
     * Shared operator for {@code BUILD_LEFT_SEMI} and {@code BUILD_LEFT_ANTI}. Only build rows
     * are ever emitted, and only when {@code probeRow} is null. While a probe row is present,
     * the build iterator is drained without output.
     *
     * <p>Both join types use the same code here. {@link #type} is passed to the hash table in
     * {@link #open()}; NOTE(review): presumably the table decides which build rows are yielded
     * when {@code probeRow} is null — confirm against BinaryHashTable.
     */
    private static class BuildLeftSemiOrAntiHashJoinOperator
    extends HashJoinOperator {
        BuildLeftSemiOrAntiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            if (!buildIter.advanceNext()) {
                return;
            }
            if (probeRow != null) {
                // Drain the iterator without emitting. NOTE(review): presumably the iteration
                // itself lets the table record matched build rows — do not remove without checking.
                while (buildIter.advanceNext()) {
                    // intentionally empty
                }
            } else {
                this.collector.collect(buildIter.getRow());
                while (buildIter.advanceNext()) {
                    this.collector.collect(buildIter.getRow());
                }
            }
        }
    }

    /** Anti join: emits the probe row only when no build row matches it. */
    private static class AntiHashJoinOperator
    extends HashJoinOperator {
        AntiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            Preconditions.checkNotNull(probeRow);
            if (!buildIter.advanceNext()) {
                this.collector.collect(probeRow);
            }
        }
    }

    /** Semi join: emits the probe row once if at least one build row matches it. */
    private static class SemiHashJoinOperator
    extends HashJoinOperator {
        SemiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            Preconditions.checkNotNull(probeRow);
            if (buildIter.advanceNext()) {
                this.collector.collect(probeRow);
            }
        }
    }

    /**
     * Full outer join:
     * <ul>
     *   <li>matched pairs are joined;</li>
     *   <li>build rows without a probe row are padded with {@link #probeSideNullRow};</li>
     *   <li>unmatched probe rows are padded with {@link #buildSideNullRow}.</li>
     * </ul>
     */
    private static class FullOuterHashJoinOperator
    extends HashJoinOperator {
        FullOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    this.innerJoin(buildIter, probeRow);
                } else {
                    this.buildOuterJoin(buildIter);
                }
            } else if (probeRow != null) {
                this.collect(this.buildSideNullRow, probeRow);
            }
        }
    }

    /**
     * Probe outer join: matched pairs are joined, and unmatched probe rows are padded with
     * {@link #buildSideNullRow}. Build rows without a probe row are not emitted.
     */
    private static class ProbeOuterHashJoinOperator
    extends HashJoinOperator {
        ProbeOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    this.innerJoin(buildIter, probeRow);
                }
            } else if (probeRow != null) {
                this.collect(this.buildSideNullRow, probeRow);
            }
        }
    }

    /**
     * Build outer join: matched pairs are joined, and build rows without a probe row are padded
     * with {@link #probeSideNullRow}. Unmatched probe rows are not emitted.
     */
    private static class BuildOuterHashJoinOperator
    extends HashJoinOperator {
        BuildOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    this.innerJoin(buildIter, probeRow);
                } else {
                    this.buildOuterJoin(buildIter);
                }
            }
        }
    }

    /** Inner join: emits every (build, probe) pair when both sides are present. */
    private static class InnerHashJoinOperator
    extends HashJoinOperator {
        InnerHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
            // advanceNext() is evaluated first, even when probeRow is null.
            if (buildIter.advanceNext() && probeRow != null) {
                this.innerJoin(buildIter, probeRow);
            }
        }
    }

    /**
     * Serializable bundle of the settings passed to {@link #newHashJoinOperator}. The generated
     * code fields are mutable because {@link #cookGeneratedClasses} clears them after compiling.
     */
    static class HashJoinParameter
    implements Serializable {
        long reservedMemorySize;
        long maxMemorySize;
        long perRequestMemorySize;
        HashJoinType type;
        GeneratedJoinConditionFunction condFuncCode;
        boolean reverseJoinFunction;
        boolean[] filterNullKeys;
        GeneratedProjection buildProjectionCode;
        GeneratedProjection probeProjectionCode;
        boolean tryDistinctBuildRow;
        int buildRowSize;
        long buildRowCount;
        long probeRowCount;
        RowType keyType;

        HashJoinParameter(long reservedMemorySize, long maxMemorySize, long perRequestMemorySize, HashJoinType type, GeneratedJoinConditionFunction condFuncCode, boolean reverseJoinFunction, boolean[] filterNullKeys, GeneratedProjection buildProjectionCode, GeneratedProjection probeProjectionCode, boolean tryDistinctBuildRow, int buildRowSize, long buildRowCount, long probeRowCount, RowType keyType) {
            this.reservedMemorySize = reservedMemorySize;
            this.maxMemorySize = maxMemorySize;
            this.perRequestMemorySize = perRequestMemorySize;
            this.type = type;
            this.condFuncCode = condFuncCode;
            this.reverseJoinFunction = reverseJoinFunction;
            this.filterNullKeys = filterNullKeys;
            this.buildProjectionCode = buildProjectionCode;
            this.probeProjectionCode = probeProjectionCode;
            this.tryDistinctBuildRow = tryDistinctBuildRow;
            this.buildRowSize = buildRowSize;
            this.buildRowCount = buildRowCount;
            this.probeRowCount = probeRowCount;
            this.keyType = keyType;
        }
    }
}

