package cn.com.duiba.nezha.alg.alg.recall;

import cn.com.duiba.nezha.alg.alg.recall.data.AdInputDto;
import cn.com.duiba.nezha.alg.alg.recall.data.CandidateDto;
import cn.com.duiba.nezha.alg.alg.recall.data.RecallResultDto;
import cn.com.duiba.nezha.alg.alg.recall.enums.RecallerEnum;
import cn.com.duiba.nezha.alg.alg.recall.model.DSSMLocalTFModel;
import cn.com.duiba.nezha.alg.alg.recall.params.CombineParam;
import cn.com.duiba.nezha.alg.alg.recall.params.FilterParam;
import cn.com.duiba.nezha.alg.alg.recall.params.RecallerParam;
import cn.com.duiba.nezha.alg.alg.recall.recaller.Recaller;
import cn.com.duiba.nezha.alg.alg.recall.recaller.VectorRecaller;
import com.google.common.primitives.Doubles;
import lombok.extern.slf4j.Slf4j;

import java.io.UnsupportedEncodingException;
import java.util.*;

//这里是召回的入口
@Slf4j
public class RecallMain {

    public static List<RecallResultDto> recall(List<AdInputDto> allAdInput, FilterParam filterParam,
                                               RecallerParam recallerParam, CombineParam combineParam,
                                               DSSMLocalTFModel dssmLocalTFModel,
                                               Map<String,String> featureMap) throws UnsupportedEncodingException {
        //先给各个召回器分配他的召回池
        Map<RecallerEnum,List<AdInputDto>> inputListMap= RecallerPool.allocation(allAdInput,filterParam);
        //然后用各个召回器做召回，结果存个Map
        Map<RecallerEnum,List<CandidateDto>> recallMap=new HashMap<>();

        List<CandidateDto> vectorCandidateDtos=new ArrayList<>();
        try {
            vectorCandidateDtos= VectorRecaller.recall(inputListMap.get(RecallerEnum.DSSM), recallerParam,
                    dssmLocalTFModel,featureMap);
        } catch (Exception e) {
            log.error("vector recall",e);
        }
        recallMap.put(RecallerEnum.DSSM,vectorCandidateDtos);

        //然后融合各个召回通道生成最终结果
        List<RecallResultDto> recallResultDtos=RecallCombination.combineRecallResult(recallMap,combineParam);
        return recallResultDtos;
    }

    public static void main(String[] args) throws Exception {

        String path="C:\\Users\\wangyaole\\Desktop\\nezha\\model\\dssm01";
        DSSMLocalTFModel model=new DSSMLocalTFModel();
        model.loadModel(path);

        Map<String,String> featureMap=new HashMap<>();
        String colStr="f501001,f504001,f508005,f108001,f201001,f507001,f503001,f406001,f502001,f502002,afee,ad_pk";
        String valStr="Android,HLK-AL00,荣耀,325913,70125,3,3609,1,16,2,7.2,72093_177812";
        String[] colList=colStr.split(",");
        String[] valList=valStr.split(",");
        for (int i = 0; i < colList.length-1; i++) {
            featureMap.put(colList[i],valList[i]);
        }

        String pkStr="73672_178993,71121_173168,65521_176061,68810_178437,49829_178811,73041_177565,"+
                "67320_177930,67320_178039,68837_178357,62556_177938,67320_179069,72917_177183,73524_0,72764_177427,68810_177405," +
                "68810_178808,67320_178976,72260_178946,73639_0,73526_0,73363_0,72917_177492,68810_173191,68810_175292,72093_177812";
        List<String> strList= Arrays.asList(pkStr.split(","));
        List<AdInputDto> adInputDtos=new ArrayList<>();
        for (String s : strList) {
            AdInputDto ad=new AdInputDto();
            ad.setAdPkOnce(s);
            ad.setReleaseTarget(Math.random()>0.5?1:3);
            ad.setNewTradeTagId(Math.random()>0.5?"15":"22");
            adInputDtos.add(ad);
        }
        FilterParam filterParam=new FilterParam();
        CombineParam combineParam=new CombineParam();
        RecallerParam recallerParam=new RecallerParam();

        List<RecallResultDto> recallResultDtos = recall(adInputDtos, filterParam, recallerParam, combineParam, model, featureMap);

        System.out.println(recallResultDtos);
    }
}
