package cn.com.duiba.nezha.alg.alg.recall;

import cn.com.duiba.nezha.alg.alg.recall.data.AdInputDto;
import cn.com.duiba.nezha.alg.alg.recall.data.CandidateDto;
import cn.com.duiba.nezha.alg.alg.recall.data.RecallResultDto;
import cn.com.duiba.nezha.alg.alg.recall.enums.RecallerEnum;
import cn.com.duiba.nezha.alg.alg.recall.params.CombineParam;
import com.google.common.collect.ImmutableMap;
import lombok.extern.slf4j.Slf4j;

import java.util.*;

//这个类用来把各个通道的召回结果拼在毅起
//现在支持的方法是每个通道指定召回个数-20210125
@Slf4j
public class RecallCombination {

    //对某个通道取前k名，有得分的通道按得分排序，没得分的通道按随机数排序（随机数会覆盖score属性，是否合理？）
    public static List<CandidateDto> topK(List<CandidateDto> candidateDtos,RecallerEnum recallerEnum,CombineParam combineParam) {
        if (recallerEnum.isScored()) {
            candidateDtos.sort(Comparator.comparing(CandidateDto::getScore).reversed());
            List<CandidateDto> candidateDtosTopk = candidateDtos.subList(0,
                    Math.min(combineParam.getRecallNumMap().getOrDefault(recallerEnum.getId(),0),
                    candidateDtos.size()));
            return candidateDtosTopk;
        }
        else {
            for (CandidateDto candidateDto : candidateDtos) {
                candidateDto.setScore(Math.random());
            }
            candidateDtos.sort(Comparator.comparing(CandidateDto::getScore).reversed());
            List<CandidateDto> candidateDtosTopk = candidateDtos.subList(0,
                    Math.min(combineParam.getRecallNumMap().getOrDefault(recallerEnum.getId(),0),
                            candidateDtos.size()));
            return candidateDtosTopk;
        }
    }

    //选出各个通道的topk
    public static Map<RecallerEnum, List<CandidateDto>> selectSome(Map<RecallerEnum, List<CandidateDto>> rawRecallerListMap, CombineParam combineParam) {
        Map<RecallerEnum, List<CandidateDto>> selectedListMap=new HashMap<RecallerEnum, List<CandidateDto>>();
        for (Map.Entry<RecallerEnum, List<CandidateDto>> recallerEnumListEntry : rawRecallerListMap.entrySet()) {
            List<CandidateDto> midCandidateDtoList=new ArrayList<CandidateDto>();
            try {
                midCandidateDtoList = topK(recallerEnumListEntry.getValue(),
                        recallerEnumListEntry.getKey(), combineParam);
            } catch (Exception e) {
                log.error("combine ",e);
            }
            selectedListMap.put(recallerEnumListEntry.getKey(),midCandidateDtoList);
        }
        return selectedListMap;
    }

    //把各个通道的topk拼起来，融合各个得分进scoreMap，实现方式感觉不甚巧妙，是否有更好办法？
    public static List<RecallResultDto> combineRecallResult(Map<RecallerEnum, List<CandidateDto>> rawRecallerListMap,
                                                            CombineParam combineParam) {
        Map<RecallerEnum, List<CandidateDto>> recallerListMap=selectSome(rawRecallerListMap,combineParam);
        Map<String,Map<Integer,Double>> midResultMap=new HashMap<String,Map<Integer,Double>>();
        for (Map.Entry<RecallerEnum, List<CandidateDto>> recallerEnumListEntry : recallerListMap.entrySet()) {
            for (CandidateDto candidateDto : recallerEnumListEntry.getValue()) {
                if(midResultMap.getOrDefault(candidateDto.getKey(),null)==null) {
                    Map<Integer,Double> tempMap=new HashMap<Integer, Double>();
                    tempMap.put(recallerEnumListEntry.getKey().getId(),candidateDto.getScore());
                    midResultMap.put(candidateDto.getKey(),tempMap);
                }
                else {
                    midResultMap.get(candidateDto.getKey()).put(recallerEnumListEntry.getKey().getId(),
                            candidateDto.getScore());
                }
            }
        }
        List<RecallResultDto> recallResultDtos=new ArrayList<RecallResultDto>();
        for (String adpk : midResultMap.keySet()) {
            RecallResultDto recallResultDto=new RecallResultDto();
            recallResultDto.setAdPkOnce(adpk);
            recallResultDto.setScoreMap(midResultMap.get(adpk));
            recallResultDtos.add(recallResultDto);
        }
        return recallResultDtos;
    }

    public static void main(String[] args) {
        String pkStr="73672_178993,71121_173168,65521_176061,68810_178437,49829_178811,73041_177565,"+
                "67320_177930,67320_178039,68837_178357,62556_177938,67320_179069,72917_177183,73524_0,72764_177427,68810_177405," +
                "68810_178808,67320_178976,72260_178946,73639_0,73526_0,73363_0,72917_177492,68810_173191,68810_175292,72093_177812";
        List<String> strList= Arrays.asList(pkStr.split(","));
        List<CandidateDto> candidateDtos=new ArrayList<>();
        for (String s : strList) {
            CandidateDto candidateDto=new CandidateDto();
            candidateDto.setAdPkOnce(s);
            candidateDto.setScore(Math.random());
            candidateDtos.add(candidateDto);
        }

        CombineParam combineParam=new CombineParam();
        combineParam.setRecallNumMap(ImmutableMap.of(2,10,3,11));
        Map<RecallerEnum, List<CandidateDto>> recallerEnumListMap=new HashMap<>();
        recallerEnumListMap.put(RecallerEnum.HUMAN,candidateDtos);

        List<CandidateDto> candidateDtos1=new ArrayList<>();
        for (String s : strList) {
            CandidateDto candidateDto=new CandidateDto();
            candidateDto.setAdPkOnce(s);
            candidateDto.setScore(Math.random());
            candidateDtos1.add(candidateDto);
        }

        recallerEnumListMap.put(RecallerEnum.DSSM,candidateDtos1);

        List<RecallResultDto> a = combineRecallResult(recallerEnumListMap, combineParam);
        System.out.println(a);

    }

}

