/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.	See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.	You may obtain a copy of the License at
 *
 *		http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.functions.aggfunctions.hyperloglog;

import org.apache.flink.table.dataformat.BinaryString;
import org.apache.flink.table.runtime.functions.aggfunctions.StreamApproximateCountDistinct.HllAcc;

import java.sql.Date;

import static com.yahoo.sketches.hash.MurmurHash3.hash;
import static java.nio.charset.StandardCharsets.UTF_8;

/**
 * The implementation of HyperLogLogSketchHelper is adapted from "com.yahoo.datasketches"
 * (https://github.com/apache/incubator-datasketches-java).
 *
 * <p>Referenced papers:
 * [1] HyperLogLog: the analysis of a near-optimal cardinality estimation algorithm.
 * [2] All-Distances Sketches, Revisited: HIP Estimators for Massive Graphs Analysis.
 */
@SuppressWarnings("unchecked")
public class HyperLogLogSketchHelper {
	// number of bits used as the address of HLL slots.
	public static final short LG_CONFIG_K = log2m(0.01);
	// minimum number of bits used as address of HLL slots.
	private static final int MIN_LOG_K = 4;
	// number of HLL slots.
	private static final int CONFIG_K = 1 << LG_CONFIG_K;
	// number of HLL slots in one map entry, default is 8, must be 2^x.
	private static final int SLOTS_PER_ENTRY = 8;
	// coupon returned for values that must not be counted. A real coupon always carries a
	// slot value >= 1 in the bits above KEY_BITS_26, so it can never be equal to 0.
	private static final int EMPTY_COUPON = 0;

	static final long DEFAULT_UPDATE_SEED = 9001L;
	static final int KEY_BITS_26 = 26;
	static final int KEY_MASK_26 = (1 << KEY_BITS_26) - 1;

	/**
	 * Accumulates {@code value} into the HLL sketch {@code buffer}.
	 *
	 * <p>The value is hashed into a coupon (slot address + slot value). The addressed slot is
	 * only overwritten, and the sketch statistics only updated, when the new slot value is
	 * larger than the stored one. Values that produce no coupon (currently: values whose
	 * {@code toString()} is empty in the generic fallback) are ignored.
	 *
	 * @param buffer the accumulator to update
	 * @param value the value to add, must not be null
	 */
	public static void insert(HllAcc buffer, Object value) throws Exception {
		final int coupon = calcCoupon(value);
		// Compare with the sentinel instead of testing the sign: a coupon whose slot value is
		// 32 or more has bit 31 set and is a negative int, yet it is perfectly valid.
		if (coupon == EMPTY_COUPON) {
			return;
		}

		// locate the map entry and the slot inside that entry.
		final int slotNo = getLow26(coupon) & (CONFIG_K - 1);
		final short entryIdx = (short) (slotNo / SLOTS_PER_ENTRY);
		final int slotIdx = slotNo % SLOTS_PER_ENTRY;

		// candidate slot value carried by the coupon.
		final byte newVal = getValue(coupon);
		assert newVal > 0;

		byte[] slots = buffer.sketch.get(entryIdx);
		if (slots == null) {
			// entry never hit before, start from all-zero slots.
			slots = new byte[SLOTS_PER_ENTRY];
		}
		final byte curVal = slots[slotIdx];

		if (newVal > curVal) {
			slots[slotIdx] = newVal;
			buffer.sketch.put(entryIdx, slots);
			updateHllStatistic(buffer, curVal, newVal);
		}
	}

	/**
	 * Merges the HLL sketch {@code other} into {@code buffer}.
	 *
	 * <p>Every slot of {@code buffer} ends up holding the maximum of both sketches. As side
	 * effects, {@code buffer.oooFlag} is set, so {@link #getEstimate(HllAcc)} switches to the
	 * composite estimator, and {@code other.sketch} is cleared.
	 *
	 * @param buffer the accumulator receiving the merged state
	 * @param other the accumulator merged into {@code buffer}; its sketch is emptied
	 */
	public static void merge(HllAcc buffer, HllAcc other) throws Exception {
		// Hip estimator can not be used after merging two Hll sketch.
		buffer.oooFlag = true;

		for (Short entryIdx : other.sketch.keys()) {
			final byte[] otherSlots = other.sketch.get(entryIdx);
			byte[] slots = buffer.sketch.get(entryIdx);
			if (slots == null) {
				slots = new byte[SLOTS_PER_ENTRY];
			}

			boolean updated = false;
			for (int i = 0; i < SLOTS_PER_ENTRY; i++) {
				if (otherSlots[i] > slots[i]) {
					// statistics must see the old value, so update them before the slot.
					updateHllStatistic(buffer, slots[i], otherSlots[i]);
					slots[i] = otherSlots[i];
					updated = true;
				}
			}
			// write back only when at least one slot actually changed.
			if (updated) {
				buffer.sketch.put(entryIdx, slots);
			}
		}
		other.sketch.clear();
	}

	/**
	 * Estimates the distinct count from the given HLL sketch {@code buffer}.
	 *
	 * @param buffer the accumulator to read
	 * @return the rounded HIP accumulator while {@code buffer.oooFlag} is unset, otherwise the
	 *         rounded composite estimate
	 */
	public static long getEstimate(HllAcc buffer) {
		final double estimate = buffer.oooFlag
			? hllCompositeEstimate(buffer.kxq0, buffer.kxq1, buffer.zeroSlotNum, LG_CONFIG_K)
			: buffer.hipAccum;
		return Math.round(estimate);
	}

	/**
	 * Hashes {@code value} and packs the result into a coupon, which contains the slot value
	 * (number of leading zeros + 1) and the slot address of the value.
	 *
	 * @return the coupon, or {@link #EMPTY_COUPON} when the value must be skipped
	 */
	private static int calcCoupon(Object value) {
		final long[] hashed;
		if (value instanceof Boolean) {
			// a Boolean cannot be cast to Byte; encode it explicitly as 1 / 0.
			final boolean flag = (Boolean) value;
			final byte[] data = {(byte) (flag ? 1 : 0)};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Byte) {
			final byte[] data = {(Byte) value};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Short) {
			final int[] data = {(Short) value};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Integer) {
			final int[] data = {(Integer) value};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Long) {
			final long[] data = {(Long) value};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Double) {
			double d = (Double) value;
			// canonicalize -0.0, 0.0
			if (d == 0.0) {
				d = 0.0;
			}
			// doubleToLongBits canonicalizes all NaN forms
			final long[] data = {Double.doubleToLongBits(d)};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Float) {
			float f = (Float) value;
			// canonicalize -0.0, 0.0
			if (f == 0.0f) {
				f = 0.0f;
			}
			// widen to double; doubleToLongBits canonicalizes all NaN forms
			final long[] data = {Double.doubleToLongBits(f)};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof BinaryString) {
			final byte[] data = ((BinaryString) value).getBytes();
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else if (value instanceof Date) {
			final long[] data = {((Date) value).getTime()};
			hashed = hash(data, DEFAULT_UPDATE_SEED);
		} else {
			// generic fallback: hash the UTF-8 bytes of the string form.
			final String text = value.toString();
			if (text.length() == 0) {
				return EMPTY_COUPON;
			}
			hashed = hash(text.getBytes(UTF_8), DEFAULT_UPDATE_SEED);
		}

		return coupon(hashed);
	}

	/**
	 * Packs a 128-bit hash into a coupon: the low 26 bits of {@code hash[0]} form the address,
	 * and the number of leading zeros of {@code hash[1]} (capped at 62) plus one forms the
	 * slot value stored above bit {@link #KEY_BITS_26}.
	 */
	private static int coupon(final long[] hash) {
		final int addr26 = (int) (hash[0] & KEY_MASK_26);
		final int lz = Math.min(Long.numberOfLeadingZeros(hash[1]), 62);
		final int value = lz + 1;
		return (value << KEY_BITS_26) | addr26;
	}

	/**
	 * Computes the inverse integer power of 2: 1/(2^e) = 2^(-e).
	 * @param e a positive value between 0 and 1023 inclusive
	 * @return  the inverse integer power of 2: 1/(2^e) = 2^(-e)
	 */
	private static double invPow2(final int e) {
		assert (e | (1024 - e - 1)) >= 0 : "e cannot be negative or greater than 1023: " + e;
		// build the double directly from its biased exponent, the mantissa stays 0.
		return Double.longBitsToDouble((1023L - e) << 52);
	}

	/**
	 * Updates the statistics kept in {@code buffer} ({@code zeroSlotNum}, {@code kxq0},
	 * {@code kxq1} and {@code hipAccum}) for one slot that changes from {@code curVal} to
	 * {@code newVal}.
	 *
	 * @param buffer the accumulator whose statistics are updated
	 * @param curVal the slot value before the change
	 * @param newVal the slot value after the change
	 */
	private static void updateHllStatistic(HllAcc buffer, byte curVal, byte newVal) {
		// update zeroSlotNum if the slot is hit first time.
		if (curVal == 0) {
			buffer.zeroSlotNum--;
			assert buffer.zeroSlotNum >= 0;
		}
		// update hipAccum BEFORE updating kxq0 and kxq1
		buffer.hipAccum += CONFIG_K / (buffer.kxq0 + buffer.kxq1);
		// update kxq0 and kxq1; subtract first, then add.
		// slot values below 32 are tracked in kxq0, all others in kxq1.
		if (curVal < 32) {
			buffer.kxq0 -= invPow2(curVal);
		} else {
			buffer.kxq1 -= invPow2(curVal);
		}
		if (newVal < 32) {
			buffer.kxq0 += invPow2(newVal);
		} else {
			buffer.kxq1 += invPow2(newVal);
		}
	}

	/**
	 * @param rsd relative standard error.
	 * @return the base-2 logarithm (truncated) of the number of HLL slots m, derived from the
	 *         HLL RSE formula: rse = 1.106 / √m
	 */
	private static short log2m(double rsd) {
		return (short) (2 * Math.log(1.106 / rsd) / Math.log(2));
	}

	/**
	 * Extracts the 26-bit address part of a coupon.
	 */
	private static int getLow26(final int coupon) {
		return coupon & KEY_MASK_26;
	}

	/**
	 * Extracts the slot value stored in the bits above {@link #KEY_BITS_26} of a coupon.
	 */
	public static byte getValue(final int coupon) {
		// unsigned shift, so coupons with bit 31 set still yield a positive value.
		return (byte) (coupon >>> KEY_BITS_26);
	}

	/**
	 * This is the (non-HIP) estimator.
	 * It is called "composite" because multiple estimators are pasted together.
	 */
	private static double hllCompositeEstimate(
		double kxq0, double kxq1, int zeroBucketNum, int lgConfigK) {
		final double rawEst = getHllRawEstimate(lgConfigK, kxq0 + kxq1);

		final int tableIdx = lgConfigK - MIN_LOG_K;
		final double[] xArr = CompositeInterpolationXTable.X_ARRS[tableIdx];
		final double yStride = CompositeInterpolationXTable.Y_STRIDES[tableIdx];

		// below the interpolation table.
		if (rawEst < xArr[0]) {
			return 0;
		}

		// above the interpolation table: scale by the slope of the last table point.
		final int lastIdx = xArr.length - 1;
		if (rawEst > xArr[lastIdx]) {
			final double finalY = yStride * lastIdx;
			final double factor = finalY / xArr[lastIdx];
			return rawEst * factor;
		}

		final double adjEst =
			CubicInterpolation.usingXArrAndYStride(xArr, yStride, rawEst);

		// We need to completely avoid the linear_counting estimator if it might have a crazy value.
		// Empirical evidence suggests that the threshold 3*k will keep us safe if 2^4 <= k <= 2^21.
		if (adjEst > (3 << lgConfigK)) {
			return adjEst;
		}

		final double linEst = getHllBitMapEstimate(lgConfigK, 0, zeroBucketNum);

		// Bias is created when the value of an estimator is compared with a threshold to decide whether
		// to use that estimator or a different one.
		// We conjecture that less bias is created when the average of the two estimators
		// is compared with the threshold. Empirical measurements support this conjecture.
		final double avgEst = (adjEst + linEst) / 2.0;

		// The following constants come from empirical measurements of the crossover point
		// between the average error of the linear estimator and the adjusted hll estimator
		final double crossOver;
		if (lgConfigK == 4) {
			crossOver = 0.718;
		} else if (lgConfigK == 5) {
			crossOver = 0.672;
		} else {
			crossOver = 0.64;
		}

		return (avgEst > (crossOver * (1 << lgConfigK))) ? adjEst : linEst;
	}

	/**
	 * Computes the raw HLL estimate {@code correctionFactor * k^2 / kxqSum}, where the
	 * correction factor depends on {@code lgConfigK}.
	 */
	//In C: again-two-registers.c hhb_get_raw_estimate L1167
	private static double getHllRawEstimate(final int lgConfigK, final double kxqSum) {
		final int configK = 1 << lgConfigK;
		final double correctionFactor;
		switch (lgConfigK) {
			case 4:
				correctionFactor = 0.673;
				break;
			case 5:
				correctionFactor = 0.697;
				break;
			case 6:
				correctionFactor = 0.709;
				break;
			default:
				correctionFactor = 0.7213 / (1.0 + (1.079 / configK));
				break;
		}
		return (correctionFactor * configK * configK) / kxqSum;
	}

	/**
	 * Estimator when N is small, roughly less than k log(k).
	 * Refer to Wikipedia: Coupon Collector Problem
	 * @param lgConfigK the current configured lgK of the sketch
	 * @param curMin the current minimum value of the HLL window
	 * @param numAtCurMin the current number of rows with the value curMin
	 * @return the very low range estimate
	 */
	//In C: again-two-registers.c hhb_get_improved_linear_counting_estimate L1274
	private static double getHllBitMapEstimate(
		final int lgConfigK, final int curMin, final int numAtCurMin) {
		final int configK = 1 << lgConfigK;
		final int numUnhitBuckets = (curMin == 0) ? numAtCurMin : 0;

		//This will eventually go away.
		if (numUnhitBuckets == 0) {
			return configK * Math.log(configK / 0.5);
		}

		final int numHitBuckets = configK - numUnhitBuckets;
		return HarmonicNumbers.getBitMapEstimate(configK, numHitBuckets);
	}
}
