package cn.com.duiba.nezha.compute.biz.app;

import cn.com.duiba.nezha.compute.biz.bo.SampleBo;
import cn.com.duiba.nezha.compute.biz.constant.model.HParamsConstant;
import cn.com.duiba.nezha.compute.biz.constant.model.ModelConstant;
import cn.com.duiba.nezha.compute.biz.dto.PsModelSample;
import cn.com.duiba.nezha.compute.biz.ps.PsAgent;
import cn.com.duiba.nezha.compute.core.LabeledESMMSparsePoint;
import cn.com.duiba.nezha.compute.core.enums.DateStyle;
import cn.com.duiba.nezha.compute.core.model.local.LocalModel;
import cn.com.duiba.nezha.compute.core.util.DateUtil;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.FMFTRL;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.FMFTRLHyperParams;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.SparseESMMFMWithFTRLBatch;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.SparseESMMFMWithFTRLBatch2;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.SparseESMMFMWithFTRLBatch3;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:cn/com/duiba/nezha/compute/biz/app/SparseESMMFMWithFTRLApp.class */
/**
 * Batch-training entry points for an ESMM model consisting of a CTR model and a CVR model, each an
 * FM trained by {@link FMFTRL}, whose parameters are pulled from and pushed back through
 * {@link PsAgent}.
 *
 * <p>The {@code run} overloads pick the trainer by {@code trainerVersion}: {@code 1} uses
 * {@link #train}, {@code 2} uses {@link #train2}, {@code 3} uses {@link #train3}. Any other value
 * skips training and logs a warning.
 */
public class SparseESMMFMWithFTRLApp {

    /**
     * Builds ESMM samples from {@code sampleList} and, unless {@code skipTraining} is set, trains
     * them with the trainer selected by {@code trainerVersion}.
     *
     * @param ctrModelId     id of the CTR model (also the key used to look up hyper-parameters)
     * @param cvrModelId     id of the CVR model
     * @param sampleList     raw samples converted by {@code SampleBo.getESMMSample}
     * @param skipTraining   when {@code true}, samples are still built but no training happens
     * @param trainArg       opaque integer forwarded to the trainer's {@code train},
     *                       {@code getIncrLocalModel} and {@code PsAgent.push}
     * @param trainerVersion 1, 2 or 3 — selects {@link #train}, {@link #train2} or {@link #train3}
     * @param batchSize      mini-batch size handed to the trainer
     * @throws Exception propagated from sample building, the parameter server or training
     */
    public static void run(String ctrModelId, String cvrModelId, List<PsModelSample> sampleList, boolean skipTraining, int trainArg, int trainerVersion, int batchSize) throws Exception {
        logWithTime("  batch run start ");
        LabeledESMMSparsePoint[] samples = SampleBo.getESMMSample(ctrModelId, cvrModelId, sampleList);
        if (!skipTraining) {
            dispatch(ctrModelId, cvrModelId, samples, trainArg, trainerVersion, batchSize);
        }
        logWithTime("  batch run end");
    }

    /**
     * Same as {@link #run(String, String, List, boolean, int, int, int)} but builds the ESMM samples
     * from two raw sample lists; how they are combined is decided by
     * {@code SampleBo.getESMMSample(String, String, List, List)}.
     */
    public static void run(String ctrModelId, String cvrModelId, List<PsModelSample> sampleList, List<PsModelSample> sampleList2, boolean skipTraining, int trainArg, int trainerVersion, int batchSize) throws Exception {
        logWithTime("  batch run start ");
        LabeledESMMSparsePoint[] samples = SampleBo.getESMMSample(ctrModelId, cvrModelId, sampleList, sampleList2);
        if (!skipTraining) {
            dispatch(ctrModelId, cvrModelId, samples, trainArg, trainerVersion, batchSize);
        }
        logWithTime("  batch run end");
    }

    /**
     * Trains with {@link SparseESMMFMWithFTRLBatch} and pushes both the CTR and CVR increments.
     * Does nothing when {@code samples} is {@code null} or empty.
     *
     * @throws NullPointerException if no hyper-parameters are registered for {@code ctrModelId}
     */
    public static void train(String ctrModelId, String cvrModelId, LabeledESMMSparsePoint[] samples, int trainArg, int batchSize) throws Exception {
        if (isEmpty(samples)) {
            return;
        }
        PreparedModels models = prepare(ctrModelId, cvrModelId, samples);
        SparseESMMFMWithFTRLBatch trainer = new SparseESMMFMWithFTRLBatch();
        trainer.setCtrFMFTRL(models.ctrFtrl).setCvrFMFTRL(models.cvrFtrl).setBatchSize(batchSize).setRho1(models.hParams.rho1()).setRho2(models.hParams.rho2());
        logWithTime("  model.train(dataArray)");
        trainer.train(samples, trainArg);
        LocalModel ctrIncrement = trainer.ctrFMFTRL().getIncrLocalModel(trainArg);
        LocalModel cvrIncrement = trainer.cvrFMFTRL().getIncrLocalModel(trainArg);
        logWithTime("  psAgent.push(incrModel)");
        models.ctrAgent.push(ctrIncrement, trainArg);
        models.cvrAgent.push(cvrIncrement, trainArg);
    }

    /**
     * Trains with {@link SparseESMMFMWithFTRLBatch2} and pushes both the CTR and CVR increments.
     * Does nothing when {@code samples} is {@code null} or empty.
     *
     * @throws NullPointerException if no hyper-parameters are registered for {@code ctrModelId}
     */
    public static void train2(String ctrModelId, String cvrModelId, LabeledESMMSparsePoint[] samples, int trainArg, int batchSize) throws Exception {
        if (isEmpty(samples)) {
            return;
        }
        PreparedModels models = prepare(ctrModelId, cvrModelId, samples);
        SparseESMMFMWithFTRLBatch2 trainer = new SparseESMMFMWithFTRLBatch2();
        trainer.setCtrFMFTRL(models.ctrFtrl).setCvrFMFTRL(models.cvrFtrl).setBatchSize(batchSize).setRho1(models.hParams.rho1()).setRho2(models.hParams.rho2());
        logWithTime("  model.train(dataArray)");
        trainer.train(samples, trainArg);
        LocalModel ctrIncrement = trainer.ctrFMFTRL().getIncrLocalModel(trainArg);
        LocalModel cvrIncrement = trainer.cvrFMFTRL().getIncrLocalModel(trainArg);
        logWithTime("  psAgent.push(incrModel)");
        models.ctrAgent.push(ctrIncrement, trainArg);
        models.cvrAgent.push(cvrIncrement, trainArg);
    }

    /**
     * Trains with {@link SparseESMMFMWithFTRLBatch3} and pushes only the CVR increment.
     * Does nothing when {@code samples} is {@code null} or empty.
     *
     * <p>NOTE(review): the CTR model is pulled and handed to the trainer, but its increment is never
     * pushed back — presumably intentional for this variant; confirm.
     *
     * @throws NullPointerException if no hyper-parameters are registered for {@code ctrModelId}
     */
    public static void train3(String ctrModelId, String cvrModelId, LabeledESMMSparsePoint[] samples, int trainArg, int batchSize) throws Exception {
        if (isEmpty(samples)) {
            return;
        }
        PreparedModels models = prepare(ctrModelId, cvrModelId, samples);
        SparseESMMFMWithFTRLBatch3 trainer = new SparseESMMFMWithFTRLBatch3();
        trainer.setCtrFMFTRL(models.ctrFtrl).setCvrFMFTRL(models.cvrFtrl).setBatchSize(batchSize).setRho1(models.hParams.rho1()).setRho2(models.hParams.rho2());
        logWithTime("  model.train(dataArray)");
        trainer.train(samples, trainArg);
        LocalModel cvrIncrement = trainer.cvrFMFTRL().getIncrLocalModel(trainArg);
        logWithTime("  psAgent.push(incrModel)");
        models.cvrAgent.push(cvrIncrement, trainArg);
    }

    /** Routes to the trainer variant selected by {@code trainerVersion}; unknown values only log. */
    private static void dispatch(String ctrModelId, String cvrModelId, LabeledESMMSparsePoint[] samples, int trainArg, int trainerVersion, int batchSize) throws Exception {
        switch (trainerVersion) {
            case 1:
                train(ctrModelId, cvrModelId, samples, trainArg, batchSize);
                break;
            case 2:
                train2(ctrModelId, cvrModelId, samples, trainArg, batchSize);
                break;
            case 3:
                train3(ctrModelId, cvrModelId, samples, trainArg, batchSize);
                break;
            default:
                // Previously an unsupported version was ignored silently; make it visible.
                logWithTime("  unknown trainer version " + trainerVersion + ", training skipped");
                break;
        }
    }

    /**
     * Shared setup for every trainer variant: loads hyper-parameters, creates one {@link PsAgent}
     * per model, pulls the models found by the search helpers and wraps them in configured
     * {@link FMFTRL} instances.
     */
    private static PreparedModels prepare(String ctrModelId, String cvrModelId, LabeledESMMSparsePoint[] samples) throws Exception {
        // Both models share the hyper-parameters registered under the CTR model id.
        FMFTRLHyperParams hParams = HParamsConstant.getHParams(ctrModelId);
        System.out.println("model=" + ctrModelId + ",hyperParams=" + hParams);
        Objects.requireNonNull(hParams, "no hyper-parameters registered for model " + ctrModelId);

        // Feature dimensions are read from the first sample only — assumes every sample in the
        // batch has the same CTR / CVR feature-vector size (TODO confirm).
        int ctrDim = samples[0].ctrLSP().feature().size;
        int cvrDim = samples[0].cvrLSP().feature().size;

        PsAgent ctrAgent = new PsAgent(ctrModelId, ModelConstant.MODEL_PAR_SIZE, ctrDim, hParams.factorNum());
        PsAgent cvrAgent = new PsAgent(cvrModelId, ModelConstant.MODEL_PAR_SIZE, cvrDim, hParams.factorNum());

        // searchCtrModel / searchCvrModel are static helpers on SparseESMMFMWithFTRLBatch and are
        // used for every trainer variant; presumably they select the parameters these samples
        // touch — confirm against their implementation.
        LocalModel ctrSearch = SparseESMMFMWithFTRLBatch.searchCtrModel(samples, ctrDim, hParams.factorNum());
        LocalModel cvrSearch = SparseESMMFMWithFTRLBatch.searchCvrModel(samples, cvrDim, hParams.factorNum());
        ctrAgent.pull(ctrSearch, false);
        cvrAgent.pull(cvrSearch, false);
        LocalModel ctrLocal = ctrAgent.getLocalModel();
        LocalModel cvrLocal = cvrAgent.getLocalModel();

        FMFTRL ctrFtrl = newFmFtrl(ctrModelId, hParams, ctrLocal, ctrDim);
        FMFTRL cvrFtrl = newFmFtrl(cvrModelId, hParams, cvrLocal, cvrDim);
        return new PreparedModels(hParams, ctrAgent, cvrAgent, ctrFtrl, cvrFtrl);
    }

    /** Creates an {@link FMFTRL} for one model, configured from the shared hyper-parameters. */
    private static FMFTRL newFmFtrl(String modelId, FMFTRLHyperParams hParams, LocalModel localModel, int dim) {
        FMFTRL ftrl = new FMFTRL();
        ftrl.setModelId(modelId)
                .setAlpha(hParams.alpha())
                .setBeta(hParams.beta())
                .setLambda1(hParams.lambda1())
                .setLambda2(hParams.lambda2())
                .setRho1(hParams.rho1())
                .setRho2(hParams.rho2())
                .setLearnRatio2(hParams.learnRatio2())
                .setNegativeSampleRatio(1.0d)
                .setLocalModel(localModel)
                .setDim(dim)
                .setFacotrNum(hParams.factorNum());
        return ftrl;
    }

    /** Returns {@code true} when there is nothing to train on. */
    private static boolean isEmpty(LabeledESMMSparsePoint[] samples) {
        return samples == null || samples.length == 0;
    }

    /** Prints {@code message} to stdout prefixed with the current time. */
    private static void logWithTime(String message) {
        System.out.println(DateUtil.getCurrentTime(DateStyle.YYYY_MM_DD_HH_MM_SS_SSS) + message);
    }

    /** State produced by {@link #prepare} and consumed by each trainer variant. */
    private static final class PreparedModels {
        final FMFTRLHyperParams hParams;
        final PsAgent ctrAgent;
        final PsAgent cvrAgent;
        final FMFTRL ctrFtrl;
        final FMFTRL cvrFtrl;

        PreparedModels(FMFTRLHyperParams hParams, PsAgent ctrAgent, PsAgent cvrAgent, FMFTRL ctrFtrl, FMFTRL cvrFtrl) {
            this.hParams = hParams;
            this.ctrAgent = ctrAgent;
            this.cvrAgent = cvrAgent;
            this.ctrFtrl = ctrFtrl;
            this.cvrFtrl = cvrFtrl;
        }
    }
}
