package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.coder.FeatureCoder;
import cn.com.duiba.nezha.alg.feature.parse.FeatureParse;
import cn.com.duiba.nezha.alg.feature.type.FeatureBaseType;
import cn.com.duiba.nezha.alg.feature.vo.CodeResult;
import cn.com.duiba.nezha.alg.feature.vo.Feature;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import cn.com.duiba.nezha.alg.model.vo.ParamsDo;
import cn.com.duiba.wolf.perf.timeprofile.DBTimeProfile;
import com.alibaba.fastjson.JSON;
import org.slf4j.LoggerFactory;


import java.io.Serializable;
import java.util.*;

/**
 * Factorization-machine (FM) model.
 *
 * <p>Holds the model parameters ({@link ParamsDo}) and the feature definitions used to encode raw
 * feature maps. It can score samples locally ({@link #predict(Map)}, {@link #predicts(Map)}) or
 * export per-field parameters for remote scoring through TF Serving
 * ({@link #predictWithTF(Map, TFServingClient)}). If the remote call fails, it falls back to local
 * scoring.
 */
public class FM implements Serializable, IModel {
    private static final long serialVersionUID = -316102112618444130L;

    private static final org.slf4j.Logger logger = LoggerFactory.getLogger(FM.class);

    /**
     * Model ID. Models whose ID starts with {@code fm_ctr_v007} or {@code fm_cvr_v007} do not get a
     * leading {@code sub_net_id} element in {@link #getParam2(Map)}.
     */
    private String modelId;

    /**
     * Update time.
     */
    private String updateTime;

    /**
     * Model parameters: bias (weight0), first-order weights, factor vectors and optional correction.
     */
    private ParamsDo paramsDo;

    /** Initial capacity of the feature-definition list. */
    private int FBT_MAX_SIZE = 64;

    /** Initial capacity of the result maps built by the batch methods. */
    private int PB_MAX_SIZE = 128;

    /**
     * Feature definitions used to encode raw feature maps.
     */
    private List<FeatureBaseType> featureBaseType = new ArrayList<>(FBT_MAX_SIZE);

    public void setFeatureBaseType(List<FeatureBaseType> featureBaseType) {
        this.featureBaseType = featureBaseType;
    }

    public List<FeatureBaseType> getFeatureBaseType() {
        return this.featureBaseType;
    }


    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    public String getModelId() {
        return this.modelId;
    }


    public void setUpdateTime(String updateTime) {
        this.updateTime = updateTime;
    }

    public String getUpdateTime() {
        return this.updateTime;
    }


    public void setParamsDo(ParamsDo paramsDo) {
        this.paramsDo = paramsDo;
    }

    public ParamsDo getParamsDo() {
        return this.paramsDo;
    }

    /**
     * Encodes a raw feature map and scores it with the locally stored FM parameters.
     *
     * @param featureMap raw feature name to value map
     * @return the score produced by {@link #predict(Feature)}, or {@code null} when encoding
     *         returns no result
     * @throws Exception propagated from encoding or scoring
     */
    public Double predict(Map<String, String> featureMap) throws Exception {
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);
        return codeResult == null ? null : predict(codeResult.getFeature());
    }

    /**
     * Encodes a raw feature map and returns the flattened per-field model parameters (see
     * {@link #getModelParams(List)}).
     *
     * <p>Unless the model ID marks a v007 model, the {@code sub_net_id} feature is parsed as a double
     * and inserted at index 0. A missing or unparsable value is treated as 0.0.
     *
     * @param featureMap raw feature name to value map
     * @return the flattened parameter vector
     * @throws Exception propagated from encoding or parameter extraction
     */
    public double[] getParam2(Map<String, String> featureMap) throws Exception {
        double[] ret = null;
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);
        if (codeResult != null) {
            ret = getModelParams(codeResult.getFeatureSet());
        }

        if (!isV007Model(modelId)) {
            String subNetIdStr = featureMap.getOrDefault("sub_net_id", "0.0");
            Double subNetId = DataUtil.string2Double(subNetIdStr);
            if (subNetId == null) {
                subNetId = 0.0;
            }
            // NOTE(review): ret is still null here when encoding failed; behavior then depends on
            // DataUtil.insertElement's handling of a null array -- confirm.
            ret = DataUtil.insertElement(ret, subNetId, 0);
        }

        return ret;
    }

    /** Returns true when the (possibly null) model ID identifies a v007 CTR/CVR model. */
    private static boolean isV007Model(String id) {
        return id != null && (id.startsWith("fm_ctr_v007") || id.startsWith("fm_cvr_v007"));
    }


    /**
     * Same as {@link #getParam2(Map)}, with each element narrowed to {@code float}.
     *
     * @param featureMap raw feature name to value map
     * @return the parameter list, or {@code null} when {@link #getParam2(Map)} returns null
     * @throws Exception propagated from {@link #getParam2(Map)}
     */
    public List<Float> getParam(Map<String, String> featureMap) throws Exception {
        double[] value = getParam2(featureMap);
        if (value == null) {
            return null;
        }
        List<Float> ret = new ArrayList<>(value.length);
        for (double v : value) {
            ret.add((float) v);
        }
        return ret;
    }

    /**
     * Scores every sample locally.
     *
     * @param featureMap sample key to raw feature map
     * @return sample key to score; empty when the input is empty
     * @throws Exception propagated from {@link #predict(Map)}
     */
    public <T> Map<T, Double> predicts(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ret = new HashMap<>(PB_MAX_SIZE);
        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), predict(entry.getValue()));
            }
        }
        return ret;
    }

    /**
     * Builds the {@link #getParam(Map)} parameter list for every sample.
     *
     * @param featureMap sample key to raw feature map
     * @return sample key to parameter list; empty when the input is empty
     * @throws Exception propagated from {@link #getParam(Map)}
     */
    public <T> Map<T, List<Float>> getParams(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>(PB_MAX_SIZE);
        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), getParam(entry.getValue()));
            }
        }
        return ret;
    }


    /**
     * Scores samples through TF Serving, falling back to local scoring.
     *
     * <p>When {@code tfServingClient} is null, or the remote call (including parameter extraction)
     * throws, the samples are scored locally with {@link #predicts(Map)}.
     *
     * @param featureMap      sample key to raw feature map
     * @param tfServingClient remote scoring client; may be null
     * @param <T>             sample key type
     * @return sample key to score
     * @throws Exception only if the local fallback fails
     */
    public <T> Map<T, Double> predictWithTF(Map<T, Map<String, String>> featureMap, TFServingClient tfServingClient) throws Exception {
        if (tfServingClient == null) {
            return predicts(featureMap);
        }
        try {
            return tfServingClient.predict(getParams(featureMap));
        } catch (Exception e) {
            // Log with the throwable so the stack trace reaches the logger instead of stderr.
            logger.error("predictWithTF error", e);
            return predicts(featureMap);
        }
    }

    /**
     * Scores an encoded sample with the FM equation, applies a sigmoid and the optional correction,
     * and formats the result with {@code DataUtil.formatdouble(value, 5)}.
     *
     * <p>A null {@code feature} yields sigmoid(0) passed through the same correction.
     *
     * @param feature encoded sample (parallel {@code indices}/{@code values} arrays); may be null
     * @return the corrected probability
     * @throws Exception rethrown after logging if scoring fails
     */
    public Double predict(Feature feature) throws Exception {
        try {
            long factorNum = paramsDo.getFactorNum();
            double logit = 0.0;

            if (feature != null) {
                double retw0 = paramsDo.getWeight0();

                // First-order term: sum_i w_i * x_i; unknown feature ids contribute 0.
                double retw = 0.0;
                for (int i = 0; i < feature.indices.length; i++) {
                    Long fId = 0L + feature.indices[i];
                    Double value = feature.values[i];
                    retw += paramsDo.getWeight().getOrDefault(fId, 0.0) * value;
                }

                // Pairwise term computed per factor f as
                // 0.5 * [ (sum_i v_if * x_i)^2 - sum_i v_if^2 * x_i^2 ].
                Map<Long, Double> emptyFactor = Collections.emptyMap();
                double retv = 0.0;
                for (int j = 0; j < factorNum; j++) {
                    Map<Long, Double> parMap = paramsDo.getVector().getOrDefault(0L + j, emptyFactor);
                    double sumVx = 0.0;
                    double sumV2x2 = 0.0;
                    for (int i = 0; i < feature.indices.length; i++) {
                        Long fId = 0L + feature.indices[i];
                        Double value = feature.values[i];
                        Double vif = parMap.getOrDefault(fId, 0.0);
                        sumVx += value * vif;
                        sumV2x2 += Math.pow(value, 2) * Math.pow(vif, 2);
                    }
                    retv += 0.5 * (Math.pow(sumVx, 2) - sumV2x2);
                }

                logit = retw0 + retw + retv;
            }

            double pValue = MathUtil.sigmoid(logit);
            double pCValue = correctValue(pValue, paramsDo.getCorrect());
            return DataUtil.formatdouble(pCValue, 5);
        } catch (Exception e) {
            logger.error("FM.predict error", e);
            throw e;
        }
    }


    /**
     * Applies the optional correction {@code correct / (1 / p - 1 + correct)} to a probability.
     *
     * @param pre_    raw probability
     * @param correct correction factor; when null, {@code pre_} is returned unchanged
     * @return the corrected probability
     */
    public double correctValue(double pre_, Double correct) {
        double ret = pre_;
        if (correct != null) {
            ret = correct / (1 / pre_ - 1 + correct);
        }
        return ret;
    }

    /**
     * Flattens the model parameters for a set-encoded sample.
     *
     * <p>Layout of the returned array, where {@code fSize = codeSetList.size()} and
     * {@code k = factorNum}:
     * <ul>
     *   <li>{@code [0]}: weight0</li>
     *   <li>{@code [1 .. fSize]}: per field, the mean first-order weight over that field's codes</li>
     *   <li>{@code [1 + fSize + i * k + j]}: for field i and factor j, the mean factor value over
     *       that field's codes</li>
     * </ul>
     * Null fields and codes missing from the parameters contribute 0.0.
     *
     * @param codeSetList per-field sets of feature codes; may be null
     * @return the flattened parameters, or {@code null} when {@code codeSetList} is null
     * @throws Exception wrapping any failure during extraction
     */
    public double[] getModelParams(List<Set<Long>> codeSetList) throws Exception {
        try {
            if (codeSetList == null) {
                return null;
            }

            int factorNum = paramsDo.getFactorNum().intValue();
            int fSize = codeSetList.size();
            int vectorOffset = 1 + fSize;
            double[] ret = new double[vectorOffset + fSize * factorNum];

            ret[0] = paramsDo.getWeight0();

            for (int i = 0; i < fSize; i++) {
                Set<Long> fSet = codeSetList.get(i);
                if (fSet != null) {
                    ret[1 + i] = meanOver(fSet, paramsDo.getWeight());
                }
            }

            Map<Long, Double> emptyFactor = Collections.emptyMap();
            for (int j = 0; j < factorNum; j++) {
                Map<Long, Double> parMap = paramsDo.getVector().getOrDefault(0L + j, emptyFactor);
                for (int i = 0; i < fSize; i++) {
                    Set<Long> fSet = codeSetList.get(i);
                    if (fSet != null) {
                        ret[vectorOffset + i * factorNum + j] = meanOver(fSet, parMap);
                    }
                }
            }
            return ret;
        } catch (Exception e) {
            logger.error("FM.getModelParams error", e);
            throw new Exception(e);
        }
    }

    /**
     * Averages {@code source} values over {@code keys}, dividing each term by the set size (the same
     * summation order as before, so results are bit-identical). Missing keys count as 0.0, and an
     * empty set yields 0.0.
     */
    private static double meanOver(Set<Long> keys, Map<?, Double> source) {
        int size = keys.size();
        double sum = 0.0;
        for (Long key : keys) {
            sum += source.getOrDefault(key, 0.0) / size;
        }
        return sum;
    }

    /**
     * Scratch demo: prints map/list sizes while inserting 200 entries into containers presized
     * for 100.
     */
    public static void main(String[] args) {
        Map<String, String> map = new HashMap<>(100);
        List<Long> list = new ArrayList<>(100);
        for (int j = 0; j < 200; j++) {
            map.put(j + "-", j + "--");
            list.add(j + 0L);
            System.out.println("map.size=" + map.size() + ",list.size=" + list.size());
        }
    }
}
