/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.rules.logical

import org.apache.flink.table.api.{RichTableSchema, TableConfigOptions, TableSchema}
import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkTypeFactory}
import org.apache.flink.table.catalog.{CatalogTable, CatalogView}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.plan.nodes.calcite.LogicalWatermarkAssigner
import org.apache.flink.table.plan.schema._
import org.apache.flink.table.plan.util.FlinkRelOptUtil
import org.apache.flink.table.sources.TableSourceUtil
import org.apache.flink.table.util.TableEnvironmentUtil

import org.apache.flink.shaded.guava18.com.google.common.collect.ImmutableList

import com.google.common.base.Strings
import org.apache.calcite.config.{CalciteConnectionConfigImpl, CalciteConnectionProperty}
import org.apache.calcite.jdbc.CalciteSchemaBuilder
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.prepare.CalciteCatalogReader
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
import org.apache.calcite.rel.logical._
import org.apache.calcite.rex.RexNode
import org.apache.calcite.schema.Table
import org.apache.calcite.schema.impl.{AbstractSchema, AbstractTable}
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.sql.{SemiJoinType, SqlLiteral}
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.{ImmutableBitSet, Pair => CPair}

import java.util
import java.util.{Properties, List => JList}

import scala.collection.JavaConversions._

/**
 * Rule that replaces a [[LogicalTableScan]] over a [[CatalogCalciteTable]] with a scan over
 * the stream table source of that catalog table.
 *
 * On top of the new scan the rule appends, in this order: the table source parser node, a
 * projection for the computed columns (only if the catalog table declares computed-column
 * SQL) and a [[LogicalWatermarkAssigner]] (only if the catalog table has a row-time field).
 */
class CatalogTableToStreamTableSourceRule
    extends RelOptRule(
      operand(classOf[LogicalTableScan], any), "CatalogTableToStreamTableSource") {

  /**
   * Accepts scans whose table unwraps to a [[CatalogCalciteTable]] that does not wrap a
   * [[CatalogView]].
   */
  override def matches(call: RelOptRuleCall): Boolean = {
    val scan = call.rel[LogicalTableScan](0)
    Option(scan.getTable.unwrap(classOf[CatalogCalciteTable]))
      .exists(candidate => !candidate.table.isInstanceOf[CatalogView])
  }

  /** Builds the replacement plan for the matched scan and registers it with the planner. */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val scan = call.rel[LogicalTableScan](0)
    val catalogTable = scan.getTable.unwrap(classOf[CatalogCalciteTable])
    val source = catalogTable.streamTableSource

    // Re-target the optimizer table at a stream table source table.
    val originalTable = scan.getTable.asInstanceOf[FlinkRelOptTable]
    val statistic = catalogTable.getStatistic()
    val collectStats = FlinkRelOptUtil.getTableConfig(scan).getConf.getBoolean(
      TableConfigOptions.SQL_OPTIMIZER_SOURCE_COLLECT_STATS_ENABLED)
    val sourceTable = new StreamTableSourceTable(source, statistic, collectStats)
    val sourceRowType = TableSourceUtil.getRelDataType(
      source,
      None,
      streaming = true,
      scan.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory])
    val retargetedTable = originalTable.copy(sourceTable, sourceRowType)

    // A source without an explanation string of its own gets the name list
    // "From" + <qualified table name> configured on the table.
    val namedTable =
      if (source.explainSource().isEmpty) {
        val names = ImmutableList.builder[String]()
          .add("From")
          .addAll(scan.getTable.getQualifiedName)
          .build()
        retargetedTable.config(names, retargetedTable.unwrap(classOf[TableSourceTable]))
      } else {
        retargetedTable
      }

    // Step 1: parser node.
    val parsed: RelNode = CatalogTableRules.appendParserNode(
      catalogTable,
      LogicalTableScan.create(scan.getCluster, namedTable),
      call.builder())

    // Step 2: computed columns, only when the catalog table declares any.
    val flinkPlanner = call.getPlanner.getContext.unwrap(classOf[FlinkPlannerImpl])
    val withComputedColumns: RelNode =
      if (Strings.isNullOrEmpty(catalogTable.table.getComputedColumnsSql)) {
        parsed
      } else {
        CatalogTableRules.appendComputedColumns(
          flinkPlanner,
          call.builder(),
          scan,
          parsed,
          catalogTable.table,
          isStreaming = true)
      }

    // Step 3: watermark assigner, only when a row-time field is defined.
    val withWatermark: RelNode =
      if (catalogTable.table.getRowTimeField == null) {
        withComputedColumns
      } else {
        new LogicalWatermarkAssigner(
          withComputedColumns.getCluster,
          withComputedColumns.getTraitSet,
          withComputedColumns,
          catalogTable.table.getRowTimeField,
          catalogTable.table.getWatermarkOffset)
      }

    call.transformTo(withWatermark)
  }
}

/**
 * Rule that replaces a [[LogicalTableScan]] over a [[CatalogCalciteTable]] with a scan over
 * the batch table source of that catalog table.
 *
 * On top of the new scan the rule appends the table source parser node and, only if the
 * catalog table declares computed-column SQL, a projection for the computed columns.
 */
class CatalogTableToBatchTableSourceRule
    extends RelOptRule(
      operand(classOf[LogicalTableScan], any), "CatalogTableToBatchTableSource") {

  /**
   * Accepts scans whose table unwraps to a [[CatalogCalciteTable]] that does not wrap a
   * [[CatalogView]].
   */
  override def matches(call: RelOptRuleCall): Boolean = {
    val scan = call.rel[LogicalTableScan](0)
    Option(scan.getTable.unwrap(classOf[CatalogCalciteTable]))
      .exists(candidate => !candidate.table.isInstanceOf[CatalogView])
  }

  /** Builds the replacement plan for the matched scan and registers it with the planner. */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val scan = call.rel[LogicalTableScan](0)
    val catalogTable = scan.getTable.unwrap(classOf[CatalogCalciteTable])
    val source = catalogTable.batchTableSource

    // Re-target the optimizer table at a batch table source table.
    val originalTable = scan.getTable.asInstanceOf[FlinkRelOptTable]
    val statistic = catalogTable.getStatistic()
    val collectStats = FlinkRelOptUtil.getTableConfig(scan).getConf.getBoolean(
      TableConfigOptions.SQL_OPTIMIZER_SOURCE_COLLECT_STATS_ENABLED)
    val sourceTable = new BatchTableSourceTable(source, statistic, collectStats)
    val sourceRowType = TableSourceUtil.getRelDataType(
      source,
      None,
      streaming = false,
      scan.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory])
    val retargetedTable = originalTable.copy(sourceTable, sourceRowType)

    // A source without an explanation string of its own gets the name list
    // "From" + <qualified table name> configured on the table.
    val namedTable =
      if (source.explainSource().isEmpty) {
        val names = ImmutableList.builder[String]()
          .add("From")
          .addAll(scan.getTable.getQualifiedName)
          .build()
        retargetedTable.config(names, retargetedTable.unwrap(classOf[TableSourceTable]))
      } else {
        retargetedTable
      }

    // Step 1: parser node.
    val parsed: RelNode = CatalogTableRules.appendParserNode(
      catalogTable,
      LogicalTableScan.create(scan.getCluster, namedTable),
      call.builder())

    // Step 2: computed columns, only when the catalog table declares any.
    val flinkPlanner = call.getPlanner.getContext.unwrap(classOf[FlinkPlannerImpl])
    val withComputedColumns: RelNode =
      if (Strings.isNullOrEmpty(catalogTable.table.getComputedColumnsSql)) {
        parsed
      } else {
        CatalogTableRules.appendComputedColumns(
          flinkPlanner,
          call.builder(),
          scan,
          parsed,
          catalogTable.table,
          isStreaming = false)
      }

    call.transformTo(withComputedColumns)
  }
}

/**
 * Singleton instances of the catalog table scan rules plus the plan-building helpers that
 * both rules share.
 */
object CatalogTableRules {
  val STREAM_TABLE_SCAN_RULE = new CatalogTableToStreamTableSourceRule
  val BATCH_TABLE_SCAN_RULE = new CatalogTableToBatchTableSourceRule

  /**
    * Appends the table source parser of `catalogTable` on top of `inputNode`.
    *
    * If the catalog table has a parser, the result is a projection over a correlate that
    * joins `inputNode` with a table function scan invoking the parser; the projection keeps
    * only the fields produced by the parser. If there is no parser, `inputNode` is returned
    * unchanged.
    *
    * @param catalogTable catalog table whose `tableSourceParser` is applied
    * @param inputNode    node providing the fields named by the parser parameters
    * @param relBuilder   builder used to resolve field references; `inputNode` and the
    *                     correlate are pushed onto it and are not popped again
    * @return the projection over the parser output, or `inputNode` if there is no parser
    */
  def appendParserNode(
      catalogTable: CatalogCalciteTable,
      inputNode: RelNode,
      relBuilder: RelBuilder): RelNode = {
    val parser = catalogTable.tableSourceParser

    if (parser == null) {
      inputNode
    } else {
      val cluster = inputNode.getCluster
      val inputFieldCount = inputNode.getRowType.getFieldCount
      val correlationId = cluster.createCorrel()

      // Resolve every parser parameter to a field reference on the input node.
      relBuilder.push(inputNode)
      val params = parser.getParameters.map {
        name: String => relBuilder.field(name)
      }

      val tableFunction = parser.getParser
      val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
      val parserSqlFunction = UserDefinedFunctionUtils.createTableSqlFunction(
        "parser",
        "parser",
        tableFunction,
        typeFactory)

      val parserCall = relBuilder.call(parserSqlFunction, params)

      // Element and row type are derived with one NULL literal per parameter as operands.
      val tableFunctionScan =
        LogicalTableFunctionScan.create(
          cluster,
          ImmutableList.of(),
          parserCall,
          parserSqlFunction.getElementType(
            typeFactory,
            params.map(_ => SqlLiteral.createNull(new SqlParserPos(0, 0)))),
          parserSqlFunction.getRowType(
            typeFactory,
            params.map(_ => SqlLiteral.createNull(new SqlParserPos(0, 0))),
            parser.getParameters.map(
              name => inputNode.getRowType.getField(name, true, false).getType)
          ),
          null)
      val parserRowType = tableFunctionScan.getRowType

      // Column set of the correlate: the input fields read by the parser plus the positions
      // of all parser output fields (offset by the number of input fields).
      val requiredColumns = ImmutableBitSet.builder()
      params.foreach(param => requiredColumns.set(param.getIndex))
      (0 until parserRowType.getFieldCount).foreach(
        idx => requiredColumns.set(inputFieldCount + idx))

      val correlate =
        LogicalCorrelate.create(
          inputNode,
          tableFunctionScan,
          correlationId,
          requiredColumns.build(),
          SemiJoinType.INNER)
      relBuilder.push(correlate)

      // Keep only the parser output, which follows the input fields in the correlate.
      val parserFields = parserRowType.getFieldList.map(
        field => relBuilder.field(inputFieldCount + field.getIndex))
      LogicalProject.create(
        correlate,
        parserFields,
        parserRowType
      )
    }
  }

  /**
    * Replaces the original logical schema table scan with
    * (physical table scan + projection with computed columns).
    *
    * The computed-column SQL of the catalog table is planned as
    * `select <computed columns> from __MockTable__` against a catalog that contains only
    * that mock table, typed with the physical schema. The resulting named projections are
    * then applied on top of `currentRel`.
    *
    * @param flinkPlanner planner used to translate the computed-column query
    * @param relBuilder   builder used to create the projection; `currentRel` is pushed on it
    * @param tableScan    original scan; its row type supplies the cast target types in
    *                     batch mode
    * @param currentRel   node the projection is put on top of
    * @param catalogTable catalog table providing the schemas and the computed-column SQL
    * @param isStreaming  whether the plan is built for streaming mode
    * @return the projection over `currentRel`, or `currentRel` itself if the computed-column
    *         query yields no projections
    * @throws RuntimeException      if a projected name is not a column of the table schema
    * @throws IllegalStateException if, in batch mode, a time indicator column is missing
    *                               from the row type of `tableScan`
    */
  def appendComputedColumns(
      flinkPlanner: FlinkPlannerImpl,
      relBuilder: RelBuilder,
      tableScan: LogicalTableScan,
      currentRel: RelNode,
      catalogTable: CatalogTable,
      isStreaming: Boolean): RelNode = {
    val physicalSchema = tableSchemaFromRichSchema(catalogTable.getRichTableSchema)
    val logicalSchema = catalogTable.getTableSchema
    val tableName = "__MockTable__"
    val viewSql = s"select ${catalogTable.getComputedColumnsSql} from $tableName"
    val catalogReader = createSingleTableCatalogReader(tableName,
      flinkPlanner.typeFactory,
      flinkPlanner.typeFactory.buildLogicalRowType(physicalSchema,
        Option.apply(isStreaming)))
    val project = TableEnvironmentUtil.queryToRel(viewSql,
      flinkPlanner, Option.apply(catalogReader)).asInstanceOf[LogicalProject]

    val namedProjects = project.getNamedProjects
    if (namedProjects == null || namedProjects.isEmpty) {
      currentRel
    } else {
      // Every projected name must be a column of the logical schema; report the first
      // one that is not.
      val logicalColumnNames = logicalSchema.getColumnNames
      namedProjects.map(_.right)
        .find(name => !logicalColumnNames.contains(name))
        .foreach(name => throw new RuntimeException(s"Column name $name does not exist."))

      val projects = namedProjects.map { namedProject =>
        val node: RexNode = namedProject.left
        // For batch mode, we have replaced the source type TimeIndicatorRelDataType to
        // Timestamp, we should sync this logic for computed columns.
        if (!isStreaming && node.getType.isInstanceOf[TimeIndicatorRelDataType]) {
          // The computed column is expected to exist in the original scan already; fail
          // with a clear message rather than a NullPointerException if it does not.
          val scanField = tableScan.getRowType.getField(namedProject.right, true, false)
          if (scanField == null) {
            throw new IllegalStateException(
              s"Computed column ${namedProject.right} is missing from the row type of " +
                "the original table scan.")
          }
          relBuilder.getRexBuilder.makeAbstractCast(scanField.getType, node)
        } else {
          node
        }
      }
      relBuilder.push(currentRel)
      relBuilder.project(projects, namedProjects.map(_.right)).build()
    }
  }

  /**
    * Builds a [[TableSchema]] from the column names, column types and nullability flags of
    * the given rich table schema.
    */
  private def tableSchemaFromRichSchema(richSchema: RichTableSchema): TableSchema = {
    val builder = TableSchema.builder()
    val columns =
      richSchema.getColumnNames.zip(richSchema.getColumnTypes).zip(richSchema.getNullables)
    columns.foreach {
      case ((name, dataType), nullable) =>
        builder.column(name, dataType, nullable)
    }
    builder.build()
  }

  /**
    * Create a catalog reader with single table. The reader is configured to be
    * case-insensitive.
    *
    * @param name        table name
    * @param typeFactory flink type factory
    * @param rowType     row type of the table
    * @return a catalog reader instance with only one table
    */
  private def createSingleTableCatalogReader(name: String,
      typeFactory: FlinkTypeFactory,
      rowType: RelDataType): CalciteCatalogReader = {
    val connectionProps = new Properties
    connectionProps.setProperty(
      CalciteConnectionProperty.CASE_SENSITIVE.camelName, String.valueOf(false))
    val rootSchema = CalciteSchemaBuilder.asRootSchema(new SimpleSchema(name, rowType))
    new CalciteCatalogReader(
      rootSchema,
      new util.ArrayList[String](),
      typeFactory,
      new CalciteConnectionConfigImpl(connectionProps))
  }

  /** Simple schema with only one table, registered as `name` with row type `rowType`. */
  private class SimpleSchema(val name: String, val rowType: RelDataType) extends AbstractSchema {
    /** Returns a fresh map holding the single table of this schema. */
    override protected def getTableMap: util.Map[String, Table] = {
      val singleTable: Table = new AbstractTable() {
        override def getRowType(relDataTypeFactory: RelDataTypeFactory): RelDataType = rowType
      }
      val tables: util.Map[String, Table] = new util.HashMap[String, Table]
      tables.put(name, singleTable)
      tables
    }
  }
}

