package cn.com.duiba.nezha.alg.model;

import cn.com.duiba.nezha.alg.common.util.AssertUtil;
import cn.com.duiba.nezha.alg.feature.vo.FeatureMapDo;
import cn.com.duiba.nezha.alg.model.enums.HyperParams;
import cn.com.duiba.nezha.alg.model.enums.MutModelType;
import cn.com.duiba.nezha.alg.model.enums.PredictResultType;
import cn.com.duiba.nezha.alg.model.tf.LocalTFModelV2;
import com.alibaba.fastjson.JSON;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:cn/com/duiba/nezha/alg/model/ESMMV3.class */
/**
 * ESMM-style CTR/CVR predictor that delegates scoring to two {@link IModel}s backed by
 * local TF models ({@link LocalTFModelV2}).
 *
 * <p>The behaviour of {@link #predictCTRsAndCVRs(Map, Map)} is selected by {@link #mutModelType}:
 * {@code DEBIAS} additionally recalibrates the CVR predictions, {@code DEEP_E2E_LOCAL} returns the
 * raw predictions, and any other value yields {@code null}.
 */
public class ESMMV3 implements Serializable {
    private static final long serialVersionUID = -316102112618444131L;
    private static final Logger logger = LoggerFactory.getLogger(ESMMV3.class);

    /**
     * Batches smaller than this get the full CVR prediction map serialised into each
     * feature's debug string; larger batches only get "null" there, to bound the cost.
     */
    private static final int CODE_DO_STR_JSON_MAX_SIZE = 40;

    /** Initial capacity used for the result maps built by this class. */
    private int PB_MAX_SIZE = 128;
    /** Algorithm type; only used for logging during warm-up here. */
    private Integer algType;
    private IModel ctrModel;
    private IModel cvrModel;
    /** Local TF model for CTR; when null, CTR prediction is skipped (empty result). */
    private LocalTFModelV2 ctrLocalTFModel;
    /** Local TF model for CVR; when null, CVR prediction is skipped (empty result). */
    private LocalTFModelV2 cvrLocalTFModel;
    private MutModelType mutModelType;
    /** When {@code Boolean.TRUE}, a warm-up log line is emitted on each prediction. */
    private Boolean isTmp;

    public ESMMV3(IModel iModel, IModel iModel2, LocalTFModelV2 localTFModelV2, LocalTFModelV2 localTFModelV22, MutModelType mutModelType) {
        this.ctrModel = iModel;
        this.cvrModel = iModel2;
        this.ctrLocalTFModel = localTFModelV2;
        this.cvrLocalTFModel = localTFModelV22;
        this.mutModelType = mutModelType;
    }

    /**
     * Predicts CTR and CVR for each candidate according to {@link #mutModelType}.
     *
     * <p>In {@code DEBIAS} mode every non-null CVR prediction {@code p} is replaced, in place, by
     * {@code p / (p + (1 - p) / w)}, where {@code w} is the {@code HyperParams.NCW} value
     * (NOTE(review): presumably a negative-sampling weight — confirm). A missing or null
     * {@code w} defaults to 1.0, for which the formula leaves {@code p} unchanged.
     *
     * @param map  candidate key to its features
     * @param map2 hyper-parameters; may be null (treated as empty)
     * @return a map holding the CTR and CVR prediction maps, or {@code null} when the model type is
     *     neither {@code DEBIAS} nor {@code DEEP_E2E_LOCAL}, or when either prediction is empty
     * @throws Exception propagated from the underlying model prediction
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRs(Map<T, FeatureMapDo> map, Map<HyperParams, Double> map2) throws Exception {
        if (MutModelType.DEBIAS == this.mutModelType) {
            Map<PredictResultType, Map<T, Double>> result = predictCTRsAndCVRsLocalTFNew(map);
            // The prediction step returns null when either result is empty; nothing to calibrate.
            if (result == null) {
                return null;
            }
            Map<T, Double> cvrPredictions = result.get(PredictResultType.CVR);
            if (cvrPredictions != null) {
                double weight = resolveNegativeWeight(map2);
                for (Map.Entry<T, Double> entry : cvrPredictions.entrySet()) {
                    Double p = entry.getValue();
                    if (p != null) {
                        entry.setValue(debias(p, weight));
                    }
                }
            }
            return result;
        }
        if (MutModelType.DEEP_E2E_LOCAL == this.mutModelType) {
            return predictCTRsAndCVRsLocalTFNew(map);
        }
        return null;
    }

    /** Reads the NCW hyper-parameter, falling back to 1.0 when the map or the value is absent. */
    private static double resolveNegativeWeight(Map<HyperParams, Double> hyperParams) {
        if (hyperParams == null) {
            return 1.0d;
        }
        Double weight = hyperParams.get(HyperParams.NCW);
        return weight == null ? 1.0d : weight;
    }

    /** Applies {@code p / (p + (1 - p) / weight)} to a single prediction. */
    private static double debias(double p, double weight) {
        return p / (p + ((1.0d - p) / weight));
    }

    /**
     * Runs the CTR and CVR local TF models over the candidates.
     *
     * <p>As a side effect, each non-null {@link FeatureMapDo} has its code string prefixed with
     * {@code "<cvrJson>;<cvrModelVersion>;<cvr>;"} for debugging. {@code cvrJson} is the whole CVR
     * prediction map as JSON for small batches and "null" otherwise.
     *
     * @param map candidate key to its features
     * @return a map with {@code CTR} and {@code CVR} entries, or {@code null} when either
     *     prediction map is empty (e.g. because its local TF model is not configured)
     * @throws Exception propagated from the underlying model prediction
     */
    public <T> Map<PredictResultType, Map<T, Double>> predictCTRsAndCVRsLocalTFNew(Map<T, FeatureMapDo> map) throws Exception {
        if (Boolean.TRUE.equals(this.isTmp)) {
            logger.info("模型预热, algType: {}", this.algType);
        }
        Map<T, Double> ctrPredictions = new HashMap<>();
        if (this.ctrLocalTFModel != null) {
            ctrPredictions = this.ctrModel.predictWithLocalTFV2(map, this.ctrLocalTFModel);
        }
        Map<T, Double> cvrPredictions = new HashMap<>();
        if (this.cvrLocalTFModel != null) {
            cvrPredictions = this.cvrModel.predictWithLocalTFV2(map, this.cvrLocalTFModel);
        }
        appendCvrDebugInfo(map, cvrPredictions);
        if (!AssertUtil.isAllNotEmpty(new Object[]{ctrPredictions, cvrPredictions})) {
            return null;
        }
        Map<PredictResultType, Map<T, Double>> result = new HashMap<>(this.PB_MAX_SIZE);
        result.put(PredictResultType.CTR, ctrPredictions);
        result.put(PredictResultType.CVR, cvrPredictions);
        return result;
    }

    /**
     * Prefixes every feature's code string with CVR debug information.
     * Null-safe with respect to the candidate map, its values, the CVR result map and
     * an unconfigured CVR local model (the original dereferenced all of these).
     */
    private <T> void appendCvrDebugInfo(Map<T, FeatureMapDo> map, Map<T, Double> cvrPredictions) {
        if (map == null || map.isEmpty()) {
            return;
        }
        String cvrJson = map.size() < CODE_DO_STR_JSON_MAX_SIZE ? JSON.toJSONString(cvrPredictions) : null;
        Object cvrVersion = this.cvrLocalTFModel == null ? null : this.cvrLocalTFModel.getVersion();
        for (Map.Entry<T, FeatureMapDo> entry : map.entrySet()) {
            FeatureMapDo features = entry.getValue();
            if (features == null) {
                continue;
            }
            Double cvr = cvrPredictions == null ? null : cvrPredictions.get(entry.getKey());
            features.setCodeDoStr(cvrJson + ";" + cvrVersion + ";" + cvr + ";" + features.getCodeDoStr());
        }
    }

    /** Manual smoke check of {@code AssertUtil.isAllNotEmpty}; prints the result to stdout. */
    public static void main(String[] strArr) {
        Map<String, Double> first = new HashMap<>(10);
        first.put("1", Double.valueOf(0.01d));
        Map<String, Double> second = new HashMap<>(10);
        second.put("2", Double.valueOf(0.22d));
        System.out.println(AssertUtil.isAllNotEmpty(new Object[]{first, second}));
    }

    /** Builds, per key, the CTR model params followed by the CVR model params (string-map input). */
    private <T> Map<T, List<Float>> getParams(Map<T, Map<String, String>> map) throws Exception {
        Map<T, List<Float>> params = new HashMap<>(this.PB_MAX_SIZE);
        if (AssertUtil.isNotEmpty(map)) {
            for (Map.Entry<T, Map<String, String>> entry : map.entrySet()) {
                params.put(entry.getKey(), paramsSplicing(this.ctrModel.getParam(entry.getValue()), this.cvrModel.getParam(entry.getValue())));
            }
        }
        return params;
    }

    /** Builds, per key, the CTR model params followed by the CVR model params ({@link FeatureMapDo} input). */
    private <T> Map<T, List<Float>> getParamsNew(Map<T, FeatureMapDo> map) throws Exception {
        Map<T, List<Float>> params = new HashMap<>(this.PB_MAX_SIZE);
        if (AssertUtil.isNotEmpty(map)) {
            for (Map.Entry<T, FeatureMapDo> entry : map.entrySet()) {
                params.put(entry.getKey(), paramsSplicing(this.ctrModel.getParam(entry.getValue()), this.cvrModel.getParam(entry.getValue())));
            }
        }
        return params;
    }

    /**
     * Concatenates two parameter lists into a new list (first then second).
     * Builds a fresh list instead of calling {@code addAll} on {@code list}, so the list
     * returned by the model is never mutated; null inputs are treated as empty.
     */
    private List<Float> paramsSplicing(List<Float> list, List<Float> list2) {
        int size = (list == null ? 0 : list.size()) + (list2 == null ? 0 : list2.size());
        List<Float> spliced = new ArrayList<>(size);
        if (list != null) {
            spliced.addAll(list);
        }
        if (list2 != null) {
            spliced.addAll(list2);
        }
        return spliced;
    }

    public int getPB_MAX_SIZE() {
        return this.PB_MAX_SIZE;
    }

    public Integer getAlgType() {
        return this.algType;
    }

    public IModel getCtrModel() {
        return this.ctrModel;
    }

    public IModel getCvrModel() {
        return this.cvrModel;
    }

    public LocalTFModelV2 getCtrLocalTFModel() {
        return this.ctrLocalTFModel;
    }

    public LocalTFModelV2 getCvrLocalTFModel() {
        return this.cvrLocalTFModel;
    }

    public MutModelType getMutModelType() {
        return this.mutModelType;
    }

    public Boolean getIsTmp() {
        return this.isTmp;
    }

    public void setPB_MAX_SIZE(int i) {
        this.PB_MAX_SIZE = i;
    }

    public void setAlgType(Integer num) {
        this.algType = num;
    }

    public void setCtrModel(IModel iModel) {
        this.ctrModel = iModel;
    }

    public void setCvrModel(IModel iModel) {
        this.cvrModel = iModel;
    }

    public void setCtrLocalTFModel(LocalTFModelV2 localTFModelV2) {
        this.ctrLocalTFModel = localTFModelV2;
    }

    public void setCvrLocalTFModel(LocalTFModelV2 localTFModelV2) {
        this.cvrLocalTFModel = localTFModelV2;
    }

    public void setMutModelType(MutModelType mutModelType) {
        this.mutModelType = mutModelType;
    }

    public void setIsTmp(Boolean bool) {
        this.isTmp = bool;
    }

    /** Field-by-field equality over all instance fields (Lombok-style, honours {@link #canEqual}). */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ESMMV3)) {
            return false;
        }
        ESMMV3 other = (ESMMV3) obj;
        return other.canEqual(this)
                && getPB_MAX_SIZE() == other.getPB_MAX_SIZE()
                && Objects.equals(getAlgType(), other.getAlgType())
                && Objects.equals(getCtrModel(), other.getCtrModel())
                && Objects.equals(getCvrModel(), other.getCvrModel())
                && Objects.equals(getCtrLocalTFModel(), other.getCtrLocalTFModel())
                && Objects.equals(getCvrLocalTFModel(), other.getCvrLocalTFModel())
                && Objects.equals(getMutModelType(), other.getMutModelType())
                && Objects.equals(getIsTmp(), other.getIsTmp());
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ESMMV3;
    }

    /** Hash consistent with {@link #equals}; same 59/43 scheme as the original Lombok output. */
    @Override
    public int hashCode() {
        int result = 1;
        result = (result * 59) + getPB_MAX_SIZE();
        result = (result * 59) + hashOrDefault(getAlgType());
        result = (result * 59) + hashOrDefault(getCtrModel());
        result = (result * 59) + hashOrDefault(getCvrModel());
        result = (result * 59) + hashOrDefault(getCtrLocalTFModel());
        result = (result * 59) + hashOrDefault(getCvrLocalTFModel());
        result = (result * 59) + hashOrDefault(getMutModelType());
        return (result * 59) + hashOrDefault(getIsTmp());
    }

    /** Returns {@code o.hashCode()}, or 43 for null (Lombok's null hash constant). */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    @Override
    public String toString() {
        return "ESMMV3(PB_MAX_SIZE=" + getPB_MAX_SIZE() + ", algType=" + getAlgType() + ", ctrModel=" + getCtrModel() + ", cvrModel=" + getCvrModel() + ", ctrLocalTFModel=" + getCtrLocalTFModel() + ", cvrLocalTFModel=" + getCvrLocalTFModel() + ", mutModelType=" + getMutModelType() + ", isTmp=" + getIsTmp() + ")";
    }
}
