/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.alg.api.facade.recall.channel;

import cn.com.duiba.nezha.alg.api.dto.recall.RecallChannelResult;
import cn.com.duiba.nezha.alg.api.dto.recall.VectorAdvertDTO;
import cn.com.duiba.nezha.alg.api.dto.recall.VectorRequest;
import cn.com.duiba.nezha.alg.api.enums.RecallEnums;
import cn.com.duiba.nezha.alg.api.facade.recall.AbstractRecallChanelFacade;
import cn.com.duiba.nezha.alg.api.facade.recall.channel.VectorRecallChannelFacade;
import cn.com.duiba.nezha.alg.api.model.LocalTFModel;
import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Vector-similarity recall channel.
 *
 * <p>Runs the request's model once over the request feature map. Each candidate
 * "advertId_packageId" key in the input map gets the score the model produced for that key
 * (0 when the model produced none). Candidates below the {@code "vecSim"} threshold are
 * dropped. The top {@code "vecSim"} N survivors are returned, keyed by
 * {@link RecallChannelResult#buildKey()}.
 */
public class VectorRecallChannelFacadeImpl
extends AbstractRecallChanelFacade<VectorAdvertDTO, VectorRequest>
implements VectorRecallChannelFacade {
    private static final Logger log = LoggerFactory.getLogger(VectorRecallChannelFacadeImpl.class);
    /**
     * Duplicate of {@link #log}. Retained only because it is package-visible and may be
     * referenced elsewhere; prefer {@link #log} in new code.
     */
    static Logger logger = LoggerFactory.getLogger(VectorRecallChannelFacadeImpl.class);

    /** Recall index name: tags results and is the key for threshold / num lookups. */
    private static final String INDEX = "vecSim";
    /** Minimum score used when the request supplies no "vecSim" threshold. */
    private static final double DEFAULT_THRESHOLD = 0.001;
    /** Result count used when the request supplies no "vecSim" num. */
    private static final int DEFAULT_NUM = 10;
    /** Model path used by {@link #main} when no path argument is given (developer-local). */
    private static final String DEFAULT_MODEL_PATH = "/Users/butcher/Desktop/dssmV01";

    /** Returns the channel type served by this facade: {@code Vector}. */
    public RecallEnums.RecallChannelType getType() {
        return RecallEnums.RecallChannelType.Vector;
    }

    /**
     * Scores the candidate adverts with the request's model and keeps the best ones.
     *
     * @param vectorAdvertDTOMap candidates keyed by "advertId_packageId". Keys that cannot be
     *                           parsed that way are skipped with a warning.
     * @param vectorRequest      carries the model, the feature map, and the per-index
     *                           threshold / result-count maps.
     * @return results keyed by {@link RecallChannelResult#buildKey()}. The map is empty when
     *         any required input is missing.
     */
    @Override
    public Map<String, RecallChannelResult> execute(Map<String, VectorAdvertDTO> vectorAdvertDTOMap, VectorRequest vectorRequest) {
        Map<String, RecallChannelResult> res = new HashMap<String, RecallChannelResult>();
        if (AssertUtil.isAnyEmpty(new Object[]{vectorAdvertDTOMap, vectorRequest})) {
            return res;
        }
        if (AssertUtil.isAnyEmpty(new Object[]{vectorRequest.getModel()})) {
            return res;
        }
        if (AssertUtil.isAnyEmpty(new Object[]{vectorRequest.getModel().getBundle()})) {
            return res;
        }

        Map<?, ?> predictRes = vectorRequest.getModel().predict(vectorRequest.getFeatureMap());
        Map<String, Number> scoreMap = buildScoreMap(predictRes);

        // Loop-invariant configuration, resolved once instead of per element.
        double threshold = lookupDouble(vectorRequest.getThreshold(), INDEX, DEFAULT_THRESHOLD);
        // Stream.limit rejects negative sizes, so clamp to 0.
        int num = Math.max(0, lookupInt(vectorRequest.getNum(), INDEX, DEFAULT_NUM));

        List<RecallChannelResult> candidates = new ArrayList<RecallChannelResult>(vectorAdvertDTOMap.size());
        for (String key : vectorAdvertDTOMap.keySet()) {
            long[] adPk = parseAdPkKey(key);
            if (adPk == null) {
                log.warn("vector recall: skip malformed candidate key {}", key);
                continue;
            }
            Number score = scoreMap.get(key);
            RecallChannelResult recallChannelResult = new RecallChannelResult();
            recallChannelResult.setAdvertId(Long.valueOf(adPk[0]));
            recallChannelResult.setPackageId(Long.valueOf(adPk[1]));
            recallChannelResult.setIndex(INDEX);
            // floatValue() keeps the float precision the original scoring applied.
            recallChannelResult.setScore(Double.valueOf(score == null ? 0.0f : score.floatValue()));
            candidates.add(recallChannelResult);
        }

        List<RecallChannelResult> innerTmpRes = candidates.stream()
                .filter(r -> r.getScore() >= threshold)
                .sorted(Comparator.comparing(RecallChannelResult::getScore).reversed())
                .limit(num)
                .collect(Collectors.toList());
        for (RecallChannelResult result : innerTmpRes) {
            res.put(result.buildKey(), result);
        }
        return res;
    }

    /**
     * Pairs the model's key column with its score column.
     *
     * <p>Each value of {@code predictRes} is examined. A non-empty list whose first element is a
     * String is taken as the key column; any other non-empty list is taken as the score column.
     * As in the original, the last match wins. NOTE(review): this assumes the model emits
     * exactly one key list and one parallel score list — confirm against the model signature.
     * Empty or non-list values are ignored. Only the common prefix of the two columns is
     * paired.
     */
    private static Map<String, Number> buildScoreMap(Map<?, ?> predictRes) {
        Map<String, Number> scoreMap = new HashMap<String, Number>();
        if (predictRes == null) {
            return scoreMap;
        }
        List<?> keys = null;
        List<?> scores = null;
        for (Object value : predictRes.values()) {
            if (!(value instanceof List) || ((List<?>) value).isEmpty()) {
                continue;
            }
            List<?> column = (List<?>) value;
            if (column.get(0) instanceof String) {
                keys = column;
            } else {
                scores = column;
            }
        }
        if (keys == null || scores == null) {
            return scoreMap;
        }
        if (keys.size() != scores.size()) {
            log.warn("vector recall: key/score size mismatch {} vs {}", keys.size(), scores.size());
        }
        int n = Math.min(keys.size(), scores.size());
        for (int i = 0; i < n; ++i) {
            Object key = keys.get(i);
            Object score = scores.get(i);
            if (key instanceof String && score instanceof Number) {
                scoreMap.put((String) key, (Number) score);
            }
        }
        return scoreMap;
    }

    /**
     * Parses "advertId_packageId" into {advertId, packageId}. Parts after the second are
     * ignored, as before.
     *
     * @return the two ids, or {@code null} when the key is null, has fewer than two parts,
     *         or contains non-numeric ids.
     */
    private static long[] parseAdPkKey(String key) {
        if (key == null) {
            return null;
        }
        String[] parts = key.split("_");
        if (parts.length < 2) {
            return null;
        }
        try {
            return new long[]{Long.parseLong(parts[0]), Long.parseLong(parts[1])};
        } catch (NumberFormatException ignored) {
            // Reported by the caller, which logs and skips the key.
            return null;
        }
    }

    /** Null-safe numeric lookup: returns {@code fallback} when the map or the value is absent. */
    private static double lookupDouble(Map<String, ? extends Number> map, String key, double fallback) {
        Number value = map == null ? null : map.get(key);
        return value == null ? fallback : value.doubleValue();
    }

    /** Null-safe integer lookup: returns {@code fallback} when the map or the value is absent. */
    private static int lookupInt(Map<String, ? extends Number> map, String key, int fallback) {
        Number value = map == null ? null : map.get(key);
        return value == null ? fallback : value.intValue();
    }

    /**
     * Manual smoke test: loads a model and scores a fixed candidate list.
     *
     * @param args optional; {@code args[0]} overrides the model path.
     */
    public static void main(String[] args) throws Exception {
        String path = args != null && args.length > 0 ? args[0] : DEFAULT_MODEL_PATH;
        LocalTFModel model = new LocalTFModel();
        model.loadModel(path);
        HashMap<String, VectorAdvertDTO> vectorAdvertDTOMap = new HashMap<String, VectorAdvertDTO>();
        VectorRequest vectorRequest = new VectorRequest();
        HashMap<String, Double> threshold = new HashMap<String, Double>();
        HashMap<String, Integer> num = new HashMap<String, Integer>();
        threshold.put("vecSim", 0.001);
        num.put("vecSim", 10);
        HashMap<String, String> featureMap = new HashMap<String, String>();
        String colStr = "f451001\tf451002\tf506001\tf507001\tf507003\tf509001\tf509002\tf505001\tf509004\tf501001\tf8330021\tf8300011\tf8300021\tf8300041\tf8310011\tf8310021\tf8310041\tf205002\tf201001\tf108001\tf503001\tf502001\tf502003\tf502002\tafee\tad_pk";
        String valStr = "0119010101\t01020301\t3\t1\tAndroid6.0.1\tvivoy55a\tvivo\t1100-1699\t2016\tAndroid\t78465,78752,73920,77698,77250,78407,77419,78446,73714,77975,77463,78326,78393,78553,78072,77498,78173,78364,76831,78174\t78465,78752,77698,77250,78407,78446,73714,77975,77463,78326,78393,78553,78072,77498,78173,76831,78174\t41283,41540,40232,40363,39724,37839,37297,41236,40855,40951,39767,41400,41465,41435\t16,21,15\t78393\t37297\t16\t268\t82239\t401503\t4403\t0\t16\t2\t7.13\t78553_191877\t";
        String[] colList = colStr.split("\t");
        String[] valList = valStr.split("\t");
        // The last column (ad_pk) is deliberately excluded from the features; also guard
        // against a value row shorter than the header.
        int featureCount = Math.min(colList.length - 1, valList.length);
        for (int i = 0; i < featureCount; ++i) {
            featureMap.put(colList[i], valList[i]);
        }
        vectorRequest.setFeatureMap(featureMap);
        vectorRequest.setModel(model);
        vectorRequest.setNum(num);
        vectorRequest.setThreshold(threshold);
        vectorRequest.setAppId(Long.valueOf(6666L));
        vectorRequest.setSlotId(Long.valueOf(9999L));
        String pkStr = "73672_199,71121_173168,65521_176061,68810_178437,49829_178811,73041_177565,67320_177930,67320_178039,68837_178357,62556_177938,67320_179069,72917_177183,73524_0,72764_177427,68810_177405,68810_178808,67320_178976,72260_178946,73639_0,73526_0,73363_0,72917_177492,68810_173191,68810_175292,72093_177812";
        List<String> strList = Arrays.asList(pkStr.split(","));
        for (String s : strList) {
            String[] adPk = s.split("_");
            VectorAdvertDTO vectorAdvertDTO = new VectorAdvertDTO();
            vectorAdvertDTO.setAdvertId(Long.valueOf(adPk[0]));
            vectorAdvertDTO.setPackageId(Long.valueOf(adPk[1]));
            vectorAdvertDTO.setAccountId(Long.valueOf(111L));
            vectorAdvertDTO.setChargeType(Integer.valueOf(2));
            vectorAdvertDTO.setStrongTarget(Boolean.valueOf(false));
            vectorAdvertDTOMap.put(vectorAdvertDTO.buildKey(), vectorAdvertDTO);
        }
        VectorRecallChannelFacadeImpl vectorRecallChannelFacade = new VectorRecallChannelFacadeImpl();
        Map<String, RecallChannelResult> recallChannelResultMap = vectorRecallChannelFacade.execute(vectorAdvertDTOMap, vectorRequest);
        System.out.println("recallChannelResultMap: " + recallChannelResultMap.toString());
    }
}

