/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.optimize.program

import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.{RichTableSchema, TableSchema, VirtualColumn}
import org.apache.flink.table.catalog.CatalogTable
import org.apache.flink.table.catalog.config.CatalogTableConfig
import org.apache.flink.table.plan.nodes.calcite.LogicalSink
import org.apache.flink.table.plan.schema.{CatalogCalciteTable, FlinkRelOptTable, TableSourceSinkTable}
import org.apache.flink.table.plan.util.DefaultRelShuttle
import org.apache.flink.table.sinks.{OperationType, UpdateDeleteTableSink}
import org.apache.flink.table.sources.DeletableTableSource

import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.TableScan
import org.apache.calcite.rel.logical.{LogicalProject, LogicalTableScan}
import org.apache.calcite.rex.RexInputRef

import java.util

import _root_.scala.collection.JavaConversions._

/**
  * Supports DELETE DML: appends the row_id and file_name virtual columns to the table source
  * fields and configures the table sink fields to exactly these columns.
  */
class FlinkDeleteDMLSupportProgram extends FlinkOptimizeProgram[BatchOptimizeContext] {

  /**
    * Rewrites the plan of a DELETE statement; any other plan is returned unchanged.
    *
    * A plan is rewritten only when its root is a [[LogicalSink]] whose sink is an
    * [[UpdateDeleteTableSink]] with operation type `OperationType.DELETE`.
    *
    * @param input   root of the logical plan
    * @param context optimize context (not read by this program)
    * @return the rewritten plan, or `input` itself when it is not a DELETE plan
    */
  override def optimize(input: RelNode, context: BatchOptimizeContext): RelNode = {
    input match {
      case logicalSink: LogicalSink =>
        logicalSink.sink match {
          // delete dml
          case t: UpdateDeleteTableSink if t.operationType == OperationType.DELETE =>
            input.accept(new UpdateVirtualColumnRelShuttle())
          case _ => input
        }
      case _ => input
    }
  }

  /**
    * Appends the row_id and file_name virtual columns to table source fields, projects only
    * those columns, and configures the table sink to write exactly these columns.
    */
  private class UpdateVirtualColumnRelShuttle extends DefaultRelShuttle {

    /**
      * Configures every [[LogicalSink]] so that its fields are exactly the virtual columns.
      *
      * @throws UnsupportedOperationException if the configured sink's field names differ from
      *                                       the virtual column names
      */
    override def visit(rel: RelNode): RelNode = {
      super.visit(rel) match {
        case oldLogicalSink: LogicalSink =>
          val newSink = oldLogicalSink.sink.configure(
            VirtualColumn.getNames, VirtualColumn.getTypes)
          if (!newSink.getFieldNames.sameElements(VirtualColumn.getNames)) {
            throw new UnsupportedOperationException("Configs tableSink fields " +
                "to only row_id and file_name failed.")
          }
          LogicalSink.create(oldLogicalSink.getInput, newSink, oldLogicalSink.sinkName)
        case other => other
      }
    }

    /**
      * Replaces the projection list with input references to the virtual columns only.
      *
      * NOTE(review): this applies to every LogicalProject in the tree, not only the one
      * directly under the sink — confirm that nested projections (e.g. subqueries in the
      * WHERE clause) cannot occur in DELETE plans reaching this program.
      */
    override def visit(project: LogicalProject): RelNode = {
      val newProject = super.visit(project).asInstanceOf[LogicalProject]
      val projectInput = newProject.getInput
      val inputRowType = projectInput.getRowType

      val projectionKeyList: util.List[String] = VirtualColumn.getNames.toList
      val projects = projectionKeyList.map { field: String =>
        val index = inputRowType.getFieldNames.indexOf(field)
        // The scan rewrite below appends the virtual columns, so they must be present here.
        assert(index >= 0, "It should not happen.")
        val fieldType = inputRowType.getFieldList.get(index).getType
        new RexInputRef(index, fieldType)
      }

      LogicalProject.create(projectInput, projects, projectionKeyList)
    }

    /**
      * Appends the virtual columns to the scanned table, for either a table source registered
      * through the table environment or a catalog table.
      *
      * @throws UnsupportedOperationException if the table has no source, the source is not a
      *                                       [[DeletableTableSource]], or the table type is
      *                                       not supported
      */
    override def visit(scan: TableScan): RelNode = scan match {
      case oldLogicalTableScan: LogicalTableScan =>
        val oldFlinkTable = oldLogicalTableScan.getTable.asInstanceOf[FlinkRelOptTable]
        val typeFactory = scan.getCluster.getTypeFactory
        val newFlinkTable = oldFlinkTable.table match {
          // table source registered by tableEnv
          case oldTableSourceSinkTable: TableSourceSinkTable[_] =>
            val oldTableSourceTable = oldTableSourceSinkTable.tableSourceTable.getOrElse(
              throw new UnsupportedOperationException(
                "DELETE requires a table source, but the scanned table only has a table sink."))
            oldTableSourceTable.tableSource match {
              case deletableSource: DeletableTableSource =>
                val newTableSource = deletableSource.applyDelete
                val newTableSourceTable = oldTableSourceTable.replaceTableSource(newTableSource)
                val newTableSourceSinkTable = new TableSourceSinkTable(
                  Option.apply(newTableSourceTable), oldTableSourceSinkTable.tableSinkTable)
                oldFlinkTable.copy(newTableSourceSinkTable,
                  newTableSourceTable.getRowType(typeFactory))
              case _ => throw new UnsupportedOperationException("table source" +
                  " should implements DeletableTableSource interface.")
            }
          // catalogTable
          case oldCatalogCalciteTable: CatalogCalciteTable =>
            val newCatalogTable = appendVirtualColumns(oldCatalogCalciteTable.table)
            // appendVirtualColumns gives the new table its own copy of the properties, so this
            // put does not modify the original (catalog-owned) table.
            oldCatalogCalciteTable.isStreaming.foreach { streaming =>
              newCatalogTable.getProperties.put(CatalogTableConfig.IS_STREAMING, streaming.toString)
            }
            val newCatalogCalciteTable = new CatalogCalciteTable(
              oldCatalogCalciteTable.name, newCatalogTable)
            oldFlinkTable.copy(newCatalogCalciteTable,
              newCatalogCalciteTable.getRowType(typeFactory))
          case t =>
            throw new UnsupportedOperationException("Un considered table source type: " +
                t.getClass)
        }
        LogicalTableScan.create(oldLogicalTableScan.getCluster, newFlinkTable)
      case _ => scan
    }

    /**
      * Builds a copy of `oldCatalogTable` whose schema and rich schema have the virtual columns
      * appended after the existing columns. Keys, indexes, computed columns, watermarks and all
      * other table metadata are carried over.
      *
      * The properties map is copied rather than shared, so later changes to the returned
      * table's properties do not affect `oldCatalogTable`.
      */
    def appendVirtualColumns(oldCatalogTable: CatalogTable): CatalogTable = {
      val oldTableSchema = oldCatalogTable.getTableSchema
      val newTableSchemaBuilder = TableSchema.builder()
      oldTableSchema.getColumns.foreach(c =>
        newTableSchemaBuilder.column(c.name(), c.internalType(), c.isNullable))
      // append virtual columns
      VirtualColumn.getVirtualColumns.foreach(v =>
        newTableSchemaBuilder.column(v.getName, v.getInternalType, v.isNullable))
      newTableSchemaBuilder.primaryKey(oldTableSchema.getPrimaryKeys: _*)
      oldTableSchema.getUniqueKeys.foreach(l => newTableSchemaBuilder.uniqueKey(l: _*))
      oldTableSchema.getNormalIndexes.foreach(l => newTableSchemaBuilder.normalIndex(l: _*))
      oldTableSchema.getComputedColumns.foreach(c =>
        newTableSchemaBuilder.computedColumn(c.name(), c.expression()))
      oldTableSchema.getWatermarks.foreach(w =>
        newTableSchemaBuilder.watermark(w.name(), w.eventTime(), w.offset()))
      val newTableSchema = newTableSchemaBuilder.build()

      val oldRichTableSchema = oldCatalogTable.getRichTableSchema
      val newRichTableSchema = new RichTableSchema(
        oldRichTableSchema.getColumnNames ++ VirtualColumn.getNames,
        oldRichTableSchema.getColumnTypes ++ VirtualColumn.getTypes,
        oldRichTableSchema.getNullables ++ VirtualColumn.getNullables,
        oldRichTableSchema.getPrimaryKeys.toList.toArray,
        oldRichTableSchema.getUniqueKeys,
        oldRichTableSchema.getPartitionColumns.toList.toArray,
        oldRichTableSchema.getIndexes,
        oldRichTableSchema.getHeaderFields)

      // Copy (not share) the properties: callers add entries to the new table's properties,
      // and a shared map would silently change the original table as well.
      val oldProperties = oldCatalogTable.getProperties
      val newProperties =
        if (oldProperties == null) null else new util.HashMap[String, String](oldProperties)

      new CatalogTable(oldCatalogTable.getTableType,
        newTableSchema,
        newProperties,
        newRichTableSchema,
        oldCatalogTable.getTableStats,
        oldCatalogTable.getComment,
        oldCatalogTable.getPartitionColumnNames,
        oldCatalogTable.isPartitioned,
        oldCatalogTable.getComputedColumnsSql,
        oldCatalogTable.getRowTimeField,
        oldCatalogTable.getWatermarkOffset,
        oldCatalogTable.getCreateTime,
        oldCatalogTable.getLastAccessTime)
    }
  }
}
