package cn.com.duiba.nezha.compute.alg;


import cn.com.duiba.nezha.compute.alg.util.FMModelUtil;
import org.apache.spark.mllib.linalg.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
 * Created by pc on 2016/12/22.
 *
 * <p>Algorithm front end that turns categorical feature values into a
 * {@link SparseVector} and hands that vector to {@link FMModelUtil#predict}.
 *
 * <p>Both {@code predict} overloads are best-effort: any exception raised while
 * encoding or scoring is logged and {@code null} is returned to the caller.
 *
 * <p>NOTE(review): {@code getDictUtil()} and {@code getFeatureIdxList()} are not
 * declared in this class; presumably they are inherited from
 * {@code BaseAlgorithm} - confirm there.
 */
public class FM2 extends BaseAlgorithm implements Serializable, IAlgorithm {

    private static final Logger logger = LoggerFactory.getLogger(FM2.class);

    /** Model helper; stays {@code null} until {@link #getModelUtil()} is first called. */
    private FMModelUtil modelUtil = null;

    /**
     * Returns the model helper, creating a fresh {@link FMModelUtil} on first use.
     *
     * @return the {@link FMModelUtil} held by this instance, never {@code null}
     */
    public FMModelUtil getModelUtil() {
        if (this.modelUtil != null) {
            return this.modelUtil;
        }
        this.modelUtil = new FMModelUtil();
        return this.modelUtil;
    }

    /**
     * Encodes the given category values with
     * {@code oneHotSparseVectorEncode} and scores the resulting vector.
     *
     * @param categoryList category values, passed to the encoder together with
     *                     {@code getFeatureIdxList()}
     * @return the model's prediction, or {@code null} when the encoder yields no
     *         vector or any exception occurs (the exception is logged)
     */
    public Double predict(List<String> categoryList) {
        try {
            SparseVector encoded =
                    getDictUtil().oneHotSparseVectorEncode(getFeatureIdxList(), categoryList);
            return scoreOrNull(encoded);
        } catch (Exception e) {
            // Best-effort contract: report the failure and fall through to null.
            logger.error("predict happend error", e);
        }
        return null;
    }

    /**
     * Encodes the given category map with
     * {@code oneHotSparseVectorEncodeWithMap} and scores the resulting vector.
     *
     * @param categoryMap category values, passed to the encoder together with
     *                    {@code getFeatureIdxList()}; presumably keyed by feature
     *                    name - verify against the encoder
     * @return the model's prediction, or {@code null} when the encoder yields no
     *         vector or any exception occurs (the exception is logged)
     */
    public Double predict(Map<String, String> categoryMap) {
        try {
            SparseVector encoded =
                    getDictUtil().oneHotSparseVectorEncodeWithMap(getFeatureIdxList(), categoryMap);
            return scoreOrNull(encoded);
        } catch (Exception e) {
            // Best-effort contract: report the failure and fall through to null.
            logger.error("predict happend error", e);
        }
        return null;
    }

    /**
     * Scores an already-encoded vector with the model helper.
     *
     * @param encoded vector produced by one of the encoders; may be {@code null}
     * @return the result of {@code getModelUtil().predict(encoded)}, or
     *         {@code null} when {@code encoded} is {@code null}
     */
    private Double scoreOrNull(SparseVector encoded) {
        if (encoded == null) {
            return null;
        }
        return getModelUtil().predict(encoded);
    }

}
