package cn.com.duiba.nezha.alg.alg.dpa;

import cn.com.duiba.nezha.alg.alg.constant.DPAConstant;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageIdDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageInfoDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PackageRecallDo;
import cn.com.duiba.nezha.alg.alg.vo.dpa.PrizeDo;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.vo.GenericPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

import static cn.com.duiba.nezha.alg.alg.dpa.PackageRecall.ucbPackage;

/**
 * @author: tanyan
 * @date: 2020/6/19
 */
public class PrizeRecall {
    private static final Logger logger= LoggerFactory.getLogger(PrizeRecall.class);

    private static final Integer RECALL_GROUP_MAX_NUM = 10;
    private static final Integer TOTAL_TOPK_PICK_RANGE = 100;
    private static final Integer PRIZE_NUM_AT_LEAST = 1;
    private static final Integer AD_PRIZE_NUM_AT_LEAST = 1;
    private static final Integer AD_PRIZE_NUM_AT_MOST = 2;
    private static final Integer TRY_NUM_FOR_ONE_GROUP = 3;

    /**
     * 对外海选接口，不需要传广告奖品列表
     * @param pks 可投奖品列表, pk的ID对象的prizeId、prizeTagId属性不能为空，adPrizeId属性为null
     * @param slotNum 活动槽位数量
     * @return 海选召回的奖品组合列表
     */
    public static List<PackageRecallDo> matchPrize(List<PackageInfoDo> pks, int slotNum){
        List<PackageInfoDo> adPks = new ArrayList<>();
        return matchPrizeAndAd(pks,adPks,slotNum);
    }

//    /**
//     * 对外实时接口，只传广告奖品
//     * @param adPks
//     * @return
//     */
//    public static List<PackageRecallDo> matchAdPrizeOnline(List<PackageInfoDo> adPks){
//        List<PackageInfoDo> pks = new ArrayList<PackageInfoDo>();
//        int slotNum = AD_PRIZE_NUM_AT_MOST;
//        return matchPrize(pks,adPks,slotNum);
//    }

    /**
     * 召回接口
     * @param pks 可投奖品列表，pk的ID对象的prizeId、prizeTagId属性不能为空，adPrizeId属性为null
     * @param adPks 可投券列表，pk的ID对象的adPrizeId不能为空，prizeId属性为null, adPrizeTagId和adPrizeType至少一个不为空
     * @param slotNum 活动槽位数，约束组合的长度
     * @return 匹配的奖品和券的组合列表
     */
    public static List<PackageRecallDo> matchPrizeAndAd(List<PackageInfoDo> pks,
                                                   List<PackageInfoDo> adPks,
                                                   int slotNum){
        if(AssertUtil.isAnyEmpty(pks) && AssertUtil.isAnyEmpty(adPks)){
            logger.error("PackageRecall matchPrize input params prizes and ad is null");
            return null;
        }

        //将券的string类型“券类型”映射到Long类型字段券tagID上
        Map<String,Long> adTag2Idx = preprocessAdList(adPks);

        List<PackageRecallDo> recallPrizeList = recallPrize(pks);
        List<PackageRecallDo> recallAdList = recallAdPrize(adPks);

        //        List<PackageRecallDo> groupList = generatePrizeGroup(recallList,prizeTypeGraph,groupNum,slotNum);
        return genPrizeGroup(recallPrizeList, recallAdList, slotNum, RECALL_GROUP_MAX_NUM);
    }

    public static Map<String,Long> preprocessAdList(List<PackageInfoDo> adPks){
        Map<String,Long> adTag2Idx = new HashMap<>();
        long offset = 10000;
        for(PackageInfoDo pk: adPks){
            String typeString = pk.getPackageIdDo().getAdPrizeType();
            Long idx;
            if(adTag2Idx.containsKey(typeString)){
                idx = adTag2Idx.get(typeString);
            }else{
                idx = adTag2Idx.size() + offset;
                adTag2Idx.put(typeString,idx);
            }
            pk.getPackageIdDo().setPrizeTagId(idx);
        }
        return adTag2Idx;
    }

    private static TagInfo getTagBucket(List<PackageRecallDo> prds){
        Map<Long, List<PackageRecallDo>> tagBuckets = new HashMap<>();
        Map<Long, Double> tagWeight = new HashMap<>();
        for (PackageRecallDo pkrd : prds) {
            Long prizeTagId = pkrd.getPackageInfoDo().getPackageIdDo().getPrizeTagId();
            if (tagBuckets.containsKey(prizeTagId)) {
                tagBuckets.get(prizeTagId).add(pkrd);
            } else {
                List<PackageRecallDo> bucket = new ArrayList<>();
                bucket.add(pkrd);
                tagBuckets.put(prizeTagId, bucket);
            }
            tagWeight.put(prizeTagId, tagWeight.getOrDefault(prizeTagId, 0.0d) + pkrd.getMatchScore());
        }
        return new TagInfo(tagBuckets,tagWeight);
    }

    private static class TagInfo{
        public Map<Long, List<PackageRecallDo>> tagBucket;
        public Map<Long, Double> tagWeight;
        public TagInfo(Map<Long, List<PackageRecallDo>> tagBucket, Map<Long, Double> tagWeight){
            this.tagBucket = tagBucket;
            this.tagWeight = tagWeight;
        }
    }

    private static List<PackageRecallDo> topKFromMergedPrizeAndAdList(List<PackageRecallDo> pks,
                                                             List<PackageRecallDo> adPks,
                                                             int topk){
        if( pks.size() == 0 && adPks.size() == 0){
            logger.warn("No input need to be merged!");
        }

        int curIdxPK = 0;
        int curIdxAd = 0;
        int totalSize = Math.min(topk, pks.size() + adPks.size());

        List<PackageRecallDo> result = new ArrayList<>(totalSize);
        for (int i = 0; i < totalSize; i++) {
            if (curIdxPK == pks.size()) {
                result.addAll(adPks.subList(curIdxAd, curIdxAd + totalSize-i));
                break;
            }
            if (curIdxAd == adPks.size()) {
                result.addAll(pks.subList(curIdxPK, curIdxPK + totalSize-i));
                break;
            }

            if (pks.get(curIdxPK).getMatchScore() > adPks.get(curIdxAd).getMatchScore()) {
                result.add(pks.get(curIdxPK));
                curIdxPK++;
            } else {
                result.add(adPks.get(curIdxAd));
                curIdxAd++;
            }

        }
        return result;
    }

    private static List<PackageRecallDo> genPrizeGroup(List<PackageRecallDo> pks,
                                                       List<PackageRecallDo> adPks,
                                                       Integer slotNum,
                                                       Integer groupNum){
        Set<String>  existingGroupSet = new HashSet<>();
        List<PackageRecallDo> result = new ArrayList<>();

        List<PackageRecallDo> sortedPks = pks.stream().sorted(Comparator.comparing(PackageRecallDo::getMatchScore)
                .reversed()).collect(Collectors.toList());
        List<PackageRecallDo> sortedAdPks = adPks.stream().sorted(Comparator.comparing(PackageRecallDo::getMatchScore)
                .reversed()).collect(Collectors.toList());

        //生成排序分最好的组合
        PackageRecallDo theBestGroup = genBestScoreGroup(sortedPks, sortedAdPks,slotNum);
        result.add(theBestGroup);
        existingGroupSet.add(groupSignature(theBestGroup));

        //不考虑类型的topk random pk扩充召回集，swap策略
        Integer swapGenNum = (int)Math.ceil(groupNum * 0.4);
        List<PackageRecallDo> swapGroups = genByRandomPick(sortedPks, sortedAdPks, existingGroupSet, slotNum, swapGenNum);
        result.addAll(swapGroups);

        //考虑类型桶的random pk 补充
        Integer tagGenNum = (int)Math.ceil(groupNum * 0.5);
        List<PackageRecallDo> tagGroups = genByPickByTag(sortedPks, sortedAdPks, existingGroupSet, slotNum, tagGenNum);
        result.addAll(tagGroups);

        return result;
    }

    private static List<PackageRecallDo> genByPickByTag(List<PackageRecallDo> pks,
                                                        List<PackageRecallDo> adPks,
                                                        Set<String>  existingGroupSet,
                                                        int slotNum,
                                                        Integer groupNum){
        List<PackageRecallDo> result = new ArrayList<>(groupNum);
        TagInfo prizeTagInfo = getTagBucket(pks);
        TagInfo adTagInfo = getTagBucket(adPks);

        for (int i = 0; i < groupNum; i++) {
            int tryNum = 0;
            boolean isNewOne = false;
            while (tryNum < TRY_NUM_FOR_ONE_GROUP && !isNewOne) {
                PackageRecallDo pkrd = genOneGroupByTagInfo(pks, adPks, prizeTagInfo, adTagInfo, slotNum);
                String groupSign = groupSignature(pkrd);
                if(!existingGroupSet.contains(groupSign)){
                    existingGroupSet.add(groupSign);
                    result.add(pkrd);
                    isNewOne = true;
                }
                tryNum++;
            }
        }
        return result;
    }

    private static PackageRecallDo genOneGroupByTagInfo(List<PackageRecallDo> pks, List<PackageRecallDo> adPks, TagInfo prizeTagInfo, TagInfo adTagInfo, int slotNum) {
        PackageRecallDo result = new PackageRecallDo();
        int adNum = Math.min(adTagInfo.tagBucket.size(), AD_PRIZE_NUM_AT_MOST);
        int prizeNum = slotNum - adNum;
        List<PackageRecallDo> prizes;
        List<PackageRecallDo> ads;
        if(adNum <= adTagInfo.tagBucket.size()){
            ads = fetchPrizeFromRandomPickedTag(adTagInfo, adNum);
        }else{
            ads = fetchPkrdFromWeightedList(adPks,adNum);
        }
        if(prizeNum <= prizeTagInfo.tagBucket.size()){
            prizes = fetchPrizeFromRandomPickedTag(prizeTagInfo, prizeNum);
        }else{
            prizes = fetchPkrdFromWeightedList(pks,prizeNum);
        }

        List<PrizeDo> group = new ArrayList<>(slotNum);
        for(PackageRecallDo pkrd : ads){
            PrizeDo prizeDo = prizeDoFromPackageRecallDo(pkrd);
            group.add(prizeDo);
        }
        for(PackageRecallDo pkrd : prizes){
            PrizeDo prizeDo = prizeDoFromPackageRecallDo(pkrd);
            group.add(prizeDo);
        }
        result.getPackageInfoDo().getPackageIdDo().setPrizeGroup(group);
        return result;
    }

    private static PrizeDo prizeDoFromPackageRecallDo(PackageRecallDo pkrd){
        PackageIdDo curPid = pkrd.getPackageInfoDo().getPackageIdDo();
        boolean isAd = curPid.getAdPrizeId() != null;
        Long prizeTagId = curPid.getPrizeTagId();
        Long id = isAd ? curPid.getAdPrizeId() : curPid.getPrizeId();
        return new PrizeDo(isAd, id,prizeTagId, pkrd.getMatchScore());
    }

    private static List<PackageRecallDo> fetchPrizeFromRandomPickedTag(TagInfo tagInfo, int k){
        List<PackageRecallDo> result = new ArrayList<>();
        List<Long> choosenAdTags = weightedPickTags(tagInfo.tagWeight, k);
        for(Long tagId: choosenAdTags){
//            PackageRecallDo pkrd = weightedRandomPickByMatchscore(tagInfo.tagBucket.get(tagId));
            PackageRecallDo pkrd = tagInfo.tagBucket.get(tagId).get(0);
            result.add(pkrd);
        }
        return result;
    }

    protected static List<Long> weightedPickTags(Map<Long, Double> tagWeight, int topk){
        List<Long> result = new ArrayList<>(topk);
        if(topk <= 0){
            return result;
        }

        List<Long> tags = new ArrayList<>(tagWeight.size());
        List<Double> weights = new ArrayList<>(tagWeight.size());
        for (Long tagId : tagWeight.keySet()) {
            tags.add(tagId);
            weights.add(tagWeight.get(tagId));
        }
        List<Integer> resultIdx;
        if(topk > tagWeight.size()){
            logger.warn("k大于类目数,截断成类目数");
            topk = tagWeight.size();
        }

        if(topk == tagWeight.size()){
            resultIdx = argSort(weights);
        }else{
            resultIdx = randKWithoutReplacement(weights,topk);
        }

        for(Integer idx : resultIdx){
            result.add(tags.get(idx));
        }

        return result;
    }

    private static List<PackageRecallDo> genByRandomPick(List<PackageRecallDo> pks,
                                                           List<PackageRecallDo> adPks,
                                                           Set<String>  existingGroupSet,
                                                           Integer slotNum,
                                                           Integer groupNum){
        List<PackageRecallDo> result = new ArrayList<>(groupNum);
        int mergeNum = Math.min(TOTAL_TOPK_PICK_RANGE, pks.size() + adPks.size());
        List<PackageRecallDo> sortedPrizes = topKFromMergedPrizeAndAdList(pks, adPks, mergeNum);
        for (int i = 0; i < groupNum; i++) {
            int tryNum = 0;
            boolean isNewOne = false;
            while (tryNum < TRY_NUM_FOR_ONE_GROUP && !isNewOne) {
                PackageRecallDo pkrd = randomPickOneGroup(sortedPrizes, slotNum);
                String groupSign = groupSignature(pkrd);
                if(!existingGroupSet.contains(groupSign)){
                    existingGroupSet.add(groupSign);
                    result.add(pkrd);
                    isNewOne = true;
                }
                tryNum++;
            }
        }
        return result;
    }

    private static PackageRecallDo randomPickOneGroup(List<PackageRecallDo> sortedPrizes, Integer slotNum) {
        PackageRecallDo result = new PackageRecallDo();
        List<PrizeDo> group = new ArrayList<>(slotNum);
        List<PackageRecallDo> pkrds = fetchPkrdFromWeightedList(sortedPrizes, slotNum);
        for(PackageRecallDo pkrd : pkrds){
            PrizeDo prizeDo = prizeDoFromPackageRecallDo(pkrd);
            group.add(prizeDo);
        }
        result.getPackageInfoDo().getPackageIdDo().setPrizeGroup(group);
        return result;
    }

    protected static List<PackageRecallDo> fetchPkrdFromWeightedList(List<PackageRecallDo> sortedPrizes, int k){
        List<PackageRecallDo> result = new ArrayList<>(k);
        if( k > sortedPrizes.size()){
            logger.warn("无放回抽样数大于样本数,返回原List");
            return sortedPrizes;
        }
        List<Double> weights = sortedPrizes.stream().map(PackageRecallDo::getMatchScore).collect(Collectors.toList());
        List<Integer> idxs = randKWithoutReplacement(weights, k);
        for(Integer idx : idxs){
            result.add(sortedPrizes.get(idx));
        }
        return result;
    }

    private static PackageRecallDo genBestScoreGroup(List<PackageRecallDo> pks,
                                                     List<PackageRecallDo> adPks,
                                                     int slotNum){
        PackageRecallDo pkrd = new PackageRecallDo();
        List<PrizeDo> prizeGroup = new ArrayList<>();

        int adPreSelectNum = Math.min(AD_PRIZE_NUM_AT_LEAST, adPks.size());
        int prizePreSelectNum = Math.min(PRIZE_NUM_AT_LEAST, pks.size());

        for (int i = 0; i < adPreSelectNum; i++) {
            PackageRecallDo ad = adPks.get(i);
            prizeGroup.add(new PrizeDo(true,ad.getPackageInfoDo().getPackageIdDo().getAdPrizeId(),ad.getPackageInfoDo().getPackageIdDo().getPrizeTagId(),ad.getMatchScore()));
        }

        for (int i = 0; i < prizePreSelectNum; i++) {
            PackageRecallDo prize = pks.get(i);
            prizeGroup.add(new PrizeDo(false,prize.getPackageInfoDo().getPackageIdDo().getPrizeId(),prize.getPackageInfoDo().getPackageIdDo().getPrizeTagId(),prize.getMatchScore()));
        }

        int leftNum = slotNum - adPreSelectNum - prizePreSelectNum;
        int pkIdx = prizePreSelectNum;
        int adPkIdx = adPreSelectNum;

        for (int i = 0; i < leftNum; i++) {
            if(pkIdx == pks.size() && adPkIdx == adPks.size()){
                logger.error("No sufficient prize or ad candidate");
                break;
            }
            if(pkIdx == pks.size()){
                PackageRecallDo ad = adPks.get(adPkIdx);
                prizeGroup.add(new PrizeDo(true,
                        ad.getPackageInfoDo().getPackageIdDo().getAdPrizeId(),ad.getPackageInfoDo().getPackageIdDo().getPrizeTagId(), ad.getMatchScore()));
                adPkIdx++;
                continue;
            }

            if(adPkIdx == adPks.size()){
                PackageRecallDo prize = pks.get(pkIdx);
                prizeGroup.add(new PrizeDo(false,
                        prize.getPackageInfoDo().getPackageIdDo().getPrizeId(),prize.getPackageInfoDo().getPackageIdDo().getPrizeTagId(),prize.getMatchScore()));
                pkIdx++;
                continue;
            }

            if(pks.get(pkIdx).getMatchScore() > adPks.get(adPkIdx).getMatchScore()){
                PackageRecallDo prize = pks.get(pkIdx);
                prizeGroup.add(new PrizeDo(false,
                        prize.getPackageInfoDo().getPackageIdDo().getPrizeId(),prize.getPackageInfoDo().getPackageIdDo().getPrizeTagId(), prize.getMatchScore()));
                pkIdx++;
            }else{
                PackageRecallDo ad = adPks.get(adPkIdx);
                prizeGroup.add(new PrizeDo(true,
                        ad.getPackageInfoDo().getPackageIdDo().getAdPrizeId(), ad.getPackageInfoDo().getPackageIdDo().getPrizeTagId(), ad.getMatchScore()));
                adPkIdx++;
            }
        }

        pkrd.getPackageInfoDo().getPackageIdDo().setPrizeGroup(prizeGroup);
        return pkrd;
    }

    /**
     * 辅助方法，根据MatchScore加权随机挑选组件
     * @param pkrds 带权重的组件列表
     * @return 选中的组件
     */
    static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> pkrds){
        if(pkrds.size() == 0){
            logger.error("Weighted random pick from empty list");
            return null;
        }
        double scoreSum = pkrds.stream().mapToDouble(PackageRecallDo::getMatchScore).sum();
        double randomNum = Math.random()*scoreSum;
        int findIdx = 0;
        for (int i = 0; i < pkrds.size(); i++) {
            findIdx = i;
            randomNum -= pkrds.get(i).getMatchScore();
            if(randomNum <= 0.0){
                break;
            }
        }
        return pkrds.get(findIdx);
    }

    protected static PackageRecallDo weightedRandomPickByMatchscore(List<PackageRecallDo> pks, List<PackageRecallDo> choosen){
        List<PackageRecallDo> candidates = new ArrayList<>(pks.size() - choosen.size() );
        for(PackageRecallDo pk : pks){
            if(!choosen.contains(pk)){
                candidates.add(pk);
            }
        }
        return weightedRandomPickByMatchscore(candidates);
    }

    /**
     * 筛选出recallNum个候选奖品，默认数量为PRIZE_NUM_AT_MOST
     * @param pks 可投奖品列表
     * @return EE筛选的奖品列表
     */
    private static List<PackageRecallDo> recallPrize(List<PackageInfoDo> pks){
        if(pks.size() == 0){
            return new ArrayList<>();
        }
        int topN = Math.min(DPAConstant.PRIZETOPN, pks.size());
        return ucbPackage(pks, topN, DPAConstant.RECALLPRIZE, DPAConstant.GLOBALRECALLPRIZE, 2);
    }

    /**
     * 筛选出recallAdPrizeNum个候选奖品，默认数量为AD_PRIZE_NUM_AT_MOST
     * @param pks 可投券列表
     * @return EE筛选的券列表
     */
    protected static List<PackageRecallDo> recallAdPrize(List<PackageInfoDo> pks){
        if(pks.size() == 0){
            return new ArrayList<>();
        }
        int topN = Math.min(DPAConstant.AD_PRIZETOPN, pks.size());
        return ucbPackage(pks, topN, DPAConstant.RECALLPRIZE, DPAConstant.GLOBALRECALLPRIZE, 2);
    }

    private static List<PackageRecallDo> mergePrize(List<PackageRecallDo> recallList,
                                                    List<PackageRecallDo> recallAdList,
                                                    int slotNum){
        List<PackageRecallDo> result = new ArrayList<>(slotNum);
        for (int i = 0; i < PRIZE_NUM_AT_LEAST; i++) {
            result.add(recallList.remove(0));
        }
        for (int i = 0; i < AD_PRIZE_NUM_AT_LEAST; i++) {
            result.add(recallAdList.remove(0));
        }
        int leftNum = slotNum - AD_PRIZE_NUM_AT_LEAST - PRIZE_NUM_AT_LEAST;
        for (int i = 0; i < leftNum; i++) {
            if( recallList.size() == 0 && recallAdList.size() == 0){
                logger.error("insufficient recalled prizes!");
                break;
            }

            if(recallList.size() == 0){
                result.addAll(recallAdList.subList(0,leftNum-i));
                break;
            }

            if(recallAdList.size() == 0){
                result.addAll(recallList.subList(0,leftNum-i));
                break;
            }

            if(recallList.get(0).getMatchScore() > recallAdList.get(0).getMatchScore()){
                result.add(recallList.remove(0));
            }else{
                result.add(recallAdList.remove(0));
            }
        }

        result.sort((o1, o2) -> {
            if(o1.getMatchScore().equals(o2.getMatchScore())){
                return 0;
            }else if(o1.getMatchScore() > o2.getMatchScore()){
                return 1;
            }else{
                return -1;
            }
        });
        return result;
    }

    private static String groupSignature(PackageInfoDo packageInfo){
        List<PrizeDo> group = packageInfo.getPackageIdDo().getPrizeGroup();
        return groupSignature(group);
    }

    private static String groupSignature(PackageRecallDo pkrd){
        return groupSignature(pkrd.getPackageInfoDo());
    }

    protected static String groupSignature(List<PrizeDo> group){
        return group.stream().sorted().map(PrizeDo::toString).collect(Collectors.joining("#"));
    }

    protected static List<Integer> randKWithReplacement(List<Double> rates, int k){
        if (null == rates || rates.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }
        return null;
    }

    protected static List<Integer> argSort(List<Double> rates){
        List<GenericPair<Integer, Double>> idxWithRates = new ArrayList<>(rates.size());
        for (int i = 0; i < rates.size(); i++) {
            idxWithRates.add(new GenericPair<>(i, rates.get(i)));
        }
        idxWithRates = idxWithRates.stream().sorted().collect(Collectors.toList());
        return idxWithRates.stream().mapToInt(item -> item.first).boxed().collect(Collectors.toList());
    }

    protected static List<Integer> randKWithoutReplacement(List<Double> rates, int k) {
        if (null == rates || rates.isEmpty()) {
            logger.error("the rates list is null or empty");
            throw new RuntimeException("the rates list is null or empty");
        }
        if (k >= rates.size()) {
            logger.error("k is bigger than rates' size");
            throw new RuntimeException("k is bigger than rates' size");
        }

        List<HeapNode> nodes = new ArrayList<>(rates.size());

        for (int index = 0; index < rates.size(); index++) {
            nodes.add(new HeapNode(rates.get(index), index));
        }

        List<Integer> result = new ArrayList<>(k);
        List<HeapNode> heap = buildHeap(nodes);

        for (int index = 0; index < k; index++) {
            result.add(heapPop(heap));
        }
        return result;
    }

    private static List<HeapNode> buildHeap(List<HeapNode> nodes) {
        List<HeapNode> heap = new ArrayList<>(nodes.size() + 1);

        heap.add(null);
        heap.addAll(nodes);

        for (int index = heap.size() - 1; index > 1; index--) {
            double curTW = heap.get(index >> 1).totalWeight;
            heap.get(index >> 1).totalWeight = curTW + heap.get(index).totalWeight;
        }
        return heap;

    }

    private static int heapPop(List<HeapNode> heap) {
        double gas = heap.get(1).totalWeight * Math.random();
        int i = 1;
        while (gas > heap.get(i).weight) {
            gas = gas - heap.get(i).weight;
            i <<= 1;
            if (gas > heap.get(i).totalWeight) {
                gas = gas - heap.get(i).totalWeight;
                i++;
            }
        }

        double weight = heap.get(i).weight;
        int value = heap.get(i).value;
        heap.get(i).weight = 0;
        while (i > 0) {
            heap.get(i).totalWeight = heap.get(i).totalWeight - weight;
            i >>= 1;
        }

        return value;
    }

    private static class HeapNode {
        double weight;
        int value;
        double totalWeight;

        public HeapNode(double weight, int value) {
            this.weight = weight;
            this.value = value;
            this.totalWeight = weight;
        }
    }
}
