/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.functions.aggfunctions.percentile;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.api.dataview.SortedMapView;
import org.apache.flink.table.dataformat.GenericArray;
import org.apache.flink.table.runtime.functions.aggfunctions.ApproximatePercentile.ApproximatePercentileAggFunction;
import org.apache.flink.table.runtime.functions.aggfunctions.ApproximatePercentile.HistogramAcc;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Util functions for {@link ApproximatePercentileAggFunction}.
 * This class is an implementation of the paper, Yael Ben-Haim and Elad Tom-Tov,
 * "A streaming parallel decision tree algorithm".
 *
 * <p>The histogram is spread over three views which the methods of this class keep in sync:
 * <ul>
 *   <li>{@code sketch}: bin value -> bin height;</li>
 *   <li>{@code siblingMap}: bin value -> two-element list {@code [previous bin, next bin]},
 *       a missing neighbour being stored as {@link #NIL};</li>
 *   <li>{@code intervalMap}: distance between two adjacent bins -> list of the left-hand bins
 *       that have this distance to their right neighbour.</li>
 * </ul>
 */
@SuppressWarnings("unchecked")
public class HistogramUtils {

	// Left(right) sibling of the smallest(largest) bin is `NIL`.
	public static final double NIL = Double.NEGATIVE_INFINITY;

	// Shared generator used by compressBins to pick among bins with the same interval.
	// java.util.Random instances are documented as thread-safe.
	private static final Random RANDOM = new Random();

	/**
	 * Inserts a bin into the histogram.
	 *
	 * <p>If a bin with exactly the same value already exists, the heights are merged; a merged
	 * height of zero removes the bin. Otherwise a new bin is created and, if the number of used
	 * bins then exceeds {@code maxBins}, two adjacent bins are merged into one.
	 *
	 * @param acc accumulator holding the histogram views and counters
	 * @param value value of the bin to insert
	 * @param height height to add; may be negative because of retraction
	 * @param p percentiles, stored into {@code acc.percentile} on every call
	 * @param maxBins limit on the number of bins, stored into {@code acc.maxBins} on every call
	 * @throws Exception propagated from the underlying map views
	 */
	public static void insertBin(
		HistogramAcc acc,
		double value,
		long height,
		double[] p,
		int maxBins) throws Exception {

		acc.percentile = new GenericArray(p, p.length, true);
		acc.maxBins = maxBins;
		acc.totalCount += height;

		Long binHeight = acc.sketch.get(value);
		// there exists an exactly matching bin.
		if (binHeight != null) {
			long mergeHeight = binHeight + height;
			// height may be negative because of retraction.
			if (mergeHeight == 0) {
				removeBin(acc.sketch, acc.intervalMap, acc.siblingMap, value);
				// the bin is gone, so it must no longer count against maxBins; otherwise a
				// later insert would trigger a compression although there is room left.
				acc.usedBins--;
			} else {
				acc.sketch.put(value, mergeHeight);
			}
			return;
		}

		// a new bin is needed: register it in the sketch, sibling map and interval map.
		HistogramUtils.updateAuxiliaryData(
			acc.sketch, acc.intervalMap, acc.siblingMap, value, height);

		acc.usedBins += 1;
		// still within the limit of max number of bins, done.
		if (acc.usedBins <= maxBins) {
			return;
		}

		// otherwise, compress two bins into one to decrease the total number of bins.
		HistogramUtils.compressBins(acc.sketch, acc.intervalMap, acc.siblingMap);
		// update number of used bins
		acc.usedBins--;
	}

	/**
	 * Gets the bin closest to the given {@code binVal}.
	 *
	 * <p>An exactly matching bin wins. Otherwise the closer one of the two bins surrounding
	 * {@code binVal} is chosen (the lower one on a tie); if {@code binVal} lies beyond the last
	 * bin, the last bin is chosen.
	 *
	 * @param sketch bin value -> bin height
	 * @param siblingMap bin value -> {@code [previous bin, next bin]}
	 * @param binVal value to look up
	 * @return the value of the chosen bin together with its height <b>minus one</b>
	 *         (presumably the height left after retracting one element - confirm with the
	 *         caller), or {@code (NIL, 0)} if the histogram is empty
	 * @throws Exception propagated from the underlying map views
	 */
	public static Tuple2<Double, Long> getClosestBin(
		SortedMapView<Double, Long> sketch,
		MapView<Double, List<Double>> siblingMap,
		double binVal) throws Exception {

		Long exactHeight = sketch.get(binVal);
		if (exactHeight != null) {
			return Tuple2.of(binVal, exactHeight - 1);
		}

		Iterator<?> itr = sketch.tailEntries(binVal).iterator();
		if (itr.hasNext()) {
			// binVal lies before some bin: compare the distances to both neighbours.
			double nextEntry = ((Map.Entry<Double, Long>) itr.next()).getKey();
			double prevEntry = siblingMap.get(nextEntry).get(0);
			double closest;
			if (prevEntry == NIL) {
				closest = nextEntry;
			} else {
				closest = (binVal - prevEntry) > (nextEntry - binVal) ? nextEntry : prevEntry;
			}
			return Tuple2.of(closest, sketch.get(closest) - 1);
		}

		// binVal lies beyond every bin: fall back to the last bin, if there is one.
		Map.Entry<Double, Long> last = sketch.lastEntry();
		// `lastEntry` may be non-null although the view is empty (see the TODO in
		// updateAuxiliaryData), hence `firstEntry` is consulted as well.
		if (last == null || sketch.firstEntry() == null) {
			return Tuple2.of(NIL, 0L);
		}
		return Tuple2.of(last.getKey(), last.getValue() - 1);
	}

	/**
	 * Removes a bin from the histogram and links its two neighbours to each other.
	 *
	 * <p>The caller is responsible for adjusting any bin counter it maintains; this method only
	 * touches the three views.
	 *
	 * @param sketch bin value -> bin height
	 * @param intervalMap interval -> left-hand bins having this interval to their right neighbour
	 * @param siblingMap bin value -> {@code [previous bin, next bin]}
	 * @param bin value of the bin to remove; must be present in {@code siblingMap}
	 * @throws Exception propagated from the underlying map views
	 */
	public static void removeBin(
			SortedMapView<Double, Long> sketch,
			SortedMapView<Double, List<Double>> intervalMap,
			MapView<Double, List<Double>> siblingMap,
			double bin) throws Exception {
		sketch.remove(bin);

		List<Double> siblings = siblingMap.get(bin);
		Double prev = siblings.get(0);
		Double next = siblings.get(1);
		if (prev != NIL) {
			// the previous bin now points forward to `next`.
			updateSibling(prev, next, 1, siblingMap);
			// the interval prev -> bin no longer exists.
			removeIntervalEntry(bin - prev, prev, intervalMap);
		}
		if (next != NIL) {
			// the next bin now points backward to `prev`.
			updateSibling(next, prev, 0, siblingMap);
			// the interval bin -> next no longer exists.
			removeIntervalEntry(next - bin, bin, intervalMap);
		}
		if (prev != NIL && next != NIL) {
			// both neighbours exist, so they form a new interval prev -> next.
			double interval = next - prev;
			addIntervalEntry(interval, prev, intervalMap);
		}

		// finally drop the removed bin's own sibling record.
		siblingMap.remove(bin);
	}

	/**
	 * Registers a new bin in the sketch and updates the sibling and interval maps accordingly.
	 *
	 * <p>Expects that no bin with value {@code inputVal} exists yet (the only caller in this
	 * class checks for an exact match first).
	 *
	 * @param sketch bin value -> bin height
	 * @param intervalMap interval -> left-hand bins having this interval to their right neighbour
	 * @param siblingMap bin value -> {@code [previous bin, next bin]}
	 * @param inputVal value of the new bin
	 * @param height height of the new bin
	 * @throws Exception propagated from the underlying map views
	 */
	private static void updateAuxiliaryData(
		SortedMapView<Double, Long> sketch,
		SortedMapView<Double, List<Double>> intervalMap,
		MapView<Double, List<Double>> siblingMap,
		double inputVal,
		long height) throws Exception {

		// locate the neighbours of the new bin before it is inserted.
		Iterator<?> itr = sketch.tailEntries(inputVal).iterator();
		Double nextEntry = NIL, prevEntry = NIL;
		if (itr.hasNext()) {
			nextEntry = ((Map.Entry<Double, Long>) itr.next()).getKey();
		}
		if (nextEntry == NIL) {
			// no bin after the new one: its predecessor, if any, is the current last bin.
			Map.Entry<Double, Long> last = sketch.lastEntry();
			Map.Entry<Double, Long> first = sketch.firstEntry();
			// TODO: fix bug of `lastEntry`: lastEntry is not null when firstEntry is null
			if (first != null && last != null) {
				prevEntry = last.getKey();
			}
		} else {
			// the predecessor of the new bin is the current predecessor of its successor.
			prevEntry = siblingMap.get(nextEntry).get(0);
		}
		// put the new bin into the histogram.
		sketch.put(inputVal, height);

		siblingMap.put(inputVal, Arrays.asList(prevEntry, nextEntry));
		if (prevEntry != NIL) {
			// update sibling map.
			updateSibling(prevEntry, inputVal, 1, siblingMap);
			// update interval map.
			addIntervalEntry(inputVal - prevEntry, prevEntry, intervalMap);
		}

		if (nextEntry != NIL) {
			// update sibling map.
			updateSibling(nextEntry, inputVal, 0, siblingMap);
			// update interval map.
			addIntervalEntry(nextEntry - inputVal, inputVal, intervalMap);
		}

		if (prevEntry != NIL && nextEntry != NIL) {
			// the new bin split the former interval prev -> next, which must be dropped.
			double interval = nextEntry - prevEntry;
			removeIntervalEntry(interval, prevEntry, intervalMap);
		}
	}


	/**
	 * Merges two adjacent bins into one.
	 *
	 * <p>The pair is taken from the first entry of {@code intervalMap} (expected to be the
	 * smallest interval, assuming the view is ordered by ascending key); if several bins share
	 * that interval, one of them is picked at random. The merged bin is placed at the
	 * height-weighted mean of both bins and carries the sum of their heights.
	 *
	 * <p>Requires a non-empty {@code intervalMap}, i.e. at least two bins.
	 *
	 * @param sketch bin value -> bin height
	 * @param intervalMap interval -> left-hand bins having this interval to their right neighbour
	 * @param siblingMap bin value -> {@code [previous bin, next bin]}
	 * @throws Exception propagated from the underlying map views
	 */
	private static void compressBins(
		SortedMapView<Double, Long> sketch,
		SortedMapView<Double, List<Double>> intervalMap,
		MapView<Double, List<Double>> siblingMap) throws Exception {

		// choose the left-hand bin of the pair to merge.
		Map.Entry<Double, List<Double>> firstEntry = intervalMap.firstEntry();
		List<Double> keysList = firstEntry.getValue();
		Double compressBin = keysList.remove(RANDOM.nextInt(keysList.size()));

		List<Double> siblings = siblingMap.get(compressBin);
		double prevEntry = siblings.get(0);
		double nextEntry = siblings.get(1);
		// a bin listed in the interval map always has a right neighbour.
		assert nextEntry != NIL;

		long h1 = sketch.get(compressBin);
		long h2 = sketch.get(nextEntry);
		List<Double> nextSiblings = siblingMap.get(nextEntry);
		Double nextNextEntry = nextSiblings.get(1);

		// rm compressed bins from sketch
		sketch.remove(compressBin);
		sketch.remove(nextEntry);

		// rm the interval between the two compressed bins from the interval map.
		if (keysList.isEmpty()) {
			intervalMap.remove(firstEntry.getKey());
		} else {
			intervalMap.put(firstEntry.getKey(), keysList);
		}

		// rm compressed bins from sibling map
		siblingMap.remove(compressBin);
		siblingMap.remove(nextEntry);

		// add new bin to sketch.
		double mergedVal = (compressBin * h1 + nextEntry * h2) / (h1 + h2);
		sketch.put(mergedVal, h1 + h2);

		if (prevEntry != NIL) {
			// the interval to the left neighbour now ends at the merged bin.
			removeIntervalEntry(compressBin - prevEntry, prevEntry, intervalMap);
			addIntervalEntry(mergedVal - prevEntry, prevEntry, intervalMap);
			// update sibling map.
			updateSibling(prevEntry, mergedVal, 1, siblingMap);
		}

		if (nextNextEntry != NIL) {
			// the interval to the right neighbour now starts at the merged bin.
			removeIntervalEntry(nextNextEntry - nextEntry, nextEntry, intervalMap);
			addIntervalEntry(nextNextEntry - mergedVal, mergedVal, intervalMap);
			// update sibling map.
			updateSibling(nextNextEntry, mergedVal, 0, siblingMap);
		}

		// update sibling map for merged value.
		siblingMap.put(mergedVal, Arrays.asList(prevEntry, nextNextEntry));
	}

	/**
	 * Calculates the given {@code percent} percentile of the input values.
	 *
	 * <p>Bins are visited in the iteration order of {@code sketch} while their heights are
	 * accumulated. For the first bin at which the accumulated fraction reaches {@code percent},
	 * the result is interpolated linearly between the previous bin and this bin; if there is no
	 * previous bin, the bin value itself is returned.
	 *
	 * @param sketch bin value -> bin height
	 * @param siblingMap bin value -> {@code [previous bin, next bin]}
	 * @param totalCnt total height of all bins
	 * @param percent requested percentile as a fraction of {@code totalCnt}
	 * @return the approximated percentile, or {@code null} if no bin reaches {@code percent}
	 *         (e.g. the sketch is empty) or if accessing a view failed
	 */
	public static Double query(
		SortedMapView<Double, Long> sketch,
		MapView<Double, List<Double>> siblingMap,
		long totalCnt,
		double percent) {
		try {
			Iterator<?> itr = sketch.iterator();
			double qSum = 0;
			while (itr.hasNext()) {
				Map.Entry<Double, Long> a = (Map.Entry<Double, Long>) itr.next();
				qSum += a.getValue();
				if (qSum / totalCnt >= percent) {
					double binValue = a.getKey();
					Double prev = siblingMap.get(binValue).get(0);
					if (prev == NIL) {
						return binValue;
					}
					// interpolate using the height accumulated *before* this bin.
					qSum -= a.getValue();
					return prev + (percent * totalCnt - qSum) * (binValue - prev) / a.getValue();
				}
			}
		} catch (Exception e) {
			// NOTE(review): failures are only printed and turned into a null result; callers
			// cannot distinguish them from an empty histogram. Kept as is to preserve the
			// existing contract.
			e.printStackTrace();
		}
		return null;
	}

	/**
	 * Replaces one sibling of a bin.
	 *
	 * @param binVal bin whose sibling record is updated; must be present in {@code siblingMap}
	 * @param newSibling new sibling value
	 * @param index 0 for the previous sibling, 1 for the next sibling
	 * @param siblingMap bin value -> {@code [previous bin, next bin]}
	 * @throws Exception propagated from the underlying map view
	 */
	private static void updateSibling(
		double binVal, double newSibling, int index,
		MapView<Double, List<Double>> siblingMap) throws Exception {

		List<Double> siblings = siblingMap.get(binVal);
		siblings.set(index, newSibling);
		// write the list back so that the change reaches the view.
		siblingMap.put(binVal, siblings);
	}

	/**
	 * Removes an interval -> binVal pair from the interval map. The interval key itself is
	 * removed once no bin is left for it.
	 *
	 * @param interval interval key; must be present in {@code intervalMap}
	 * @param binVal left-hand bin of the interval
	 * @param intervalMap interval -> left-hand bins having this interval to their right neighbour
	 * @throws Exception propagated from the underlying map view
	 */
	private static void removeIntervalEntry(
		double interval,
		double binVal,
		SortedMapView<Double, List<Double>> intervalMap) throws Exception {

		List<Double> entries = intervalMap.get(interval);
		// box explicitly: the element is removed by value (remove(Object)), not by index.
		entries.remove(Double.valueOf(binVal));
		if (entries.isEmpty()) {
			intervalMap.remove(interval);
		} else {
			intervalMap.put(interval, entries);
		}
	}

	/**
	 * Adds an interval -> binVal pair to the interval map, creating the list for the interval
	 * if it does not exist yet.
	 *
	 * @param interval interval key
	 * @param binVal left-hand bin of the interval
	 * @param intervalMap interval -> left-hand bins having this interval to their right neighbour
	 * @throws Exception propagated from the underlying map view
	 */
	private static void addIntervalEntry(
		double interval,
		double binVal,
		SortedMapView<Double, List<Double>> intervalMap) throws Exception {

		List<Double> entries = intervalMap.get(interval);
		if (entries == null) {
			entries = new ArrayList<>();
		}
		entries.add(binVal);
		intervalMap.put(interval, entries);
	}

	/**
	 * Converts a boxed numeric value to a primitive {@code double}.
	 *
	 * @param value any {@link Number} (Byte, Short, Integer, Long, Float, Double, BigDecimal, ...)
	 * @return {@link Number#doubleValue()} of {@code value}, or 0 if {@code value} is
	 *         {@code null} or not a {@link Number}
	 */
	public static double cast2Double(Object value) {
		if (value instanceof Number) {
			return ((Number) value).doubleValue();
		}
		return 0;
	}
}
