package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.common.vo.GenericPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class RecallUtil {
    private static final Logger logger= LoggerFactory.getLogger(RecallUtil.class);

    protected static List<int[]> combination(int m, int n, int cNum)
    {
        List<int[]> result = new ArrayList<>(cNum);
        int[] value = new int[n];
        for (int i = 0; i < n; i++) {
            value[i] = i;
        }
        int count = 0;
        do
        {
            result.add(value.clone());
        } while ( (++count < cNum) && move(value, m));
        return result;
    }

    private static boolean move(int[] value, int max)
    {
        int currentIndex = value.length - 1;
        while (currentIndex > -1)
        {
            int current = value[currentIndex];
            current++;
            if (current == max) {
                currentIndex--;
            }else {
                value[currentIndex] = current;
                for (int i = currentIndex + 1; i < value.length; i++) {
                    value[i] = value[i - 1] + 1;
                }
                if (Arrays.stream(value).noneMatch(x -> x >= max)){
                    return true;
                }
            }
        }
        return false;
    }

    protected static List<Integer> randKWithoutReplacement(List<Double> rates, int k) {
        if (null == rates || rates.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }
        if (k >= rates.size()) {
            logger.error("k is bigger than rates' size");
            throw new RuntimeException("k is bigger than rates' size");
        }

        List<HeapNode> nodes = new ArrayList<>(rates.size());

        for (int index = 0; index < rates.size(); index++) {
            nodes.add(new HeapNode(rates.get(index), index));
        }

        List<Integer> result = new ArrayList<>(k);
        List<HeapNode> heap = buildHeap(nodes);

        for (int index = 0; index < k; index++) {
            result.add(heapPop(heap));
        }
        return result;
    }

    private static List<HeapNode> buildHeap(List<HeapNode> nodes) {
        List<HeapNode> heap = new ArrayList<>(nodes.size() + 1);

        heap.add(null);
        heap.addAll(nodes);

        for (int index = heap.size() - 1; index > 1; index--) {
            double curTW = heap.get(index >> 1).totalWeight;
            heap.get(index >> 1).totalWeight = curTW + heap.get(index).totalWeight;
        }
        return heap;

    }

    private static int heapPop(List<HeapNode> heap) {
        double gas = heap.get(1).totalWeight * Math.random();
        int i = 1;
        while (gas > heap.get(i).weight) {
            gas = gas - heap.get(i).weight;
            i <<= 1;
            if (gas > heap.get(i).totalWeight) {
                gas = gas - heap.get(i).totalWeight;
                i++;
            }
        }

        double weight = heap.get(i).weight;
        int value = heap.get(i).value;
        heap.get(i).weight = 0;
        while (i > 0) {
            heap.get(i).totalWeight = heap.get(i).totalWeight - weight;
            i >>= 1;
        }

        return value;
    }

    private static class HeapNode {
        double weight;
        int value;
        double totalWeight;

        public HeapNode(double weight, int value) {
            this.weight = weight;
            this.value = value;
            this.totalWeight = weight;
        }
    }

    protected static List<Integer> argSort(List<Double> rates){
        List<GenericPair<Integer, Double>> idxWithRates = new ArrayList<>(rates.size());
        for (int i = 0; i < rates.size(); i++) {
            idxWithRates.add(new GenericPair<>(i, rates.get(i)));
        }
        idxWithRates = idxWithRates.stream().sorted().collect(Collectors.toList());
        return idxWithRates.stream().mapToInt(item -> item.first).boxed().collect(Collectors.toList());
    }


    protected static List<Integer> randKWithReplacement(List<Double> rates, int k){
        if (null == rates || rates.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }

        List<Integer> result = new ArrayList<>(k);
        double sum = rates.stream().mapToDouble(v -> v).sum();
        for (int i = 0; i < k; i++) {
            result.add(weightedRandPick(rates, sum));
        }

        return result;
    }

    protected static int weightedRandPick(List<Double> rates, double sum){
        double randomNum = Math.random()*sum;
        int findIdx = 0;
        for (int j = 0; j < rates.size(); j++) {
            findIdx = j;
            randomNum -= rates.get(j);
            if(randomNum <= 0.0){
                break;
            }
        }
        return findIdx;
    }

    protected static int weightedRandPick(List<Double> rates){
        double sum = rates.stream().mapToDouble(v -> v).sum();
        return weightedRandPick(rates, sum);
    }
}
