/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.healthmanager.plugins.utils;

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.healthmanager.HealthMonitor;
import org.apache.flink.runtime.healthmanager.RestServerClient;
import org.apache.flink.runtime.healthmanager.metrics.MetricAggType;
import org.apache.flink.runtime.healthmanager.metrics.MetricProvider;
import org.apache.flink.runtime.healthmanager.metrics.TaskMetricSubscription;
import org.apache.flink.runtime.healthmanager.metrics.timeline.TimelineAggType;
import org.apache.flink.runtime.jobgraph.JobVertexID;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_DELAY;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_LATENCY_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_LATENCY_SUM;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_PARTITION_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_PARTITION_LATENCY_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_PARTITION_LATENCY_SUM;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_PROCESS_LATENCY_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.SOURCE_PROCESS_LATENCY_SUM;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.TASK_INPUT_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.TASK_LATENCY_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.TASK_LATENCY_SUM;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.TASK_OUTPUT_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.TASK_TIMER_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.WAIT_OUTPUT_COUNT;
import static org.apache.flink.runtime.healthmanager.plugins.utils.MetricNames.WAIT_OUTPUT_SUM;

/**
 * Subscribes to the task performance metrics of every vertex of the monitored job and derives
 * a {@link TaskMetrics} snapshot per vertex from them.
 *
 * <p>Lifecycle: {@link #open()} creates the subscriptions, {@link #getTaskMetrics()} reads them and
 * {@link #close()} releases them again. This class performs no synchronization of its own.
 */
public class TaskMetricsSubscriber {

	private static final Logger LOGGER = LoggerFactory.getLogger(TaskMetricsSubscriber.class);

	// metric subscriptions, keyed by job vertex; populated by open().

	private Map<JobVertexID, TaskMetricSubscription> inputTpsSubs;
	private Map<JobVertexID, TaskMetricSubscription> maxTpsPerMinuteSubs;
	private Map<JobVertexID, TaskMetricSubscription> outputTpsSubs;
	private Map<JobVertexID, TaskMetricSubscription> timerCountSubs;
	private Map<JobVertexID, TaskMetricSubscription> taskLatencyCountRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> taskLatencySumRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourceLatencyCountRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourceLatencySumRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourceProcessLatencyCountRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourceProcessLatencySumRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> waitOutputCountRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> waitOutputSumRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourcePartitionCountSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourcePartitionLatencyCountRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourcePartitionLatencySumRangeSubs;
	private Map<JobVertexID, TaskMetricSubscription> sourceDelayRateSubs;

	private final HealthMonitor monitor;

	/**
	 * Interval handed to the metric provider for most subscriptions.
	 * NOTE(review): presumably milliseconds, given the 60_000L minute bucketing in open() — confirm.
	 */
	private final long interval;

	/**
	 * @param monitor  health monitor providing the job config, job id, topology and metric provider
	 * @param interval subscription interval passed to the metric provider
	 */
	public TaskMetricsSubscriber(HealthMonitor monitor, long interval) {
		this.monitor = monitor;
		this.interval = interval;
	}

	/**
	 * Creates metric subscriptions for every vertex present in the job config at call time.
	 * Source vertices additionally get source latency, partition, process latency and delay
	 * subscriptions.
	 */
	public void open() {

		inputTpsSubs = new HashMap<>();
		maxTpsPerMinuteSubs = new HashMap<>();
		outputTpsSubs = new HashMap<>();
		timerCountSubs = new HashMap<>();
		taskLatencyCountRangeSubs = new HashMap<>();
		taskLatencySumRangeSubs = new HashMap<>();
		sourceLatencyCountRangeSubs = new HashMap<>();
		sourceLatencySumRangeSubs = new HashMap<>();
		sourceProcessLatencyCountRangeSubs = new HashMap<>();
		sourceProcessLatencySumRangeSubs = new HashMap<>();
		waitOutputCountRangeSubs = new HashMap<>();
		waitOutputSumRangeSubs = new HashMap<>();
		sourcePartitionCountSubs = new HashMap<>();
		sourcePartitionLatencyCountRangeSubs = new HashMap<>();
		sourcePartitionLatencySumRangeSubs = new HashMap<>();
		sourceDelayRateSubs = new HashMap<>();

		RestServerClient.JobConfig jobConfig = monitor.getJobConfig();

		JobID jobID = monitor.getJobID();
		MetricProvider metricProvider = monitor.getMetricProvider();

		// subscribe metrics.
		for (JobVertexID vertexId : jobConfig.getVertexConfigs().keySet()) {
			// input tps
			TaskMetricSubscription inputTpsSub = metricProvider.subscribeTaskMetric(
					jobID, vertexId, TASK_INPUT_COUNT, MetricAggType.SUM, interval, TimelineAggType.RATE);
			inputTpsSubs.put(vertexId, inputTpsSub);

			// input tps per minute: max over 60_000L buckets of the input rate.
			// NOTE(review): the interval is truncated to whole multiples of 60_000L, which yields 0
			// when interval < 60_000L — confirm the metric provider handles that.
			TaskMetricSubscription tpsPerMinuteSub = metricProvider.subscribeTaskMetric(
					jobID, vertexId, TASK_INPUT_COUNT, MetricAggType.SUM, interval / 60_000L * 60_000L, TimelineAggType.MAX, 60_000L, TimelineAggType.RATE);
			maxTpsPerMinuteSubs.put(vertexId, tpsPerMinuteSub);

			// output tps
			TaskMetricSubscription outputTps = metricProvider.subscribeTaskMetric(
					jobID, vertexId, TASK_OUTPUT_COUNT, MetricAggType.SUM, interval, TimelineAggType.RATE);
			outputTpsSubs.put(vertexId, outputTps);

			// timer count
			TaskMetricSubscription timerCount = metricProvider.subscribeTaskMetric(
					jobID, vertexId, TASK_TIMER_COUNT, MetricAggType.SUM, 1, TimelineAggType.LATEST);
			timerCountSubs.put(vertexId, timerCount);

			// source latency
			if (monitor.getJobTopologyAnalyzer().isSource(vertexId)) {
				sourceLatencyCountRangeSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_LATENCY_COUNT, MetricAggType.SUM, interval, TimelineAggType.LATEST));
				sourceLatencySumRangeSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_LATENCY_SUM, MetricAggType.SUM, interval, TimelineAggType.LATEST));

				sourcePartitionCountSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_PARTITION_COUNT, MetricAggType.SUM, interval, TimelineAggType.LATEST));
				sourcePartitionLatencyCountRangeSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_PARTITION_LATENCY_COUNT, MetricAggType.SUM, interval, TimelineAggType.RANGE));
				sourcePartitionLatencySumRangeSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_PARTITION_LATENCY_SUM, MetricAggType.SUM, interval, TimelineAggType.RANGE));
				sourceProcessLatencyCountRangeSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_PROCESS_LATENCY_COUNT, MetricAggType.SUM, interval, TimelineAggType.LATEST));
				sourceProcessLatencySumRangeSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_PROCESS_LATENCY_SUM, MetricAggType.SUM, interval, TimelineAggType.LATEST));
				sourceDelayRateSubs.put(vertexId,
						metricProvider.subscribeTaskMetric(
								jobID, vertexId, SOURCE_DELAY, MetricAggType.AVG, interval, TimelineAggType.RATE));
			}

			// task latency
			TaskMetricSubscription latencyCountRangeSub = metricProvider.subscribeTaskMetric(
					jobID, vertexId, TASK_LATENCY_COUNT, MetricAggType.SUM, interval, TimelineAggType.LATEST);
			taskLatencyCountRangeSubs.put(vertexId, latencyCountRangeSub);
			TaskMetricSubscription latencySumRangeSub = metricProvider.subscribeTaskMetric(
					jobID, vertexId, TASK_LATENCY_SUM, MetricAggType.SUM, interval, TimelineAggType.LATEST);
			taskLatencySumRangeSubs.put(vertexId, latencySumRangeSub);

			// wait output
			TaskMetricSubscription waitOutputCountRangeSub = metricProvider.subscribeTaskMetric(
					jobID, vertexId, WAIT_OUTPUT_COUNT, MetricAggType.SUM, interval, TimelineAggType.LATEST);
			waitOutputCountRangeSubs.put(vertexId, waitOutputCountRangeSub);
			TaskMetricSubscription waitOutputSumRangeSub = metricProvider.subscribeTaskMetric(
					jobID, vertexId, WAIT_OUTPUT_SUM, MetricAggType.SUM, interval, TimelineAggType.LATEST);
			waitOutputSumRangeSubs.put(vertexId, waitOutputSumRangeSub);
		}
	}

	/**
	 * Unsubscribes every subscription created by {@link #open()}. Does nothing when there is no
	 * monitor or metric provider; maps that were never created (open() not called) are skipped.
	 */
	public void close() {
		if (monitor == null) {
			return;
		}

		MetricProvider metricProvider = monitor.getMetricProvider();

		if (metricProvider == null) {
			return;
		}

		// every map populated in open() must be released here, otherwise its subscriptions leak.
		unsubscribeAll(metricProvider, inputTpsSubs);
		unsubscribeAll(metricProvider, maxTpsPerMinuteSubs);
		unsubscribeAll(metricProvider, outputTpsSubs);
		unsubscribeAll(metricProvider, timerCountSubs);
		unsubscribeAll(metricProvider, taskLatencyCountRangeSubs);
		unsubscribeAll(metricProvider, taskLatencySumRangeSubs);
		unsubscribeAll(metricProvider, sourceLatencyCountRangeSubs);
		unsubscribeAll(metricProvider, sourceLatencySumRangeSubs);
		unsubscribeAll(metricProvider, sourceProcessLatencyCountRangeSubs);
		unsubscribeAll(metricProvider, sourceProcessLatencySumRangeSubs);
		unsubscribeAll(metricProvider, waitOutputCountRangeSubs);
		unsubscribeAll(metricProvider, waitOutputSumRangeSubs);
		unsubscribeAll(metricProvider, sourcePartitionCountSubs);
		unsubscribeAll(metricProvider, sourcePartitionLatencyCountRangeSubs);
		unsubscribeAll(metricProvider, sourcePartitionLatencySumRangeSubs);
		unsubscribeAll(metricProvider, sourceDelayRateSubs);
	}

	/**
	 * Builds a {@link TaskMetrics} snapshot for every vertex in the current job config.
	 * {@link #open()} must have been called first, since the subscription maps are created there.
	 *
	 * @return metrics keyed by vertex, or {@code null} when the job config is unavailable or any
	 *         vertex lacks a metric required to compute its snapshot
	 */
	public Map<JobVertexID, TaskMetrics> getTaskMetrics() {

		RestServerClient.JobConfig jobConfig = monitor.getJobConfig();
		if (jobConfig == null) {
			return null;
		}

		Map<JobVertexID, TaskMetrics> metrics = new HashMap<>();
		for (JobVertexID vertexId : jobConfig.getVertexConfigs().keySet()) {
			TaskMetrics taskMetrics = computeTaskMetrics(vertexId);
			if (taskMetrics == null) {
				// a partial result is not useful to callers; report the whole snapshot as incomplete.
				return null;
			}
			metrics.put(vertexId, taskMetrics);
		}
		return metrics;
	}

	/**
	 * Computes the snapshot of a single vertex.
	 *
	 * @return the vertex metrics, or {@code null} if a required metric is missing or fails
	 *         {@link MetricUtils#validateTaskMetric} (details are logged at debug level)
	 */
	private TaskMetrics computeTaskMetrics(JobVertexID vertexId) {
		TaskMetricSubscription inputTpsSub = inputTpsSubs.get(vertexId);
		TaskMetricSubscription maxTpsPerMinuteSub = maxTpsPerMinuteSubs.get(vertexId);
		TaskMetricSubscription outputTpsSub = outputTpsSubs.get(vertexId);
		TaskMetricSubscription timerCountSub = timerCountSubs.get(vertexId);
		TaskMetricSubscription taskLatencyCountRangeSub = taskLatencyCountRangeSubs.get(vertexId);
		TaskMetricSubscription taskLatencySumRangeSub = taskLatencySumRangeSubs.get(vertexId);
		TaskMetricSubscription sourceLatencyCountRangeSub = sourceLatencyCountRangeSubs.get(vertexId);
		TaskMetricSubscription sourceLatencySumRangeSub = sourceLatencySumRangeSubs.get(vertexId);
		TaskMetricSubscription sourceProcessLatencyCountRangeSub = sourceProcessLatencyCountRangeSubs.get(vertexId);
		TaskMetricSubscription sourceProcessLatencySumRangeSub = sourceProcessLatencySumRangeSubs.get(vertexId);
		TaskMetricSubscription waitOutputCountRangeSub = waitOutputCountRangeSubs.get(vertexId);
		TaskMetricSubscription waitOutputSumRangeSub = waitOutputSumRangeSubs.get(vertexId);
		TaskMetricSubscription sourcePartitionCountSub = sourcePartitionCountSubs.get(vertexId);
		TaskMetricSubscription sourcePartitionLatencyCountRangeSub = sourcePartitionLatencyCountRangeSubs.get(vertexId);
		TaskMetricSubscription sourcePartitionLatencySumRangeSub = sourcePartitionLatencySumRangeSubs.get(vertexId);
		TaskMetricSubscription sourceDelayIncreasingRateSub = sourceDelayRateSubs.get(vertexId);

		boolean isSource = monitor.getJobTopologyAnalyzer().isSource(vertexId);
		long validWindow = interval * 2;

		// metrics every vertex must report.
		if (!MetricUtils.validateTaskMetric(monitor, validWindow, inputTpsSub,
				taskLatencyCountRangeSub, taskLatencySumRangeSub, timerCountSub)) {
			LOGGER.debug("input metric missing {}", vertexId);
			LOGGER.debug("input tps {}, task latency count range {}, task latency sum range {}, timer count {}",
					describe(inputTpsSub), describe(taskLatencyCountRangeSub),
					describe(taskLatencySumRangeSub), describe(timerCountSub));
			return null;
		}

		// non-source vertices must also report output metrics.
		if (!isSource && !MetricUtils.validateTaskMetric(monitor, validWindow,
				outputTpsSub, waitOutputCountRangeSub, waitOutputSumRangeSub)) {
			LOGGER.debug("output metric missing {}", vertexId);
			LOGGER.debug("output tps {}, wait output count range {}, wait output sum range {}",
					describe(outputTpsSub), describe(waitOutputCountRangeSub), describe(waitOutputSumRangeSub));
			return null;
		}

		// source vertices must report source latency.
		if (isSource && !MetricUtils.validateTaskMetric(monitor, validWindow,
				sourceLatencyCountRangeSub, sourceLatencySumRangeSub)) {
			LOGGER.debug("input metric missing for source {}", vertexId);
			LOGGER.debug("source latency count range {}, source latency sum range {}",
					describe(sourceLatencyCountRangeSub), describe(sourceLatencySumRangeSub));
			return null;
		}

		// a source is a parallel reader when its partition metrics validate and the partition count is positive.
		boolean isParallelReader = false;
		if (isSource) {
			isParallelReader = MetricUtils.validateTaskMetric(monitor, validWindow,
					sourcePartitionCountSub, sourcePartitionLatencyCountRangeSub, sourcePartitionLatencySumRangeSub)
					&& sourcePartitionCountSub.getValue().f1 > 0;
			LOGGER.debug("Treat vertex {} as {} reader.", vertexId, isParallelReader ? "parallel" : "non-parallel");
			LOGGER.debug("source partition count {}, source partition latency count range {}, source partition latency sum range {}",
					describe(sourcePartitionCountSub), describe(sourcePartitionLatencyCountRangeSub),
					describe(sourcePartitionLatencySumRangeSub));
		}

		boolean isSink = monitor.getJobTopologyAnalyzer().isSink(vertexId);

		// Some values read below are not covered by the validations above: output tps of sources,
		// wait output of non-sink sources, and process latency of parallel readers. Report a
		// missing value as incomplete metrics instead of failing with a NullPointerException.
		if (!hasValue(outputTpsSub)
				|| (!isSink && (!hasValue(waitOutputCountRangeSub) || !hasValue(waitOutputSumRangeSub)))
				|| (isParallelReader
						&& (!hasValue(sourceProcessLatencyCountRangeSub) || !hasValue(sourceProcessLatencySumRangeSub)))) {
			LOGGER.debug("metric value missing {}", vertexId);
			return null;
		}

		double inputTps = inputTpsSub.getValue().f1;
		double outputTps = outputTpsSub.getValue().f1;

		// fall back to the current input rate when the per-minute maximum is not available yet.
		double maxInputTpsPerMinute = inputTps;
		if (hasValue(maxTpsPerMinuteSub)) {
			maxInputTpsPerMinute = maxTpsPerMinuteSub.getValue().f1;
		}

		double timerCount = timerCountSub.getValue().f1;

		double taskLatency = meanLatency(
				taskLatencySumRangeSub.getValue().f1, taskLatencyCountRangeSub.getValue().f1);

		double sourceLatency = 0.0;
		if (isSource) {
			sourceLatency = meanLatency(
					sourceLatencySumRangeSub.getValue().f1, sourceLatencyCountRangeSub.getValue().f1);
		}

		double waitOutput = 0.0;
		if (!isSink) {
			waitOutput = meanLatency(
					waitOutputSumRangeSub.getValue().f1, waitOutputCountRangeSub.getValue().f1);
		}
		// wait-output time per output record, rescaled by the output/input ratio.
		double waitOutputPerInputRecord = inputTps <= 0.0 ? 0.0 : waitOutput * outputTps / inputTps;

		double partitionLatency = 0;
		double partitionCount = 0;
		if (isParallelReader) {
			// parallel readers use the source process latency in place of the task latency.
			taskLatency = meanLatency(
					sourceProcessLatencySumRangeSub.getValue().f1, sourceProcessLatencyCountRangeSub.getValue().f1);
			partitionLatency = meanLatency(
					sourcePartitionLatencySumRangeSub.getValue().f1, sourcePartitionLatencyCountRangeSub.getValue().f1);
			partitionCount = sourcePartitionCountSub.getValue().f1;
		}

		double workload = (taskLatency - waitOutputPerInputRecord) * maxInputTpsPerMinute;

		double delayIncreasingRate = 0;
		if (hasValue(sourceDelayIncreasingRateSub)) {
			delayIncreasingRate = sourceDelayIncreasingRateSub.getValue().f1 / 1.0e3;
		}

		return new TaskMetrics(
				vertexId,
				isParallelReader,
				inputTps,
				maxInputTpsPerMinute,
				outputTps,
				timerCount,
				taskLatency,
				sourceLatency,
				waitOutputPerInputRecord,
				workload,
				delayIncreasingRate,
				partitionLatency,
				partitionCount
		);
	}

	/** Unsubscribes every subscription in {@code subs}; a {@code null} map is ignored. */
	private static void unsubscribeAll(
			MetricProvider metricProvider, Map<JobVertexID, TaskMetricSubscription> subs) {
		if (subs == null) {
			return;
		}
		for (TaskMetricSubscription sub : subs.values()) {
			metricProvider.unsubscribe(sub);
		}
	}

	/** Returns whether the subscription exists and currently holds a value. */
	private static boolean hasValue(TaskMetricSubscription sub) {
		return sub != null && sub.getValue() != null;
	}

	/** Null-safe accessor used only for debug logging. */
	private static Object describe(TaskMetricSubscription sub) {
		return sub == null ? null : sub.getValue();
	}

	/**
	 * Returns {@code sum / count / 1.0e9}, or 0 when {@code count} is not positive.
	 * NOTE(review): the 1e9 divisor presumably converts nanoseconds to seconds — confirm.
	 */
	private static double meanLatency(double sum, double count) {
		return count <= 0.0 ? 0.0 : sum / count / 1.0e9;
	}

}
