/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.rules.physical.stream

import org.apache.flink.table.api.{AggPhaseEnforcer, TableConfig, TableConfigOptions}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.`trait`.{AccMode, AccModeTrait, FlinkRelDistribution}
import org.apache.flink.table.plan.metadata.FlinkRelMetadataQuery
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.physical.stream.{StreamExecExchange, StreamExecGlobalWindowAggregate, StreamExecGroupWindowAggregate, StreamExecLocalWindowAggregate}
import org.apache.flink.table.plan.rules.physical.FlinkExpandConversionRule._
import org.apache.flink.table.plan.schema.BaseRowSchema
import org.apache.flink.table.plan.util.{AggregateInfoList, AggregateUtil, EmitStrategy, FlinkRelOptUtil, WindowAggregateUtil}

import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode

import java.util.{ArrayList => JArrayList}

/**
 * Rule that splits a [[StreamExecGroupWindowAggregate]] sitting on top of a
 * [[StreamExecExchange]] into a two-stage aggregation: a [[StreamExecLocalWindowAggregate]]
 * placed directly on the exchange's input, followed by an exchange that satisfies the global
 * distribution (hash on the forwarded grouping keys, or singleton when there are none) and a
 * [[StreamExecGlobalWindowAggregate]].
 *
 * The rule only fires when the configured window-aggregate phase enforcer is not `ONE_PHASE`
 * and [[WindowAggregateUtil.isWindowMiniBatchApplicable]] accepts the window and its aggregates.
 */
class TwoStageOptimizedWindowAggregateRule extends RelOptRule(
  operand(
    classOf[StreamExecGroupWindowAggregate],
    operand(
      classOf[StreamExecExchange],
      operand(classOf[RelNode], any))), "TwoStageOptimizedWindowAggregateRule") {

  override def matches(call: RelOptRuleCall): Boolean = {
    val windowAgg: StreamExecGroupWindowAggregate = call.rel(0)
    val realInput: RelNode = call.rel(2)
    val tableConfig = FlinkRelOptUtil.getTableConfig(windowAgg)

    val isOnePhaseEnforced = tableConfig.getConf
      .getString(TableConfigOptions.SQL_OPTIMIZER_WINDOW_AGG_PHASE_ENFOREER)
      .equalsIgnoreCase(AggPhaseEnforcer.ONE_PHASE.toString)

    // `&&` short-circuits: the aggregate info is only derived when two-phase is allowed.
    !isOnePhaseEnforced && {
      val aggInfoList = createAggInfoList(
        call, windowAgg, realInput, isStateBackendDataViews = true)
      WindowAggregateUtil.isWindowMiniBatchApplicable(
        tableConfig, windowAgg.window, aggInfoList.aggInfos)
    }
  }

  override def onMatch(call: RelOptRuleCall): Unit = {
    val agg: StreamExecGroupWindowAggregate = call.rel(0)
    val realInput: RelNode = call.rel(2)

    // The local agg uses heap data views, the global agg uses state data views.
    val localAggInfoList = createAggInfoList(
      call, agg, realInput, isStateBackendDataViews = false)
    val globalAggInfoList = createAggInfoList(
      call, agg, realInput, isStateBackendDataViews = true)

    transformToTwoStageAgg(
      call,
      realInput,
      localAggInfoList,
      globalAggInfoList,
      agg)
  }

  /**
   * Builds the [[AggregateInfoList]] for the aggregate calls of `windowAgg`, evaluated against
   * the row type of `realInput` (the node beneath the matched exchange).
   *
   * Whether each aggregate call must handle retraction is derived from whether `realInput`
   * is in accumulate-retract mode and from the modified monotonicity reported for
   * `windowAgg` by the metadata query.
   *
   * @param call the rule call providing the metadata query
   * @param windowAgg the window aggregate whose calls are analysed
   * @param realInput the input beneath the exchange
   * @param isStateBackendDataViews forwarded to
   *        [[AggregateUtil.transformToStreamAggregateInfoList]]; `false` for the local stage
   *        (heap data views), `true` for the global stage (state data views)
   */
  private def createAggInfoList(
    call: RelOptRuleCall,
    windowAgg: StreamExecGroupWindowAggregate,
    realInput: RelNode,
    isStateBackendDataViews: Boolean): AggregateInfoList = {
    val needRetraction = StreamExecRetractionRules.isAccRetract(realInput)
    val modifiedMono = call.getMetadataQuery.asInstanceOf[FlinkRelMetadataQuery]
      .getRelModifiedMonotonicity(windowAgg)
    val needRetractionArray = AggregateUtil.getNeedRetractions(
      windowAgg.getGroupings.length, needRetraction, modifiedMono, windowAgg.aggCalls)

    AggregateUtil.transformToStreamAggregateInfoList(
      windowAgg.aggCalls,
      realInput.getRowType,
      needRetractionArray,
      needRetraction,
      isStateBackendDataViews = isStateBackendDataViews)
  }

  /**
   * Replaces `agg` with a local window aggregate on `input`, an exchange satisfying the
   * global distribution, and a global window aggregate on top, then registers the result
   * via `call.transformTo`.
   *
   * The difference between `localAggInfoList` and `globalAggInfoList` is that the local agg
   * uses heap data views, while the global agg uses state data views.
   *
   * @param call the rule call to register the new plan with
   * @param input the node beneath the original exchange
   * @param localAggInfoList aggregate info for the local stage
   * @param globalAggInfoList aggregate info for the global stage
   * @param agg the original single-stage window aggregate
   */
  private[flink] def transformToTwoStageAgg(
    call: RelOptRuleCall,
    input: RelNode,
    localAggInfoList: AggregateInfoList,
    globalAggInfoList: AggregateInfoList,
    agg: StreamExecGroupWindowAggregate): Unit = {

    // prepare local window agg.
    val localWindowAggType = WindowAggregateUtil.inferLocalWindowAggType(
      localAggInfoList,
      input.getRowType,
      agg.getGroupings,
      input.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory])

    // Retraction is not supported on window aggregates yet, so the local agg
    // is planned in accumulate-only mode.
    val localWindowAggTraitSet = input.getTraitSet.plus(new AccModeTrait(AccMode.Acc))
    val localWindowAgg = new StreamExecLocalWindowAggregate(
      agg.window,
      agg.getCluster,
      localWindowAggTraitSet,
      input,
      localAggInfoList,
      agg.aggCalls,
      new BaseRowSchema(localWindowAggType),
      new BaseRowSchema(input.getRowType),
      agg.getGroupings,
      agg.inputTimestampIndex)

    // prepare global window agg.
    val globalDistribution = if (agg.getGroupings.nonEmpty) {
      val fields = new JArrayList[Integer]()
      // Grouping keys are forwarded by the local agg as its leading fields,
      // so hash on their positions rather than on the original grouping indices.
      agg.getGroupings.indices.foreach(fields.add(_))
      FlinkRelDistribution.hash(fields)
    } else {
      FlinkRelDistribution.SINGLETON
    }

    val newInput = satisfyDistribution(
      FlinkConventions.STREAM_PHYSICAL, localWindowAgg, globalDistribution)
    val globalAggProvidedTraitSet = agg.getTraitSet

    // The local agg emits the grouping keys first, so the timestamp follows them.
    val inputTimestampIndexFromLocal = agg.getGroupings.length
    // Use the same TableConfig lookup as `matches` for consistency.
    val config = FlinkRelOptUtil.getTableConfig(agg)
    val emitStrategy = EmitStrategy(config, agg.window)

    val globalWindowAgg = new StreamExecGlobalWindowAggregate(
      agg.window,
      agg.getWindowProperties,
      agg.getCluster,
      globalAggProvidedTraitSet,
      newInput,
      localAggInfoList,
      globalAggInfoList,
      input.getRowType,
      agg.aggCalls,
      new BaseRowSchema(agg.getRowType),
      new BaseRowSchema(newInput.getRowType),
      agg.getGroupings.indices.toArray,
      inputTimestampIndexFromLocal,
      emitStrategy)

    call.transformTo(globalWindowAgg)
  }
}

object TwoStageOptimizedWindowAggregateRule {

  /** Shared instance of [[TwoStageOptimizedWindowAggregateRule]] for use in rule sets. */
  val INSTANCE: RelOptRule = new TwoStageOptimizedWindowAggregateRule()
}
