/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.rules.physical.batch

import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.FlinkLogicalSink
import org.apache.flink.table.plan.nodes.physical.batch.BatchExecSink

import org.apache.calcite.plan.RelOptRule
import org.apache.calcite.rel.{RelCollations, RelNode}
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.flink.table.api.{TableException, VirtualColumn}
import org.apache.flink.table.connector.DefinedDistribution
import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.runtime.aggregate.RelFieldCollations
import org.apache.flink.table.sinks.{OperationType, PartitionableTableSink, UpdateDeleteTableSink}

import collection.JavaConversions._

/**
  * Rule that converts a [[FlinkLogicalSink]] into a [[BatchExecSink]].
  *
  * Besides switching the node to the BATCH_PHYSICAL convention, the rule derives the traits
  * the sink's input must provide:
  *  - a [[DefinedDistribution]] sink whose partition fields are non-null, non-empty and
  *    contain no null entry needs its input hash-distributed on the dynamic partition fields.
  *    The input is additionally sorted ascending by them when `sortLocalPartition()` is true.
  *    If every partition is static, no distribution or collation is required.
  *  - an [[UpdateDeleteTableSink]] whose operation type is DELETE needs its input
  *    hash-distributed and sorted on the virtual file-name column.
  *  - any other sink only has its input converted to BATCH_PHYSICAL.
  */
class BatchExecSinkRule extends ConverterRule(
  classOf[FlinkLogicalSink],
  FlinkConventions.LOGICAL,
  FlinkConventions.BATCH_PHYSICAL,
  "BatchExecSinkRule") {

  def convert(rel: RelNode): RelNode = {
    val sinkNode = rel.asInstanceOf[FlinkLogicalSink]
    val newTrait = rel.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
    val inputTraitSet = sinkNode.getInput.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)

    val requiredTraitSet = sinkNode.sink match {
      case partitionSink: DefinedDistribution
        if partitionSink.getPartitionFields != null &&
          partitionSink.getPartitionFields.nonEmpty &&
          !partitionSink.getPartitionFields.contains(null) =>
        val partitionsToHash = dynamicPartitionIndices(partitionSink)
        if (partitionsToHash.isEmpty) {
          // all the partitions are static, no need to hash
          inputTraitSet
        } else {
          val hashed = inputTraitSet.plus(
            FlinkRelDistribution.hash(
              partitionsToHash.map(Integer.valueOf), requireStrict = false))
          if (partitionSink.sortLocalPartition()) {
            val fieldCollations = partitionsToHash.map(RelFieldCollations.of) // default to asc.
            hashed.plus(RelCollations.of(fieldCollations: _*))
          } else {
            hashed
          }
        }

      // hash by input_file_name and sort by it.
      case deleteSink: UpdateDeleteTableSink
        if deleteSink.operationType == OperationType.DELETE =>
        val index = fileNameIndex(deleteSink)
        inputTraitSet
          .plus(FlinkRelDistribution.hash(Seq(Integer.valueOf(index)), requireStrict = false))
          .plus(RelCollations.of(RelFieldCollations.of(index)))

      case _ => inputTraitSet
    }
    val newInput = RelOptRule.convert(sinkNode.getInput, requiredTraitSet)

    new BatchExecSink(
      rel.getCluster,
      newTrait,
      newInput,
      sinkNode.sink,
      sinkNode.sinkName)
  }

  /**
    * Validates the partition layout of `partitionSink` and returns the schema indices of its
    * dynamic partition fields, i.e. the partition fields that follow the static ones.
    *
    * The caller guarantees that `getPartitionFields` is non-null, non-empty and contains no
    * null entry. For a [[PartitionableTableSink]] with static partitions, every static
    * partition column must be in the schema, and the static partitions must form a prefix of
    * the partition fields. An empty result means every partition is static.
    *
    * @throws TableException if a partition field or static partition column is not in the
    *                        schema, or if the static partitions are not a prefix of the
    *                        partition fields
    */
  private def dynamicPartitionIndices(partitionSink: DefinedDistribution): Seq[Int] = {
    val fieldNames = partitionSink.getFieldNames
    val partitionFields = partitionSink.getPartitionFields
    val partitionIndices = partitionFields.map(fieldNames.indexOf(_))
    // validate partition columns must exist.
    if (partitionIndices.exists(_ < 0)) {
      throw new TableException("Partition fields must be in the schema.")
    }

    // Number of leading partition fields whose values are fixed statically.
    val numStaticPartitions = partitionSink match {
      case sink: PartitionableTableSink =>
        val staticPartitions = sink.getStaticPartitions
        if (staticPartitions != null && staticPartitions.nonEmpty) {
          // Calcite already validates that static partition columns exist. The check is
          // repeated here to cover user-defined column names.
          staticPartitions.foreach { p =>
            if (!fieldNames.contains(p._1)) {
              throw new TableException(s"Partition column ${p._1} not exists in the schema.")
            }
          }
          // validate static partitions must be in front of dynamic partitions.
          staticPartitions.map(_._1)
            .zip(partitionFields.slice(0, staticPartitions.size()))
            .foreach {
              case (p1, p2) =>
                if (p1 != p2) {
                  throw new TableException(s"Static partition column $p1 " +
                    s"appears after dynamic partition $p2.")
                }
            }
          // zip stops at the shorter side, so the prefix check above cannot detect more
          // static partitions than partition fields. Without this check the input would be
          // hash-distributed on an empty key list.
          if (staticPartitions.size() > partitionFields.length) {
            throw new TableException(
              s"Static partition columns ${staticPartitions.map(_._1).mkString(", ")} " +
                s"exceed the partition fields ${partitionFields.mkString(", ")}.")
          }
          staticPartitions.size()
        } else {
          0
        }
      case _ => 0
    }

    // only hash by dynamic partitions.
    partitionIndices.slice(numStaticPartitions, partitionFields.length).toSeq
  }

  /**
    * Returns the schema index of the virtual file-name column of a DELETE sink.
    *
    * @throws TableException if the column is not part of the sink's schema
    */
  private def fileNameIndex(deleteSink: UpdateDeleteTableSink): Int = {
    val fileNameColumn = VirtualColumn.FILENAME.getName
    val index = deleteSink.getFieldNames.indexOf(fileNameColumn)
    if (index < 0) {
      // Without this check the input would be hash-distributed and sorted on field -1.
      throw new TableException(
        s"Virtual column $fileNameColumn must be in the schema of a DELETE sink.")
    }
    index
  }
}

/** Companion of [[BatchExecSinkRule]] that exposes a single reusable rule instance. */
object BatchExecSinkRule {
  /** The shared [[BatchExecSinkRule]] instance, typed as a plain [[RelOptRule]]. */
  val INSTANCE: RelOptRule = new BatchExecSinkRule()
}
