/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.biz.app;

import cn.com.duiba.nezha.compute.biz.bo.SampleBo;
import cn.com.duiba.nezha.compute.biz.constant.model.HParamsConstant;
import cn.com.duiba.nezha.compute.biz.constant.model.ModelConstant;
import cn.com.duiba.nezha.compute.biz.dto.PsModelSample;
import cn.com.duiba.nezha.compute.biz.ps.PsAgent;
import cn.com.duiba.nezha.compute.core.LabeledESMMSparsePoint;
import cn.com.duiba.nezha.compute.core.enums.DateStyle;
import cn.com.duiba.nezha.compute.core.model.local.LocalModel;
import cn.com.duiba.nezha.compute.core.util.DateUtil;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.FMFTRL;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.FMFTRLHyperParams;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.SparseESMMFMWithFTRLBatch;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.SparseESMMFMWithFTRLBatch2;
import cn.com.duiba.nezha.compute.mllib.fm.ftrl.SparseESMMFMWithFTRLBatch3;
import java.util.List;

/**
 * Batch entry points for incremental ESMM (CTR + CVR) factorization-machine training with FTRL,
 * where model parameters are exchanged with a parameter server through {@link PsAgent}.
 *
 * <p>Each {@code run} overload converts raw samples into {@link LabeledESMMSparsePoint}s and,
 * unless {@code isReplay} is set, hands them to the {@code train*} variant selected by
 * {@code type}. Each {@code train*} variant pulls parameters from the parameter server, trains
 * locally on the batch and pushes the resulting incremental model(s) back.
 */
public class SparseESMMFMWithFTRLApp {
    /** {@code type} value selecting {@link #train}. */
    private static final int TYPE_BATCH = 1;
    /** {@code type} value selecting {@link #train2}. */
    private static final int TYPE_BATCH2 = 2;
    /** {@code type} value selecting {@link #train3}. */
    private static final int TYPE_BATCH3 = 3;

    /**
     * Builds ESMM samples from {@code samples} and, unless replaying, trains on them.
     *
     * @param psCtrModelId parameter-server id of the CTR model (also used to look up hyper-parameters)
     * @param psCvrModelId parameter-server id of the CVR model
     * @param samples      raw samples to convert
     * @param isReplay     when {@code true} the samples are built but no training or push happens
     * @param partNums     partition count forwarded to training and to the push
     * @param type         trainer variant: 1 = {@link #train}, 2 = {@link #train2}, 3 = {@link #train3};
     *                     any other value skips training (a warning is printed)
     * @param batchSize    mini-batch size forwarded to the trainer
     * @throws Exception   propagated from sample conversion, training or the parameter server
     */
    public static void run(String psCtrModelId, String psCvrModelId, List<PsModelSample> samples, boolean isReplay, int partNums, int type, int batchSize) throws Exception {
        logWithTime("  batch run start ");
        LabeledESMMSparsePoint[] dataArray = SampleBo.getESMMSample(psCtrModelId, psCvrModelId, samples);
        if (!isReplay) {
            dispatchTrain(psCtrModelId, psCvrModelId, dataArray, partNums, type, batchSize);
        }
        logWithTime("  batch run end");
    }

    /**
     * Same as {@link #run(String, String, List, boolean, int, int, int)}, but the ESMM samples are
     * built from both {@code samples} and {@code bCvrSamples}.
     *
     * @param bCvrSamples additional samples passed to {@code SampleBo.getESMMSample}
     * @throws Exception  propagated from sample conversion, training or the parameter server
     */
    public static void run(String psCtrModelId, String psCvrModelId, List<PsModelSample> samples, List<PsModelSample> bCvrSamples, boolean isReplay, int partNums, int type, int batchSize) throws Exception {
        logWithTime("  batch run start ");
        LabeledESMMSparsePoint[] dataArray = SampleBo.getESMMSample(psCtrModelId, psCvrModelId, samples, bCvrSamples);
        if (!isReplay) {
            dispatchTrain(psCtrModelId, psCvrModelId, dataArray, partNums, type, batchSize);
        }
        logWithTime("  batch run end");
    }

    /**
     * Trains with {@link SparseESMMFMWithFTRLBatch} and pushes both the CTR and the CVR increments.
     * Returns immediately when {@code dataArray} is {@code null} or empty.
     *
     * @throws Exception propagated from training or the parameter server
     */
    public static void train(String psCtrModelId, String psCvrModelId, LabeledESMMSparsePoint[] dataArray, int partNums, int batchSize) throws Exception {
        if (dataArray == null || dataArray.length < 1) {
            return;
        }
        TrainContext ctx = new TrainContext(psCtrModelId, psCvrModelId, dataArray);
        SparseESMMFMWithFTRLBatch model = new SparseESMMFMWithFTRLBatch();
        model.setCtrFMFTRL(ctx.ctrModel).setCvrFMFTRL(ctx.cvrModel).setBatchSize(batchSize).setRho1(ctx.hyperParams.rho1()).setRho2(ctx.hyperParams.rho2());
        logWithTime("  model.train(dataArray)");
        model.train(dataArray, partNums);
        LocalModel ctrIncrModel = model.ctrFMFTRL().getIncrLocalModel(partNums);
        LocalModel cvrIncrModel = model.cvrFMFTRL().getIncrLocalModel(partNums);
        logWithTime("  psAgent.push(incrModel)");
        ctx.psCtrAgent.push(ctrIncrModel, partNums);
        ctx.psCvrAgent.push(cvrIncrModel, partNums);
    }

    /**
     * Trains with {@link SparseESMMFMWithFTRLBatch2} and pushes both the CTR and the CVR increments.
     * Returns immediately when {@code dataArray} is {@code null} or empty.
     *
     * @throws Exception propagated from training or the parameter server
     */
    public static void train2(String psCtrModelId, String psCvrModelId, LabeledESMMSparsePoint[] dataArray, int partNums, int batchSize) throws Exception {
        if (dataArray == null || dataArray.length < 1) {
            return;
        }
        TrainContext ctx = new TrainContext(psCtrModelId, psCvrModelId, dataArray);
        SparseESMMFMWithFTRLBatch2 model = new SparseESMMFMWithFTRLBatch2();
        model.setCtrFMFTRL(ctx.ctrModel).setCvrFMFTRL(ctx.cvrModel).setBatchSize(batchSize).setRho1(ctx.hyperParams.rho1()).setRho2(ctx.hyperParams.rho2());
        logWithTime("  model.train(dataArray)");
        model.train(dataArray, partNums);
        LocalModel ctrIncrModel = model.ctrFMFTRL().getIncrLocalModel(partNums);
        LocalModel cvrIncrModel = model.cvrFMFTRL().getIncrLocalModel(partNums);
        logWithTime("  psAgent.push(incrModel)");
        ctx.psCtrAgent.push(ctrIncrModel, partNums);
        ctx.psCvrAgent.push(cvrIncrModel, partNums);
    }

    /**
     * Trains with {@link SparseESMMFMWithFTRLBatch3} and pushes only the CVR increment.
     * Returns immediately when {@code dataArray} is {@code null} or empty.
     *
     * <p>NOTE(review): the CTR parameters are pulled and handed to the trainer, but no CTR
     * increment is pushed back — presumably Batch3 updates only the CVR tower; confirm.
     *
     * @throws Exception propagated from training or the parameter server
     */
    public static void train3(String psCtrModelId, String psCvrModelId, LabeledESMMSparsePoint[] dataArray, int partNums, int batchSize) throws Exception {
        if (dataArray == null || dataArray.length < 1) {
            return;
        }
        TrainContext ctx = new TrainContext(psCtrModelId, psCvrModelId, dataArray);
        SparseESMMFMWithFTRLBatch3 model = new SparseESMMFMWithFTRLBatch3();
        model.setCtrFMFTRL(ctx.ctrModel).setCvrFMFTRL(ctx.cvrModel).setBatchSize(batchSize).setRho1(ctx.hyperParams.rho1()).setRho2(ctx.hyperParams.rho2());
        logWithTime("  model.train(dataArray)");
        model.train(dataArray, partNums);
        LocalModel cvrIncrModel = model.cvrFMFTRL().getIncrLocalModel(partNums);
        logWithTime("  psAgent.push(incrModel)");
        ctx.psCvrAgent.push(cvrIncrModel, partNums);
    }

    /** Routes a batch to the {@code train*} variant selected by {@code type}; unknown types are reported and skipped. */
    private static void dispatchTrain(String psCtrModelId, String psCvrModelId, LabeledESMMSparsePoint[] dataArray, int partNums, int type, int batchSize) throws Exception {
        switch (type) {
            case TYPE_BATCH:
                train(psCtrModelId, psCvrModelId, dataArray, partNums, batchSize);
                break;
            case TYPE_BATCH2:
                train2(psCtrModelId, psCvrModelId, dataArray, partNums, batchSize);
                break;
            case TYPE_BATCH3:
                train3(psCtrModelId, psCvrModelId, dataArray, partNums, batchSize);
                break;
            default:
                // Previously a silent no-op; keep skipping (callers may rely on it) but make it visible.
                logWithTime("  unsupported train type=" + type + ", training skipped");
                break;
        }
    }

    /** Prints {@code message} to stdout prefixed with the current millisecond-precision timestamp. */
    private static void logWithTime(String message) {
        System.out.println(DateUtil.getCurrentTime(DateStyle.YYYY_MM_DD_HH_MM_SS_SSS) + message);
    }

    /** Creates an FM-FTRL learner configured from {@code hyperParams} on top of {@code localModel}. */
    private static FMFTRL newLearner(String modelId, FMFTRLHyperParams hyperParams, LocalModel localModel, int dim) {
        FMFTRL learner = new FMFTRL();
        learner.setModelId(modelId).setAlpha(hyperParams.alpha()).setBeta(hyperParams.beta()).setLambda1(hyperParams.lambda1()).setLambda2(hyperParams.lambda2()).setRho1(hyperParams.rho1()).setRho2(hyperParams.rho2()).setLearnRatio2(hyperParams.learnRatio2()).setNegativeSampleRatio(1.0).setLocalModel(localModel).setDim(dim).setFacotrNum(hyperParams.factorNum());
        return learner;
    }

    /**
     * Setup shared by every {@code train*} variant: hyper-parameters, the two parameter-server
     * agents, and CTR/CVR learners built on the parameters pulled for the current batch.
     */
    private static final class TrainContext {
        private final FMFTRLHyperParams hyperParams;
        private final PsAgent psCtrAgent;
        private final PsAgent psCvrAgent;
        private final FMFTRL ctrModel;
        private final FMFTRL cvrModel;

        /**
         * Loads hyper-parameters, pulls CTR/CVR parameters from the parameter server and builds the learners.
         *
         * @param dataArray non-empty batch; feature dimensions are read from its first element
         * @throws Exception propagated from the parameter server
         */
        private TrainContext(String psCtrModelId, String psCvrModelId, LabeledESMMSparsePoint[] dataArray) throws Exception {
            // NOTE(review): both learners use the hyper-parameters registered for the CTR model id — confirm intended.
            this.hyperParams = HParamsConstant.getHParams(psCtrModelId);
            System.out.println("model=" + psCtrModelId + ",hyperParams=" + hyperParams);
            int ctrDim = dataArray[0].ctrLSP().feature().size;
            int cvrDim = dataArray[0].cvrLSP().feature().size;
            this.psCtrAgent = new PsAgent(psCtrModelId, ModelConstant.MODEL_PAR_SIZE, ctrDim, hyperParams.factorNum());
            this.psCvrAgent = new PsAgent(psCvrModelId, ModelConstant.MODEL_PAR_SIZE, cvrDim, hyperParams.factorNum());
            // Every variant, including Batch2/Batch3, uses Batch's search helpers to pick what to pull.
            LocalModel ctrSearchLocalModel = SparseESMMFMWithFTRLBatch.searchCtrModel(dataArray, ctrDim, (int) hyperParams.factorNum());
            LocalModel cvrSearchLocalModel = SparseESMMFMWithFTRLBatch.searchCvrModel(dataArray, cvrDim, (int) hyperParams.factorNum());
            psCtrAgent.pull(ctrSearchLocalModel, false);
            psCvrAgent.pull(cvrSearchLocalModel, false);
            LocalModel ctrLocalModel = psCtrAgent.getLocalModel();
            LocalModel cvrLocalModel = psCvrAgent.getLocalModel();
            this.ctrModel = newLearner(psCtrModelId, hyperParams, ctrLocalModel, ctrDim);
            this.cvrModel = newLearner(psCvrModelId, hyperParams, cvrLocalModel, cvrDim);
        }
    }
}

