/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.functions.aggfunctions;

import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.table.dataformat.BinaryString;
import org.apache.flink.table.dataformat.Decimal;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.runtime.functions.aggfunctions.hyperloglog.HyperLogLogPlusPlus;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.DataTypes;
import org.apache.flink.table.types.DecimalType;
import org.apache.flink.table.types.GenericType;
import org.apache.flink.table.types.TypeInfoWrappedDataType;
import org.apache.flink.table.typeutils.BinaryStringTypeInfo;
import org.apache.flink.table.typeutils.DecimalTypeInfo;

import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.Arrays;

import static org.apache.flink.table.runtime.functions.aggfunctions.hyperloglog.XxHash64Function.DEFAULT_SEED;
import static org.apache.flink.table.runtime.functions.aggfunctions.hyperloglog.XxHash64Function.hashInt;
import static org.apache.flink.table.runtime.functions.aggfunctions.hyperloglog.XxHash64Function.hashLong;
import static org.apache.flink.table.runtime.functions.aggfunctions.hyperloglog.XxHash64Function.hashUnsafeBytes;

/**
 * Built-in approximate count distinct aggregate functions for batch execution.
 *
 * <p>Each nested function hashes its (non-null) input value to a {@code long} with the
 * XxHash64 helpers and feeds that hash into a {@link HyperLogLogPlusPlus} estimator whose
 * state lives in an {@link HllBuffer} accumulator. The per-type subclasses differ only in the
 * input type they declare and in how a value is turned into a hash.
 */
public class BatchApproximateCountDistinct {

	/**
	 * Base function for approximate count distinct aggregate.
	 *
	 * @param <T> the Java type of the values being counted
	 */
	public abstract static class ApproximateCountDistinctAggFunction<T> extends AggregateFunction<Long, HllBuffer> {

		/** Relative standard deviation handed to the HLL++ estimator (presumably ~1% target error). */
		private static final double RELATIVE_SD = 0.01;

		/**
		 * Shared estimator. It is constructed from the constant {@link #RELATIVE_SD} only, so a
		 * single instance serves every accumulator. It is transient and created lazily via
		 * {@link #estimator()} so that accumulate/merge/getValue/resetAccumulator work even on an
		 * instance (e.g. a deserialized copy) that never called {@link #createAccumulator()}.
		 */
		private transient HyperLogLogPlusPlus hyperLogLogPlusPlus;

		/**
		 * Returns the data type of the single input argument accepted by this function.
		 */
		public abstract DataType getValueTypeInfo();

		/**
		 * Returns the shared HLL++ estimator, creating it on first use.
		 */
		private HyperLogLogPlusPlus estimator() {
			if (hyperLogLogPlusPlus == null) {
				hyperLogLogPlusPlus = new HyperLogLogPlusPlus(RELATIVE_SD);
			}
			return hyperLogLogPlusPlus;
		}

		/**
		 * Creates an empty accumulator whose word array is sized by the estimator.
		 */
		@Override
		public HllBuffer createAccumulator() {
			HllBuffer buffer = new HllBuffer();
			buffer.array = new long[estimator().numWords()];
			resetAccumulator(buffer);
			return buffer;
		}

		/**
		 * Records one input value in {@code buffer}; {@code null} inputs are ignored.
		 *
		 * @param buffer the accumulator to update
		 * @param input the value to count; must be an instance of {@code T} when non-null
		 */
		// The framework passes inputs as Object; the declared input type guarantees they are T.
		@SuppressWarnings("unchecked")
		public void accumulate(HllBuffer buffer, Object input) throws Exception {
			if (input != null) {
				estimator().updateByHashcode(buffer, getHashcode((T) input));
			}
		}

		/**
		 * Hashes a non-null input value for insertion into the HLL++ sketch.
		 */
		abstract long getHashcode(T t);

		/**
		 * Merges every partial accumulator in {@code it} into {@code buffer}.
		 */
		public void merge(HllBuffer buffer, Iterable<HllBuffer> it) throws Exception {
			HyperLogLogPlusPlus hll = estimator();
			for (HllBuffer partial : it) {
				hll.merge(buffer, partial);
			}
		}

		/**
		 * Clears {@code buffer} back to the empty-sketch state (all words zero).
		 */
		public void resetAccumulator(HllBuffer buffer) {
			// The array is always allocated with estimator().numWords() entries, so zeroing the
			// whole array is equivalent to zeroing the first numWords() words.
			Arrays.fill(buffer.array, 0L);
		}

		/**
		 * Returns the estimated number of distinct values recorded in {@code buffer}.
		 */
		@Override
		public Long getValue(HllBuffer buffer) {
			return estimator().query(buffer);
		}

		/**
		 * Declares the input types: one argument of {@link #getValueTypeInfo()}, or none.
		 *
		 * @throws UnsupportedOperationException if more than one argument is supplied
		 */
		@Override
		public DataType[] getUserDefinedInputTypes(Class[] signature) {
			if (signature.length == 1) {
				return new DataType[]{getValueTypeInfo()};
			} else if (signature.length == 0) {
				return new DataType[0];
			} else {
				throw new UnsupportedOperationException(
						"Approximate count distinct accepts at most one argument, but got "
								+ signature.length + ".");
			}
		}

		/**
		 * The accumulator is handled as a generic (non-composite) {@link HllBuffer} type.
		 */
		@Override
		public DataType getAccumulatorType() {
			return new GenericType<>(new GenericTypeInfo<>(HllBuffer.class));
		}
	}

	/**
	 * Accumulator used by HLL++: an array of {@code long} words, allocated with
	 * {@code HyperLogLogPlusPlus#numWords()} entries by
	 * {@link ApproximateCountDistinctAggFunction#createAccumulator()}.
	 */
	public static class HllBuffer {
		public long[] array;
	}

	/**
	 * Built-in byte approximate count distinct aggregate function.
	 */
	public static class ByteApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Byte> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.BYTE;
		}

		/** Hashes the byte widened to an int. */
		@Override
		long getHashcode(Byte value) {
			return hashInt(value, DEFAULT_SEED());
		}
	}

	/**
	 * Built-in decimal approximate count distinct aggregate function.
	 */
	public static class DecimalApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Decimal> {

		private final int precision;
		private final int scale;

		/**
		 * @param decimalType the declared decimal type of the input column
		 */
		public DecimalApproximateCountDistinctAggFunction(DecimalType decimalType) {
			this.precision = decimalType.precision();
			this.scale = decimalType.scale();
		}

		@Override
		public DataType getValueTypeInfo() {
			return new TypeInfoWrappedDataType(new DecimalTypeInfo(precision, scale));
		}

		/**
		 * Hashes the unscaled value: as a long when the precision fits in a long, otherwise as
		 * the big-endian two's-complement bytes of the unscaled {@code BigInteger}.
		 */
		@Override
		long getHashcode(Decimal d) {
			if (precision <= Decimal.MAX_LONG_DIGITS) {
				return hashLong(d.toUnscaledLong(), DEFAULT_SEED());
			} else {
				byte[] bytes = d.toBigDecimal().unscaledValue().toByteArray();
				MemorySegment segment = MemorySegmentFactory.wrap(bytes);
				return hashUnsafeBytes(segment, 0, segment.size(), DEFAULT_SEED());
			}
		}
	}

	/**
	 * Built-in double approximate count distinct aggregate function.
	 */
	public static class DoubleApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Double> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.DOUBLE;
		}

		/** Hashes the IEEE 754 bit pattern ({@link Double#doubleToLongBits}, NaNs collapsed). */
		@Override
		long getHashcode(Double value) {
			return hashLong(Double.doubleToLongBits(value), DEFAULT_SEED());
		}
	}

	/**
	 * Built-in float approximate count distinct aggregate function.
	 */
	public static class FloatApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Float> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.FLOAT;
		}

		/** Hashes the IEEE 754 bit pattern ({@link Float#floatToIntBits}, NaNs collapsed). */
		@Override
		long getHashcode(Float value) {
			return hashInt(Float.floatToIntBits(value), DEFAULT_SEED());
		}
	}

	/**
	 * Built-in int approximate count distinct aggregate function.
	 */
	public static class IntApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Integer> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.INT;
		}

		@Override
		long getHashcode(Integer value) {
			return hashInt(value, DEFAULT_SEED());
		}
	}

	/**
	 * Built-in long approximate count distinct aggregate function.
	 */
	public static class LongApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Long> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.LONG;
		}

		@Override
		long getHashcode(Long value) {
			return hashLong(value, DEFAULT_SEED());
		}
	}

	/**
	 * Built-in short approximate count distinct aggregate function.
	 */
	public static class ShortApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Short> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.SHORT;
		}

		/** Hashes the short widened to an int. */
		@Override
		long getHashcode(Short value) {
			return hashInt(value, DEFAULT_SEED());
		}
	}

	/**
	 * Built-in boolean approximate count distinct aggregate function.
	 */
	public static class BooleanApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Boolean> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.BOOLEAN;
		}

		/** Hashes {@code true} as 1 and {@code false} as 0. */
		@Override
		long getHashcode(Boolean value) {
			return hashInt(value ? 1 : 0, DEFAULT_SEED());
		}
	}

	/**
	 * Built-in date approximate count distinct aggregate function.
	 */
	public static class DateApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Date> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.DATE;
		}

		/** Hashes the epoch-millisecond value returned by {@link Date#getTime()}. */
		@Override
		long getHashcode(Date value) {
			return hashLong(value.getTime(), DEFAULT_SEED());
		}
	}

	/**
	 * Built-in time approximate count distinct aggregate function.
	 */
	public static class TimeApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Time> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.TIME;
		}

		/** Hashes the millisecond value returned by {@link Time#getTime()}. */
		@Override
		long getHashcode(Time value) {
			return hashLong(value.getTime(), DEFAULT_SEED());
		}
	}

	/**
	 * Built-in timestamp approximate count distinct aggregate function.
	 */
	public static class TimestampApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<Timestamp> {

		@Override
		public DataType getValueTypeInfo() {
			return DataTypes.TIMESTAMP;
		}

		/**
		 * Hashes the epoch-millisecond value returned by {@link Timestamp#getTime()}; note that
		 * sub-millisecond nanos are therefore not distinguished.
		 */
		@Override
		long getHashcode(Timestamp value) {
			return hashLong(value.getTime(), DEFAULT_SEED());
		}
	}

	/**
	 * Built-in string approximate count distinct aggregate function.
	 */
	public static class StringApproximateCountDistinctAggFunction extends ApproximateCountDistinctAggFunction<BinaryString> {

		@Override
		public DataType getValueTypeInfo() {
			return new TypeInfoWrappedDataType(BinaryStringTypeInfo.INSTANCE);
		}

		/**
		 * Hashes the string's bytes. A string stored in a single memory segment is hashed in
		 * place; one spanning several segments is first copied into a contiguous array.
		 */
		@Override
		long getHashcode(BinaryString s) {
			MemorySegment[] segments = s.getSegments();
			if (segments.length == 1) {
				return hashUnsafeBytes(
						segments[0], s.getOffset(), s.numBytes(), DEFAULT_SEED());
			} else {
				return hashUnsafeBytes(
						MemorySegmentFactory.wrap(s.getBytes()), 0, s.numBytes(), DEFAULT_SEED());
			}
		}
	}
}
