/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.nodes.calcite

import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.logical.LogicalWindow

import com.google.common.collect.ImmutableList
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
import org.apache.calcite.rel.{AbstractRelNode, RelNode, RelWriter, SingleRel}
import org.apache.calcite.util.{ImmutableBitSet, Util}

import java.util

import scala.collection.JavaConverters._

/**
  * Base relational expression for aggregations that are evaluated over a window.
  *
  * The output row type is the regular [[Aggregate]] row type derived from `groupSet` and
  * `aggCalls`, followed by one nullable field per named window property.
  *
  * @param window          the window the aggregation is defined on
  * @param namedProperties window properties appended to every output row, in order
  * @param cluster         cluster this relational expression belongs to
  * @param traitSet        traits of this relational expression
  * @param child           the single input relational expression
  * @param groupSet        bit set of the input fields to group by
  * @param aggCalls        aggregate calls to compute
  */
abstract class WindowAggregate(
    window: LogicalWindow,
    namedProperties: Seq[NamedWindowProperty],
    cluster: RelOptCluster,
    traitSet: RelTraitSet,
    child: RelNode,
    groupSet: ImmutableBitSet,
    aggCalls: Seq[AggregateCall])
  extends SingleRel(cluster, traitSet, child) {

  /** Returns the window this aggregation is defined on. */
  def getWindow: LogicalWindow = window

  /** Returns the named window properties appended after the aggregate results. */
  def getNamedProperties: Seq[NamedWindowProperty] = namedProperties

  /** Returns the bit set of grouping fields. */
  def getGroupSet: ImmutableBitSet = groupSet

  /** Returns the aggregate calls computed by this node. */
  def getAggCallList: Seq[AggregateCall] = aggCalls

  /**
    * Copies this node with new traits and inputs, keeping the current aggregate calls.
    *
    * `inputs` is expected to hold exactly one node (it is unwrapped via
    * `AbstractRelNode.sole`); the work is delegated to the abstract three-argument `copy`.
    */
  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
    val singleInput = AbstractRelNode.sole(inputs)
    copy(traitSet, singleInput, aggCalls)
  }

  /**
    * Creates a copy of this node with the given traits, input and aggregate calls.
    * Implemented by concrete subclasses.
    */
  def copy(traitSet: RelTraitSet, input: RelNode, aggCalls: Seq[AggregateCall]): WindowAggregate

  /**
    * Derives the output row type: the Calcite aggregate row type for a single grouping
    * set (`groupSet`) and `aggCalls`, with one nullable field per named window property
    * appended at the end.
    */
  override def deriveRowType(): RelDataType = {
    val aggRowType = Aggregate.deriveRowType(
      getCluster.getTypeFactory,
      getInput.getRowType,
      false, // `indicator` flag of Aggregate.deriveRowType
      groupSet,
      ImmutableList.of(groupSet), // exactly one grouping set
      aggCalls.asJava)

    val flinkTypeFactory = getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
    val rowBuilder = flinkTypeFactory.builder
    rowBuilder.addAll(aggRowType.getFieldList)
    // Window properties come after all grouping and aggregate fields.
    for (property <- namedProperties) {
      rowBuilder.add(
        property.name,
        flinkTypeFactory.createTypeFromInternalType(property.property.resultType, isNullable = true)
      )
    }
    rowBuilder.build()
  }

  /**
    * Returns the number of grouping fields, i.e. the cardinality of [[getGroupSet]].
    * These grouping fields are the leading fields in both the input and output records.
    *
    * <p>NOTE: The [[getGroupSet]] data structure allows the grouping fields not to be on
    * the leading edge. New code should, if possible, assume that grouping fields are in
    * arbitrary positions in the input relational expression.
    *
    * @return number of grouping fields
    */
  def getGroupCount: Int = groupSet.cardinality()

  /**
    * Returns each aggregate call paired with its output field name.
    *
    * @return list of calls to aggregate functions and their output field names
    */
  def getNamedAggCalls: Seq[(AggregateCall, String)] = {
    val groupCount = getGroupCount
    // The output row starts with the grouping fields, so their names are skipped.
    val namesAfterGroups = Util.skip(getRowType.getFieldNames, groupCount).asScala
    // zip stops at the shorter side, which drops the trailing window-property names.
    aggCalls.zip(namesAfterGroups).toList
  }

  /**
    * Writes the terms that describe this node.
    *
    * The digest of a node is computed from these terms and must uniquely identify it:
    * another node is equivalent if and only if it has the same digest. The window and its
    * named properties are therefore included, otherwise a WindowAggregate node could be
    * replaced by another WindowAggregate node defined over a different window.
    *
    * When the writer nests, the aggregate calls are written as a single "aggs" term;
    * otherwise one term per call is written, named after the call or "agg#<index>".
    */
  override def explainTerms(pw: RelWriter): RelWriter = {
    val writer = super.explainTerms(pw).item("group", groupSet)
    writer.itemIf("aggs", aggCalls, pw.nest())

    if (!pw.nest()) {
      for ((aggCall, index) <- aggCalls.zipWithIndex) {
        pw.item(Util.first(aggCall.name, "agg#" + index), aggCall)
      }
    }

    pw.item("window", window)
      .item("properties", namedProperties.map(prop => prop.name).mkString(", "))
  }

  /**
    * Returns whether any of the aggregate calls is DISTINCT.
    *
    * @return whether any of the aggregates are DISTINCT
    */
  def containsDistinctCall: Boolean = aggCalls.exists(_.isDistinct)
}
