/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.codegen.calls

import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.CodeGeneratorContext.BINARY_STRING
import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.types.{DataTypes, InternalType}

/**
 * Produces the generated code for one call (a function or operator invocation) whose operands
 * have already been turned into [[GeneratedExpression]]s.
 */
trait CallGenerator {

  /**
   * Builds the expression for the call named `callName` applied to `operands`.
   *
   * @param ctx        code generation context used to register fields/members
   * @param callName   name of the call being generated
   * @param operands   already-generated operand expressions, in argument order
   * @param returnType internal type of the call's result
   * @param nullCheck  whether null-handling code should be emitted (the helpers in the companion
   *                   object use it to decide on null propagation and exception catching)
   * @return the generated expression holding the call's result and null flag
   */
  def generate(ctx: CodeGeneratorContext, callName: String, operands: Seq[GeneratedExpression],
      returnType: InternalType, nullCheck: Boolean): GeneratedExpression
}

object CallGenerator {

  /** Name of the generated logger member used to report exceptions thrown inside calls. */
  val loggerTerm: String = "logger$"

  /** Logger name handed to `LoggerFactory.getLogger` for the generated logger member. */
  val loggerName: String = "_builtInFunc$_"

  /** Java declaration of the logger member; every generator below registers it with the ctx. */
  val loggerMember: String =
    s"""org.slf4j.Logger $loggerTerm = org.slf4j.LoggerFactory.getLogger("$loggerName");"""

  /**
   * Generates a call on a single operand that is evaluated only when the operand is not null
   * (see [[generateCallIfArgsNotNull]]).
   *
   * @param expr builds the Java expression from the operand's result term
   */
  def generateUnaryOperatorIfNotNull(
      ctx: CodeGeneratorContext,
      callName: String,
      nullCheck: Boolean,
      returnType: InternalType,
      operand: GeneratedExpression,
      primitiveNullable: Boolean = false)
  (expr: String => String)
    : GeneratedExpression = {
    generateCallIfArgsNotNull(
      ctx, callName, nullCheck, returnType, Seq(operand), primitiveNullable) {
      args => expr(args.head)
    }
  }

  /**
   * Generates a call on two operands that is evaluated only when neither operand is null
   * (see [[generateCallIfArgsNotNull]]).
   *
   * @param primitiveNullable if true the result is held in a boxed field that may be null
   *                          (defaults to false, matching the previous behavior)
   * @param expr              builds the Java expression from the left and right result terms
   */
  def generateOperatorIfNotNull(
      ctx: CodeGeneratorContext,
      callName: String,
      nullCheck: Boolean,
      returnType: InternalType,
      left: GeneratedExpression,
      right: GeneratedExpression,
      primitiveNullable: Boolean = false)
      (expr: (String, String) => String)
    : GeneratedExpression = {
    generateCallIfArgsNotNull(
      ctx, callName, nullCheck, returnType, Seq(left, right), primitiveNullable) {
      args => expr(args.head, args(1))
    }
  }

  /**
   * Generates a null-checked call whose Java expression yields a `java.lang.String`. The value
   * is wrapped with `BINARY_STRING.fromString` so the result has type `DataTypes.STRING`.
   */
  def generateReturnStringCallIfArgsNotNull(
      ctx: CodeGeneratorContext,
      callName: String,
      operands: Seq[GeneratedExpression])
      (call: Seq[String] => String): GeneratedExpression = {
    generateCallIfArgsNotNull(ctx, callName, nullCheck = true, DataTypes.STRING, operands) {
      args => s"$BINARY_STRING.fromString(${call(args)})"
    }
  }

  /**
   * Like [[generateReturnStringCallIfArgsNotNull]], but `call` returns a preparatory statement
   * plus the `java.lang.String` result expression.
   */
  def generateReturnStringCallWithStmtIfArgsNotNull(
      ctx: CodeGeneratorContext,
      callName: String,
      operands: Seq[GeneratedExpression])
      (call: Seq[String] => (String, String)): GeneratedExpression = {
    generateCallWithStmtIfArgsNotNull(
      ctx, callName, nullCheck = true, DataTypes.STRING, operands) {
      args =>
        val (stmt, result) = call(args)
        (stmt, s"$BINARY_STRING.fromString($result)")
    }
  }

  /**
   * Generates a call that is evaluated only when no operand is null (see
   * [[generateCallWithStmtIfArgsNotNull]]). `call` maps operand result terms to the Java result
   * expression; no extra statement is emitted.
   */
  def generateCallIfArgsNotNull(
      ctx: CodeGeneratorContext,
      callName: String,
      nullCheck: Boolean,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      primitiveNullable: Boolean = false)
      (call: Seq[String] => String): GeneratedExpression = {
    generateCallWithStmtIfArgsNotNull(
      ctx, callName, nullCheck, returnType, operands, primitiveNullable) {
      args => ("", call(args))
    }
  }

  /**
   * Generates a call that is evaluated only when no operand is null.
   *
   * With `nullCheck` enabled, the result is null when any operand is null. The call's statement
   * and result assignment run inside a `try`; any `Throwable` is logged through the generated
   * logger and turns the result into null. With `nullCheck` disabled, the call is emitted
   * unguarded and the null flag is always false.
   *
   * @param primitiveNullable if true the result field uses the boxed type and defaults to null
   * @param call              maps operand result terms to (statement, result expression)
   */
  def generateCallWithStmtIfArgsNotNull(
      ctx: CodeGeneratorContext,
      callName: String,
      nullCheck: Boolean,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      primitiveNullable: Boolean = false)
      (call: Seq[String] => (String, String)): GeneratedExpression = {
    val fields = newResultFields(ctx, returnType, primitiveNullable)
    val nullTerm = fields.nullTerm
    val resultTerm = fields.resultTerm
    val nullResultCode = fields.nullResultCode(nullCheck)

    val (stmt, result) = call(operands.map(_.resultTerm))

    ctx.addReusableMember(loggerMember)
    val loggerCode = generateLoggerCode(callName, operands)

    val operandCode = operands.map(_.code).mkString("\n")
    val resultCode = if (!nullCheck) {
      // Not catch built-in func exception if nullCheck is false
      s"""
         |$nullTerm = false;
         |$operandCode
         |$stmt
         |$resultTerm = $result;
         |""".stripMargin
    } else if (operands.isEmpty) {
      // No operand can be null, so the call always runs; only an exception makes it null.
      s"""
         |$nullTerm = false;
         |$resultTerm = ${fields.defaultValue};
         |try {
         |  $stmt
         |  $resultTerm = $result;
         |  $nullResultCode
         |} catch (Throwable e) {
         |    $loggerCode
         |    $nullTerm = true;
         |}
         |""".stripMargin
    } else {
      s"""
         |$operandCode
         |$nullTerm = ${operands.map(_.nullTerm).mkString(" || ")};
         |$resultTerm = ${fields.defaultValue};
         |if (!$nullTerm) {
         |  try {
         |    $stmt
         |    $resultTerm = $result;
         |    $nullResultCode
         |  } catch (Throwable e) {
         |     $loggerCode
         |     $nullTerm = true;
         |  }
         |}
         |""".stripMargin
    }

    GeneratedExpression(resultTerm, nullTerm, resultCode, returnType)
  }

  /**
   * Generates a call that is evaluated even when operands are null.
   *
   * `STRING` operands are passed as Java `null` when their null flag is set. Every other operand
   * is passed as its result term unchanged. With `nullCheck` enabled, the call runs inside a
   * `try`; a `Throwable` is logged and turns the result into null. A nullable result is also null
   * when the returned value is `null`.
   *
   * @param primitiveNullable if true the result field uses the boxed type and defaults to null
   * @param call              maps the (possibly null-substituted) arguments to the result expr
   */
  def generateCallIfArgsNullable(
      ctx: CodeGeneratorContext,
      callName: String,
      nullCheck: Boolean,
      returnType: InternalType,
      operands: Seq[GeneratedExpression],
      primitiveNullable: Boolean = false)
      (call: Seq[String] => String): GeneratedExpression = {
    val fields = newResultFields(ctx, returnType, primitiveNullable)
    val nullTerm = fields.nullTerm
    val resultTerm = fields.resultTerm
    val nullCode = fields.nullResultCode(nullCheck)

    val parameters = operands.map { operand =>
      if (operand.resultType.equals(DataTypes.STRING)) {
        "( " + operand.nullTerm + " ) ? null : (" + operand.resultTerm + ")"
      } else {
        operand.resultTerm
      }
    }

    ctx.addReusableMember(loggerMember)
    val loggerCode = generateLoggerCode(callName, operands)

    val operandCode = operands.map(_.code).mkString("\n")
    val resultCode = if (nullCheck) {
      s"""
         |$operandCode
         |$nullTerm = false;
         |$resultTerm = ${fields.defaultValue};
         |try {
         |  $resultTerm = ${call(parameters)};
         |  $nullCode
         |} catch (Throwable e) {
         |    $loggerCode
         |    $nullTerm = true;
         |}
         |""".stripMargin
    } else {
      // Not catch built-in func exception if nullCheck is false
      // (nullCode is always empty here, so it is not emitted).
      s"""
         |$operandCode
         |$nullTerm = false;
         |$resultTerm = ${call(parameters)};
         |""".stripMargin
    }

    GeneratedExpression(resultTerm, nullTerm, resultCode, returnType)
  }

  /**
   * Reusable fields allocated for one call result.
   *
   * @param nullTerm     name of the generated `boolean` null-flag field
   * @param resultTerm   name of the generated result field
   * @param defaultValue Java literal the result field is reset to before evaluation
   * @param isNullable   whether the result value itself may be `null`
   */
  private final case class ResultFields(
      nullTerm: String,
      resultTerm: String,
      defaultValue: String,
      isNullable: Boolean) {

    /** Java statement re-deriving the null flag from the computed result, or "" if not needed. */
    def nullResultCode(nullCheck: Boolean): String =
      if (nullCheck && isNullable) s"$nullTerm = ($resultTerm == null);" else ""
  }

  /**
   * Allocates the `isNull` and `result` reusable fields for a call returning `returnType`.
   * `isNull` is allocated first, then `result`.
   */
  private def newResultFields(
      ctx: CodeGeneratorContext,
      returnType: InternalType,
      primitiveNullable: Boolean): ResultFields = {
    val (resultTypeTerm, defaultValue) = if (primitiveNullable) {
      (boxedTypeTermForType(returnType), "null")
    } else {
      (primitiveTypeTermForType(returnType), primitiveDefaultValue(returnType))
    }
    val nullTerm = ctx.newReusableField("isNull", "boolean")
    val resultTerm = ctx.newReusableField("result", resultTypeTerm)
    val isNullable =
      (isReference(returnType) && !isInternalPrimitive(returnType)) || primitiveNullable
    ResultFields(nullTerm, resultTerm, defaultValue, isNullable)
  }

  /**
   * Escapes `s` so it can be embedded verbatim inside a Java double-quoted string literal.
   * Without this, a call name containing a quote, backslash or line break would break the
   * generated source.
   */
  private def escapeJavaStringLiteral(s: String): String =
    s.map {
      case '\\' => "\\\\"
      case '"' => "\\\""
      case '\n' => "\\n"
      case '\r' => "\\r"
      case '\t' => "\\t"
      case c => c.toString
    }.mkString

  /**
   * Generates the Java statement that logs, at error level, the call name, the number of
   * arguments, each argument value and the caught exception `e`.
   */
  private def generateLoggerCode(
      callName: String,
      operands: Seq[GeneratedExpression]): String = {
    // Each argument is followed by a comma so that the exception `e` can be appended last.
    val inputArgs = operands.map(operand => s"${operand.resultTerm},").mkString
    val argPlaceholders = List.fill(operands.size)("[{}]").mkString("\\n")
    s"""
       |$loggerTerm.error("Result of call [{}] is null because an exception happened.\\n" +
       |"There are {} args: \\n$argPlaceholders",
       |"${escapeJavaStringLiteral(callName)}", ${operands.size}, $inputArgs e);
       |""".stripMargin
  }
}
