package cn.com.duiba.nezha.compute.alg.util;

import cn.com.duiba.nezha.compute.api.enums.SerializerEnum;
import cn.com.duiba.nezha.compute.common.util.AssertUtil;
import cn.com.duiba.nezha.compute.common.util.OneHotUtil;
import cn.com.duiba.nezha.compute.api.dict.CategoryFeatureDict;
import cn.com.duiba.nezha.compute.common.util.serialize.KryoUtil;
import cn.com.duiba.nezha.compute.common.util.serialize.SerializeTool;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Created by pc on 2016/12/16.
 */
public class CategoryFeatureDictUtil implements Serializable {

    private Logger logger = LoggerFactory.getLogger(this.getClass().getName());
    private static final long serialVersionUID = -316102112618444930L;
    private CategoryFeatureDict dict = null;


    public void setFeatureDict(CategoryFeatureDict dict) {
        this.dict = dict;
    }

    public void setFeatureDict(String dictStr, SerializerEnum serializerEnum) {
        this.dict = getFeatureDict(dictStr, serializerEnum);
    }


    public CategoryFeatureDict getFeatureDict() {
        return this.dict;
    }

    public CategoryFeatureDict getFeatureDict(String dictStr, SerializerEnum serializerEnum) {
        CategoryFeatureDict dict = null;
        if (serializerEnum == SerializerEnum.JAVA_ORIGINAL) {
            dict = SerializeTool.getObjectFromString(dictStr);
        }
        if (serializerEnum == SerializerEnum.KRYO) {
            dict = KryoUtil.deserializationObject(dictStr, CategoryFeatureDict.class);
        }


        return dict;
    }

    public String featureIdxList2Str(List<String> idxStr, SerializerEnum serializerEnum) {
        String ret = null;
        if (serializerEnum == SerializerEnum.JAVA_ORIGINAL) {
            ret = SerializeTool.object2String(idxStr);
        }
        if (serializerEnum == SerializerEnum.KRYO) {
            ret = KryoUtil.serializationListObject(idxStr);
        }


        return ret;
    }

    public List<String> getFeatureIdxList(String featureIdxStr, SerializerEnum serializerEnum) {
        List<String> ret = null;

        if (serializerEnum == SerializerEnum.JAVA_ORIGINAL) {
            ret = (List<String>) SerializeTool.getObjectFromStr(featureIdxStr);
        }
        if (serializerEnum == SerializerEnum.KRYO) {
            ret = KryoUtil.deserializationListObject(featureIdxStr);
        }

        return ret;
    }


    public String getFeatureDictStr(CategoryFeatureDict dict, SerializerEnum serializerEnum) {
        String ret = null;

        if (serializerEnum == SerializerEnum.JAVA_ORIGINAL) {
            ret = SerializeTool.object2String(dict);
        }
        if (serializerEnum == SerializerEnum.KRYO) {
            ret = KryoUtil.serializationObject(dict);
        }

        return ret;
    }

    public String getFeatureDictStr(SerializerEnum serializerEnum) {
        return getFeatureDictStr(dict, serializerEnum);
    }


    //
    public List<String> getFeature(String featureIdx) {
        if (this.dict == null) {
            logger.warn("dict is null");
            return null;
        }

        if (AssertUtil.isEmpty(featureIdx)) {
            logger.warn("getFeature cn.com.duiba.nezha.compute.biz.param invalid");
            return null;
        }
        return this.dict.getFeatureDict().get(featureIdx);
    }


    //
    public int getFeatureSize(String featureIdx) {
        int ret = 0;
        List<String> featureMap = getFeature(featureIdx);
        if (AssertUtil.isNotEmpty(featureMap)) {
            ret = featureMap.size();
        }
        return ret;
    }


    public int getFeatureCategoryIndex(String featureIdx, String category) {
        int ret = -1;
        List<String> featureMap = getFeature(featureIdx);
        if (AssertUtil.isNotEmpty(featureMap)) {
            ret = featureMap.indexOf(category);
        }
        return ret;
    }


    public double[] featureOneHotDoubleEncode(String featureIdx, String category) {
        double[] ret = null;
        int idx = getFeatureCategoryIndex(featureIdx, category);
        int size = getFeatureSize(featureIdx);
        if (AssertUtil.isNotEmpty(idx)) {
            ret = OneHotUtil.getOneHotDouble(idx, size);
        }
        return ret;
    }

    /**
     * @param featureIdxList
     * @param categoryList
     * @return
     */
    public double[] oneHotDoubleEncode(List<String> featureIdxList, List<String> categoryList) {
        double[] ret = null;
        if (AssertUtil.isAnyEmpty(featureIdxList, categoryList) ||
                categoryList.size() != featureIdxList.size()) {
            return null;
        }
        try {
            for (int i = 0; i < featureIdxList.size(); i++) {
                String featureIdx = featureIdxList.get(i);
                String category = categoryList.get(i);
                double[] partRet = featureOneHotDoubleEncode(featureIdx, category);
                if (partRet == null) {
                    return null;
                } else {
                    ret = ArrayUtils.addAll(ret, partRet);
                }
            }
        } catch (Exception e) {
            logger.error("oneHotDoubleEncode happend error", e);
        }

        return ret;
    }


    /**
     * @param featureIdxList
     * @param categoryList
     * @return
     */
    public SparseVector oneHotSparseVectorEncode(List<String> featureIdxList, List<String> categoryList) {
        SparseVector ret = null;
        if (AssertUtil.isAnyEmpty(featureIdxList, categoryList) ||
                categoryList.size() != featureIdxList.size()) {
            return null;
        }
        try {
            int[] sizeArray = new int[featureIdxList.size()];
            int[] indexArray = new int[featureIdxList.size()];

            int[] r_indices = new int[featureIdxList.size()];
            double[] r_values = new double[featureIdxList.size()];
            int sizeSum = 0;
            for (int i = 0; i < featureIdxList.size(); i++) {
                String featureIdx = featureIdxList.get(i);
                String category = categoryList.get(i);

                sizeArray[i] = getFeatureSize(featureIdx);
                indexArray[i] = getFeatureCategoryIndex(featureIdx, category);

                r_indices[i] = sizeSum + getFeatureCategoryIndex(featureIdx, category) + 1;
                sizeSum += getFeatureSize(featureIdx) + 1;
                r_values[i] = 1.0;
            }


            ret = (SparseVector) Vectors.sparse(sizeSum, r_indices, r_values);

        } catch (Exception e) {
            logger.error("oneHotDoubleEncode happend error", e);
        }

        return ret;
    }


    /**
     * @param featureIdxList
     * @param categoryMap    <featureIdx,category>
     * @return
     */

    public double[] oneHotDoubleEncodeWithMap(List<String> featureIdxList, Map<String, String> categoryMap) {
        double[] ret = null;

        if (AssertUtil.isAnyEmpty(featureIdxList, categoryMap)) {
            return null;
        }
        try {
            ret = oneHotDoubleEncodeWithMap(featureIdxList, categoryMap, true);
        } catch (Exception e) {
            logger.error("oneHotIntEncode happend error", e);

        }

        return ret;
    }

    /**
     * @param featureIdxList
     * @param categoryMap    <featureIdx,category>
     * @return
     */

    public double[] oneHotDoubleEncodeWithMap(List<String> featureIdxList, Map<String, String> categoryMap, boolean ignoreNull) {
        double[] ret = null;

        if (AssertUtil.isAnyEmpty(featureIdxList, categoryMap)) {
            return null;
        }
        try {
            int status = 1;
            List<String> categoryList = new ArrayList<>();
            for (int i = 0; i < featureIdxList.size(); i++) {

                String featureIdx = featureIdxList.get(i);

                if (!categoryMap.containsKey(featureIdx) && !ignoreNull) {
                    status = -1;
                    break;
                }

                String category = categoryMap.get(featureIdx);


                if (category != null) {
                    category = category.toLowerCase();
                }
                categoryList.add(category);
            }
            if (status == 1) {
                ret = oneHotDoubleEncode(featureIdxList, categoryList);
            }
        } catch (Exception e) {
            logger.error("oneHotIntEncode happend error", e);

        }

        return ret;
    }


    /**
     * @param featureIdxList
     * @param categoryMap    <featureIdx,category>
     * @return
     */

    public SparseVector oneHotSparseVectorEncodeWithMap(List<String> featureIdxList, Map<String, String> categoryMap) {
        SparseVector ret = null;

        if (AssertUtil.isAnyEmpty(featureIdxList, categoryMap)) {
            return null;
        }
        try {
            ret = oneHotSparseVectorEncodeWithMap(featureIdxList, categoryMap, true);
        } catch (Exception e) {
            logger.error("oneHotIntEncode happend error", e);

        }

        return ret;
    }


    /**
     * @param featureIdxList
     * @param categoryMap    <featureIdx,category>
     * @return
     */

    public SparseVector oneHotSparseVectorEncodeWithMap(List<String> featureIdxList, Map<String, String> categoryMap, boolean ignoreNull) {
        SparseVector ret = null;

        if (AssertUtil.isAnyEmpty(featureIdxList, categoryMap)) {
            return null;
        }
        try {
            int status = 1;
            List<String> categoryList = new ArrayList<>();
            for (int i = 0; i < featureIdxList.size(); i++) {

                String featureIdx = featureIdxList.get(i);

                if (!categoryMap.containsKey(featureIdx) && !ignoreNull) {
                    status = -1;
                    break;
                }

                String category = categoryMap.get(featureIdx);

                if (category != null) {
                    category = category.toLowerCase();
                }
                categoryList.add(category);
            }
            if (status == 1) {
                ret = oneHotSparseVectorEncode(featureIdxList, categoryList);
            }
        } catch (Exception e) {
            logger.error("oneHotIntEncode happend error", e);

        }

        return ret;
    }


}
