package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.coder.FeatureCoder;
import cn.com.duiba.nezha.alg.feature.type.FeatureBaseType;
import cn.com.duiba.nezha.alg.feature.vo.CodeResult;
import cn.com.duiba.nezha.alg.feature.vo.Feature;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import cn.com.duiba.nezha.alg.model.vo.ParamsDo;

import java.io.Serializable;
import java.util.*;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.coder.FeatureCoder;
import cn.com.duiba.nezha.alg.feature.type.FeatureBaseType;
import cn.com.duiba.nezha.alg.feature.vo.CodeResult;
import cn.com.duiba.nezha.alg.feature.vo.Feature;
import cn.com.duiba.nezha.alg.model.IModel;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import cn.com.duiba.nezha.alg.model.vo.ParamsDo;
import com.alibaba.fastjson.JSON;

import java.io.Serializable;
import java.util.*;

/**
 * Field-aware factorization machine (FFM) scorer.
 *
 * <p>The raw score of an encoded feature vector is
 * {@code w0 + sum_i w[idx_i] * x_i + sum_{i<j} <v(idx_i, field_j), v(idx_j, field_i)> * x_i * x_j}.
 * It is mapped through {@code MathUtil.sigmoid}, then through {@link #correctValue(double, Double)},
 * then through {@code DataUtil.formatdouble(value, 5)}.
 *
 * <p>Created by Administrator on 2018/11/8.
 */
public class FFM2 implements Serializable, IModel {
    private static final long serialVersionUID = -316102112618444130L;

    /**
     * Model ID.
     */
    private String modelId;

    /**
     * Update time.
     */
    private String updateTime;

    /**
     * Model parameters: bias, linear weights, per-field latent vectors, factor count, correction.
     */
    private ParamsDo paramsDoffm;

    // NOTE(review): not read anywhere in this class. It is kept as a non-static instance field
    // because changing or removing it would alter the default Java serialized form.
    private int FBT_MAX_SIZE = 64;

    // Initial capacity of the result map built by getParams. Kept as an instance field for the
    // same serialized-form reason as FBT_MAX_SIZE.
    private int PB_MAX_SIZE = 128;


    /**
     * Feature definitions used by FeatureCoder to encode raw feature maps.
     */
    private List<FeatureBaseType> featureBaseType = new ArrayList<FeatureBaseType>();

    public void setFeatureBaseType(List<FeatureBaseType> featureBaseType) {
        this.featureBaseType = featureBaseType;
    }

    public List<FeatureBaseType> getFeatureBaseType() {
        return this.featureBaseType;
    }


    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    public String getModelId() {
        return this.modelId;
    }


    public void setUpdateTime(String updateTime) {
        this.updateTime = updateTime;
    }

    public String getUpdateTime() {
        return this.updateTime;
    }


    public void setParamsDoffm(ParamsDo paramsDoffm) {
        this.paramsDoffm = paramsDoffm;
    }

    public ParamsDo getParamsDoffm() {
        return this.paramsDoffm;
    }


    /**
     * Encodes a raw feature map with {@link #getFeatureBaseType()} and scores it.
     *
     * @param featureMap raw feature name to value map
     * @return the predicted probability, or {@code null} when FeatureCoder returns no result
     * @throws Exception propagated from encoding or scoring
     */
    public Double predict(Map<String, String> featureMap) throws Exception {
        Double ret = null;
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);

        if (codeResult != null) {
            ret = predict(codeResult.getFeature());
        }
        return ret;
    }


    /**
     * Same as {@link #predict(Map)}, but scores through {@link #predict2(Feature)}.
     * Both paths share one implementation and produce identical results.
     *
     * @param featureMap raw feature name to value map
     * @return the predicted probability, or {@code null} when FeatureCoder returns no result
     * @throws Exception propagated from encoding or scoring
     */
    public Double predict2(Map<String, String> featureMap) throws Exception {
        Double ret = null;
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);

        if (codeResult != null) {
            ret = predict2(codeResult.getFeature());
        }
        return ret;
    }

    /**
     * Encodes a raw feature map and looks up its model parameters via
     * {@link #getModelParams(List)}.
     *
     * <p>NOTE(review): {@code getModelParams} is currently a stub returning {@code null}, so this
     * method always returns {@code null}.
     *
     * @param featureMap raw feature name to value map
     * @return the parameter array, or {@code null}
     * @throws Exception propagated from encoding
     */
    public double[] getParam2(Map<String, String> featureMap) throws Exception {
        double[] ret = null;
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);

        if (codeResult != null) {
            ret = getModelParams(codeResult.getFeatureSet());
        }

        return ret;
    }


    /**
     * Boxed-float variant of {@link #getParam2(Map)}.
     *
     * @param featureMap raw feature name to value map
     * @return the parameters as a list of floats, or {@code null} when {@code getParam2} returns null
     * @throws Exception propagated from {@link #getParam2(Map)}
     */
    public List<Float> getParam(Map<String, String> featureMap) throws Exception {
        List<Float> ret = null;

        double[] value = getParam2(featureMap);
        if (value != null) {
            ret = new ArrayList<>(value.length);
            for (double v : value) {
                ret.add((float) v);
            }
        }
        return ret;

    }

    /**
     * FFM batch prediction: scores every entry with {@link #predict(Map)}.
     * The logic is the same as the FM model's batch prediction.
     *
     * @param featureMap key to raw feature map
     * @param <T>        key type
     * @return key to predicted probability (a value may be {@code null}); empty when input is empty
     * @throws Exception propagated from scoring
     */
    public <T> Map<T, Double> predicts(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), predict(entry.getValue()));
            }
        }

        return ret;
    }


    /**
     * Batch variant of {@link #predict2(Map)}.
     *
     * @param featureMap key to raw feature map
     * @param <T>        key type
     * @return key to predicted probability; empty when input is empty
     * @throws Exception propagated from scoring
     */
    public <T> Map<T, Double> predicts2(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), predict2(entry.getValue()));
            }
        }

        return ret;
    }


    /**
     * Batch variant of {@link #getParam(Map)}.
     *
     * @param featureMap key to raw feature map
     * @param <T>        key type
     * @return key to parameter list; values are currently always {@code null} (see {@link #getModelParams(List)})
     * @throws Exception propagated from {@link #getParam(Map)}
     */
    public <T> Map<T, List<Float>> getParams(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>(PB_MAX_SIZE);

        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), getParam(entry.getValue()));
            }
        }
        return ret;

    }


    /**
     * Scores locally when {@code tfServingClient} is null; otherwise delegates to the TF Serving client.
     *
     * <p>NOTE(review): the TF path sends {@link #getParams(Map)}. Because {@link #getModelParams(List)}
     * is a stub, every parameter list in that request is currently {@code null}.
     *
     * @param featureMap      key to raw feature map
     * @param tfServingClient optional remote client; {@code null} means score locally
     * @param <T>             key type
     * @return key to predicted probability
     * @throws Exception propagated from scoring or the client
     */
    public <T> Map<T, Double> predictWithTF(Map<T, Map<String, String>> featureMap, TFServingClient tfServingClient) throws Exception {
        if (tfServingClient == null) {
            return predicts(featureMap);
        }
        return tfServingClient.predict(getParams(featureMap));
    }

    /**
     * FFM single-sample prediction.
     *
     * <p>A {@code null} feature contributes a raw score of 0, i.e. the result is the corrected
     * {@code sigmoid(0)}.
     *
     * @param feature encoded feature vector; may be {@code null}
     * @return the predicted probability after correction and formatting
     * @throws Exception any runtime failure, after printing "FFM.predict error"
     */
    public Double predict(Feature feature) throws Exception {
        try {
            return toProbability(rawScore(feature));
        } catch (Exception e) {
            System.out.println("FFM.predict error");
            throw e;
        }
    }

    /**
     * FFM single-sample prediction, the "optimized" entry point.
     * It shares its implementation with {@link #predict(Feature)} and returns identical values.
     *
     * @param feature encoded feature vector; may be {@code null}
     * @return the predicted probability after correction and formatting
     * @throws Exception any runtime failure, after printing "FFM.predict error"
     */
    public Double predict2(Feature feature) throws Exception {
        try {
            return toProbability(rawScore(feature));
        } catch (Exception e) {
            System.out.println("FFM.predict error");
            throw e;
        }
    }

    /**
     * Computes the raw FFM score: bias + linear terms + field-aware pairwise interactions.
     *
     * <p>Latent factors are looked up as {@code vectorV3.get(field).get(k).get(featureIndex)}.
     * A missing field, factor slot, or feature entry makes that product term contribute nothing.
     * Feature indices are widened to {@code long} so lookups use {@code Long} keys.
     *
     * @param feature encoded feature vector; {@code null} yields 0.0
     * @return the pre-sigmoid score
     */
    private double rawScore(Feature feature) {
        if (feature == null) {
            return 0.0;
        }

        long factorNum = paramsDoffm.getFactorNum();
        int[] indices = feature.indices;
        double[] x = feature.values;
        String[] fields = feature.fields;
        int n = indices.length;

        Map<String, Map<Long, Map<Long, Double>>> vectorV3 = paramsDoffm.getVectorV3();
        // Typed, immutable defaults for absent fields / factor slots (never mutated, only read).
        Map<Long, Map<Long, Double>> noField = Collections.emptyMap();
        Map<Long, Double> noFactor = Collections.emptyMap();

        double bias = 0.0;
        bias += paramsDoffm.getWeight0();

        // Linear (first-order) terms.
        double linear = 0.0;
        for (int i = 0; i < n; i++) {
            long fId = indices[i];
            linear += paramsDoffm.getWeight().getOrDefault(fId, 0.0) * x[i];
        }

        // Pairwise (second-order) terms: <v(i, field_j), v(j, field_i)> * x_i * x_j.
        double interaction = 0.0;
        for (int i = 0; i < n; i++) {
            long vi = indices[i];
            double xi = x[i];
            Map<Long, Map<Long, Double>> latentInFieldI = vectorV3.getOrDefault(fields[i], noField);

            for (int j = i + 1; j < n; j++) {
                long vj = indices[j];
                double xj = x[j];
                Map<Long, Map<Long, Double>> latentInFieldJ = vectorV3.getOrDefault(fields[j], noField);

                double dot = 0.0;
                for (long k = 0; k < factorNum; k++) {
                    // Feature i's k-th factor toward field j, and feature j's k-th factor toward field i.
                    Double vifjK = latentInFieldJ.getOrDefault(k, noFactor).get(vi);
                    Double vjfiK = latentInFieldI.getOrDefault(k, noFactor).get(vj);
                    if (vifjK != null && vjfiK != null) {
                        dot += vifjK * vjfiK;
                    }
                }

                interaction += dot * xi * xj;
            }
        }

        return bias + linear + interaction;
    }

    /**
     * Maps a raw score to the published probability: sigmoid, then correction, then formatting.
     *
     * @param score pre-sigmoid score
     * @return the corrected probability passed through {@code DataUtil.formatdouble(value, 5)}
     */
    private Double toProbability(double score) {
        double pValue = MathUtil.sigmoid(score);
        double pCValue = correctValue(pValue, paramsDoffm.getCorrect());
        return DataUtil.formatdouble(pCValue, 5);
    }


    /**
     * Applies the probability correction {@code c / (1/p - 1 + c)}, which equals
     * {@code c*p / (c*p + 1 - p)}. When {@code correct} is null, {@code pre_} is returned unchanged.
     *
     * @param pre_    probability to correct
     * @param correct correction factor, or {@code null} for no correction
     * @return the corrected probability
     */
    public double correctValue(double pre_, Double correct) {
        double ret = pre_;
        if (correct != null) {
            ret = correct / (1 / pre_ - 1 + correct);
        }
        return ret;
    }

    /**
     * Looks up model parameters for a set of encoded feature codes.
     *
     * <p>NOTE(review): not implemented. It always returns {@code null}, so {@link #getParam2(Map)},
     * {@link #getParam(Map)} and {@link #getParams(Map)} yield null parameter values.
     *
     * @param codeSetList encoded feature code sets (currently ignored)
     * @return always {@code null}
     * @throws Exception declared for interface compatibility; never thrown
     */
    public double[] getModelParams(List<Set<Long>> codeSetList) throws Exception {
        return null;
    }

    /**
     * Debug helper: prints {@code i * factorNum + j} for a small fixed grid.
     */
    public static void main(String[] args) {
        int factorNum = 3;
        int fSize = 10;
        for (int j = 0; j < factorNum; j++) {

            for (int i = 0; i < fSize; i++) {

                System.out.println(i * factorNum + j);

            }

        }

    }

}