/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.window.aligned;

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.metrics.Gauge;
import org.apache.flink.table.api.window.TimeWindow;
import org.apache.flink.table.codegen.GeneratedSubKeyedAggsHandleFunction;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.dataformat.JoinedRow;
import org.apache.flink.table.dataformat.util.BaseRowUtil;
import org.apache.flink.table.runtime.functions.ExecutionContext;
import org.apache.flink.table.runtime.functions.SubKeyedAggsHandleFunction;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * The emitting of Local window result is not triggered by watermark exactly as the functionality of
 * local window agg is local-accumulating elements under same key, and the exact triggering of window is
 * completed on global window agg.
 */
public class LocalAlignedWindowAggregator
	implements AlignedWindowAggregator<BaseRow, TimeWindow, BaseRow> {

	private static final long serialVersionUID = 1L;

	private final TypeInformation<BaseRow> keyTypeInfo;
	private final TypeInformation<BaseRow> accTypeInfo;
	private final GeneratedSubKeyedAggsHandleFunction<TimeWindow> generatedWindowAggregator;

	private final long minibatchSize;

	private transient SubKeyedAggsHandleFunction<TimeWindow> windowAggregator;

	// aligned window buffer
	private transient Map<TimeWindow, Map<BaseRow, BaseRow>> buffer;

	// stores buffer status, schema: <Window, Key, ACC>
	private transient ListState<Tuple3<TimeWindow, BaseRow, BaseRow>> bufferState;

	private transient int numOfElements;

	private transient JoinedRow outRow;

	private transient Collector<BaseRow> out;

	//----------------------------------------------------------------------------------

	public LocalAlignedWindowAggregator(
		TypeInformation<BaseRow> keyTypeInfo,
		TypeInformation<BaseRow> accTypeInfo,
		GeneratedSubKeyedAggsHandleFunction<TimeWindow> generatedWindowAggregator,
		long minibatchSize) {
		this.keyTypeInfo = keyTypeInfo;
		this.accTypeInfo = accTypeInfo;
		this.generatedWindowAggregator = generatedWindowAggregator;
		this.minibatchSize = minibatchSize;
	}

	@Override
	public void open(Context context) throws Exception {
		ExecutionContext ctx = context.getExecutionContext();
		this.buffer = new HashMap<>();
		this.numOfElements = 0;
		this.outRow = new JoinedRow();
		this.out = context.getCollector();
		BaseRowUtil.setAccumulate(outRow);

		this.windowAggregator = (SubKeyedAggsHandleFunction<TimeWindow>) generatedWindowAggregator.newInstance(
			ctx.getRuntimeContext().getUserCodeClassLoader());
		windowAggregator.open(ctx);

		//noinspection unchecked
		TypeInformation<Tuple3<TimeWindow, BaseRow, BaseRow>> tupleType = new TupleTypeInfo<>(
			new GenericTypeInfo(TimeWindow.class), this.keyTypeInfo, this.accTypeInfo);

		this.bufferState = context.getOpStateStore().getListState(
			new ListStateDescriptor<>("_local_window_buffer_", tupleType));

		// recover buffer from partition state.
		if (bufferState != null) {
			for (Tuple3<TimeWindow, BaseRow, BaseRow> tuple: bufferState.get()) {
				TimeWindow window = tuple.f0;
				BaseRow key = tuple.f1;
				BaseRow acc = tuple.f2;
				Map<BaseRow, BaseRow> keyAccs = buffer.computeIfAbsent(
					window,
					k->new HashMap<>());
				keyAccs.put(key, acc);
				numOfElements++;
			}
			bufferState.clear();
		}

		// counter metric to get the size of bundle.
		ctx.getRuntimeContext().getMetricGroup().gauge(
			"bundleSize", (Gauge<Integer>) () -> numOfElements);
		ctx.getRuntimeContext().getMetricGroup().gauge("bundleRatio", (Gauge<Double>) () -> {
			int numOfKeys = buffer.size();
			if (numOfKeys == 0) {
				return 0.0;
			} else {
				return 1.0 * numOfElements / numOfKeys;
			}
		});
	}

	@Override
	public void addElement(BaseRow key, TimeWindow window, BaseRow input) throws Exception {
		numOfElements++;
		// aggregate current input row to buffer.
		Map<BaseRow, BaseRow> keyAccs = buffer.computeIfAbsent(window, k -> new HashMap<>());
		BaseRow acc = keyAccs.get(key);
		if (acc == null) {
			acc = windowAggregator.createAccumulators();
		}
		// null namespace means using heap data view.
		windowAggregator.setAccumulators(null, acc);
		if (BaseRowUtil.isAccumulateMsg(input)) {
			windowAggregator.accumulate(input);
		} else {
			windowAggregator.retract(input);
		}
		acc = windowAggregator.getAccumulators();
		keyAccs.put(key, acc);

		if (minibatchSize > 0 && numOfElements > minibatchSize) {
			// flush buffer acc to downstream directly.
			for (TimeWindow w: windows()) {
				fireWindow(w);
			}
			expireAllWindows();
		}
	}

	@Override
	public void fireWindow(TimeWindow window) throws Exception {
		Map<BaseRow, BaseRow> fired = buffer.get(window);
		if (fired == null) {
			return;
		}
		for (Map.Entry<BaseRow, BaseRow> entry: fired.entrySet()) {
			BaseRow key = entry.getKey();
			BaseRow acc = entry.getValue();
			// add window.start as the input timestamp for global window.
			BaseRow windowTime = new GenericRow(1);
			windowTime.setLong(0, window.getStart());
			outRow.replace(new JoinedRow(key, windowTime), acc);
			out.collect(outRow);
		}
	}

	@Override
	public void snapshot() throws Exception {
		bufferState.clear();
		List<Tuple3<TimeWindow, BaseRow, BaseRow>> stateToPut = new ArrayList<>();
		for (TimeWindow window: buffer.keySet()) {
			for (Map.Entry<BaseRow, BaseRow> entry: buffer.get(window).entrySet()) {
				BaseRow key = entry.getKey();
				BaseRow acc = entry.getValue();
				stateToPut.add(Tuple3.of(window, key, acc));
			}
		}
		bufferState.addAll(stateToPut);
	}

	@Override
	public Iterable<TimeWindow> windows() throws Exception {
		return this.buffer.keySet();
	}

	@Override
	public void close() throws Exception {
		// do nothing.
	}

	@Override
	public void expireWindow(TimeWindow window) throws Exception {
		buffer.remove(window);
	}

	@Override
	public void expireAllWindows() throws Exception {
		this.numOfElements = 0;
		buffer.clear();
	}

	@Override
	public TimeWindow lowestWindow() throws Exception {
		throw new UnsupportedOperationException(
			"This method should not be called for LocalAlignedWindowAggregator.");
	}

	@Override
	public Iterable<TimeWindow> ascendingWindows() throws Exception {
		throw new UnsupportedOperationException(
			"This method should not be called for LocalAlignedWindowAggregator.");
	}

	@Override
	public Iterable<TimeWindow> ascendingWindows(TimeWindow fromWindow) throws Exception {
		throw new UnsupportedOperationException(
			"This method should not be called for LocalAlignedWindowAggregator.");
	}

}
