/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.rules.physical.batch.runtimefilter

import org.apache.flink.table.api.TableConfigOptions
import org.apache.flink.table.functions.sql.internal.SqlRuntimeFilterFunction
import org.apache.flink.table.plan.nodes.physical.batch.{BatchExecCalc, BatchExecHashJoinBase}
import org.apache.flink.table.plan.rules.physical.batch.runtimefilter.BaseRuntimeFilterPushDownRule.findRuntimeFilters
import org.apache.flink.table.plan.rules.physical.batch.runtimefilter.InsertRuntimeFilterRule.{SQL_EXEC_RUNTIME_FILTER_JOIN_PUSH_DOWN_ENABLED, SQL_EXEC_RUNTIME_FILTER_JOIN_PUSH_DOWN_WHEN_WAIT_ENABLED}
import org.apache.flink.table.plan.util.FlinkRelOptUtil
import org.apache.flink.table.runtime.join.batch.HashJoinType._

import org.apache.calcite.plan.RelOptRuleCall
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rex._
import org.apache.calcite.util.ImmutableBitSet

import scala.collection.JavaConversions._

/**
  * Planner rule that pushes a [[SqlRuntimeFilterFunction]] condition found in a
  * [[BatchExecCalc]] on top of a [[BatchExecHashJoinBase]] down into the join's probe input.
  */
class RuntimeFilterJoinTransposeRule extends BaseRuntimeFilterPushDownRule(
  classOf[BatchExecHashJoinBase],
  "RuntimeFilterJoinTransposeRule"){

  /**
    * The rule fires only when all of the following hold:
    *  - runtime-filter join push-down is enabled;
    *  - either runtime filters do not wait, or push-down-when-wait is also enabled;
    *  - the join is INNER, PROBE_OUTER, SEMI or ANTI;
    *  - the calc program above the join contains at least one runtime filter.
    */
  override def matches(call: RelOptRuleCall): Boolean = {
    val calc: BatchExecCalc = call.rel(0)
    val join: BatchExecHashJoinBase = call.rel(1)

    val conf = FlinkRelOptUtil.getTableConfig(join)

    val enabled = conf.getConf.getBoolean(SQL_EXEC_RUNTIME_FILTER_JOIN_PUSH_DOWN_ENABLED)

    val enabledInWait = conf.getConf.getBoolean(
      SQL_EXEC_RUNTIME_FILTER_JOIN_PUSH_DOWN_WHEN_WAIT_ENABLED)
    val rfWait = conf.getConf.getBoolean(TableConfigOptions.SQL_EXEC_RUNTIME_FILTER_WAIT)

    enabled &&
        (!rfWait || enabledInWait) &&
        (join.hashJoinType == INNER ||
            join.hashJoinType == PROBE_OUTER ||
            join.hashJoinType == SEMI ||
            join.hashJoinType == ANTI) &&
        findRuntimeFilters(calc.getProgram).nonEmpty
  }

  /**
    * Returns true if `cond` can be pushed into the probe input. That requires `cond` to be a
    * [[SqlRuntimeFilterFunction]] call that references exactly one join-output column, and that
    * column must come from the probe side.
    *
    * @param rel   the hash join the condition currently sits above
    * @param rCols the join-output column indices referenced by `cond`
    * @param cond  the candidate condition
    */
  override def canPush(
      rel: BatchExecHashJoinBase,
      rCols: ImmutableBitSet,
      cond: RexNode): Boolean = {
    cond match {
      case call: RexCall => call.getOperator match {
        case _: SqlRuntimeFilterFunction =>
          // Probe columns occupy the right part of the output when the left input is the build
          // side, and the left part otherwise.
          val leftFieldCount = rel.getLeft.getRowType.getFieldCount
          val probeFields =
            if (rel.leftIsBuild) {
              leftFieldCount until leftFieldCount + rel.getRight.getRowType.getFieldCount
            } else {
              0 until leftFieldCount
            }
          // `cardinality` is the number of referenced columns. `length` (highest set bit + 1)
          // would only equal 1 when the referenced column is #0.
          rCols.cardinality() == 1 && probeFields.contains(rCols.nextSetBit(0))
        case _ => false
      }
      case _ => false
    }
  }

  /**
    * Per-field index shifts that re-base join-output field references onto the probe input.
    * The pushed filter is placed directly over `probeRel` (see `getInputOfInput`), so its
    * references must be probe-input relative. Entry `i` is the value to add to output index `i`.
    *
    * Left-side fields already line up with the left input (shift 0). Right-side fields are
    * shifted back by the left field count. For SEMI/ANTI joins the output contains only probe
    * fields, so every shift is 0.
    */
  override def getFieldAdjustments(rel: BatchExecHashJoinBase): Array[Int] = {
    val outputFieldCount = rel.getRowType.getFieldCount
    val adjustments = new Array[Int](outputFieldCount)
    if (rel.hashJoinType != SEMI && rel.hashJoinType != ANTI) {
      val leftFieldCount = rel.getLeft.getRowType.getFieldCount
      // Cover every right-side field, not just min(build, probe) of them, so that wide probe
      // inputs are fully re-based.
      (leftFieldCount until outputFieldCount).foreach { i =>
        adjustments(i) = -leftFieldCount
      }
    }
    adjustments
  }

  /** Delegates the runtime-filter rewrite of the pushed program to the shared base helper. */
  override def updateRfFunction(filterInput: RelNode, program: RexProgram): Unit =
    BaseRuntimeFilterPushDownRule.updateRuntimeFilterFunction(filterInput, program)

  /** The pushed filter is placed on top of the join's probe input. */
  override def getInputOfInput(input: BatchExecHashJoinBase): RelNode = input.probeRel

  /**
    * Rebuilds the join with `filter` in place of the probe input. The build input stays on its
    * original side.
    */
  override def replaceInput(input: BatchExecHashJoinBase, filter: BatchExecCalc): RelNode = {
    val inputs = if (input.leftIsBuild) {
      Seq(input.buildRel, filter)
    } else {
      Seq(filter, input.buildRel)
    }
    input.copy(input.getTraitSet, inputs)
  }
}

/** Companion object exposing a singleton [[RuntimeFilterJoinTransposeRule]]. */
object RuntimeFilterJoinTransposeRule {

  /** Singleton instance of the rule. */
  val INSTANCE: RuntimeFilterJoinTransposeRule = new RuntimeFilterJoinTransposeRule()
}
