package cn.com.duiba.nezha.alg.api.facade.recall.channel;

import cn.com.duiba.nezha.alg.api.dto.recall.RecallChannelResult;
import cn.com.duiba.nezha.alg.api.dto.recall.VectorAdvertDTO;
import cn.com.duiba.nezha.alg.api.dto.recall.VectorRequest;
import cn.com.duiba.nezha.alg.api.enums.RecallEnums.RecallChannelType;
import cn.com.duiba.nezha.alg.api.facade.recall.AbstractRecallChanelFacade;
import cn.com.duiba.nezha.alg.api.model.LocalTFModel;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Vector recall channel.
 *
 * <p>Runs the local model carried by the request, maps the predicted scores onto the
 * candidate configurations supplied by the caller, and keeps the best-scoring candidates
 * after applying the configured score threshold and maximum recall size.
 *
 * @author lijicong
 * @since 2021-08-17
 */
@Slf4j
public class VectorRecallChannelFacadeImpl extends AbstractRecallChanelFacade<VectorAdvertDTO, VectorRequest>
        implements VectorRecallChannelFacade {

    static final Logger logger = LoggerFactory.getLogger(VectorRecallChannelFacadeImpl.class);

    /** Index name written on every result; also the lookup key in the threshold / num config maps. */
    private static final String VEC_SIM_INDEX = "vecSim";

    /** Score threshold used when the request does not configure one for {@link #VEC_SIM_INDEX}. */
    private static final double DEFAULT_SCORE_THRESHOLD = 0.001D;

    /** Maximum recall size used when the request does not configure one for {@link #VEC_SIM_INDEX}. */
    private static final int DEFAULT_RECALL_NUM = 10;

    /** Separator between the advert id and the package id inside a candidate key. */
    private static final String KEY_SEPARATOR = "_";

    @Override
    public RecallChannelType getType() {
        return RecallChannelType.Vector;
    }

    /**
     * Builds the vector recall result.
     *
     * @param vectorAdvertDTOMap candidate configurations; each key is expected to look like
     *                           {@code advertId_packageId} (e.g. {@code 379245_138832})
     * @param vectorRequest      request carrying the local model, its feature map and the
     *                           configured score threshold / recall size
     * @return the recalled candidates keyed by {@code RecallChannelResult#buildKey()}; empty when
     *         the inputs, the model or the model bundle are missing
     */
    @Override
    public Map<String, RecallChannelResult> execute(Map<String, VectorAdvertDTO> vectorAdvertDTOMap, VectorRequest vectorRequest) {

        Map<String, RecallChannelResult> res = new HashMap<>();

        if (AssertUtil.isAnyEmpty(vectorAdvertDTOMap, vectorRequest)) {
            return res;
        }
        if (AssertUtil.isAnyEmpty(vectorRequest.getModel())) {
            return res;
        }
        if (AssertUtil.isAnyEmpty(vectorRequest.getModel().getBundle())) {
            return res;
        }

        Map<String, List<?>> predictRes = vectorRequest.getModel().predict(vectorRequest.getFeatureMap());
        Map<String, Double> scoreMap = buildScoreMap(predictRes);

        // Attach the similarity score to every candidate; candidates unknown to the model score 0.
        List<RecallChannelResult> resAll = new ArrayList<>(vectorAdvertDTOMap.size());
        for (String candidateKey : vectorAdvertDTOMap.keySet()) {
            String[] idParts = candidateKey == null ? new String[0] : candidateKey.split(KEY_SEPARATOR);
            if (idParts.length < 2) {
                // One malformed key must not abort the recall for every other candidate.
                logger.warn("vector recall: skip candidate with malformed key [{}]", candidateKey);
                continue;
            }
            Long advertId;
            Long packageId;
            try {
                advertId = Long.valueOf(idParts[0]);
                packageId = Long.valueOf(idParts[1]);
            } catch (NumberFormatException e) {
                logger.warn("vector recall: skip candidate with non-numeric key [{}]", candidateKey, e);
                continue;
            }
            RecallChannelResult recallChannelResult = new RecallChannelResult();
            recallChannelResult.setAdvertId(advertId);
            recallChannelResult.setPackageId(packageId);
            recallChannelResult.setIndex(VEC_SIM_INDEX);
            recallChannelResult.setScore(scoreMap.getOrDefault(candidateKey, 0.0D));
            resAll.add(recallChannelResult);
        }

        // Resolve the configuration once, tolerating missing maps and missing / null entries.
        Map<String, Double> thresholdConfig = vectorRequest.getThreshold();
        Double configuredThreshold = thresholdConfig == null ? null : thresholdConfig.get(VEC_SIM_INDEX);
        final double scoreThreshold = configuredThreshold == null ? DEFAULT_SCORE_THRESHOLD : configuredThreshold;

        Map<String, Integer> numConfig = vectorRequest.getNum();
        Integer configuredNum = numConfig == null ? null : numConfig.get(VEC_SIM_INDEX);
        // Stream.limit rejects negative sizes, so clamp a bad configuration to "recall nothing".
        final long maxRecall = Math.max(0, configuredNum == null ? DEFAULT_RECALL_NUM : configuredNum);

        // Truncate: drop low scores, order by score descending, keep at most maxRecall entries.
        List<RecallChannelResult> kept = resAll.stream()
                .filter(candidate -> candidate.getScore() >= scoreThreshold)
                .sorted(Comparator.comparing(RecallChannelResult::getScore).reversed())
                .limit(maxRecall)
                .collect(Collectors.toList());
        kept.forEach(result -> res.put(result.buildKey(), result));
        return res;
    }

    /**
     * Converts the raw model output into a {@code candidateKey -> score} map.
     *
     * <p>The output list whose first element is a {@code String} is taken as the key list; any
     * other non-empty list is taken as the score list. Only positions present in both lists are
     * paired, and entries that are not a {@code String} / {@code Number} pair are ignored.
     */
    private static Map<String, Double> buildScoreMap(Map<String, List<?>> predictRes) {
        Map<String, Double> scoreMap = new HashMap<>();
        if (predictRes == null) {
            return scoreMap;
        }

        List<?> adPkUnionList = Collections.emptyList();
        List<?> scoreList = Collections.emptyList();
        for (List<?> output : predictRes.values()) {
            if (output == null || output.isEmpty()) {
                continue;
            }
            if (output.get(0) instanceof String) {
                adPkUnionList = output;
            } else {
                scoreList = output;
            }
        }

        int pairCount = Math.min(adPkUnionList.size(), scoreList.size());
        for (int i = 0; i < pairCount; i++) {
            Object key = adPkUnionList.get(i);
            Object score = scoreList.get(i);
            if (key instanceof String && score instanceof Number) {
                scoreMap.put((String) key, ((Number) score).doubleValue());
            }
        }
        return scoreMap;
    }

    /**
     * Local manual smoke run: loads a model from a developer-machine path, builds a sample
     * request and candidate set, and prints the recall result.
     */
    public static void main(String[] args) throws Exception {
        String path = "/Users/butcher/Desktop/dssmV01";
        LocalTFModel model = new LocalTFModel();
        model.loadModel(path);
        Map<String, VectorAdvertDTO> vectorAdvertDTOMap = new HashMap<>();
        VectorRequest vectorRequest = new VectorRequest();
        Map<String, Double> threshold = new HashMap<>();
        Map<String, Integer> num = new HashMap<>();
        threshold.put("vecSim", 0.001);
        num.put("vecSim", 10);

        Map<String, String> featureMap = new HashMap<>();
        String colStr = "f451001\tf451002\tf506001\tf507001\tf507003\tf509001\tf509002\tf505001\tf509004\tf501001\t" +
                "f8330021\tf8300011\tf8300021\tf8300041\tf8310011\tf8310021\tf8310041\tf205002\tf201001\tf108001\t" +
                "f503001\tf502001\tf502003\tf502002\tafee\tad_pk";
        String valStr = "0119010101\t01020301\t3\t1\tAndroid6.0.1\tvivoy55a\tvivo\t1100-1699\t2016\tAndroid\t" +
                "78465,78752,73920,77698,77250,78407,77419,78446,73714,77975,77463,78326,78393,78553,78072,77498,78173,78364,76831,78174\t" +
                "78465,78752,77698,77250,78407,78446,73714,77975,77463,78326,78393,78553,78072,77498,78173,76831,78174\t" +
                "41283,41540,40232,40363,39724,37839,37297,41236,40855,40951,39767,41400,41465,41435\t16,21,15\t78393\t" +
                "37297\t16\t268\t82239\t401503\t4403\t0\t16\t2\t7.13\t78553_191877\t";
        String[] colList = colStr.split("\t");
        String[] valList = valStr.split("\t");
        // The last column (ad_pk) is not copied into the feature map.
        for (int i = 0; i < colList.length - 1; i++) {
            featureMap.put(colList[i], valList[i]);
        }

        vectorRequest.setFeatureMap(featureMap);
        vectorRequest.setModel(model);
        vectorRequest.setNum(num);
        vectorRequest.setThreshold(threshold);
        vectorRequest.setAppId(6666L);
        vectorRequest.setSlotId(9999L);

        String pkStr="73672_199,71121_173168,65521_176061,68810_178437,49829_178811,73041_177565,"+
                "67320_177930,67320_178039,68837_178357,62556_177938,67320_179069,72917_177183,73524_0,72764_177427,68810_177405," +
                "68810_178808,67320_178976,72260_178946,73639_0,73526_0,73363_0,72917_177492,68810_173191,68810_175292,72093_177812";
        List<String> strList= Arrays.asList(pkStr.split(","));

        for (String s: strList) {
            String[] idParts = s.split("_");
            VectorAdvertDTO vectorAdvertDTO = new VectorAdvertDTO();
            vectorAdvertDTO.setAdvertId(Long.valueOf(idParts[0]));
            vectorAdvertDTO.setPackageId(Long.valueOf(idParts[1]));
            vectorAdvertDTO.setAccountId(111L);
            vectorAdvertDTO.setChargeType(2);
            vectorAdvertDTO.setStrongTarget(false);
            vectorAdvertDTOMap.put(vectorAdvertDTO.buildKey(), vectorAdvertDTO);
        }

        VectorRecallChannelFacadeImpl vectorRecallChannelFacade = new VectorRecallChannelFacadeImpl();
        Map<String, RecallChannelResult> recallChannelResultMap = vectorRecallChannelFacade.execute(vectorAdvertDTOMap, vectorRequest);
        System.out.println("recallChannelResultMap: "  +  recallChannelResultMap.toString());
    }

}
