/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.healthmanager.plugins.detectors;

import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.healthmanager.HealthMonitor;
import org.apache.flink.runtime.healthmanager.RestServerClient;
import org.apache.flink.runtime.healthmanager.metrics.MetricProvider;
import org.apache.flink.runtime.healthmanager.plugins.Detector;
import org.apache.flink.runtime.healthmanager.plugins.Symptom;
import org.apache.flink.runtime.healthmanager.plugins.symptoms.JobVertexOverParallelized;
import org.apache.flink.runtime.healthmanager.plugins.utils.HealthMonitorOptions;
import org.apache.flink.runtime.healthmanager.plugins.utils.JobTopologyAnalyzer;
import org.apache.flink.runtime.healthmanager.plugins.utils.TaskMetrics;
import org.apache.flink.runtime.healthmanager.plugins.utils.TaskMetricsSubscriber;
import org.apache.flink.runtime.jobgraph.JobVertexID;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * OverParallelizedDetector detects job vertices whose parallelism is higher than their workload requires.
 *
 * <p>A vertex is reported in a {@link JobVertexOverParallelized} symptom when both of the following hold:
 * <ul>
 *   <li>{@code current_parallelism > workload * ratio}, where {@code workload} is
 *       {@code task_latency_per_record * max_input_tps} for non-source vertices and
 *       {@code partition_count * task_latency_per_record / partition_latency} for parallel sources
 *       (0 when the partition latency is not positive). {@code ratio} is
 *       {@link HealthMonitorOptions#PARALLELISM_MAX_RATIO}, additionally multiplied by
 *       {@link HealthMonitorOptions#PARALLELISM_SCALE_MULTI_OUTPUT_RATIO} when any direct input of the
 *       vertex has more than one output.</li>
 *   <li>{@code current_parallelism / min_parallelism > 1 + min_diff_ratio}, where {@code min_parallelism}
 *       is the number of tasks needed so that no task of a parallel source reads more than
 *       {@link HealthMonitorOptions#MAX_PARTITION_PER_TASK} partitions (1 for every other vertex), and
 *       {@code min_diff_ratio} is {@link HealthMonitorOptions#PARALLELISM_SCALE_MIN_DIFF_RATIO}.</li>
 * </ul>
 */
public class OverParallelizedDetector implements Detector {

	private static final Logger LOGGER = LoggerFactory.getLogger(OverParallelizedDetector.class);

	private JobID jobID;
	private HealthMonitor monitor;
	private MetricProvider metricProvider;

	/** Base factor applied to the estimated workload ({@code PARALLELISM_MAX_RATIO}). */
	private double ratio;
	/** Interval handed to {@code HealthMonitor#subscribeTaskMetrics} ({@code PARALLELISM_SCALE_INTERVAL}). */
	private long checkInterval;
	/** Maximum number of source partitions one task should read; values {@code <= 0} disable the bound. */
	private int maxPartitionPerTask;
	/** Extra factor applied to {@link #ratio} when an input of the vertex has multiple outputs. */
	private double multiOutputRatio;

	/** Minimum relative gap between current and minimal parallelism required to report a vertex. */
	private double minDiffParallelismRatio;

	private TaskMetricsSubscriber subscriber;

	@Override
	public void open(HealthMonitor monitor) {
		this.jobID = monitor.getJobID();
		this.monitor = monitor;
		this.metricProvider = monitor.getMetricProvider();

		ratio = monitor.getConfig().getDouble(HealthMonitorOptions.PARALLELISM_MAX_RATIO);
		checkInterval = monitor.getConfig().getLong(HealthMonitorOptions.PARALLELISM_SCALE_INTERVAL);
		maxPartitionPerTask = monitor.getConfig().getInteger(HealthMonitorOptions.MAX_PARTITION_PER_TASK);
		multiOutputRatio = monitor.getConfig().getDouble(HealthMonitorOptions.PARALLELISM_SCALE_MULTI_OUTPUT_RATIO);
		minDiffParallelismRatio = monitor.getConfig().getDouble(HealthMonitorOptions.PARALLELISM_SCALE_MIN_DIFF_RATIO);

		subscriber = this.monitor.subscribeTaskMetrics(checkInterval);
	}

	/**
	 * Closes the detector. Currently releases nothing.
	 *
	 * <p>NOTE(review): the task-metrics subscription created in {@link #open} is not cancelled here;
	 * confirm whether the {@link HealthMonitor} takes care of it.
	 */
	@Override
	public void close() {
		if (metricProvider == null) {
			return;
		}

	}

	/**
	 * Checks every vertex with available task metrics for over-parallelization.
	 *
	 * @return a {@link JobVertexOverParallelized} symptom listing the affected vertices, or {@code null}
	 *         when the job config or task metrics are unavailable, or no vertex is over-parallelized
	 * @throws Exception propagated from the health monitor
	 */
	@Override
	public Symptom detect() throws Exception {
		LOGGER.debug("Start detecting.");

		RestServerClient.JobConfig jobConfig = monitor.getJobConfig();
		if (jobConfig == null) {
			return null;
		}

		Map<JobVertexID, TaskMetrics> allTaskMetrics = subscriber.getTaskMetrics();
		if (allTaskMetrics == null) {
			return null;
		}

		List<JobVertexID> jobVertexIDs = new ArrayList<>();
		for (Map.Entry<JobVertexID, TaskMetrics> entry : allTaskMetrics.entrySet()) {
			JobVertexID vertexId = entry.getKey();
			TaskMetrics metrics = entry.getValue();
			LOGGER.debug("vertex metrics: {}", metrics);

			// Metrics may exist for a vertex the job config has no entry for; skip it rather than
			// failing the whole detection with a NullPointerException.
			if (jobConfig.getVertexConfigs().get(vertexId) == null) {
				LOGGER.debug("Skip vertex {} which is absent from the job config.", vertexId);
				continue;
			}
			int parallelism = jobConfig.getVertexConfigs().get(vertexId).getParallelism();

			double workload = estimateWorkload(metrics);
			double minParallelism = getMinParallelism(metrics);

			if (parallelism > workload * getRatio(vertexId) && parallelism / minParallelism > (1 + minDiffParallelismRatio)) {
				jobVertexIDs.add(vertexId);
				LOGGER.info("Over parallelized detected for vertex {}, parallelism:{} workload:{}.", vertexId, parallelism, workload);
			}
		}

		if (!jobVertexIDs.isEmpty()) {
			return new JobVertexOverParallelized(jobID, jobVertexIDs);
		}
		return null;
	}

	/**
	 * Estimates the workload of a vertex, in the same unit as its parallelism (the two are compared in
	 * {@link #detect()}).
	 */
	private static double estimateWorkload(TaskMetrics metrics) {
		if (metrics.isParallelSource()) {
			// A non-positive partition latency gives no usable estimate; treat the workload as 0.
			return metrics.getPartitionLatency() <= 0.0 ? 0.0 :
					metrics.getPartitionCount() * metrics.getTaskLatencyPerRecord() / metrics.getPartitionLatency();
		}
		return metrics.getTaskLatencyPerRecord() * metrics.getMaxInputTpsPerMinute();
	}

	/**
	 * Returns the smallest parallelism the vertex may be scaled down to (always at least 1).
	 * For parallel sources this is the number of tasks needed so that no task reads more than
	 * {@link #maxPartitionPerTask} partitions.
	 */
	private double getMinParallelism(TaskMetrics metrics) {
		if (!metrics.isParallelSource() || maxPartitionPerTask <= 0) {
			return 1;
		}
		// Divide in floating point: integer division would truncate before ceil() is applied.
		double required = Math.ceil((double) metrics.getPartitionCount() / maxPartitionPerTask);
		// At least one task; this also keeps detect() from dividing by zero when there are no partitions.
		return Math.max(1.0, required);
	}

	/**
	 * Returns the workload ratio for a vertex. It is the configured base ratio, multiplied by the
	 * multi-output ratio if any direct input of the vertex has more than one output.
	 */
	private double getRatio(JobVertexID vertexID) {
		JobTopologyAnalyzer jobTopologyAnalyzer = monitor.getJobTopologyAnalyzer();
		for (JobVertexID input : jobTopologyAnalyzer.getInputs(vertexID)) {
			if (jobTopologyAnalyzer.getOutputs(input).size() > 1) {
				return ratio * multiOutputRatio;
			}
		}
		return ratio;
	}
}
