package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.common.util.DataUtil;
import cn.com.duiba.nezha.alg.common.util.MathUtil;
import cn.com.duiba.nezha.alg.feature.coder.FeatureCoder;
import cn.com.duiba.nezha.alg.feature.coder.FeatureNewCoder;
import cn.com.duiba.nezha.alg.feature.type.FeatureBaseType;
import cn.com.duiba.nezha.alg.feature.vo.CodeResult;
import cn.com.duiba.nezha.alg.feature.vo.Feature;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModel;
import cn.com.duiba.nezha.alg.model.tf.TFServingClient;
import cn.com.duiba.nezha.alg.model.vo.ParamsDo;

import java.io.Serializable;
import java.util.*;

/**
 * Field-weighted Factorization Machine (FwFM) model.
 *
 * <p>{@link #predict(Feature)} scores one encoded sample as
 * {@code sigmoid(w0 + sum_i w[i]*x[i] + sum_{i<j} r[F(i)][F(j)] * <v[i], v[j]> * x[i]*x[j])}.
 * It then applies the optional odds calibration in {@link #correctValue(double, Double)}.
 * The parameters map to {@code ParamsDo} as follows:
 * {@code w0} is {@code getWeight0()}, {@code w} is {@code getWeight()},
 * {@code r} (field-pair weights) is {@code getVectorC()}, and {@code v} (latent factors) is
 * {@code getVector()}.
 *
 * <p>NOTE(review): the class is {@link Serializable} but declares no {@code serialVersionUID}.
 * Any change to its fields or non-private members changes the computed UID. Consider pinning
 * one explicitly.
 *
 * Created by yudian on 2021/05/26.
 */
public class FwFM implements Serializable, IModel {

    /**
     * Model ID.
     */
    private String modelId;

    /**
     * Model update time.
     */
    private String updateTime;

    /**
     * Model parameters: bias, first-order weights, latent vectors, field-pair weights and
     * calibration factor.
     */
    private ParamsDo paramsDoffm;

    /**
     * Feature definitions used to encode raw feature maps.
     */
    private List<FeatureBaseType> featureBaseType = new ArrayList<FeatureBaseType>();

    /** Currently unused in this class. */
    private int FBT_MAX_SIZE = 64;

    /** Initial capacity of the map built by {@link #getParams(Map)}. */
    private int PB_MAX_SIZE = 128;


    /**
     * Returns the model ID.
     *
     * @return the model ID, possibly {@code null}
     */
    @Override
    public String getModelId() {
        return modelId;
    }

    /**
     * Encodes {@code featureMap} and returns the model parameters for it as floats.
     *
     * @param featureMap raw feature name/value pairs
     * @return the parameters as floats, or {@code null} when encoding yields no result or no
     *     parameters are produced (currently always, because {@link #getModelParams(List)}
     *     is not implemented for FwFM)
     * @throws Exception propagated from feature encoding
     */
    @Override
    public List<Float> getParam(Map<String, String> featureMap) throws Exception {
        return toFloatList(getParam2(featureMap));
    }

    /**
     * Encodes {@code featureMap} and looks up the model parameters for the encoded feature sets.
     *
     * @param featureMap raw feature name/value pairs
     * @return the parameters, or {@code null} when encoding yields no result or no parameters
     *     are produced
     * @throws Exception propagated from feature encoding
     */
    public double[] getParam2(Map<String, String> featureMap) throws Exception {
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);
        return codeResult == null ? null : getModelParams(codeResult.getFeatureSet());
    }

    /**
     * Encodes {@code featureMap} and looks up the model parameters for the encoded feature sets.
     *
     * @param featureMap structured feature container
     * @return the parameters, or {@code null} when encoding yields no result or no parameters
     *     are produced
     * @throws Exception propagated from feature encoding
     */
    public double[] getParam2(FeatureMapDo featureMap) throws Exception {
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);
        return codeResult == null ? null : getModelParams(codeResult.getFeatureSet());
    }

    /**
     * Builds the parameter vector for a set of encoded feature indices.
     *
     * <p>Not implemented for FwFM: always returns {@code null}.
     *
     * @param codeSetList encoded feature index sets (currently ignored)
     * @return always {@code null}
     * @throws Exception never thrown; declared for signature compatibility
     */
    public double[] getModelParams(List<Set<Integer>> codeSetList) throws Exception {
        // TODO: implement a parameter export for FwFM if TF-serving prediction is needed.
        return null;
    }

    /**
     * Encodes {@code featureMapDo} and returns the model parameters for it as floats.
     *
     * @param featureMapDo structured feature container
     * @return the parameters as floats, or {@code null} when encoding yields no result or no
     *     parameters are produced (currently always, because {@link #getModelParams(List)}
     *     is not implemented for FwFM)
     * @throws Exception propagated from feature encoding
     */
    @Override
    public List<Float> getParam(FeatureMapDo featureMapDo) throws Exception {
        return toFloatList(getParam2(featureMapDo));
    }

    /**
     * Encodes {@code featureMap} and predicts a single sample.
     *
     * @param featureMap raw feature name/value pairs
     * @return the predicted value, or {@code null} when encoding yields no result
     * @throws Exception propagated from encoding or prediction
     */
    @Override
    public Double predict(Map<String, String> featureMap) throws Exception {
        CodeResult codeResult = FeatureCoder.code(getFeatureBaseType(), featureMap);
        return codeResult == null ? null : predict(codeResult.getFeature());
    }

    /**
     * Predicts a single encoded sample.
     *
     * <p>The raw score is {@code w0 + sum_i w[i]*x[i] + sum_{i<j} r[F(i)][F(j)] * <v[i], v[j]> * x[i]*x[j]}.
     * Missing keys and {@code null}-valued entries in any parameter map contribute {@code 0}.
     * A {@code null} feature scores {@code 0}. The score is passed through
     * {@code MathUtil.sigmoid} and {@link #correctValue(double, Double)}, then through
     * {@code DataUtil.formatDouble(value, 5)} (presumably rounding to 5 decimals).
     *
     * @param feature encoded sample whose {@code indices}, {@code values} and {@code fields} are
     *     parallel; may be {@code null}
     * @return the calibrated prediction
     * @throws Exception propagated from parameter access or the helper utilities
     */
    public Double predict(Feature feature) throws Exception {
        try {
            double ret = 0.0;

            if (feature != null) {
                List<Integer> indices = feature.indices;
                double[] x = feature.values;
                List<String> fields = feature.fields;
                int n = indices.size();

                // Gather every feature's latent vector once, instead of re-reading the nested
                // maps for every (i, j) pair inside the quadratic loop.
                double[][] latent = latentVectors(indices, Math.max(paramsDoffm.getFactorNum(), 0));

                Map<String, Double> emptyFieldMap = new HashMap<>(1);

                double retW0 = paramsDoffm.getWeight0();
                double retW = 0.0;
                double retV = 0.0;

                for (int i = 0; i < n; i++) {
                    double xi = x[i];

                    // First-order term.
                    retW += xi * valueOrZero(paramsDoffm.getWeight().getOrDefault(indices.get(i), 0.0));

                    // Second-order terms: cross weights between the current field and later fields.
                    Map<String, Double> fieldIMap = paramsDoffm.getVectorC().getOrDefault(fields.get(i), emptyFieldMap);
                    if (fieldIMap == null) {
                        fieldIMap = emptyFieldMap;
                    }

                    double[] latentI = latent[i];
                    for (int j = i + 1; j < n; j++) {
                        double wij = valueOrZero(fieldIMap.getOrDefault(fields.get(j), 0.0));

                        // Latent-vector dot product <v_i, v_j>.
                        double[] latentJ = latent[j];
                        double dot = 0.0;
                        for (int k = 0; k < latentI.length; k++) {
                            dot += latentI[k] * latentJ[k];
                        }
                        retV += wij * dot * xi * x[j];
                    }
                }

                ret = retW0 + retW + retV;
            }

            double pValue = MathUtil.sigmoid(ret);
            double pCValue = correctValue(pValue, paramsDoffm.getCorrect());
            return DataUtil.formatDouble(pCValue, 5);
        } catch (Exception e) {
            System.out.println("FwFM.predict error");
            throw e;
        }
    }

    /**
     * Looks up the latent vector of every feature index.
     *
     * @return {@code result[i][k]} = factor {@code k} of {@code indices.get(i)}, with missing or
     *     {@code null} entries set to {@code 0}
     */
    private double[][] latentVectors(List<Integer> indices, int factorNum) {
        Map<Integer, Double> emptyFactorMap = new HashMap<>(1);
        double[][] latent = new double[indices.size()][factorNum];
        for (int k = 0; k < factorNum; k++) {
            Map<Integer, Double> factorKMap = paramsDoffm.getVector().getOrDefault(k, emptyFactorMap);
            if (factorKMap == null) {
                continue;
            }
            for (int i = 0; i < latent.length; i++) {
                latent[i][k] = valueOrZero(factorKMap.getOrDefault(indices.get(i), 0.0));
            }
        }
        return latent;
    }

    /** Unboxes {@code value}, treating {@code null} as {@code 0}. */
    private static double valueOrZero(Double value) {
        return value == null ? 0.0 : value;
    }

    /** Converts a double array to a float list; {@code null} maps to {@code null}. */
    private static List<Float> toFloatList(double[] values) {
        if (values == null) {
            return null;
        }
        List<Float> out = new ArrayList<>(values.length);
        for (double v : values) {
            out.add((float) v);
        }
        return out;
    }

    /**
     * Predicts an already-encoded sample.
     *
     * @param codeResult encoding result; may be {@code null}
     * @return the prediction, or {@code null} when {@code codeResult} is {@code null}
     * @throws Exception propagated from {@link #predict(Feature)}
     */
    public Double predict(CodeResult codeResult) throws Exception {
        return codeResult == null ? null : predict(codeResult.getFeature());
    }

    /**
     * Calibrates a probability by scaling its odds.
     *
     * <p>Computes {@code correct / (1/pre_ - 1 + correct)}. This is the probability whose odds
     * equal {@code correct * pre_/(1-pre_)}.
     *
     * @param pre_ uncalibrated probability
     * @param correct odds multiplier; {@code null} disables calibration
     * @return the calibrated probability, or {@code pre_} when {@code correct} is {@code null}
     */
    public double correctValue(double pre_, Double correct) {
        double ret = pre_;
        if (correct != null) {
            ret = correct / (1 / pre_ - 1 + correct);
        }
        return ret;
    }

    /**
     * Predicts every sample in {@code featureMap} locally.
     *
     * @param featureMap samples keyed by caller-defined keys
     * @return predictions under the same keys; empty when the input is empty
     * @throws Exception propagated from {@link #predict(Map)}
     */
    @Override
    public <T> Map<T, Double> predicts(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), predict(entry.getValue()));
            }
        }
        return ret;
    }

    /**
     * Computes {@link #getParam(Map)} for every sample in {@code featureMap}.
     *
     * @param featureMap samples keyed by caller-defined keys
     * @return parameter lists under the same keys (values may be {@code null})
     * @throws Exception propagated from {@link #getParam(Map)}
     */
    public <T> Map<T, List<Float>> getParams(Map<T, Map<String, String>> featureMap) throws Exception {
        Map<T, List<Float>> ret = new HashMap<>(PB_MAX_SIZE);
        if (AssertUtil.isNotEmpty(featureMap)) {
            for (Map.Entry<T, Map<String, String>> entry : featureMap.entrySet()) {
                ret.put(entry.getKey(), getParam(entry.getValue()));
            }
        }
        return ret;
    }

    /**
     * Predicts via TF serving, or locally when no client is given.
     *
     * <p>NOTE(review): {@link #getModelParams(List)} is not implemented, so with a non-null client
     * every parameter list passed to {@code tfServingClient.predict} is {@code null}.
     *
     * @param featureMap samples keyed by caller-defined keys
     * @param tfServingClient TF-serving client; {@code null} selects local prediction
     * @return predictions keyed like the input
     * @throws Exception propagated from prediction
     */
    @Override
    public <T> Map<T, Double> predictWithTF(Map<T, Map<String, String>> featureMap, TFServingClient tfServingClient) throws Exception {
        if (tfServingClient == null) {
            return predicts(featureMap);
        }
        return tfServingClient.predict(getParams(featureMap));
    }

    /**
     * Not supported by this model.
     *
     * @return always {@code null}
     */
    @Override
    public <T> Map<T, Double> predictWithLocalTF(Map<T, Map<String, String>> featureMap, LocalTFModel localTFModel) throws Exception {
        return null;
    }

    /**
     * Not supported by this model.
     *
     * @return always {@code null}
     */
    @Override
    public <T> Map<T, Double> predictWithTFNew(Map<T, FeatureMapDo> featureMap, TFServingClient tfServingClient) throws Exception {
        return null;
    }

    /**
     * Not supported by this model.
     *
     * @return always {@code null}
     */
    @Override
    public <T> Map<T, Double> predictWithLocalTFNew(Map<T, FeatureMapDo> featureMap, LocalTFModel localTFModel) throws Exception {
        return null;
    }

    /**
     * Batch-encodes {@code featureMap} with {@code FeatureNewCoder} and predicts every sample
     * locally.
     *
     * @param featureMap samples keyed by caller-defined keys
     * @return predictions under the same keys. A sample without an encoding result maps to
     *     {@code null}. The map is empty when the input is empty.
     * @throws Exception propagated from encoding or prediction
     */
    @Override
    public <T> Map<T, Double> predictsNew(Map<T, FeatureMapDo> featureMap) throws Exception {
        Map<T, Double> ret = new HashMap<>();
        if (AssertUtil.isNotEmpty(featureMap)) {
            Map<T, CodeResult> codeResultMap = FeatureNewCoder.code(getFeatureBaseType(), featureMap);

            for (T key : featureMap.keySet()) {
                // A missing batch result is treated like a failed encoding: predict(null) is null.
                CodeResult cr = codeResultMap == null ? null : codeResultMap.get(key);
                ret.put(key, predict(cr));
            }
        }
        return ret;
    }

    /** Sets the model ID. */
    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    /** Returns the model update time. */
    public String getUpdateTime() {
        return updateTime;
    }

    /** Sets the model update time. */
    public void setUpdateTime(String updateTime) {
        this.updateTime = updateTime;
    }

    /** Returns the model parameters. */
    public ParamsDo getParamsDoffm() {
        return paramsDoffm;
    }

    /** Sets the model parameters. */
    public void setParamsDoffm(ParamsDo paramsDoffm) {
        this.paramsDoffm = paramsDoffm;
    }

    /** Returns the feature definitions (the live list, not a copy). */
    public List<FeatureBaseType> getFeatureBaseType() {
        return featureBaseType;
    }

    /** Sets the feature definitions. */
    public void setFeatureBaseType(List<FeatureBaseType> featureBaseType) {
        this.featureBaseType = featureBaseType;
    }
}
