/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.util

import java.io.{PrintWriter, StringWriter}
import java.util

import org.apache.flink.streaming.api.transformations.{OneInputTransformation, StreamTransformation}
import org.apache.flink.table.plan.nodes.exec.StreamExecNode
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecDataStreamScan
import org.apache.flink.table.plan.schema.IntermediateDataStreamTable
import org.apache.flink.table.runtime.functions.utils.Md5Utils

import _root_.scala.collection.JavaConversions._
import _root_.scala.collection.mutable

object ExecNodeUidCalculator {

  // Memoized results of getUid / getStateDigest, keyed by node identity (not equals()).
  // NOTE(review): these caches live in a global object and nothing here removes entries, so
  // they retain every node ever passed in; access is also unsynchronized.
  private val cachedUids = new util.IdentityHashMap[StreamExecNode[_], Option[String]]()

  private val cachedStateDigests = new util.IdentityHashMap[StreamExecNode[_], Option[String]]()

  // Maps each recorded StreamTransformation (by identity) to the ExecNode that produced it.
  // Used to see through the scan nodes generated by subsection optimization.
  private val transformation2ExecNodeMapping =
    new util.IdentityHashMap[StreamTransformation[_], StreamExecNode[_]]

  /**
    * Returns the optional uid of the specified ExecNode, or None if the specified
    * exec node has no state.
    *
    * The uid is the MD5 sum of the node's own state digest followed by the state digests of
    * its transitive inputs, so it also reflects the state of upstream nodes. The result is
    * cached per node instance.
    */
  def getUid(node: StreamExecNode[_]): Option[String] = {
    cachedUids.getOrElseUpdate(node, {
      getStateDigest(node).map { plainUid =>
        Md5Utils.md5sum(plainUid + getInputStateDigests(node).mkString(", "))
      }
    })
  }

  /**
    * Returns the state digest of the specified ExecNode itself, NOT including its inputs,
    * or None if the node writes an empty digest. The result is cached per node instance.
    */
  def getStateDigest(node: StreamExecNode[_]): Option[String] = {
    cachedStateDigests.getOrElseUpdate(node, {
      val sw = new StringWriter
      val pw = new PrintWriter(sw)
      val execNodeInfoWriter = new ExecNodeInfoWriter(pw, ExecNodeInfoWriter.STREAM_EXEC, true)
      node.getStateDigest(execNodeInfoWriter).done(node)
      // an empty digest means the node contributes no state
      Some(sw.toString).filter(_.nonEmpty)
    })
  }

  /**
    * Records the mapping between a StreamTransformation and the ExecNode that produced it.
    * This is needed in case subsection optimization is triggered: a scan node over an
    * intermediate table is mapped to the real node behind it rather than to the scan itself.
    *
    * @param transformation the transformation translated from `node`
    * @param node           the exec node that was translated
    * @throws IllegalStateException if `node` is an intermediate scan whose real node cannot be
    *                               found by walking up single-input transformations
    */
  def addTransformationExecNodeMapping(
      transformation: StreamTransformation[_], node: StreamExecNode[_]): Unit = {
    val realNode = if (intermediateTableOf(node).isDefined) {
      // needs to find the real node for the scan node generated by subsection optimization
      findRealNode(transformation)
    } else {
      node
    }
    transformation2ExecNodeMapping.put(transformation, realNode)
  }

  /**
    * Returns the state digests of all transitive inputs of the specified ExecNode,
    * ancestors first. Scans generated by subsection optimization are replaced by the real node
    * they stand for, so the result is the same whether or not that optimization ran.
    */
  private def getInputStateDigests(node: StreamExecNode[_]): Array[String] = {
    val inputStateDigests = mutable.ArrayBuffer[String]()
    node.getInputNodes.foreach {
      case input: StreamExecNode[_] =>
        val realInput = resolveRealNode(input)
        inputStateDigests.appendAll(getInputStateDigests(realInput))
        // the real node's own digest must be included as well, exactly as for a plain input
        getStateDigest(realInput).foreach(inputStateDigests.append(_))
    }
    inputStateDigests.toArray
  }

  /**
    * Returns the node that `node` stands for: the real upstream node if `node` is a scan over
    * an IntermediateDataStreamTable (generated by subsection optimization), else `node` itself.
    */
  private def resolveRealNode(node: StreamExecNode[_]): StreamExecNode[_] =
    intermediateTableOf(node) match {
      case Some(table) => findRealNode(table.dataStream.getTransformation)
      case None => node
    }

  /**
    * Returns the IntermediateDataStreamTable scanned by `node` if `node` is a
    * StreamExecDataStreamScan over such a table, None otherwise.
    */
  private def intermediateTableOf(
      node: StreamExecNode[_]): Option[IntermediateDataStreamTable[_]] =
    node match {
      case scan: StreamExecDataStreamScan =>
        scan.dataStreamTable match {
          case table: IntermediateDataStreamTable[_] => Some(table)
          case _ => None
        }
      case _ => None
    }

  /**
    * Finds the exec node recorded for `transformation`. If there is none, it walks up through
    * single-input transformations until a recorded one is found.
    *
    * @throws IllegalStateException if no recorded node is reachable through
    *                               OneInputTransformation inputs
    */
  @_root_.scala.annotation.tailrec
  private def findRealNode(transformation: StreamTransformation[_]): StreamExecNode[_] = {
    val mapped = transformation2ExecNodeMapping.get(transformation)
    if (mapped != null) {
      mapped
    } else {
      transformation match {
        case oneInput: OneInputTransformation[_, _] => findRealNode(oneInput.getInput)
        case _ =>
          throw new IllegalStateException(
            s"Cannot find the ExecNode that produced transformation $transformation.")
      }
    }
  }
}
