package cn.com.duiba.nezha.compute.common.util;

import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;

/**
 * Wraps a Spark MLlib {@link LogisticRegressionModel} and converts models to and from a
 * {@code String} representation using {@code SerializeTool}.
 *
 * <p>This class implements {@link Serializable}. The logger is {@code static} so that it is not
 * part of an instance's serialized state. NOTE(review): instances are presumably shipped to Spark
 * executors inside closures — confirm with callers.
 *
 * <p>Created by pc on 2016/12/21.
 */
public class LogisticRegressionModelUtil implements Serializable {

    private static final long serialVersionUID = -316102112618444922L;

    // static: a Logger is infrastructure, not object state. Static fields are excluded from Java
    // serialization, so this Serializable wrapper never tries to serialize the logger instance.
    private static final Logger logger = LoggerFactory.getLogger(LogisticRegressionModelUtil.class);

    /** The wrapped model; {@code null} until one is set. */
    private LogisticRegressionModel model;

    /**
     * Replaces the wrapped model.
     *
     * @param model the model to wrap; {@code null} clears it, after which {@link #predict(Vector)}
     *     returns {@code null}
     */
    public void setModel(LogisticRegressionModel model) {
        this.model = model;
    }

    /**
     * Replaces the wrapped model with one decoded from its string form via
     * {@link #getModel(String)}.
     *
     * @param modelStr encoded model; presumably produced by {@link #getModelStr()} or
     *     {@link #getModelStr(LogisticRegressionModel)} — verify against {@code SerializeTool}
     */
    public void setModel(String modelStr) {
        this.model = getModel(modelStr);
    }

    /**
     * Returns the wrapped model.
     *
     * @return the current model, or {@code null} if none has been set
     */
    public LogisticRegressionModel getModel() {
        return this.model;
    }

    /**
     * Decodes a model from its string form without changing the wrapped model.
     *
     * @param src encoded model string, passed to {@code SerializeTool.getObjectFromString}
     * @return the decoded model
     */
    public LogisticRegressionModel getModel(String src) {
        // NOTE(review): if SerializeTool uses Java-native deserialization, src must never come
        // from an untrusted source — confirm SerializeTool's implementation.
        LogisticRegressionModel decoded = SerializeTool.getObjectFromString(src);
        return decoded;
    }

    /**
     * Encodes the given model as a string using {@code SerializeTool.object2String}.
     *
     * @param model the model to encode
     * @return the encoded model
     */
    public String getModelStr(LogisticRegressionModel model) {
        return SerializeTool.object2String(model);
    }

    /**
     * Encodes the currently wrapped model as a string using {@code SerializeTool.object2String}.
     *
     * @return the encoded form of the wrapped model
     */
    public String getModelStr() {
        return SerializeTool.object2String(this.model);
    }

    /**
     * Scores a feature vector with the wrapped model.
     *
     * @param data feature vector passed to {@link LogisticRegressionModel#predict(Vector)}
     * @return the model's prediction, or {@code null} if no model has been set
     */
    public Double predict(Vector data) {
        // Read the model once so that the null check and the prediction use the same instance.
        LogisticRegressionModel current = getModel();
        if (current == null) {
            return null;
        }
        return current.predict(data);
    }
}
