package cn.com.duiba.nezha.compute.alg;

import cn.com.duiba.nezha.compute.alg.util.LogisticRegressionModelUtil;
import cn.com.duiba.nezha.compute.alg.vo.VectorResult;
import cn.com.duiba.nezha.compute.api.PredResultVo;
import cn.com.duiba.nezha.compute.api.dto.AdvertModelEntity;
import cn.com.duiba.nezha.compute.api.enums.SerializerEnum;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:cn/com/duiba/nezha/compute/alg/LR.class */
public class LR extends BaseAlgorithm implements Serializable, IAlgorithm {
    private static final Logger logger = LoggerFactory.getLogger(LR.class);

    /** Delegate that holds the underlying logistic-regression model; created lazily by {@link #getModelUtil()}. */
    private LogisticRegressionModelUtil modelUtil = null;

    /**
     * Returns the model helper, creating an empty {@link LogisticRegressionModelUtil} on first access.
     *
     * <p>NOTE(review): the lazy initialisation is not synchronized; confirm instances are not
     * initialised concurrently from multiple threads.
     *
     * @return the (never null) model helper
     */
    public LogisticRegressionModelUtil getModelUtil() {
        if (this.modelUtil == null) {
            this.modelUtil = new LogisticRegressionModelUtil();
        }
        return this.modelUtil;
    }

    /** Creates an empty algorithm; the model and feature metadata must be set before predicting. */
    public LR() {
    }

    /**
     * Creates an algorithm from serialized model/feature strings.
     *
     * @param str            serialized feature index list (passed to {@code setFeatureIdxList})
     * @param str2           serialized feature dictionary (passed to {@code setFeatureDict})
     * @param str3           serialized model (passed to {@code setModel})
     * @param str4           serialized feature collection list (passed to {@code setFeatureCollectionList})
     * @param serializerEnum serializer used to decode all of the above
     */
    public LR(String str, String str2, String str3, String str4, SerializerEnum serializerEnum) {
        setFeatureDict(str2, serializerEnum);
        setModel(str3, serializerEnum);
        setFeatureIdxList(str, serializerEnum);
        setFeatureCollectionList(str4, serializerEnum);
    }

    /**
     * Creates an algorithm from a persisted model entity.
     *
     * @param advertModelEntity entity carrying the serialized model, dictionary and feature lists
     */
    public LR(AdvertModelEntity advertModelEntity) {
        SerializerEnum serializerEnum = resolveSerializer(advertModelEntity);
        setFeatureDict(advertModelEntity.getFeatureDictStr(), serializerEnum);
        setModel(advertModelEntity.getModelStr(), serializerEnum);
        setFeatureIdxList(advertModelEntity.getFeatureIdxListStr(), serializerEnum);
        setFeatureCollectionList(advertModelEntity.getFeatureCollectListStr(), serializerEnum);
    }

    /**
     * Replaces the model and feature metadata with those carried by {@code advertModelEntity}.
     *
     * @param advertModelEntity entity carrying the serialized model, dictionary and feature lists
     */
    public void setEntity(AdvertModelEntity advertModelEntity) {
        SerializerEnum serializerEnum = resolveSerializer(advertModelEntity);
        setFeatureDict(advertModelEntity.getFeatureDictStr(), serializerEnum);
        setFeatureIdxList(advertModelEntity.getFeatureIdxListStr(), serializerEnum);
        setFeatureCollectionList(advertModelEntity.getFeatureCollectListStr(), serializerEnum);
        setModel(advertModelEntity.getModelStr(), serializerEnum);
    }

    /**
     * Chooses the serializer recorded on the entity: KRYO when its id matches, otherwise JAVA_ORIGINAL.
     *
     * <p>NOTE(review): uses {@code ==}; if both sides are boxed {@code Integer}s this is a reference
     * comparison — confirm the declared return types of {@code getSerializerId}/{@code getIndex}.
     */
    private static SerializerEnum resolveSerializer(AdvertModelEntity advertModelEntity) {
        return advertModelEntity.getSerializerId() == SerializerEnum.KRYO.getIndex()
                ? SerializerEnum.KRYO
                : SerializerEnum.JAVA_ORIGINAL;
    }

    /**
     * Installs an already-built Spark logistic-regression model.
     *
     * @param logisticRegressionModel model to delegate predictions to
     */
    public void setModel(LogisticRegressionModel logisticRegressionModel) {
        getModelUtil().setModel(logisticRegressionModel);
    }

    /** Decodes and installs a serialized model using {@code serializerEnum}. */
    @Override // cn.com.duiba.nezha.compute.alg.BaseAlgorithm
    public void setModel(String str, SerializerEnum serializerEnum) {
        getModelUtil().setModel(str, serializerEnum);
    }

    /** Returns the current model serialized with {@code serializerEnum}. */
    @Override // cn.com.duiba.nezha.compute.alg.BaseAlgorithm
    public String getModelStr(SerializerEnum serializerEnum) {
        return getModelUtil().getModelStr(serializerEnum);
    }

    /**
     * One-hot encodes a list of feature values and scores it with the model.
     *
     * @param list raw feature values, aligned with the configured feature index list
     * @return the model output, or {@code null} if encoding produced no vector or any error occurred
     *         (errors are logged, not thrown)
     */
    @Override // cn.com.duiba.nezha.compute.alg.IAlgorithm
    public Double predict(List<String> list) {
        try {
            return predictEncoded(getDictUtil().oneHotSparseVectorEncode(getFeatureIdxList(), list, getFeatureCollectionList()));
        } catch (Exception e) {
            logger.error("predict happend error", e);
            return null;
        }
    }

    /**
     * One-hot encodes a feature map and scores it with the model.
     *
     * @param map raw features keyed by feature name
     * @return the model output, or {@code null} if encoding produced no vector or any error occurred
     *         (errors are logged, not thrown)
     */
    @Override // cn.com.duiba.nezha.compute.alg.IAlgorithm
    public Double predict(Map<String, String> map) {
        try {
            return predictEncoded(getDictUtil().oneHotSparseVectorEncodeWithMap(getFeatureIdxList(), map, getFeatureCollectionList()));
        } catch (Exception e) {
            logger.error("predict happend error", e);
            return null;
        }
    }

    /**
     * Scores an encoding result, returning {@code null} when there is nothing to score.
     * Exceptions propagate to the caller, which owns the error handling.
     */
    private Double predictEncoded(VectorResult encoded) {
        if (encoded == null || encoded.getVector() == null) {
            return null;
        }
        return getModelUtil().predict(encoded.getVector());
    }

    /**
     * Encodes a feature map and returns the prediction together with feature-count diagnostics.
     *
     * @param map raw features keyed by feature name
     * @return a result object; its fields are left unset if encoding produced no vector or an error
     *         occurred (errors are logged, not thrown). Never {@code null}.
     */
    public PredResultVo predictWithInfo(Map<String, String> map) {
        PredResultVo predResultVo = new PredResultVo();
        try {
            VectorResult oneHotSparseVectorEncodeWithMap = getDictUtil().oneHotSparseVectorEncodeWithMap(getFeatureIdxList(), map, getFeatureCollectionList());
            if (oneHotSparseVectorEncodeWithMap != null && oneHotSparseVectorEncodeWithMap.getVector() != null) {
                predResultVo.setPredValue(predictWithVector(oneHotSparseVectorEncodeWithMap.getVector()));
                predResultVo.setNewFeatureNums(oneHotSparseVectorEncodeWithMap.getNewFeatureNums());
                predResultVo.setTotalFeatureNums(oneHotSparseVectorEncodeWithMap.getTotalFeatureNums());
            }
        } catch (Exception e) {
            logger.error("predict happend error", e);
        }
        return predResultVo;
    }

    /**
     * Scores an already-encoded sparse vector.
     *
     * @param sparseVector encoded features; {@code null} yields {@code null}
     * @return the model output, or {@code null} on null input or error (errors are logged, not thrown)
     */
    public Double predictWithVector(SparseVector sparseVector) {
        Double d = null;
        if (sparseVector != null) {
            try {
                // Was System.out.println: stdout noise on every call and an unconditional toString();
                // parameterized debug logging only formats the vector when debug is enabled.
                logger.debug("vector {}", sparseVector);
                d = getModelUtil().predict(sparseVector);
            } catch (Exception e) {
                logger.error("predict happend error", e);
            }
        }
        return d;
    }
}
