/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.runtime.io;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.Counter;
import org.apache.flink.metrics.SimpleCounter;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.PartitionConsumptionFailedEvent;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
import org.apache.flink.runtime.io.network.api.serialization.SerializerManagerUtility;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.metrics.MetricNames;
import org.apache.flink.runtime.metrics.SumAndCount;
import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
import org.apache.flink.runtime.plugable.DeserializationDelegate;
import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.metrics.WatermarkGauge;
import org.apache.flink.streaming.runtime.streamrecord.ReusingRecordValueDeserializationDelegate;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.streamstatus.StatusWatermarkValve;
import org.apache.flink.streaming.runtime.streamstatus.StreamStatus;
import org.apache.flink.streaming.runtime.streamstatus.StreamStatusMaintainer;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.util.checkpointlock.CheckpointLockDelegate;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.BitSet;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
 * Input reader for {@link org.apache.flink.streaming.runtime.tasks.OneInputStreamTask}.
 *
 * <p>This internally uses a {@link StatusWatermarkValve} to keep track of {@link Watermark} and
 * {@link StreamStatus} events, and forwards them to event subscribers once the
 * {@link StatusWatermarkValve} determines the {@link Watermark} from all inputs has advanced, or
 * that a {@link StreamStatus} needs to be propagated downstream to denote a status change.
 *
 * <p>Forwarding elements, watermarks, or stream status elements must be protected by synchronizing
 * on the given lock object. This ensures that we don't call methods on a
 * {@link OneInputStreamOperator} concurrently with the timer callback or other things.
 *
 * @param <IN> The type of the record that can be read with this record reader.
 */
@Internal
public class StreamInputProcessor<IN> {

	private static final Logger LOG = LoggerFactory.getLogger(StreamInputProcessor.class);

	/** One deserializer per input channel, indexed by channel index. */
	private final RecordDeserializer<DeserializationDelegate<StreamElement>>[] recordDeserializers;

	/** Deserializer of the channel whose buffer is currently being drained, or {@code null} if none. */
	private RecordDeserializer<DeserializationDelegate<StreamElement>> currentRecordDeserializer;

	private final DeserializationDelegate<StreamElement> deserializationDelegate;

	private final SelectedReadingBarrierHandler barrierHandler;

	/** Runs operator calls while holding the lock object passed to the constructor. */
	private final CheckpointLockDelegate lockDelegate;

	// ---------------- Status and Watermark Valve ------------------

	/** Valve that controls how watermarks and stream statuses are forwarded. */
	private final StatusWatermarkValve statusWatermarkValve;

	/** Number of input channels the valve needs to handle. */
	private final int numInputChannels;

	/** Channels that have already delivered an {@link EndOfPartitionEvent}. */
	private final BitSet channelsWithEndOfPartitionEvents;

	/**
	 * The channel from which a buffer came, tracked so that we can appropriately map
	 * the watermarks and stream statuses to channel indexes of the valve.
	 */
	private int currentChannel = -1;

	private final StreamStatusMaintainer streamStatusMaintainer;

	private final OneInputStreamOperator<IN, ?> streamOperator;

	// ---------------- Metrics ------------------

	private final WatermarkGauge watermarkGauge;

	/** Lazily initialized on the first call to {@link #processInput()}. */
	private Counter numRecordsIn;

	/** Lazily initialized on the first call to {@link #processInput()}. */
	private Counter numRecordsReceived;

	/** Whether sampled latency / wait-time metrics are collected; switched off if their setup fails. */
	private boolean enableTracingMetrics;

	/** One out of every {@code tracingMetricsInterval} deserialized elements is sampled. */
	private final int tracingMetricsInterval;

	/** Number of fully deserialized stream elements seen so far (used for sampling). */
	private long tracingInputCount;

	private SumAndCount taskLatency;
	private SumAndCount waitInput;

	/** {@link System#nanoTime()} at the end of the last sampled element; valid only if {@link #hasLastProcessedTime}. */
	private long lastProcessedTime;

	/**
	 * Whether {@link #lastProcessedTime} holds a real timestamp. A separate flag is used instead of a
	 * sentinel value because {@link System#nanoTime()} may legitimately return negative numbers.
	 */
	private boolean hasLastProcessedTime;

	private boolean isFinished;

	@SuppressWarnings("unchecked")
	public StreamInputProcessor(
			InputGate[] inputGates,
			TypeSerializer<IN> inputSerializer,
			boolean isCheckpointingEnabled,
			StreamTask<?, ?> checkpointedTask,
			CheckpointingMode checkpointMode,
			Object lock,
			IOManager ioManager,
			Configuration taskManagerConfig,
			StreamStatusMaintainer streamStatusMaintainer,
			OneInputStreamOperator<IN, ?> streamOperator,
			TaskIOMetricGroup metrics,
			WatermarkGauge watermarkGauge,
			boolean objectReuse,
			boolean enableTracingMetrics,
			int tracingMetricsInterval) throws IOException {

		// Validate before acquiring any resources: a non-positive interval would otherwise
		// fail with an ArithmeticException (modulo by zero) on the first record.
		if (enableTracingMetrics && tracingMetricsInterval <= 0) {
			throw new IllegalArgumentException(
				"tracingMetricsInterval must be positive when tracing metrics are enabled, but was "
					+ tracingMetricsInterval);
		}

		this.barrierHandler = InputProcessorUtil.createCheckpointBarrierHandler(
			isCheckpointingEnabled, checkpointedTask, checkpointMode, ioManager, taskManagerConfig, inputGates);

		this.lockDelegate = new CheckpointLockDelegate(checkNotNull(lock));

		StreamElementSerializer<IN> ser = new StreamElementSerializer<>(inputSerializer);
		this.deserializationDelegate = objectReuse ?
			new ReusingRecordValueDeserializationDelegate<>(ser) :
			new NonReusingDeserializationDelegate<>(ser);

		this.numInputChannels = this.barrierHandler.getNumberOfInputChannels();

		// Initialize one deserializer per input channel
		SerializerManagerUtility<DeserializationDelegate<StreamElement>> serializerManagerUtility =
			new SerializerManagerUtility<>(taskManagerConfig);
		this.recordDeserializers = serializerManagerUtility.createRecordDeserializers(
			barrierHandler.getAllInputChannels(), ioManager.getSpillingDirectoriesPaths());

		this.channelsWithEndOfPartitionEvents = new BitSet(this.numInputChannels);

		this.streamStatusMaintainer = checkNotNull(streamStatusMaintainer);
		this.streamOperator = checkNotNull(streamOperator);

		this.statusWatermarkValve = new StatusWatermarkValve(
				numInputChannels,
				new ForwardingValveOutputHandler(streamOperator, lockDelegate.getLock()));

		this.watermarkGauge = watermarkGauge;
		metrics.gauge("checkpointAlignmentTime", barrierHandler::getAlignmentDurationNanos);

		this.enableTracingMetrics = enableTracingMetrics;
		this.tracingMetricsInterval = tracingMetricsInterval;
		this.tracingInputCount = 0;
	}

	/**
	 * Reads input until one record has been handed to the operator or the input is exhausted.
	 * Watermarks, stream statuses, latency markers and events are handled along the way without
	 * returning.
	 *
	 * @return {@code true} if a record was processed, {@code false} once the input is finished
	 * @throws IllegalStateException if the barrier handler still holds data when the input ends
	 * @throws IOException on an unexpected event
	 */
	public boolean processInput() throws Exception {
		if (isFinished) {
			return false;
		}
		initializeMetricsIfNeeded();

		while (true) {
			if (currentRecordDeserializer != null) {
				DeserializationResult result = currentRecordDeserializer.getNextRecord(deserializationDelegate);

				if (result.isBufferConsumed()) {
					currentRecordDeserializer.getCurrentBuffer().recycleBuffer();
					currentRecordDeserializer = null;
				}

				if (result.isFullRecord()) {
					numRecordsReceived.inc();
					if (processRecordOrMarkWithTracing(deserializationDelegate.getInstance())) {
						return true;
					}
					// Not a record (watermark, status, marker): keep draining input.
					continue;
				}
			}

			final BufferOrEvent bufferOrEvent = barrierHandler.getNextNonBlocked();
			if (bufferOrEvent == null) {
				isFinished = true;
				if (!barrierHandler.isEmpty()) {
					throw new IllegalStateException("Trailing data in checkpoint barrier handler.");
				}
				return false;
			}

			if (bufferOrEvent.isBuffer()) {
				currentChannel = bufferOrEvent.getChannelIndex();
				currentRecordDeserializer = recordDeserializers[currentChannel];
				currentRecordDeserializer.setNextBuffer(bufferOrEvent.getBuffer());
			} else {
				handleEvent(bufferOrEvent);
			}
		}
	}

	/**
	 * Lazily sets up the metrics on first use. Setup is best-effort: record counters fall back to
	 * unregistered {@link SimpleCounter}s and tracing metrics are disabled if their setup fails.
	 */
	private void initializeMetricsIfNeeded() {
		if (numRecordsIn == null) {
			try {
				numRecordsIn = ((OperatorMetricGroup) streamOperator.getMetricGroup()).getIOMetricGroup().getNumRecordsInCounter();
			} catch (Exception e) {
				LOG.warn("An exception occurred during the metrics setup.", e);
				numRecordsIn = new SimpleCounter();
			}
		}

		if (numRecordsReceived == null) {
			try {
				numRecordsReceived = ((OperatorMetricGroup) streamOperator.getMetricGroup()).parent().getIOMetricGroup().getNumRecordsReceived();
			} catch (Exception e) {
				LOG.warn("An exception occurred during the metrics setup.", e);
				numRecordsReceived = new SimpleCounter();
			}
		}

		if (enableTracingMetrics && (taskLatency == null || waitInput == null)) {
			try {
				if (taskLatency == null) {
					taskLatency = new SumAndCount(
							MetricNames.TASK_LATENCY, ((OperatorMetricGroup) streamOperator.getMetricGroup()).parent());
				}
				if (waitInput == null) {
					waitInput = new SumAndCount(
							MetricNames.IO_WAIT_INPUT, ((OperatorMetricGroup) streamOperator.getMetricGroup()).parent());
				}
			} catch (Exception e) {
				LOG.warn("An exception occurred during the tracing metrics setup; tracing metrics are disabled.", e);
				enableTracingMetrics = false;
			}
		}
	}

	/**
	 * Processes one deserialized element and, for every {@code tracingMetricsInterval}-th element
	 * when tracing is enabled, records its processing latency and the time since the previous
	 * sampled element finished.
	 *
	 * <p>NOTE(review): the wait-input sample spans from the end of the previous <em>sampled</em>
	 * element, so it includes processing of the unsampled elements in between.
	 *
	 * @return whether the element was a record handed to the operator
	 */
	private boolean processRecordOrMarkWithTracing(StreamElement recordOrMark) throws Exception {
		tracingInputCount++;
		if (!enableTracingMetrics || tracingInputCount % tracingMetricsInterval != 0) {
			return processRecordOrMark(recordOrMark);
		}

		final long start = System.nanoTime();
		// The very first sample has no predecessor, so there is no meaningful wait time to report.
		if (hasLastProcessedTime) {
			waitInput.update(start - lastProcessedTime);
		}
		final boolean recordProcessed = processRecordOrMark(recordOrMark);
		lastProcessedTime = System.nanoTime();
		hasLastProcessedTime = true;
		taskLatency.update(lastProcessedTime - start);
		return recordProcessed;
	}

	/**
	 * Handles a non-buffer input: end-of-partition (calls {@code endInput()} once all channels
	 * have ended) or partition-consumption failure (drops the channel's partial record).
	 *
	 * @throws IOException on any other event type
	 */
	private void handleEvent(BufferOrEvent bufferOrEvent) throws Exception {
		final AbstractEvent event = bufferOrEvent.getEvent();
		final int channelIndex = bufferOrEvent.getChannelIndex();

		if (event.getClass() == EndOfPartitionEvent.class) {
			channelsWithEndOfPartitionEvents.set(channelIndex);
			if (channelsWithEndOfPartitionEvents.cardinality() == numInputChannels) {
				lockDelegate.lockAndRun(() -> {
					streamOperator.endInput();
				});
			}
		} else if (event.getClass() == PartitionConsumptionFailedEvent.class) {
			// Clear the incomplete record, if any, buffered for the failed channel.
			clearDeserializer(recordDeserializers[channelIndex]);
			currentRecordDeserializer = null;
		} else {
			throw new IOException("Unexpected event: " + event);
		}
	}

	/**
	 * Dispatches one stream element: records go to the operator, watermarks and stream statuses to
	 * the valve (attributed to {@link #currentChannel}), latency markers to the operator.
	 *
	 * @return {@code true} only if the element was a record
	 */
	private boolean processRecordOrMark(StreamElement recordOrMark) throws Exception {
		if (recordOrMark.isRecord()) {
			StreamRecord<IN> record = recordOrMark.asRecord();

			// now we can do the actual processing
			lockDelegate.lockAndRun(() -> {
				numRecordsIn.inc();
				streamOperator.setKeyContextElement1(record);
				streamOperator.processElement(record);
			});
			return true;
		} else if (recordOrMark.isWatermark()) {
			// handle watermark
			statusWatermarkValve.inputWatermark(recordOrMark.asWatermark(), currentChannel);
			return false;
		} else if (recordOrMark.isStreamStatus()) {
			// handle stream status
			statusWatermarkValve.inputStreamStatus(recordOrMark.asStreamStatus(), currentChannel);
			return false;
		} else if (recordOrMark.isLatencyMarker()) {
			// handle latency marker
			lockDelegate.lockAndRun(() -> {
				streamOperator.processLatencyMarker(recordOrMark.asLatencyMarker());
			});
			return false;
		} else {
			throw new RuntimeException("Unexpected stream element type " + recordOrMark);
		}
	}

	/** Recycles buffers still held by the deserializers, clears them, then cleans up the barrier handler. */
	public void cleanup() throws IOException {
		// clear the buffers first. this part should not ever fail
		for (RecordDeserializer<?> deserializer : recordDeserializers) {
			clearDeserializer(deserializer);
		}

		// cleanup the barrier handler resources
		barrierHandler.cleanup();
	}

	/** Recycles the deserializer's current buffer if it has not been recycled yet, then clears it. */
	private void clearDeserializer(RecordDeserializer<?> deserializer) {
		final Buffer buffer = deserializer.getCurrentBuffer();
		if (buffer != null && !buffer.isRecycled()) {
			buffer.recycleBuffer();
		}
		deserializer.clear();
	}

	/** Forwards the valve's output (watermarks, stream status changes) while holding the lock. */
	private class ForwardingValveOutputHandler implements StatusWatermarkValve.ValveOutputHandler {
		private final OneInputStreamOperator<IN, ?> operator;
		private final CheckpointLockDelegate lockDelegate;

		private ForwardingValveOutputHandler(final OneInputStreamOperator<IN, ?> operator, final Object lock) {
			this.operator = checkNotNull(operator);
			this.lockDelegate = new CheckpointLockDelegate(checkNotNull(lock));
		}

		@Override
		public void handleWatermark(Watermark watermark) {
			try {
				lockDelegate.lockAndRun(() -> {
					watermarkGauge.setCurrentWatermark(watermark.getTimestamp());
					operator.processWatermark(watermark);
				});
			} catch (Exception e) {
				throw new RuntimeException("Exception occurred while processing valve output watermark: ", e);
			}
		}

		@Override
		public void handleStreamStatus(StreamStatus streamStatus) {
			try {
				lockDelegate.lockAndRun(() -> {
					streamStatusMaintainer.toggleStreamStatus(streamStatus);
				});
			} catch (Exception e) {
				throw new RuntimeException("Exception occurred while processing valve output stream status: ", e);
			}
		}
	}

}
