package org.apache.flink.table.plan.rules.logical;

import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.table.plan.rules.logical.SegmentTopTransformRule;
import scala.Predef$;
import scala.collection.JavaConversions$;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Range;
import scala.collection.mutable.Buffer;
import scala.reflect.ScalaSignature;
import scala.runtime.IntRef;
import scala.runtime.NonLocalReturnControl;
import scala.runtime.RichInt$;

/* compiled from: SegmentTopTransformRule.scala */
@ScalaSignature(bytes = "\u0006\u000113A!\u0001\u0002\u0001#\t)R*\u001e7uS*{\u0017N\u001c+p'\u0016<W.\u001a8u)>\u0004(BA\u0002\u0005\u0003\u001dawnZ5dC2T!!\u0002\u0004\u0002\u000bI,H.Z:\u000b\u0005\u001dA\u0011\u0001\u00029mC:T!!\u0003\u0006\u0002\u000bQ\f'\r\\3\u000b\u0005-a\u0011!\u00024mS:\\'BA\u0007\u000f\u0003\u0019\t\u0007/Y2iK*\tq\"A\u0002pe\u001e\u001c\u0001a\u0005\u0002\u0001%A\u00111\u0003F\u0007\u0002\u0005%\u0011QC\u0001\u0002\u0018'\u0016<W.\u001a8u)>\u0004HK]1og\u001a|'/\u001c*vY\u0016DQa\u0006\u0001\u0005\u0002a\ta\u0001P5oSRtD#A\r\u0011\u0005M\u0001\u0001\"B\u000e\u0001\t\u0013a\u0012!D7bi\u000eDW\r\u001a$bGR|'\u000f\u0006\u0003\u001eG!\u001a\u0004C\u0001\u0010\"\u001b\u0005y\"\"\u0001\u0011\u0002\u000bM\u001c\u0017\r\\1\n\u0005\tz\"a\u0002\"p_2,\u0017M\u001c\u0005\u0006Ii\u0001\r!J\u0001\bM\u0006\u001cG/\u00133y!\tqb%\u0003\u0002(?\t\u0019\u0011J\u001c;\t\u000b%R\u0002\u0019\u0001\u0016\u0002\u00135,H\u000e^5K_&t\u0007CA\u00162\u001b\u0005a#BA\u0003.\u0015\tqs&A\u0002sK2T!\u0001\r\u0007\u0002\u000f\r\fGnY5uK&\u0011!\u0007\f\u0002\u000e\u0019>\u0004H/T;mi&Tu.\u001b8\t\u000bQR\u0002\u0019A\u001b\u0002\u00055\f\bC\u0001\u001c:\u001b\u00059$B\u0001\u001d.\u0003!iW\r^1eCR\f\u0017B\u0001\u001e8\u0005A\u0011V\r\\'fi\u0006$\u0017\r^1Rk\u0016\u0014\u0018\u0010C\u0003=\u0001\u0011\u0005S(A\u0004nCR\u001c\u0007.Z:\u0015\u0005uq\u0004\"B <\u0001\u0004\u0001\u0015\u0001B2bY2\u0004\"!Q\"\u000e\u0003\tS!aB\u0018\n\u0005\u0011\u0013%A\u0004*fY>\u0003HOU;mK\u000e\u000bG\u000e\u001c\u0005\u0006\r\u0002!\teR\u0001\b_:l\u0015\r^2i)\tA5\n\u0005\u0002\u001f\u0013&\u0011!j\b\u0002\u0005+:LG\u000fC\u0003@\u000b\u0002\u0007\u0001\t")
/* loaded from: input_file:org/apache/flink/table/plan/rules/logical/MultiJoinToSegmentTop.class */
/*
 * Planner rule (decompiled from Scala) that extends SegmentTopTransformRule and
 * operates on a rule call whose root relation (rel(0)) is a Calcite MultiJoin.
 * All predicate helpers and the final plan construction used below are
 * inherited; this class only wires them together.
 */
public class MultiJoinToSegmentTop extends SegmentTopTransformRule {
    /**
     * Tests one join factor of the multi-join against the inherited predicates.
     *
     * <p>The method name is mangled by the Scala compiler (presumably so that
     * generated closure classes can call it - confirm against the Scala source).
     *
     * @param i index of the join factor to test
     * @param loptMultiJoin multi-join wrapper the factor belongs to
     * @param relMetadataQuery metadata query handed to {@code isSimpleFactorMatched}
     * @return {@code false} when {@code isNonInnerJoinOrSelfJoin} holds for the factor;
     *     otherwise {@code true} only if {@code isMaxMinAggMatched},
     *     {@code isSimpleFactorMatched} and {@code isMultiJoinMatched} all accept it
     */
    public boolean org$apache$flink$table$plan$rules$logical$MultiJoinToSegmentTop$$matchedFactor(int i, LoptMultiJoin loptMultiJoin, RelMetadataQuery relMetadataQuery) {
        if (isNonInnerJoinOrSelfJoin(loptMultiJoin, i)) {
            return false;
        }
        RelNode factor = loptMultiJoin.getJoinFactor(i);
        // The three checks are evaluated in this order and stop at the first failure.
        if (!isMaxMinAggMatched(factor)) {
            return false;
        }
        if (!isSimpleFactorMatched(factor, relMetadataQuery)) {
            return false;
        }
        return isMultiJoinMatched(factor);
    }

    /**
     * Decides whether the rule applies to the multi-join at the root of the call.
     *
     * @param relOptRuleCall rule call whose {@code rel(0)} is cast to {@link MultiJoin}
     * @return {@code false} for a full outer join; otherwise {@code true} only when
     *     exactly one mapped factor satisfies the counting closure and every factor
     *     index satisfies the {@code forall} closure
     */
    @Override // org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        LoptMultiJoin joinView = new LoptMultiJoin((MultiJoin) relOptRuleCall.rel(0));
        RelMetadataQuery mq = relOptRuleCall.getMetadataQuery();
        if (joinView.getMultiJoinRel().isFullOuterJoin()) {
            return false;
        }
        // Scala range 0 until numJoinFactors, i.e. every factor index.
        Range factorIndices = RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), joinView.getNumJoinFactors());
        // NOTE(review): the closure bodies are compiled elsewhere; what exactly is
        // mapped and counted cannot be confirmed from this file.
        IndexedSeq mappedFactors = (IndexedSeq) factorIndices.map(new MultiJoinToSegmentTop$$anonfun$7(this, joinView), IndexedSeq$.MODULE$.canBuildFrom());
        if (mappedFactors.count(new MultiJoinToSegmentTop$$anonfun$matches$3(this)) != 1) {
            return false;
        }
        return factorIndices.forall(new MultiJoinToSegmentTop$$anonfun$matches$1(this, joinView, mq));
    }

    /**
     * Performs the transformation: locates the aggregate factor, validates its
     * shape and its join filters, and delegates to {@code constructFinalPlan}
     * when every check passes. Returns without doing anything as soon as one
     * check fails.
     *
     * @param relOptRuleCall rule call whose {@code rel(0)} is cast to {@link MultiJoin}
     */
    @Override // org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        // Identity token compared against NonLocalReturnControl.key() in the catch
        // below; it is handed to the getOrElse fallback closures.
        Object returnKey = new Object();
        try {
            MultiJoin topJoin = (MultiJoin) relOptRuleCall.rel(0);
            LoptMultiJoin topView = new LoptMultiJoin(topJoin);
            RelMetadataQuery mq = relOptRuleCall.getMetadataQuery();
            RelBuilder relBuilder = relOptRuleCall.builder();
            int factorCount = topView.getNumJoinFactors();

            // The foreach closure receives this mutable cell; -1 afterwards means
            // no factor index was stored in it.
            IntRef aggFactorIdx = IntRef.create(-1);
            Range factorIndices = RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), factorCount);
            factorIndices.foreach$mVc$sp(new MultiJoinToSegmentTop$$anonfun$onMatch$1(this, topView, aggFactorIdx));
            Predef$.MODULE$.require(aggFactorIdx.elem != -1);

            LogicalAggregate agg = (LogicalAggregate) getRealFactor(topView.getJoinFactor(aggFactorIdx.elem));
            // Only aggregates with one group key, at most one grouping set and a
            // single aggregate call are handled.
            boolean supportedShape = agg.getGroupCount() == 1 && agg.getGroupSets().size() <= 1 && agg.getAggCallList().size() == 1;
            if (!supportedShape) {
                return;
            }

            AggregateCall onlyCall = (AggregateCall) JavaConversions$.MODULE$.asScalaBuffer(agg.getAggCallList()).head();
            SqlKind aggKind = onlyCall.getAggregation().getKind();
            Integer groupKey = (Integer) JavaConversions$.MODULE$.iterableAsScalaIterable(agg.getGroupSet()).head();
            RexInputRef groupRef = RexInputRef.of(Predef$.MODULE$.Integer2int(groupKey), agg.getRowType());
            SegmentTopTransformRule.CorrelateFactorFinder correlateFinder = new SegmentTopTransformRule.CorrelateFactorFinder(this, groupRef);
            correlateFinder.visit(agg);
            if (correlateFinder.notFound()) {
                return;
            }

            // Field at the aggregate factor's join start: its single filter, the
            // first other factor that filter references, and the related offset.
            int firstField = topView.getJoinStart(aggFactorIdx.elem);
            RexNode firstFilter = (RexNode) getFactorSingleFilter(topView, firstField).getOrElse(new MultiJoinToSegmentTop$$anonfun$8(this, returnKey));
            Integer firstPeer = (Integer) JavaConversions$.MODULE$.iterableAsScalaIterable(topView.getFactorsRefByJoinFilter(firstFilter).clear(aggFactorIdx.elem)).head();
            int firstOffset = relationFieldOffset(topView, firstFilter, Predef$.MODULE$.Integer2int(firstPeer), firstField);

            // Same lookup for the immediately following field.
            int secondField = firstField + 1;
            RexNode secondFilter = (RexNode) getFactorSingleFilter(topView, secondField).getOrElse(new MultiJoinToSegmentTop$$anonfun$9(this, returnKey));
            Integer secondPeer = (Integer) JavaConversions$.MODULE$.iterableAsScalaIterable(topView.getFactorsRefByJoinFilter(secondFilter).clear(aggFactorIdx.elem)).head();
            int secondOffset = relationFieldOffset(topView, secondFilter, Predef$.MODULE$.Integer2int(secondPeer), secondField);

            RelNode peerFactor = getRealFactor(topView.getJoinFactor(Predef$.MODULE$.Integer2int(firstPeer)));
            if (peerFactor instanceof LogicalFilter) {
                // A filter factor is accepted only when its input reports unique
                // keys and at least one of them satisfies the closure built from
                // firstOffset.
                Set<ImmutableBitSet> inputKeys = mq.getUniqueKeys(peerFactor.getInput(0));
                if (inputKeys == null || !JavaConversions$.MODULE$.asScalaSet(inputKeys).exists(new MultiJoinToSegmentTop$$anonfun$onMatch$3(this, firstOffset))) {
                    return;
                }
            }

            // Locate the multi-join found by visiting the aggregate.
            SegmentTopTransformRule.MultiJoinFinder innerFinder = new SegmentTopTransformRule.MultiJoinFinder(this);
            agg.accept(innerFinder);
            MultiJoin innerJoin = (MultiJoin) innerFinder.getMultiJoin().get();
            Predef$.MODULE$.require(innerJoin != null);
            LoptMultiJoin innerView = new LoptMultiJoin(innerJoin);

            // At least one top-level join filter must satisfy the first closure.
            if (!JavaConversions$.MODULE$.asScalaBuffer(topView.getJoinFilters()).exists(new MultiJoinToSegmentTop$$anonfun$10(this, topView, correlateFinder, firstPeer))) {
                return;
            }
            // No more than one may satisfy the second closure.
            if (JavaConversions$.MODULE$.asScalaBuffer(topView.getJoinFilters()).count(new MultiJoinToSegmentTop$$anonfun$11(this, topView, aggFactorIdx, firstPeer)) > 1) {
                return;
            }
            // The filters kept by the third closure are used to compare the outer
            // and inner multi-joins.
            Buffer keptFilters = (Buffer) JavaConversions$.MODULE$.asScalaBuffer(topView.getJoinFilters()).filter(new MultiJoinToSegmentTop$$anonfun$12(this, topView, aggFactorIdx, firstPeer));
            if (multiJoinEq(topView, innerView, keptFilters.toList())) {
                constructFinalPlan(relBuilder, topJoin, topView, aggFactorIdx.elem, mq, Predef$.MODULE$.Integer2int(firstPeer), firstOffset, Predef$.MODULE$.Integer2int(secondPeer), secondOffset, aggKind, relOptRuleCall);
            }
        } catch (NonLocalReturnControl e) {
            // Only a control exception carrying this invocation's token is handled
            // here; any other one is propagated unchanged.
            if (e.key() == returnKey) {
                e.value$mcV$sp();
            } else {
                throw e;
            }
        }
    }
}
