/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.nodes.physical.stream

import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation}
import org.apache.flink.table.api.window.TimeWindow
import org.apache.flink.table.api.{StreamTableEnvironment, TableConfig, TableConfigOptions, TableException}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen._
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.errorcode.TableErrors
import org.apache.flink.table.expressions.ExpressionUtils.{isTimeIntervalLiteral, _}
import org.apache.flink.table.plan.logical._
import org.apache.flink.table.plan.nodes.exec.{ExecNodeWriter, RowStreamExecNode}
import org.apache.flink.table.plan.nodes.physical.FlinkPhysicalRel
import org.apache.flink.table.plan.rules.physical.stream.StreamExecRetractionRules
import org.apache.flink.table.plan.schema.BaseRowSchema
import org.apache.flink.table.plan.util._
import org.apache.flink.table.runtime.fault.tolerant.FaultTolerantUtil
import org.apache.flink.table.runtime.window.aligned.{GlobalAlignedWindowAggregator, InternalAlignedWindowTriggers}
import org.apache.flink.table.runtime.window.assigners.{SlidingWindowAssigner, TumblingWindowAssigner}
import org.apache.flink.table.runtime.window.{AbstractAlignedWindowOperator, KeyedAlignedWindowOperator}
import org.apache.flink.table.types.{DataTypes, InternalType, TypeConverters}
import org.apache.flink.table.typeutils.BaseRowTypeInfo
import org.apache.flink.util.Preconditions

import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}

import scala.collection.JavaConversions._

/**
  * Flink RelNode for the global phase of a local/global stream group window aggregate.
  *
  * Its input is produced by the corresponding local window aggregate. The partially merged
  * accumulators in each input row start after the grouping keys and one timestamp field
  * (offset `grouping.length + 1`). This node merges those partial accumulators per key and
  * window and emits the final aggregate values followed by the requested window properties.
  *
  * Only rowtime tumbling and sliding windows whose size is a time-interval literal can be
  * translated; other windows are rejected during translation.
  *
  * @param window              the logical group window
  * @param namedProperties     window properties (e.g. window start/end) appended to the output
  * @param cluster             Calcite cluster
  * @param traitSet            Calcite trait set
  * @param inputNode           the input node (the local window aggregate)
  * @param localAggInfoList    aggregate infos used to generate the handler that merges the
  *                            partial accumulators read from the input rows
  * @param globalAggInfoList   aggregate infos used to generate the global handler; also
  *                            provides the accumulator and result value types
  * @param windowAggInputType  row type whose field types are passed to both generated handlers
  * @param aggCalls            the aggregate calls (used for naming and determinism checks)
  * @param outputSchema        schema of the rows emitted by this node
  * @param inputSchema         schema of the rows consumed by this node
  * @param grouping            indices of the grouping keys in the input row
  * @param inputTimestampIndex index passed to the window operator as the time field index
  * @param emitStrategy        early/late firing strategy of the window
  */
class StreamExecGlobalWindowAggregate(
    val window: LogicalWindow,
    namedProperties: Seq[NamedWindowProperty],
    cluster: RelOptCluster,
    traitSet: RelTraitSet,
    inputNode: RelNode,
    localAggInfoList: AggregateInfoList,
    globalAggInfoList: AggregateInfoList,
    val windowAggInputType: RelDataType,
    val aggCalls: Seq[AggregateCall],
    outputSchema: BaseRowSchema,
    inputSchema: BaseRowSchema,
    grouping: Array[Int],
    val inputTimestampIndex: Int,
    val emitStrategy: EmitStrategy)
  extends SingleRel(cluster, traitSet, inputNode)
  with StreamPhysicalRel
  with RowStreamExecNode {

  override def deriveRowType(): RelDataType = outputSchema.relDataType

  override def producesUpdates: Boolean = emitStrategy.produceUpdates

  override def consumesRetractions: Boolean = true

  override def needsUpdatesAsRetraction(input: RelNode): Boolean = true

  /** Rowtime tumbling/sliding windows with literal sizes and rowtime session windows. */
  override def requireWatermark: Boolean = window match {
    case TumblingGroupWindow(_, timeField, size)
      if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) => true
    case SlidingGroupWindow(_, timeField, size, _)
      if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) => true
    case SessionGroupWindow(_, timeField, _)
      if isRowtimeAttribute(timeField) => true
    case _ => false
  }

  def getGroupings: Array[Int] = grouping

  def getWindowProperties: Seq[NamedWindowProperty] = namedProperties

  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
    new StreamExecGlobalWindowAggregate(
      window,
      namedProperties,
      cluster,
      traitSet,
      inputs.get(0),
      localAggInfoList,
      globalAggInfoList,
      windowAggInputType,
      aggCalls,
      outputSchema,
      inputSchema,
      grouping,
      inputTimestampIndex,
      emitStrategy)
  }

  /** Grouping keys rendered against the input row type; shared by explain, digest and name. */
  private def groupingString =
    AggregateNameUtil.groupingToString(inputSchema.relDataType, grouping)

  /**
    * Select list rendered for the global phase.
    *
    * @param withOutputFieldNames whether output field names are included in the text
    */
  private def aggregationString(withOutputFieldNames: Boolean) =
    AggregateNameUtil.aggregationToString(
      inputSchema.relDataType,
      grouping,
      outputSchema.relDataType,
      aggCalls,
      namedProperties,
      withOutputFieldNames = withOutputFieldNames,
      isLocal = false)

  override def explainTerms(pw: RelWriter): RelWriter = {
    super.explainTerms(pw)
      .itemIf("groupBy", groupingString, grouping.nonEmpty)
      .item("window", window)
      .itemIf("properties", namedProperties.map(_.name).mkString(", "), namedProperties.nonEmpty)
      .item("select", aggregationString(withOutputFieldNames = true))
      .itemIf("emit", emitStrategy, !emitStrategy.toString.isEmpty)
  }

  override def isDeterministic: Boolean = AggregateUtil.isDeterministic(aggCalls)

  //~ ExecNode methods -----------------------------------------------------------

  override def getFlinkPhysicalRel: FlinkPhysicalRel = this

  /** Like [[explainTerms]], but keyed by the input row type and without output field names. */
  override def getStateDigest(pw: ExecNodeWriter): ExecNodeWriter = {
    pw.item("inputType", input.getRowType)
      .itemIf("groupBy", groupingString, grouping.nonEmpty)
      .item("window", window)
      .itemIf("properties", namedProperties.map(_.name).mkString(", "), namedProperties.nonEmpty)
      .item("select", aggregationString(withOutputFieldNames = false))
      .itemIf("emit", emitStrategy, !emitStrategy.toString.isEmpty)
  }

  /**
    * Translates this node into a keyed aligned window [[OneInputTransformation]].
    *
    * @throws IllegalArgumentException if the window is not a rowtime window
    * @throws TableException if the input produces retractions or the window kind is unsupported
    */
  override def translateToPlanInternal(
    tableEnv: StreamTableEnvironment): StreamTransformation[BaseRow] = {

    val config = tableEnv.getConfig

    // only rowtime windows are supported for now
    Preconditions.checkArgument(
      isRowtimeAttribute(window.timeAttribute),
      "StreamExecGlobalWindowAggregate only supports rowtime windows, but got: %s",
      window)
    // validation
    emitStrategy.checkValidation()

    val inputTransform = getInputNodes.get(0).translateToPlan(tableEnv)
      .asInstanceOf[StreamTransformation[BaseRow]]
    val inputRowType = inputTransform.getOutputType.asInstanceOf[BaseRowTypeInfo]
    val selector = StreamExecUtil.getKeySelector(grouping, inputRowType)

    if (StreamExecRetractionRules.isAccRetract(input)) {
      throw new TableException(
        TableErrors.INST.sqlGroupWindowAggTranslateRetractNotSupported())
    }
    // An acc/retract input has been rejected above, so the generated handlers never have to
    // retract.
    val needRetraction = false

    val aggString = aggregationString(withOutputFieldNames = true)
    val handlerInputTypes = FlinkTypeFactory.toInternalFieldTypes(windowAggInputType)

    val localAggsHandler = WindowAggregateUtil.createAggsHandler(
      "LocalGroupWindowAggsHandler",
      window,
      namedProperties,
      localAggInfoList,
      config,
      tableEnv.getRelBuilder,
      handlerInputTypes,
      needRetraction,
      true,
      // merged acc is from local window agg, offset is keyLength + inputTimestampLength
      grouping.length + 1,
      true)

    val globalAggsHandler = WindowAggregateUtil.createAggsHandler(
      "GlobalGroupWindowAggsHandler",
      window,
      namedProperties,
      globalAggInfoList,
      config,
      tableEnv.getRelBuilder,
      handlerInputTypes,
      needRetraction,
      true,
      0,
      true)

    val accTypes = globalAggInfoList.getAccTypes.map(_.toInternalType)
    val windowPropertyTypes = namedProperties
      .map(_.property.resultType)
      .toArray
    val aggValueTypes = globalAggInfoList.getActualValueTypes.map(_.toInternalType)

    val operator = createGlobalWindowOperator(
      config,
      localAggsHandler,
      globalAggsHandler,
      accTypes,
      windowPropertyTypes,
      aggValueTypes,
      inputTimestampIndex)

    val operatorName = if (grouping.nonEmpty) {
      s"global-window: ($window), groupBy: ($groupingString), select: ($aggString)"
    } else {
      s"global-window: ($window), select: ($aggString)"
    }

    val transformation = new OneInputTransformation(
      inputTransform,
      operatorName,
      FaultTolerantUtil.addFaultTolerantProxyIfNeed(
        operator,
        operatorName,
        config),
      outputSchema.typeInfo(),
      inputTransform.getParallelism)

    // without grouping keys every record belongs to the same key, so run a single instance
    if (grouping.isEmpty) {
      transformation.setParallelism(1)
      transformation.setMaxParallelism(1)
    }

    transformation.setResources(getResource.getReservedResourceSpec,
      getResource.getPreferResourceSpec)
    // set KeyType and Selector for state
    transformation.setStateKeySelector(selector)
    transformation.setStateKeyType(selector.getProducedType)

    transformation
  }

  /**
    * Builds the keyed aligned window operator that merges local partial accumulators.
    *
    * @param config              table config (time zone and mini-batch size are read from it)
    * @param localAggsHandler    generated handler merging accumulators read from input rows
    * @param globalAggsHandler   generated handler producing the final aggregate values
    * @param accTypes            accumulator field types
    * @param windowPropertyTypes types of the appended window properties
    * @param aggValueTypes       types of the aggregate result values
    * @param timeIdx             time field index handed to the operator
    * @throws TableException if the window is not a rowtime tumbling/sliding window with a
    *                        time-interval size
    */
  private def createGlobalWindowOperator(config: TableConfig,
    localAggsHandler: GeneratedSubKeyedAggsHandleFunction[_],
    globalAggsHandler: GeneratedSubKeyedAggsHandleFunction[_],
    accTypes: Array[InternalType],
    windowPropertyTypes: Array[InternalType],
    aggValueTypes: Array[InternalType],
    timeIdx: Int): AbstractAlignedWindowOperator = {

    val accTypeInfo = TypeConverters.createInternalTypeInfoFromDataType(
      DataTypes.createRowType(accTypes: _*))
    // result row layout: aggregate values first, then window properties
    val aggResultType = DataTypes.createRowType(aggValueTypes ++ windowPropertyTypes: _*)
    val aggResultTypeInfo = TypeConverters.createInternalTypeInfoFromDataType(aggResultType)
    val minibatchSize = config.getConf.getLong(TableConfigOptions.SQL_EXEC_MINIBATCH_SIZE)

    val windowRunner = new GlobalAlignedWindowAggregator(
      accTypeInfo.asInstanceOf[BaseRowTypeInfo],
      aggResultTypeInfo.asInstanceOf[BaseRowTypeInfo],
      localAggsHandler.asInstanceOf[GeneratedSubKeyedAggsHandleFunction[TimeWindow]],
      globalAggsHandler.asInstanceOf[GeneratedSubKeyedAggsHandleFunction[TimeWindow]],
      minibatchSize,
      false)

    // pick the assigner/trigger pair matching the window kind, both using the configured
    // time zone; any other window kind cannot be translated by this operator
    val (windowAssigner, windowTrigger) = window match {
      case TumblingGroupWindow(_, timeField, size)
        if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
        val sizeDuration = toDuration(size)
        val assigner = TumblingWindowAssigner.of(sizeDuration).withTimeZone(config.getTimeZone)
        val trigger = InternalAlignedWindowTriggers.tumbling(sizeDuration, config.getTimeZone)
        (assigner, trigger)
      case SlidingGroupWindow(_, timeField, size, slide)
        if isRowtimeAttribute(timeField) && isTimeIntervalLiteral(size) =>
        val sizeDuration = toDuration(size)
        val slideDuration = toDuration(slide)
        val assigner = SlidingWindowAssigner
          .of(sizeDuration, slideDuration)
          .withTimeZone(config.getTimeZone)
        val trigger = InternalAlignedWindowTriggers.sliding(
          sizeDuration, slideDuration, config.getTimeZone)
        (assigner, trigger)
      case _ =>
        throw new TableException(
          s"Window $window is not supported by StreamExecGlobalWindowAggregate. " +
            "Only rowtime tumbling and sliding windows with time interval sizes are supported.")
    }

    new KeyedAlignedWindowOperator(
      windowRunner,
      windowAssigner,
      windowTrigger,
      timeIdx)
  }
}
