/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.rules.logical

import org.apache.flink.table.plan.nodes.logical._
import org.apache.flink.table.plan.util.{FlinkRelMdUtil, InputRefVisitor, InputRewriter, VariableRankRange}

import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rel.RelCollations
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex.{RexBuilder, RexInputRef, RexProgram}
import org.apache.calcite.util.ImmutableBitSet

import scala.collection.JavaConversions._

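/**
  * Transposes a Calc past a [[FlinkLogicalRank]] to reduce the Rank's state size.
  *
  * A new Calc is pushed below the Rank. It projects only two groups of columns:
  *  - the columns the original Calc references;
  *  - the columns the Rank itself needs: partition key, order key, unique keys of the Rank's
  *    input, and the rank end index of a [[VariableRankRange]].
  *
  * The original Calc's projections and condition are then re-applied on top of the new Rank,
  * unless the rewritten program is trivial. The rule fires only if at least one input column
  * is pruned.
  */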
class CalcRankTransposeRule
  extends RelOptRule(
    // Match FlinkLogicalCalc rather than any Calc: onMatch binds the Calc as FlinkLogicalCalc
    // and copies it next to FlinkLogical nodes, so any other Calc subclass would otherwise
    // pass `matches` and then fail with a ClassCastException in `onMatch`.
    operand(
      classOf[FlinkLogicalCalc],
      operand(classOf[FlinkLogicalRank], any())),
    "CalcRankTransposeRule") {

  /**
    * Fires only when the Calc pushed below the Rank would prune at least one of the Rank's
    * input columns.
    *
    * @param call the matched `FlinkLogicalCalc(FlinkLogicalRank)` pair
    * @return true if fewer columns than the Rank input has would be pushed down
    */
  override def matches(call: RelOptRuleCall): Boolean = {
    val calc: FlinkLogicalCalc = call.rel(0)
    val rank: FlinkLogicalRank = call.rel(1)

    val totalColumnCount = rank.getInput.getRowType.getFieldCount
    // apply the rule only when calc could prune some columns
    getPushableColumns(calc, rank).length < totalColumnCount
  }

  /**
    * Rewrites `Calc(Rank(input))` into `Calc(Rank'(Calc'(input)))`:
    *  - `Calc'` projects only the pushable columns of `input`;
    *  - `Rank'` is the original Rank with its keys remapped onto `Calc'`'s output;
    *  - the top Calc re-applies the original projections/condition, and is omitted when its
    *    rewritten program is trivial.
    *
    * @param call the matched `FlinkLogicalCalc(FlinkLogicalRank)` pair
    */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val calc: FlinkLogicalCalc = call.rel(0)
    val rank: FlinkLogicalRank = call.rel(1)
    val rexBuilder = calc.getCluster.getRexBuilder

    val pushableColumns = getPushableColumns(calc, rank)

    // create a new Calc that projects only the pushable columns of Rank's input
    val innerProgram = createNewInnerCalcProgram(
      pushableColumns,
      rank.getInput.getRowType,
      rexBuilder)
    val newInnerCalc = calc.copy(calc.getTraitSet, rank.getInput, innerProgram)

    // old Rank-input field index -> field index in the new inner Calc's output
    val inputFieldMapping: Map[Int, Int] = pushableColumns.zipWithIndex.toMap

    // create a new Rank on top of the new Calc
    val newRank = createNewRankOnCalc(inputFieldMapping, newInnerCalc, rank)

    // The rank function column is produced by the Rank itself (not by its input), so its
    // position has to be remapped separately from the pushed-down input columns.
    val fieldMapping: Map[Int, Int] = if (rank.outputRankFunColumn) {
      val oldRankFunFieldIdx = FlinkRelMdUtil.getRankFunColumnIndex(rank)
      val newRankFunFieldIdx = FlinkRelMdUtil.getRankFunColumnIndex(newRank)
      inputFieldMapping + (oldRankFunFieldIdx -> newRankFunFieldIdx)
    } else {
      inputFieldMapping
    }

    // re-apply the original Calc's projections and condition on top of the new Rank
    val topProgram = createNewTopCalcProgram(
      calc.getProgram,
      fieldMapping,
      newRank.getRowType,
      rexBuilder)

    val equiv = if (topProgram.isTrivial) {
      // skip the new top Calc if its program is trivial
      newRank
    } else {
      calc.copy(calc.getTraitSet, newRank, topProgram)
    }
    call.transformTo(equiv)
  }

  /**
    * Computes the Rank input columns that must be kept below the Rank.
    *
    * These are the columns referenced by the Calc, minus the Rank's rank function column (which
    * `onMatch` remaps separately), together with the Rank's key fields.
    *
    * @return distinct column indices, sorted ascending
    */
  private def getPushableColumns(calc: Calc, rank: FlinkLogicalRank): Array[Int] = {
    val rankFunFieldIndex = FlinkRelMdUtil.getRankFunColumnIndex(rank)
    val usedFieldsExcludeRankFun = getUsedFields(calc).filter(_ != rankFunFieldIndex)

    (usedFieldsExcludeRankFun ++ getKeyFields(rank)).distinct.sorted
  }

  /**
    * Collects the input field indices referenced by the Calc's projections and condition.
    * Both are expanded from local refs first, via `RexProgram.split`.
    */
  private def getUsedFields(calc: Calc): Array[Int] = {
    val projectsAndConditions = calc.getProgram.split()
    val (projects, conditions) = (projectsAndConditions.left, projectsAndConditions.right)

    val visitor = new InputRefVisitor
    projects.foreach(_.accept(visitor))
    conditions.foreach(_.accept(visitor))
    visitor.getFields
  }

  /**
    * Returns the Rank input fields the Rank itself depends on. These include the partition key,
    * the order key, every field of the input's unique keys (if the metadata query reports any),
    * and the rank end index of a [[VariableRankRange]].
    *
    * @return distinct field indices (unordered)
    */
  private def getKeyFields(rank: FlinkLogicalRank): Array[Int] = {
    val partitionKey = rank.partitionKey.toArray
    val orderKey = rank.sortCollation.getFieldCollations.map(_.getFieldIndex).toArray
    val uniqueKeys = rank.getCluster.getMetadataQuery.getUniqueKeys(rank.getInput)
    // the metadata query may return null when unique keys are unknown
    val keysInUniqueKeys = if (uniqueKeys == null || uniqueKeys.isEmpty) {
      Array[Int]()
    } else {
      uniqueKeys.flatMap(_.toArray).toArray
    }
    val rankRangeKey = rank.rankRange match {
      case v: VariableRankRange => Array(v.rankEndIndex)
      case _ => Array[Int]()
    }
    (partitionKey ++ orderKey ++ keysInUniqueKeys ++ rankRangeKey).distinct
  }

  /**
    * Builds a projection-only program that keeps `projectedFields` of `inputRowType`, in the
    * given order, under their original field names.
    */
  private def createNewInnerCalcProgram(
    projectedFields: Array[Int],
    inputRowType: RelDataType,
    rexBuilder: RexBuilder): RexProgram = {
    val projects = projectedFields.map(RexInputRef.of(_, inputRowType))
    val inputColNames = inputRowType.getFieldNames
    val colNames = projectedFields.map(inputColNames.get)
    RexProgram.create(
      inputRowType,
      projects.toList,
      null,
      colNames.toList,
      rexBuilder)
  }

  /**
    * Rewrites `oldTopProgram` so that it reads from `inputRowType` (the new Rank's output).
    * Every input reference is remapped through `fieldMapping`. The old program's output field
    * names are kept.
    */
  private def createNewTopCalcProgram(
    oldTopProgram: RexProgram,
    fieldMapping: Map[Int, Int],
    inputRowType: RelDataType,
    rexBuilder: RexBuilder): RexProgram = {
    val inputRewriter = new InputRewriter(fieldMapping)
    val oldProjects = oldTopProgram.getProjectList
    val projects = oldProjects.map(oldTopProgram.expandLocalRef).map(_.accept(inputRewriter))
    val oldCondition = oldTopProgram.getCondition
    val condition = if (oldCondition != null) {
      oldTopProgram.expandLocalRef(oldCondition).accept(inputRewriter)
    } else {
      null
    }
    val colNames = oldTopProgram.getOutputRowType.getFieldNames
    RexProgram.create(
      inputRowType,
      projects,
      condition,
      colNames,
      rexBuilder)
  }

  /**
    * Creates a copy of `rank` over `input`, with its partition key and sort collation remapped
    * through `fieldMapping`.
    *
    * `fieldMapping` must contain every partition and order key. `getPushableColumns` guarantees
    * this, because it includes `getKeyFields`.
    */
  private def createNewRankOnCalc(
    fieldMapping: Map[Int, Int],
    input: Calc,
    rank: FlinkLogicalRank): FlinkLogicalRank = {
    val newPartitionKey = rank.partitionKey.toArray.map(fieldMapping(_))
    val oldSortCollation = rank.sortCollation
    val oldFieldCollations = oldSortCollation.getFieldCollations
    // Reuse the original collation when no order-key index moves. The previous check
    // (`newFieldCollations.eq(oldFieldCollations)`) compared references and was always false,
    // because `map` always builds a new collection.
    val collationUnchanged = oldFieldCollations.forall { fc =>
      fieldMapping(fc.getFieldIndex) == fc.getFieldIndex
    }
    val newSortCollation = if (collationUnchanged) {
      oldSortCollation
    } else {
      val newFieldCollations = oldFieldCollations.map {
        fc => fc.copy(fieldMapping(fc.getFieldIndex))
      }
      RelCollations.of(newFieldCollations)
    }
    new FlinkLogicalRank(
      rank.getCluster,
      rank.getTraitSet,
      input,
      rank.rankFunction,
      ImmutableBitSet.of(newPartitionKey: _*),
      newSortCollation,
      rank.rankRange,
      rank.outputRankFunColumn)
  }
}

/** Companion object exposing a shared instance of [[CalcRankTransposeRule]]. */
object CalcRankTransposeRule {

  /** The shared rule instance. */
  val INSTANCE: CalcRankTransposeRule = new CalcRankTransposeRule()
}
