/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.validate

import org.apache.flink.table.api.ValidationException

import org.apache.calcite.sql._
import org.apache.flink.table.calcite.{FlinkTypeFactory, LazySqlOperatorTable}
import org.apache.flink.table.catalog._
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getResultTypeOfCTDFunction
import org.apache.flink.table.functions.utils.{AggSqlFunction, ScalarSqlFunction, TableSqlFunction}
import org.apache.flink.table.types.InternalType
import org.apache.flink.table.util.Logging

import scala.collection.JavaConversions._

/**
  * A [[FunctionCatalog]] that registers, lists, looks up and drops UDFs in the default catalog
  * and default database of a [[CatalogManager]]. Used during the validation phases of both the
  * Table API and the SQL API.
  *
  * Function names are lower-cased before being used as catalog object names, so registration,
  * lookup and drop all agree on the same key regardless of the caller's casing.
  *
  * @param catalogManager provides the default catalog and default database that all function
  *                       operations of this class act on
  * @param typeFactory    type factory passed along when converting catalog functions into
  *                       Calcite SQL functions
  */
class ExternalFunctionCatalog(catalogManager: CatalogManager, typeFactory: FlinkTypeFactory)
  extends FunctionCatalog with Logging {

  /**
    * Creates `catalogFunction` under `name` in the default database of the default catalog.
    * The catalog is asked not to ignore an existing function of the same name
    * (`ignoreIfExists = false`).
    *
    * @throws ValidationException if the default catalog is not a [[ReadableWritableCatalog]]
    */
  override def registerFunction(name: String, catalogFunction: CatalogFunction): Unit = {
    writableDefaultCatalog.createFunction(functionPathOf(name), catalogFunction, false)
  }

  /**
    * Registers `catalogFunction` under `name`, first dropping any function already stored under
    * that name.
    *
    * @throws ValidationException if the default catalog is not a [[ReadableWritableCatalog]]
    */
  override def registerOrReplaceFunction(name: String, catalogFunction: CatalogFunction): Unit = {
    // ignoreIfNotExists = true: "register or replace" must also succeed when no function of
    // this name has been registered yet; `false` would make the drop throw in that case.
    writableDefaultCatalog.dropFunction(functionPathOf(name), true)
    registerFunction(name, catalogFunction)
  }

  /**
    * Lists the functions of the default database of the default catalog, each prefixed with the
    * database name as `"<database>.<function>"`.
    */
  override def listFunctions(): List[String] = {
    // Read the default database once so the listed database and the prefix always agree.
    val db = catalogManager.getDefaultDatabaseName
    catalogManager.getDefaultCatalog
      .listFunctions(db)
      .map(functionName => s"$db.$functionName")
      .toList
  }

  /** Returns a [[LazySqlOperatorTable]] backed by this catalog's [[CatalogManager]]. */
  override def getSqlOperatorTable: SqlOperatorTable = {
    LOG.info("Getting sql operator tables")

    new LazySqlOperatorTable(catalogManager, typeFactory)
  }

  /**
    * Resolves `name` in the default database of the default catalog and builds a call expression
    * for it over `children`.
    *
    * @param name     function name; lower-cased before the catalog lookup
    * @param children argument expressions of the call
    * @return a [[ScalarFunctionCall]], [[TableFunctionCall]] or [[AggFunctionCall]], depending
    *         on the kind of SQL function the catalog entry converts to
    * @throws ValidationException if the converted SQL function is of none of those kinds
    */
  override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
    val catalog = catalogManager.getDefaultCatalog
    val functionPath = functionPathOf(name)

    val sqlFunction = FunctionCatalogUtils.toSqlFunction(
      catalog,
      functionPath.getObjectName,
      catalog.getFunction(functionPath),
      typeFactory
    )

    // Values of literal arguments; non-literal arguments are represented by null.
    def childrenLiterals: Array[AnyRef] =
      children.map {
        case literal: Literal => literal.value.asInstanceOf[AnyRef]
        case _ => null
      }.toArray

    // Result types of already-validated arguments; arguments not yet valid map to null.
    def childrenTypes: Array[InternalType] =
      children.map(expr => if (expr.valid) expr.resultType else null).toArray

    sqlFunction match {
      case scalarFunction: ScalarSqlFunction =>
        ScalarFunctionCall(
          scalarFunction.makeFunction(childrenLiterals, childrenTypes),
          children)
      case tableFunction: TableSqlFunction =>
        val tf = tableFunction.makeFunction(childrenLiterals, childrenTypes)
        TableFunctionCall(
          name,
          tf,
          children,
          getResultTypeOfCTDFunction(
            tf,
            children.toArray,
            () => tableFunction.getImplicitResultType))
      case aggFunction: AggSqlFunction =>
        AggFunctionCall(
          aggFunction.makeFunction(childrenLiterals, childrenTypes),
          aggFunction.externalResultType,
          aggFunction.externalAccType,
          children)
      case _ =>
        throw new ValidationException(s"Cannot match sql function $name with any existing types")
    }
  }

  /**
    * Drops the function `name` from the default database of the default catalog. The catalog is
    * asked not to ignore a missing function (`ignoreIfNotExists = false`).
    *
    * @throws ValidationException if the default catalog is not a [[ReadableWritableCatalog]]
    */
  override def dropFunction(name: String): Unit = {
    writableDefaultCatalog.dropFunction(functionPathOf(name), false)
  }

  /**
    * Object path of function `name` (lower-cased) in the current default database.
    * NOTE(review): `toLowerCase` uses the JVM default locale; consider `Locale.ROOT`.
    */
  private def functionPathOf(name: String): ObjectPath =
    new ObjectPath(catalogManager.getDefaultDatabaseName, name.toLowerCase)

  /**
    * The default catalog as a [[ReadableWritableCatalog]].
    *
    * @throws ValidationException if the default catalog does not support modification
    */
  private def writableDefaultCatalog: ReadableWritableCatalog =
    catalogManager.getDefaultCatalog match {
      case writable: ReadableWritableCatalog => writable
      case other =>
        val catalogType = if (other == null) "null" else other.getClass.getName
        throw new ValidationException(
          s"Default catalog ($catalogType) is not a ReadableWritableCatalog; " +
            "functions cannot be registered in or dropped from it.")
    }
}
