/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.plan.rules.logical

import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.flink.table.api.TableSchema
import org.apache.flink.table.calcite.{FlinkPlannerImpl, FlinkTypeFactory}
import org.apache.flink.table.catalog.{CatalogTable, CatalogView}
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils
import org.apache.flink.table.plan.nodes.calcite.LogicalWatermarkAssigner
import org.apache.flink.table.plan.schema._
import org.apache.flink.table.plan.stats.TableStats
import org.apache.flink.table.sources.{BatchTableSource, StreamTableSource, TableSourceUtil}
import org.apache.flink.table.types.{DataType, DataTypes}
import org.apache.flink.table.util.TableEnvironmentUtil

import org.apache.flink.shaded.guava18.com.google.common.collect.ImmutableList

import com.google.common.base.Strings
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.logical._
import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.sql.{SemiJoinType, SqlLiteral}
import org.apache.calcite.tools.RelBuilder
import org.apache.calcite.util.ImmutableBitSet
import org.apache.calcite.util.{Pair => CPair}

import java.util.{List => JList}

import scala.collection.JavaConversions._

/**
 * Catalog Table to stream table source rule.
 *
 * Rewrites a [[LogicalTableScan]] over a non-view [[CatalogCalciteTable]] into a scan over the
 * catalog table's stream table source. On top of that scan it appends, when configured on the
 * catalog table:
 *  - the source parser (see [[CatalogTableRules.appendParserNode]]),
 *  - a projection for computed columns (see [[CatalogTableRules.appendComputedColumns]]),
 *  - a [[LogicalWatermarkAssigner]] when a rowtime field is declared.
 */
class CatalogTableToStreamTableSourceRule
    extends RelOptRule(
      operand(classOf[LogicalTableScan], any), "CatalogTableToStreamTableSource") {

  /** Matches scans whose table unwraps to a [[CatalogCalciteTable]] that is not a view. */
  override def matches(call: RelOptRuleCall): Boolean = {
    val rel = call.rel(0).asInstanceOf[LogicalTableScan]
    val table = rel.getTable.unwrap(classOf[CatalogCalciteTable])
    table != null && !table.table.isInstanceOf[CatalogView]
  }

  override def onMatch(call: RelOptRuleCall): Unit = {
    val oldRel = call.rel(0).asInstanceOf[LogicalTableScan]
    val catalogTable = oldRel.getTable.unwrap(classOf[CatalogCalciteTable])
    val tableSource = catalogTable.streamTableSource

    val sourceTable = oldRel.getTable.asInstanceOf[FlinkRelOptTable].copy(
      new StreamTableSourceTable(tableSource, catalogTable.getStatistic()),
      TableSourceUtil.getRelDataType(
        tableSource,
        None,
        streaming = true,
        oldRel.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]))

    // When the source does not describe itself, label the table with "From" followed by the
    // qualified table name. A null description is treated like an empty one so that a source
    // returning null does not cause a NullPointerException here.
    val table = if (Strings.isNullOrEmpty(tableSource.explainSource())) {
      val names = ImmutableList.builder[String]()
      names.add("From")
      names.addAll(oldRel.getTable.getQualifiedName)
      sourceTable.config(names.build(), sourceTable.unwrap(classOf[TableSourceTable]))
    } else {
      sourceTable
    }

    // Parser.
    val parsed: RelNode = CatalogTableRules.appendParserNode(
      catalogTable,
      LogicalTableScan.create(oldRel.getCluster, table),
      call.builder())

    // Computed columns; the planner is only looked up when there is SQL to plan.
    val withComputedColumns: RelNode =
      if (Strings.isNullOrEmpty(catalogTable.table.getComputedColumnsSql)) {
        parsed
      } else {
        val flinkPlanner = call.getPlanner.getContext.unwrap(classOf[FlinkPlannerImpl])
        CatalogTableRules.appendComputedColumns(
          flinkPlanner,
          call.builder(),
          oldRel,
          parsed,
          catalogTable.table,
          isStreaming = true)
      }

    // Watermark, only when a rowtime field is declared.
    val rowTimeField = catalogTable.table.getRowTimeField
    val newRel: RelNode = if (rowTimeField != null) {
      new LogicalWatermarkAssigner(
        withComputedColumns.getCluster,
        withComputedColumns.getTraitSet,
        withComputedColumns,
        rowTimeField,
        catalogTable.table.getWatermarkOffset)
    } else {
      withComputedColumns
    }

    call.transformTo(newRel)
  }
}

/**
 * Catalog Table to batch table source rule.
 *
 * Rewrites a [[LogicalTableScan]] over a non-view [[CatalogCalciteTable]] into a scan over the
 * catalog table's batch table source. On top of that scan it appends, when configured on the
 * catalog table:
 *  - the source parser (see [[CatalogTableRules.appendParserNode]]),
 *  - a projection for computed columns (see [[CatalogTableRules.appendComputedColumns]]).
 */
class CatalogTableToBatchTableSourceRule
    extends RelOptRule(
      operand(classOf[LogicalTableScan], any), "CatalogTableToBatchTableSource") {

  /** Matches scans whose table unwraps to a [[CatalogCalciteTable]] that is not a view. */
  override def matches(call: RelOptRuleCall): Boolean = {
    val rel = call.rel(0).asInstanceOf[LogicalTableScan]
    val table = rel.getTable.unwrap(classOf[CatalogCalciteTable])
    table != null && !table.table.isInstanceOf[CatalogView]
  }

  override def onMatch(call: RelOptRuleCall): Unit = {
    val oldRel = call.rel(0).asInstanceOf[LogicalTableScan]
    val catalogTable = oldRel.getTable.unwrap(classOf[CatalogCalciteTable])
    val tableSource = catalogTable.batchTableSource

    val sourceTable = oldRel.getTable.asInstanceOf[FlinkRelOptTable].copy(
      new BatchTableSourceTable(tableSource, catalogTable.getStatistic()),
      TableSourceUtil.getRelDataType(
        tableSource,
        None,
        streaming = false,
        oldRel.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]))

    // When the source does not describe itself, label the table with "From" followed by the
    // qualified table name. A null description is treated like an empty one so that a source
    // returning null does not cause a NullPointerException here.
    val table = if (Strings.isNullOrEmpty(tableSource.explainSource())) {
      val names = ImmutableList.builder[String]()
      names.add("From")
      names.addAll(oldRel.getTable.getQualifiedName)
      sourceTable.config(names.build(), sourceTable.unwrap(classOf[TableSourceTable]))
    } else {
      sourceTable
    }

    // Parser.
    val parsed: RelNode = CatalogTableRules.appendParserNode(
      catalogTable,
      LogicalTableScan.create(oldRel.getCluster, table),
      call.builder())

    // Computed columns; the planner is only looked up when there is SQL to plan.
    val newRel: RelNode =
      if (Strings.isNullOrEmpty(catalogTable.table.getComputedColumnsSql)) {
        parsed
      } else {
        val flinkPlanner = call.getPlanner.getContext.unwrap(classOf[FlinkPlannerImpl])
        CatalogTableRules.appendComputedColumns(
          flinkPlanner,
          call.builder(),
          oldRel,
          parsed,
          catalogTable.table,
          isStreaming = false)
      }

    call.transformTo(newRel)
  }
}

object CatalogTableRules {
  val STREAM_TABLE_SCAN_RULE = new CatalogTableToStreamTableSourceRule
  val BATCH_TABLE_SCAN_RULE = new CatalogTableToBatchTableSourceRule

  /**
   * Appends the catalog table's source parser, if one is configured, on top of `inputNode`.
   *
   * The parser is wrapped as a table function called with the input fields named by
   * `parser.getParameters`. The input is correlated with a scan of that function, and a final
   * [[LogicalProject]] keeps only the fields that follow the input's own fields, i.e. the
   * parser's output.
   *
   * @param catalogTable table whose `tableSourceParser` is applied; may be null
   * @param inputNode    node producing the rows the parser reads from
   * @param relBuilder   builder used to resolve parameter fields and build the function call
   * @return the projected parser output, or `inputNode` unchanged when there is no parser
   */
  def appendParserNode(
      catalogTable: CatalogCalciteTable, inputNode: RelNode, relBuilder: RelBuilder): RelNode = {
    val parser = catalogTable.tableSourceParser
    if (parser == null) {
      inputNode
    } else {
      val cluster = inputNode.getCluster
      val inputRowType = inputNode.getRowType
      val inputFieldCount = inputRowType.getFieldCount
      val colId = cluster.createCorrel()

      // Resolve the parser's parameter names to references into the input row.
      relBuilder.push(inputNode)
      val params = parser.getParameters.map {
        name: String => relBuilder.field(name)
      }

      val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
      val parserSqlFunction = UserDefinedFunctionUtils.createTableSqlFunction(
        "parser",
        "parser",
        parser.getParser,
        typeFactory)
      val rexCall = relBuilder.call(parserSqlFunction, params)

      // Null literals stand in for the call operands when deriving the function's types.
      val tableFunctionScan =
        LogicalTableFunctionScan.create(
          cluster,
          ImmutableList.of(),
          rexCall,
          parserSqlFunction.getElementType(
            typeFactory,
            params.map(_ => SqlLiteral.createNull(new SqlParserPos(0, 0)))),
          parserSqlFunction.getRowType(
            typeFactory,
            params.map(_ => SqlLiteral.createNull(new SqlParserPos(0, 0))),
            parser.getParameters.map(
              name => inputRowType.getField(name, true, false).getType)),
          null)
      val outputDataType = tableFunctionScan.getRowType

      // Required columns: the parameter fields plus one index per parser output field, offset
      // past the input's fields.
      // NOTE(review): offsets beyond the left input's width look unusual for
      // Correlate.requiredColumns; kept as-is because downstream code may rely on it.
      val columnSetBuilder = ImmutableBitSet.builder()
      params.foreach(param => columnSetBuilder.set(param.getIndex))
      (0 until outputDataType.getFieldCount).foreach(
        idx => columnSetBuilder.set(inputFieldCount + idx))
      val correlate = LogicalCorrelate.create(
        inputNode, tableFunctionScan, colId, columnSetBuilder.build(), SemiJoinType.INNER)

      // Project away the input's own fields and keep only the parser output.
      relBuilder.push(correlate)
      val projects = outputDataType.getFieldList.map(
        field => relBuilder.field(inputFieldCount + field.getIndex))
      LogicalProject.create(correlate, projects, outputDataType)
    }
  }

  /**
   * Appends a projection computing the catalog table's computed columns on top of `node`.
   *
   * The computed-column SQL is planned as `select <computedColumnsSql> from <table>`. The
   * resulting named projections are then re-applied on `node`.
   *
   * @param flinkPlanner planner used to turn the generated query into a relational plan
   * @param relBuilder   builder used to create the projection
   * @param oldRel       original scan; provides the qualified table name and row type
   * @param node         node the projection is applied to
   * @param catalogTable table providing the computed-column SQL and the declared schema
   * @param isStreaming  when false, time-indicator typed expressions are cast to the type of the
   *                     same-named field in `oldRel`'s row type
   * @return the projected node, or `node` itself when the query yields no named projections
   * @throws RuntimeException if a projected column name is not in the table schema
   */
  def appendComputedColumns(
      flinkPlanner: FlinkPlannerImpl,
      relBuilder: RelBuilder,
      oldRel: LogicalTableScan,
      node: RelNode,
      catalogTable: CatalogTable,
      isStreaming: Boolean): RelNode = {
    val logicalSchema = catalogTable.getTableSchema
    val fullTableName = oldRel.getTable.getQualifiedName.map(p => s"`$p`").mkString(".")
    val viewSql = "select " + catalogTable.getComputedColumnsSql + " from " + fullTableName
    val project = TableEnvironmentUtil.queryToRel(viewSql, flinkPlanner)
      .asInstanceOf[LogicalProject]

    // Explicit null/empty guard instead of a type-erased (unchecked) pattern match.
    val namedProjects: JList[CPair[RexNode, String]] = project.getNamedProjects
    if (namedProjects == null || namedProjects.isEmpty) {
      node
    } else {
      val columnNames = namedProjects.map(_.right)

      // Validate all the projection fields against the declared schema.
      val declaredNames = logicalSchema.getFieldNames.toSet
      columnNames.foreach { name =>
        if (!declaredNames.contains(name)) {
          throw new RuntimeException(s"Computed column name $name does not exist.")
        }
      }

      val exprs = namedProjects.map { pair =>
        val expr: RexNode = pair.left
        // For batch mode, we have replaced the source type TimeIndicatorRelDataType to
        // Timestamp, we should sync this logic for computed columns.
        if (!isStreaming && expr.getType.isInstanceOf[TimeIndicatorRelDataType]) {
          relBuilder.getRexBuilder.makeAbstractCast(
            // Assumes that computed columns field info already exists in original node.
            // Or this should be a bug.
            oldRel.getRowType.getField(pair.right, true, false).getType,
            expr)
        } else {
          expr
        }
      }

      relBuilder.push(node)
      relBuilder.project(exprs, columnNames).build()
    }
  }
}

/**
 * Mock source usable as both a batch and a stream source of [[BaseRow]].
 *
 * It only carries a name and a schema: both stream factories return null and no table
 * statistics are reported.
 */
private class MockTableSource(var name: String, var schema: TableSchema)
    extends BatchTableSource[BaseRow]
    with StreamTableSource[BaseRow] {

  /** The source name doubles as its explanation. */
  override def explainSource(): String = name

  override def getTableSchema: TableSchema = schema

  /** Row type assembled from the schema's field types and field names. */
  override def getReturnType: DataType = {
    val fieldTypes = schema.getFieldTypes
    val fieldNames = schema.getFieldNames
    DataTypes.createRowTypeV2(fieldTypes, fieldNames)
  }

  override def getTableStats: TableStats = null

  override def getBoundedStream(streamEnv: StreamExecutionEnvironment): DataStream[BaseRow] =
    null

  override def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[BaseRow] = null
}

