/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.functions.aggfunctions;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.table.api.Types;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.api.dataview.Order;
import org.apache.flink.table.api.dataview.SortedMapView;
import org.apache.flink.table.dataformat.GenericArray;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.runtime.functions.aggfunctions.percentile.HistogramUtils;
import org.apache.flink.table.types.DataTypes;
import org.apache.flink.table.types.TypeInfoWrappedDataType;

import org.apache.commons.lang3.ArrayUtils;

import java.math.BigDecimal;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static org.apache.flink.table.runtime.functions.aggfunctions.percentile.HistogramUtils.NIL;
import static org.apache.flink.table.runtime.functions.aggfunctions.percentile.HistogramUtils.cast2Double;
import static org.apache.flink.table.runtime.functions.aggfunctions.percentile.HistogramUtils.getClosestBin;
import static org.apache.flink.table.runtime.functions.aggfunctions.percentile.HistogramUtils.insertBin;
import static org.apache.flink.table.runtime.functions.aggfunctions.percentile.HistogramUtils.removeBin;

/**
 * This class is adapted from Hive's `NumericHistogram`, an implementation of the paper
 * by Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm".
 */
public class ApproximatePercentile {

	/**
	 * Base aggregate function that returns an approximate pth percentile of a numeric column in
	 * the group, backed by a {@link HistogramAcc} sketch.
	 *
	 * <p>The {@code accumulate} and {@code retract} overloads accept the requested percentile(s)
	 * as a single {@link BigDecimal}, a {@code BigDecimal[]} or a {@code Double[]}. The first two
	 * convert their argument and delegate to the {@code Double[]} variant, which validates that
	 * every percentile lies in {@code [0.0, 1.0]}.
	 *
	 * @param <T> result type produced by {@code getValue} in the concrete subclasses
	 */
	@SuppressWarnings("unchecked")
	public abstract static class ApproximatePercentileAggFunction<T>
		extends AggregateFunction<T, HistogramAcc> {

		/**
		 * Creates an empty accumulator: no requested percentiles, all counters set to zero and
		 * freshly created (empty) sketch, interval and sibling views.
		 *
		 * @return a new, empty {@link HistogramAcc}
		 */
		@Override
		public HistogramAcc createAccumulator() {
			HistogramAcc acc = new HistogramAcc();
			acc.percentile = new GenericArray(new double[0], 0, true);
			acc.maxBins = 0;
			acc.totalCount = 0;
			acc.usedBins = 0;
			acc.sketch = new SortedMapView(Order.ASCENDING, DataTypes.DOUBLE, DataTypes.LONG);
			acc.intervalMap = new SortedMapView(Order.ASCENDING, DataTypes.DOUBLE,
				new TypeInfoWrappedDataType(new ListTypeInfo(Types.DOUBLE())));
			acc.siblingMap = new MapView(DataTypes.DOUBLE,
				new TypeInfoWrappedDataType(new ListTypeInfo(Types.DOUBLE())));
			return acc;
		}

		/**
		 * Adds a value to the sketch for a single requested percentile.
		 *
		 * @param acc     accumulator to update
		 * @param value   input value; {@code null} values are ignored
		 * @param p       requested percentile, expected in {@code [0.0, 1.0]}
		 * @param maxBins bin limit forwarded to {@code insertBin}
		 * @throws IllegalArgumentException if the percentile is outside {@code [0.0, 1.0]}
		 */
		public void accumulate(
			HistogramAcc acc, Object value, BigDecimal p, Integer maxBins) throws Exception {
			accumulate(acc, value, new Double[]{p.doubleValue()}, maxBins);
		}

		/**
		 * Adds a value to the sketch for several requested percentiles given as decimals.
		 *
		 * @param acc     accumulator to update
		 * @param value   input value; {@code null} values are ignored
		 * @param p       requested percentiles, each expected in {@code [0.0, 1.0]}
		 * @param maxBins bin limit forwarded to {@code insertBin}
		 * @throws IllegalArgumentException if any percentile is outside {@code [0.0, 1.0]}
		 */
		public void accumulate(
			HistogramAcc acc, Object value, BigDecimal[] p, Integer maxBins) throws Exception {
			accumulate(acc, value, toDoubles(p), maxBins);
		}

		/**
		 * Adds a value to the sketch. After validating the percentiles, the value is handed to
		 * {@code insertBin} with a count of 1.
		 *
		 * @param acc     accumulator to update
		 * @param value   input value; {@code null} values are ignored (before any validation)
		 * @param p       requested percentiles, each expected in {@code [0.0, 1.0]}
		 * @param maxBins bin limit forwarded to {@code insertBin}
		 * @throws IllegalArgumentException if any percentile is NaN or outside {@code [0.0, 1.0]}
		 */
		public void accumulate(
			HistogramAcc acc, Object value, Double[] p, Integer maxBins) throws Exception {
			if (value == null) {
				return;
			}
			checkPercentiles(p);
			insertBin(acc, cast2Double(value), 1, ArrayUtils.toPrimitive(p), maxBins);
		}

		/**
		 * Retracts a value from the sketch for several requested percentiles given as decimals.
		 *
		 * @param acc   accumulator to update
		 * @param value value to retract; {@code null} values are ignored
		 * @param p     requested percentiles, each expected in {@code [0.0, 1.0]}
		 * @param bin   forwarded to {@code insertBin} in the position where {@code accumulate}
		 *              passes {@code maxBins}
		 * @throws IllegalArgumentException if any percentile is outside {@code [0.0, 1.0]}
		 */
		public void retract(
			HistogramAcc acc, Object value, BigDecimal[] p, Integer bin) throws Exception {
			retract(acc, value, toDoubles(p), bin);
		}

		/**
		 * Retracts a value from the sketch for a single requested percentile.
		 *
		 * @param acc   accumulator to update
		 * @param value value to retract; {@code null} values are ignored
		 * @param p     requested percentile, expected in {@code [0.0, 1.0]}
		 * @param bin   forwarded to {@code insertBin} in the position where {@code accumulate}
		 *              passes {@code maxBins}
		 * @throws IllegalArgumentException if the percentile is outside {@code [0.0, 1.0]}
		 */
		public void retract(
			HistogramAcc acc, Object value, BigDecimal p, Integer bin) throws Exception {
			retract(acc, value, new Double[]{p.doubleValue()}, bin);
		}

		/**
		 * Retracts a value from the sketch.
		 *
		 * <p>The total count is decremented, then the closest bin to the value is looked up.
		 * If no bin is found, the retraction is recorded through {@code insertBin} with a count
		 * of -1. If the returned bin height is still positive, it is written back to the sketch;
		 * otherwise the bin is removed and the used-bin counter is decremented.
		 *
		 * @param acc   accumulator to update
		 * @param value value to retract; {@code null} values are ignored (before any validation)
		 * @param p     requested percentiles, each expected in {@code [0.0, 1.0]}
		 * @param bin   forwarded to {@code insertBin} in the position where {@code accumulate}
		 *              passes {@code maxBins}
		 * @throws IllegalArgumentException if any percentile is NaN or outside {@code [0.0, 1.0]}
		 */
		public void retract(
			HistogramAcc acc, Object value, Double[] p, Integer bin) throws Exception {
			if (value == null) {
				return;
			}
			checkPercentiles(p);

			// decrease total count.
			acc.totalCount--;

			Tuple2<Double, Long> hitBin = getClosestBin(
				acc.sketch, acc.siblingMap, cast2Double(value));
			// no bins in the histogram.
			// NOTE(review): this relies on NIL being a primitive double (numeric comparison after
			// unboxing); if NIL is a boxed Double this is a reference comparison - confirm.
			if (hitBin.f0 == NIL) {
				insertBin(acc, cast2Double(value), -1, ArrayUtils.toPrimitive(p), bin);
				return;
			}

			// the hit bin isn't empty after deleting a value.
			// NOTE(review): the height is written back unchanged here, so getClosestBin presumably
			// returns the already-decremented height - confirm against HistogramUtils.
			if (hitBin.f1 > 0) {
				acc.sketch.put(hitBin.f0, hitBin.f1);
				return;
			}
			// otherwise, remove the hit bin.
			acc.usedBins--;
			removeBin(acc.sketch, acc.intervalMap, acc.siblingMap, hitBin.f0);
		}

		/**
		 * Merges other accumulators into {@code acc} by re-inserting each of their bins
		 * (bin value and height) through {@code insertBin}, using the percentile array and
		 * {@code maxBins} stored in the accumulator being merged.
		 *
		 * @param acc accumulator that receives the merged bins
		 * @param itr accumulators to merge into {@code acc}
		 */
		public void merge(HistogramAcc acc, Iterable<HistogramAcc> itr) throws Exception {
			for (HistogramAcc mergedAcc: itr) {
				Iterator<?> binItr = mergedAcc.sketch.iterator();
				// the view may hand back no iterator at all; skip such accumulators.
				if (binItr == null) {
					continue;
				}
				while (binItr.hasNext()) {
					Map.Entry<Double, Long> bin = (Map.Entry<Double, Long>) binItr.next();
					insertBin(
						acc,
						bin.getKey(),
						bin.getValue(),
						mergedAcc.percentile.toDoubleArray(),
						mergedAcc.maxBins);
				}
			}
		}

		/**
		 * Converts decimal percentiles to boxed doubles, preserving order.
		 */
		private static Double[] toDoubles(BigDecimal[] p) {
			Double[] percentiles = new Double[p.length];
			for (int i = 0; i < p.length; i++) {
				percentiles[i] = p[i].doubleValue();
			}
			return percentiles;
		}

		/**
		 * Verifies that every percentile lies in {@code [0.0, 1.0]}.
		 *
		 * @throws IllegalArgumentException on the first percentile that is NaN or out of range
		 */
		private static void checkPercentiles(Double[] p) {
			for (double percent: p) {
				// negated range test: comparisons with NaN are always false, so the plain
				// "< 0.0 || > 1.0" form would silently accept NaN.
				if (!(percent >= 0.0 && percent <= 1.0)) {
					throw new IllegalArgumentException(
						"Percentile of APPROX_PERCENTILE should be >= 0.0 and <= 1.0," +
							" but get [" + percent + "].");
				}
			}
		}
	}

	/**
	 * Accumulator for Histogram algorithm sketch.
	 *
	 * <p>Holds the requested percentiles, the bookkeeping counters and the three views
	 * (bins, intervals, siblings) that the histogram helpers read and update.
	 */
	public static class HistogramAcc {
		// percentiles to be queried; read back as a double[] when the result is computed.
		public GenericArray percentile;
		// maximum number of bins used in Histogram algorithm.
		public int maxBins;
		// total number of input values.
		public long totalCount;
		// number of used bins.
		public int usedBins;
		// maintain value and height of histogram bins (bin value -> height).
		public SortedMapView<Double, Long> sketch;
		// maintain the interval between bins (interval -> prev_bin)
		public SortedMapView<Double, List<Double>> intervalMap;
		// maintain the siblings of every bin in the histogram.
		public MapView<Double, List<Double>> siblingMap;
	}

	/**
	 * Variant of the approximate percentile aggregate that is queried with exactly one
	 * percentile and therefore yields a single {@code Double}.
	 */
	@SuppressWarnings("unchecked")
	public static class ApproxSinglePercentileAggFunction
		extends ApproximatePercentileAggFunction<Double> {

		/**
		 * Queries the sketch for the first percentile stored in the accumulator.
		 *
		 * @param acc accumulator holding the sketch and the requested percentile
		 * @return the result of {@code HistogramUtils.query} for that percentile, or
		 *         {@code null} when the accumulator stores no percentile
		 */
		@Override
		public Double getValue(HistogramAcc acc) {
			final double[] requested = acc.percentile.toDoubleArray();
			// nothing was ever requested, so there is nothing to report.
			if (requested.length == 0) {
				return null;
			}
			final double target = requested[0];
			return HistogramUtils.query(acc.sketch, acc.siblingMap, acc.totalCount, target);
		}
	}

	/**
	 * Variant of the approximate percentile aggregate that is queried with several
	 * percentiles at once and yields one {@code Double} per requested percentile.
	 */
	@SuppressWarnings("unchecked")
	public static class ApproxMultiPercentileAggFunction
		extends ApproximatePercentileAggFunction<Double[]> {

		/**
		 * Queries the sketch once for every percentile stored in the accumulator.
		 *
		 * @param acc accumulator holding the sketch and the requested percentiles
		 * @return the results of {@code HistogramUtils.query}, in the same order as the
		 *         stored percentiles, or {@code null} when the accumulator stores none
		 */
		@Override
		public Double[] getValue(HistogramAcc acc) {
			final double[] requested = acc.percentile.toDoubleArray();
			// nothing was ever requested, so there is nothing to report.
			if (requested.length == 0) {
				return null;
			}
			final Double[] results = new Double[requested.length];
			int slot = 0;
			for (double target : requested) {
				results[slot++] = HistogramUtils.query(
					acc.sketch, acc.siblingMap, acc.totalCount, target);
			}
			return results;
		}
	}
}
