/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.functions.aggfunctions;

import java.util.Map;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.DataTypes;
import org.apache.flink.table.types.DecimalType;
import org.apache.flink.table.types.InternalType;
import org.apache.flink.table.types.RowType;
import org.apache.flink.table.types.TypeConverters;
import org.apache.flink.table.types.TypeInfoWrappedDataType;
import org.apache.flink.table.typeutils.BinaryStringTypeInfo;
import org.apache.flink.table.typeutils.MapViewTypeInfo;

public class CountDistinct {

    /**
     * Count-distinct aggregate function whose counted values are described by
     * {@code BinaryStringTypeInfo.INSTANCE} wrapped in a {@link TypeInfoWrappedDataType}.
     */
    public static class StringCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return a new {@link TypeInfoWrappedDataType} around {@code BinaryStringTypeInfo.INSTANCE}
         */
        @Override
        public DataType getValueTypeInfo() {
            return new TypeInfoWrappedDataType(BinaryStringTypeInfo.INSTANCE);
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.TIMESTAMP}.
     */
    public static class TimestampCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.TIMESTAMP}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.TIMESTAMP;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.TIME}.
     */
    public static class TimeCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.TIME}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.TIME;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.DATE}.
     */
    public static class DateCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.DATE}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.DATE;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.BOOLEAN}.
     */
    public static class BooleanCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.BOOLEAN}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.BOOLEAN;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.SHORT}.
     */
    public static class ShortCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.SHORT}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.SHORT;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.LONG}.
     */
    public static class LongCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.LONG}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.LONG;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.INT}.
     */
    public static class IntCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.INT}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.INT;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.FLOAT}.
     */
    public static class FloatCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.FLOAT}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.FLOAT;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.DOUBLE}.
     */
    public static class DoubleCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.DOUBLE}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.DOUBLE;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have the
     * {@link DecimalType} supplied at construction time.
     */
    public static class DecimalCountDistinctAggFunction extends CountDistinctAggFunction {

        /** Decimal type handed to the constructor; returned unchanged as the value type. */
        public final DecimalType decimalType;

        /**
         * Creates the function for one specific decimal type.
         *
         * @param decimalType the decimal type reported by {@link #getValueTypeInfo()};
         *                    stored as given, without validation
         */
        public DecimalCountDistinctAggFunction(DecimalType decimalType) {
            this.decimalType = decimalType;
        }

        /**
         * Returns the type of the values being counted.
         *
         * @return the {@link DecimalType} passed to the constructor
         */
        @Override
        public DataType getValueTypeInfo() {
            return decimalType;
        }
    }

    /**
     * Count-distinct aggregate function whose counted values have type
     * {@code DataTypes.BYTE}.
     */
    public static class ByteCountDistinctAggFunction extends CountDistinctAggFunction {

        /**
         * Returns the type of the values being counted.
         *
         * @return {@code DataTypes.BYTE}
         */
        @Override
        public DataType getValueTypeInfo() {
            return DataTypes.BYTE;
        }
    }

    /**
     * Base class of the count-distinct aggregate functions in this file.
     *
     * <p>The accumulator is a two-field {@link GenericRow}:
     * <ul>
     *   <li>field 0 ("count"): the running distinct count, returned by {@link #getValue};</li>
     *   <li>field 1 ("map"): a {@link MapView} from input value to a signed {@code Long}
     *       counter. {@link #accumulate} adds one, {@link #retract} subtracts one, an entry
     *       whose counter reaches zero is removed, and a retraction of a value that is not
     *       in the map is stored as a negative counter.</li>
     * </ul>
     *
     * <p>Subclasses only supply the type of the counted values via {@link #getValueTypeInfo()}.
     */
    public abstract static class CountDistinctAggFunction
    extends AggregateFunction<Long, GenericRow> {

        /**
         * Returns the data type of the values being counted. It is used as the key type of
         * the accumulator map and as the single input type of the function.
         *
         * @return the value type
         */
        public abstract DataType getValueTypeInfo();

        /**
         * Creates an empty accumulator: count 0 and an empty map view.
         *
         * @return a new two-field accumulator row
         */
        @Override
        public GenericRow createAccumulator() {
            // Exactly two slots, matching the two-field row type ("count", "map")
            // declared by getAccumulatorType(); no third field is ever read or written.
            GenericRow acc = new GenericRow(2);
            acc.setLong(0, 0L);
            acc.update(1, new MapView<Object, Long>(this.getValueTypeInfo(), DataTypes.LONG));
            return acc;
        }

        /**
         * Adds one occurrence of {@code input} to the accumulator. {@code null} inputs are ignored.
         *
         * @param acc   the accumulator to update
         * @param input the value to add, may be {@code null}
         * @throws Exception if the underlying map view access fails
         */
        public void accumulate(GenericRow acc, Object input) throws Exception {
            if (input == null) {
                return;
            }
            long count = acc.getLong(0);
            MapView<Object, Long> map = distinctMap(acc);
            Long valueCnt = map.get(input);
            if (valueCnt == null) {
                // First occurrence of this value: it becomes a new distinct value.
                map.put(input, 1L);
                acc.update(0, count + 1L);
                return;
            }
            long updatedCnt = valueCnt + 1L;
            if (updatedCnt == 0L) {
                // A stored negative counter (earlier retraction) is cancelled out.
                map.remove(input);
            } else {
                map.put(input, updatedCnt);
            }
        }

        /**
         * Removes one occurrence of {@code input} from the accumulator. {@code null} inputs
         * are ignored.
         *
         * @param acc   the accumulator to update
         * @param input the value to retract, may be {@code null}
         * @throws Exception if the underlying map view access fails
         */
        public void retract(GenericRow acc, Object input) throws Exception {
            if (input == null) {
                return;
            }
            long count = acc.getLong(0);
            MapView<Object, Long> map = distinctMap(acc);
            Long valueCnt = map.get(input);
            if (valueCnt == null) {
                // Retraction of a value not in the map: remember it as a negative counter.
                map.put(input, -1L);
                return;
            }
            long updatedCnt = valueCnt - 1L;
            if (updatedCnt == 0L) {
                // Last occurrence retracted: the value no longer counts as distinct.
                map.remove(input);
                acc.update(0, count - 1L);
            } else {
                map.put(input, updatedCnt);
            }
        }

        /**
         * Merges the given accumulators into {@code acc}.
         *
         * <p>Per key, the counters are summed. The distinct count is decremented when a
         * previously positive counter becomes zero or negative, and incremented when a key
         * is new with a positive counter or a previously negative counter becomes positive.
         *
         * @param acc the accumulator receiving the merged state
         * @param it  the accumulators to merge in
         * @throws Exception if the underlying map view access fails
         */
        public void merge(GenericRow acc, Iterable<GenericRow> it) throws Exception {
            MapView<Object, Long> map = distinctMap(acc);
            long count = acc.getLong(0);
            for (GenericRow mergeAcc : it) {
                Iterable<Map.Entry<Object, Long>> entries = distinctMap(mergeAcc).entries();
                if (entries == null) {
                    continue;
                }
                for (Map.Entry<Object, Long> entry : entries) {
                    Object key = entry.getKey();
                    Long mergeCnt = entry.getValue();
                    Long valueCnt = map.get(key);

                    if (valueCnt == null) {
                        // Key unknown on this side: adopt any non-zero counter as is.
                        if (mergeCnt > 0L) {
                            map.put(key, mergeCnt);
                            ++count;
                        } else if (mergeCnt < 0L) {
                            map.put(key, mergeCnt);
                        }
                        continue;
                    }

                    long mergedCnt = valueCnt + mergeCnt;
                    if (mergedCnt == 0L) {
                        map.remove(key);
                    } else {
                        map.put(key, mergedCnt);
                    }

                    // Adjust the distinct count only when the key crosses the
                    // "positive counter" boundary in either direction.
                    if (valueCnt > 0L && mergedCnt <= 0L) {
                        --count;
                    } else if (valueCnt < 0L && mergedCnt > 0L) {
                        ++count;
                    }
                }
            }
            acc.update(0, count);
        }

        /**
         * Resets the accumulator to its initial state: count 0 and an empty map.
         *
         * @param acc the accumulator to reset
         */
        public void resetAccumulator(GenericRow acc) {
            acc.setLong(0, 0L);
            distinctMap(acc).clear();
        }

        /**
         * Returns the current distinct count.
         *
         * @param accumulator the accumulator to read
         * @return the value stored in field 0 of the accumulator
         */
        @Override
        public Long getValue(GenericRow accumulator) {
            return accumulator.getLong(0);
        }

        /**
         * Returns the input types for an {@code accumulate}/{@code retract} signature.
         *
         * @param signature the parameter classes of the signature
         * @return the value type for a one-argument signature, an empty array for a
         *         zero-argument signature
         * @throws UnsupportedOperationException if the signature has more than one parameter
         */
        @Override
        public DataType[] getUserDefinedInputTypes(Class[] signature) {
            if (signature.length == 1) {
                return new DataType[]{this.getValueTypeInfo()};
            }
            if (signature.length == 0) {
                return new DataType[0];
            }
            throw new UnsupportedOperationException(
                "Count distinct accepts at most one input, but the signature has "
                    + signature.length + " parameters");
        }

        /**
         * Returns the accumulator type: a row with the fields {@code "count"} (LONG) and
         * {@code "map"} (a generic type built from a {@link MapViewTypeInfo}).
         *
         * @return the accumulator row type
         */
        @Override
        public DataType getAccumulatorType() {
            // NOTE(review): the array is allocated as InternalType[] exactly as before;
            // confirm whether RowType depends on that runtime array type before changing it.
            DataType[] fieldTypes = new InternalType[]{
                DataTypes.LONG,
                DataTypes.createGenericType(
                    new MapViewTypeInfo(
                        TypeConverters.createExternalTypeInfoFromDataType(this.getValueTypeInfo()),
                        TypeConverters.createExternalTypeInfoFromDataType(DataTypes.LONG),
                        false,
                        true))
            };
            String[] fieldNames = new String[]{"count", "map"};
            return new RowType(fieldTypes, fieldNames);
        }

        /** Reads the value-to-counter map view stored in field 1 of the given accumulator. */
        @SuppressWarnings("unchecked") // field 1 is only ever populated by createAccumulator()
        private static MapView<Object, Long> distinctMap(GenericRow acc) {
            return (MapView<Object, Long>) acc.getField(1);
        }
    }
}

