package cn.com.duiba.nezha.alg.alg.recall.recaller;

import cn.com.duiba.nezha.alg.alg.recall.data.AdInputDto;
import cn.com.duiba.nezha.alg.alg.recall.data.CandidateDto;
import cn.com.duiba.nezha.alg.alg.recall.model.DSSMLocalTFModel;
import cn.com.duiba.nezha.alg.alg.recall.params.RecallerParam;

import java.io.UnsupportedEncodingException;
import java.time.Duration;
import java.time.Instant;
import java.util.*;

/**
 * Vector (DSSM model) based recaller.
 *
 * <p>Scores each candidate ad with the score produced by a {@link DSSMLocalTFModel} and returns
 * the candidates whose score passes a minimum threshold, highest score first. The result size is
 * capped at {@code RecallerParam#getMaxDssmMatchNum()}.
 */
public class VectorRecaller {

    /**
     * Recalls candidate ads ranked by the model's prediction scores.
     *
     * <p>The model output is consumed as one list of ad keys (String elements) plus one list of
     * scores (numeric elements), paired by index. This is inferred from how the result is read
     * here; confirm against {@code DSSMLocalTFModel#predict}. Ads that the model did not score get
     * a score of 0. Ads scoring below {@code recallerParam.getMinDssmScore()} are dropped. The rest
     * are sorted by descending score and truncated to {@code recallerParam.getMaxDssmMatchNum()}.
     *
     * @param adInputDtos   candidate ads; {@code null} or empty yields an empty result, and
     *                      {@code null} elements are skipped
     * @param recallerParam thresholds: minimum score and maximum result size
     * @param model         model used to score the ads
     * @param featureMap    feature name to value map passed to {@code model.predict}
     * @return a new mutable list of candidates, highest score first (never {@code null})
     * @throws UnsupportedEncodingException declared for compatibility; presumably propagated from
     *                                      {@code model.predict}
     */
    public static List<CandidateDto> recall(List<AdInputDto> adInputDtos, RecallerParam recallerParam,
                                            DSSMLocalTFModel model, Map<String, String> featureMap
                                            ) throws UnsupportedEncodingException {
        if (adInputDtos == null || adInputDtos.isEmpty()) {
            return new ArrayList<>();
        }
        Map<String, Float> scoreMap = buildScoreMap(model.predict(featureMap));

        List<CandidateDto> candidateDtos = new ArrayList<>(adInputDtos.size());
        for (AdInputDto adpk : adInputDtos) {
            if (adpk == null) {
                continue;
            }
            CandidateDto candidateDto = new CandidateDto();
            candidateDto.setAdPkOnce(adpk.getKey());
            // Ads missing from the model output default to a score of 0.
            candidateDto.setScore(Double.valueOf(scoreMap.getOrDefault(adpk.getKey(), 0.0f)));
            if (candidateDto.getScore() >= recallerParam.getMinDssmScore()) {
                candidateDtos.add(candidateDto);
            }
        }

        // Sort by descending score, then keep the top-k.
        candidateDtos.sort(Comparator.comparing(CandidateDto::getScore).reversed());

        // Clamp so a negative configured limit cannot make subList throw.
        int limit = Math.max(0, Math.min(recallerParam.getMaxDssmMatchNum(), candidateDtos.size()));
        // Return an independent copy rather than a subList view bound to the backing list.
        return new ArrayList<>(candidateDtos.subList(0, limit));
    }

    /**
     * Converts the raw model output into a map from ad key to score.
     *
     * <p>Each value is classified by its first element:
     * <ul>
     *   <li>a list starting with a {@link String} is taken as the ad keys;</li>
     *   <li>a list starting with a {@link Number} is taken as the scores.</li>
     * </ul>
     * If several lists match the same role, the last one iterated wins. Keys and scores are paired
     * by index up to the shorter of the two lists, and pairs with a {@code null} or mistyped entry
     * are skipped. Non-list values, empty lists and lists with a {@code null} first element are
     * ignored.
     *
     * @param predResult model output; may be {@code null}
     * @return mutable map from ad key to score; empty if nothing usable was found
     */
    private static Map<String, Float> buildScoreMap(Map<String, ?> predResult) {
        Map<String, Float> scoreMap = new HashMap<>();
        if (predResult == null) {
            return scoreMap;
        }
        List<?> keyList = Collections.emptyList();
        List<?> scoreList = Collections.emptyList();
        for (Object value : predResult.values()) {
            if (!(value instanceof List)) {
                continue;
            }
            List<?> list = (List<?>) value;
            if (list.isEmpty()) {
                continue;
            }
            Object first = list.get(0);
            if (first instanceof String) {
                keyList = list;
            } else if (first instanceof Number) {
                scoreList = list;
            }
        }
        int pairCount = Math.min(keyList.size(), scoreList.size());
        for (int i = 0; i < pairCount; i++) {
            Object key = keyList.get(i);
            Object score = scoreList.get(i);
            if (key instanceof String && score instanceof Number) {
                // Later duplicates of the same key overwrite earlier ones.
                scoreMap.put((String) key, ((Number) score).floatValue());
            }
        }
        return scoreMap;
    }

    /**
     * Runs an ad-hoc local latency benchmark.
     *
     * <p>Loads a DSSM model, times repeated {@link #recall} calls on a fixed sample, and prints
     * each call's elapsed milliseconds.
     *
     * @param args optional overrides. {@code args[0]} is the model path (default: a hard-coded
     *             developer-local path). {@code args[1]} is the iteration count (default: 100000).
     * @throws Exception propagated from model loading or recall
     */
    public static void main(String[] args) throws Exception {
        // Developer-local default; pass a model path as args[0] to run on another machine.
        String path = args.length > 0 ? args[0] : "C:\\Users\\wangyaole\\Desktop\\nezha\\model\\dssm01";
        int iterations = args.length > 1 ? Integer.parseInt(args[1]) : 100000;

        DSSMLocalTFModel model = new DSSMLocalTFModel();
        RecallerParam recallerParam = new RecallerParam();
        model.loadModel(path);

        Map<String, String> featureMap = new HashMap<>();
        String colStr = "f501001,f504001,f508005,f108001,f201001,f507001,f503001,f406001,f502001,f502002,afee,ad_pk";
        String valStr = "Android,HLK-AL00,荣耀,325913,70125,3,3609,1,16,2,7.2,72093_177812";

        String[] colList = colStr.split(",");
        String[] valList = valStr.split(",");
        // The last two columns (afee, ad_pk) are not put into the feature map. Presumably they
        // describe the ad rather than the request (TODO confirm).
        for (int i = 0; i < colList.length - 2; i++) {
            featureMap.put(colList[i], valList[i]);
        }
        String pkStr = "73672_199,71121_173168,65521_176061,68810_178437,49829_178811,73041_177565," +
                "67320_177930,67320_178039,68837_178357,62556_177938,67320_179069,72917_177183,73524_0,72764_177427,68810_177405," +
                "68810_178808,67320_178976,72260_178946,73639_0,73526_0,73363_0,72917_177492,68810_173191,68810_175292,72093_177812";
        List<String> strList = Arrays.asList(pkStr.split(","));
        List<AdInputDto> adInputDtos = new ArrayList<>();
        for (String s : strList) {
            AdInputDto ad = new AdInputDto();
            ad.setAdPkOnce(s);
            adInputDtos.add(ad);
        }
        for (int i = 0; i < iterations; i++) {
            Instant start = Instant.now();
            List<CandidateDto> candidateDtos = VectorRecaller.recall(adInputDtos, recallerParam, model, featureMap);
            Instant finish = Instant.now();
            long timeElapsed = Duration.between(start, finish).toMillis();
            System.out.println(timeElapsed);
        }
    }
}
