/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.optimizing.adam;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.common.enums.DateStyle;
import cn.com.duiba.nezha.compute.common.util.DateUtil;
import cn.com.duiba.nezha.compute.mllib.optimizing.FMGD$;
import cn.com.duiba.nezha.compute.mllib.optimizing.adam.SparseFMAdamUpdater$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import java.util.Random;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Driver-side mini-batch optimizer that fits factorization-machine parameters
 * ({@code w0}, {@code w}, {@code v}) with Adam-style first/second moment updates.
 *
 * <p>This is the Java view of a Scala {@code object}: the single instance is {@link #MODULE$},
 * created by the static initializer.
 */
public final class SparseFMAdamSGDOptimizer$ {
    /** The singleton instance; assigned by the private constructor. */
    public static final SparseFMAdamSGDOptimizer$ MODULE$;

    static {
        new SparseFMAdamSGDOptimizer$();
    }

    /**
     * Trains FM parameters with Adam-style mini-batch updates.
     *
     * <p>Each outer iteration (epoch) randomly splits {@code data} into
     * {@code N / batchSize + 1} equally weighted pieces and performs one update per piece.
     * Training stops after {@code MAX_ITERATIONS} epochs, or as soon as the tracked parameter
     * change {@code delta} is no longer greater than {@code DELTA_THRESHOLD}. {@code delta} is
     * only refreshed once the epoch index reaches {@code MIN_ITERATIONS}, and never from the last
     * piece of an epoch.
     *
     * <p>Side effects: {@code data} is cached by this method and left cached; progress is printed
     * to stdout.
     *
     * @param data training points; must be non-empty ({@code first()} is called on it)
     * @param batchSize nominal mini-batch size; must be positive. NOTE(review): the summed
     *     gradient and the rmse are normalized by this value, while each random split actually
     *     holds about {@code N / (N / batchSize + 1)} points; confirm this is intended
     * @param F number of latent factors (column count of the initial {@code v}); also broadcast
     *     to the executors
     * @param beta1 decay weight for the first-moment estimate {@code m_t}
     * @param beta2 decay weight for the second-moment estimate {@code v_t}
     * @param learningRate base learning rate, adjusted per step by
     *     {@code SparseFMAdamUpdater.learningRateUpdate}
     * @param r1 regularization coefficient forwarded to {@code FMGD.gradWithRegularization}
     *     (exact meaning not visible here; confirm)
     * @param r2 second regularization coefficient forwarded to
     *     {@code FMGD.gradWithRegularization} (exact meaning not visible here; confirm)
     * @param MAX_ITERATIONS maximum number of epochs
     * @param MIN_ITERATIONS first epoch index at which {@code delta} starts being updated
     * @param DELTA_THRESHOLD training continues only while {@code delta} is greater than this
     * @return the parameters produced by the last update (the random initialization if no
     *     update ran)
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public Point.FMParams run(RDD<Point.LabeledSPoint> data, int batchSize, int F, double beta1, double beta2, double learningRate, double r1, double r2, int MAX_ITERATIONS, int MIN_ITERATIONS, double DELTA_THRESHOLD) {
        if (batchSize <= 0) {
            // Previously surfaced as an ArithmeticException (/ 0) or a malformed split below.
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        Point.FMParams fm_params;
        // Fixed seed so the per-epoch random splits are reproducible between runs.
        Random rand2 = new Random(11L);
        // Cache before the first action so first() reads cached data instead of recomputing
        // the lineage a second time.
        long N = data.cache().count();
        int D = ((Point.LabeledSPoint)data.first()).x().size();
        Broadcast F_b = data.context().broadcast((Object)BoxesRunTime.boxToInteger((int)F), ClassTag$.MODULE$.Int());
        // Random initialization of w0, w (length D) and v (D x F).
        Point.FMParams params_new = fm_params = new Point.FMParams(SparseUtil$.MODULE$.rand(0, 0.5), SparseUtil$.MODULE$.rand(D, 0, 0.1), SparseUtil$.MODULE$.rand(D, F, 0, 0.01));
        // Large sentinel so the loops start.
        double delta = 999999.0;
        // Global step counter across all epochs; passed to learningRateUpdate.
        int t = 0;
        // Adam moment accumulators, starting at zero.
        Point.FMGradParams m_t = new Point.FMGradParams(0.0, SparseUtil$.MODULE$.zero(D), SparseUtil$.MODULE$.zero(D, F));
        Point.FMGradParams v_t = new Point.FMGradParams(0.0, SparseUtil$.MODULE$.zero(D), SparseUtil$.MODULE$.zero(D, F));
        Point.FMGradParams grad_mse = new Point.FMGradParams(0.0, SparseUtil$.MODULE$.zero(D), SparseUtil$.MODULE$.zero(D, F));
        // Loop-invariant: batch size, number of splits and the equal split weights never change
        // between epochs, so compute them once.
        int batch = batchSize;
        long batchNums = N / (long)batch + 1L;
        double[] batchWeightArray = (double[])Array$.MODULE$.tabulate((int)batchNums, (Function1)new Serializable(batchNums){
            public static final long serialVersionUID = 0L;
            private final long batchNums$1;

            public final double apply(int ti) {
                return this.apply$mcDI$sp(ti);
            }

            public double apply$mcDI$sp(int ti) {
                return 1.0 / (double)this.batchNums$1;
            }
            {
                this.batchNums$1 = batchNums$1;
            }
        }, ClassTag$.MODULE$.Double());
        for (int i = 0; delta > DELTA_THRESHOLD && i < MAX_ITERATIONS; ++i) {
            // A new random partition of the data every epoch.
            RDD[] splitsData = data.randomSplit(batchWeightArray, (long)rand2.nextInt());
            int j = 0;
            while (delta > DELTA_THRESHOLD && (long)j < batchNums) {
                Predef$.MODULE$.println((Object)new StringBuilder().append((Object)DateUtil.getCurrentTime((DateStyle)DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{", On iteration(i=", ",j_total=", ",j=", ")"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)i), BoxesRunTime.boxToLong((long)batchNums), BoxesRunTime.boxToInteger((int)j)}))).toString());
                ++t;
                // Ship a dense snapshot of the current parameters to the executors for this step.
                Broadcast fm_m_params_b = data.context().broadcast((Object)new Point.FMModelParams(fm_params.w0(), (Vector)fm_params.w().toDense(), (Matrix)fm_params.v().toDense()), ClassTag$.MODULE$.apply(Point.FMModelParams.class));
                Predef$.MODULE$.println((Object)new StringBuilder().append((Object)DateUtil.getCurrentTime((DateStyle)DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{",grad computer start ,batchSize=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)batch)}))).toString());
                // Per-point (gradient, error) pairs summed component-wise across the split.
                // NOTE(review): reduce throws on an empty split; unlikely for reasonable batch sizes.
                Tuple2 grad_and_e = (Tuple2)splitsData[j].map((Function1)new Serializable(F_b, fm_m_params_b){
                    public static final long serialVersionUID = 0L;
                    private final Broadcast F_b$1;
                    private final Broadcast fm_m_params_b$1;

                    public final Tuple2<Point.FMGradParams, Object> apply(Point.LabeledSPoint p) {
                        return FMGD$.MODULE$.computeWithErr2(p, (Broadcast<Point.FMModelParams>)this.fm_m_params_b$1, (Broadcast<Object>)this.F_b$1);
                    }
                    {
                        this.F_b$1 = F_b$1;
                        this.fm_m_params_b$1 = fm_m_params_b$1;
                    }
                }, ClassTag$.MODULE$.apply(Tuple2.class)).reduce((Function2)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final Tuple2<Point.FMGradParams, Object> apply(Tuple2<Point.FMGradParams, Object> ge1, Tuple2<Point.FMGradParams, Object> ge2) {
                        return new Tuple2((Object)new Point.FMGradParams(((Point.FMGradParams)ge1._1()).grad_w0() + ((Point.FMGradParams)ge2._1()).grad_w0(), SparseUtil$.MODULE$.add(((Point.FMGradParams)ge1._1()).grad_w(), ((Point.FMGradParams)ge2._1()).grad_w()), SparseUtil$.MODULE$.add(((Point.FMGradParams)ge1._1()).grad_v(), ((Point.FMGradParams)ge2._1()).grad_v())), (Object)BoxesRunTime.boxToDouble((double)(ge1._2$mcD$sp() + ge2._2$mcD$sp())));
                    }
                });
                // The reduce action has completed, so this step's snapshot is no longer needed.
                // Release it now instead of waiting for driver GC and Spark's ContextCleaner;
                // a new D + D*F dense copy is broadcast every step.
                fm_m_params_b.unpersist(false);
                Predef$.MODULE$.println((Object)new StringBuilder().append((Object)DateUtil.getCurrentTime((DateStyle)DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{",grad computer end "})).s((Seq)Nil$.MODULE$)).toString());
                Point.FMGradParams grad = FMGD$.MODULE$.grad((Point.FMGradParams)grad_and_e._1(), batch);
                Point.FMGradParams grad_regularization = FMGD$.MODULE$.gradWithRegularization(fm_params, grad, r1, r2);
                // First/second moment updates, weighted (beta, 1 - beta).
                m_t = FMGD$.MODULE$.gradMergeAddUpdate(m_t, grad_regularization, beta1, 1.0 - beta1);
                v_t = FMGD$.MODULE$.gradMergeAddUpdate(v_t, FMGD$.MODULE$.gradSquare(grad_regularization), beta2, 1.0 - beta2);
                double learningRate_t = SparseFMAdamUpdater$.MODULE$.learningRateUpdate(learningRate, beta1, beta2, t);
                Point.FMGradParams grad_new = SparseFMAdamUpdater$.MODULE$.gradNew(m_t, v_t);
                params_new = SparseFMAdamUpdater$.MODULE$.paramsUpdate(fm_params, grad_new, learningRate_t);
                // Diagnostics only; p_delta additionally feeds the convergence check below.
                double grad_rmse = FMGD$.MODULE$.g_rmse(grad);
                double p_delta_rmse = FMGD$.MODULE$.p_rmse(fm_params, params_new);
                double p_delta = FMGD$.MODULE$.p_delta(fm_params, params_new);
                double rmse = package$.MODULE$.sqrt(grad_and_e._2$mcD$sp() / (double)batch);
                Predef$.MODULE$.println((Object)new StringBuilder().append((Object)DateUtil.getCurrentTime((DateStyle)DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{", delta=", ", p_delta = ", ", p_delta_rmse=", ", grad_rmse=", ",rmse=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)delta), BoxesRunTime.boxToDouble((double)p_delta), BoxesRunTime.boxToDouble((double)p_delta_rmse), BoxesRunTime.boxToDouble((double)grad_rmse), BoxesRunTime.boxToDouble((double)rmse)}))).toString());
                // Convergence is only tracked after MIN_ITERATIONS epochs and never from the
                // last split of an epoch.
                if (i >= MIN_ITERATIONS && (long)j < batchNums - 1L) {
                    delta = p_delta;
                }
                ++j;
                fm_params = params_new;
            }
        }
        // Every job that read F_b has finished.
        F_b.unpersist(false);
        return fm_params;
    }

    /**
     * Combines two (gradient, error-sum) pairs; alias for {@link #gradAcc}.
     *
     * @param ge1 first partial result
     * @param ge2 second partial result
     * @return the component-wise sum of both pairs
     */
    public Tuple2<Point.FMGradParams, Object> comb(Tuple2<Point.FMGradParams, Object> ge1, Tuple2<Point.FMGradParams, Object> ge2) {
        return this.gradAcc(ge1, ge2);
    }

    /**
     * Sums two (gradient, error-sum) pairs component-wise: {@code grad_w0} values are added,
     * {@code grad_w} and {@code grad_v} are added through {@code SparseUtil.add}, and the
     * double error sums are added. This matches the inline reduce function in {@link #run}.
     *
     * @param ge1 first partial result
     * @param ge2 second partial result
     * @return a new pair holding the summed gradient and summed error
     */
    public Tuple2<Point.FMGradParams, Object> gradAcc(Tuple2<Point.FMGradParams, Object> ge1, Tuple2<Point.FMGradParams, Object> ge2) {
        return new Tuple2((Object)new Point.FMGradParams(((Point.FMGradParams)ge1._1()).grad_w0() + ((Point.FMGradParams)ge2._1()).grad_w0(), SparseUtil$.MODULE$.add(((Point.FMGradParams)ge1._1()).grad_w(), ((Point.FMGradParams)ge2._1()).grad_w()), SparseUtil$.MODULE$.add(((Point.FMGradParams)ge1._1()).grad_v(), ((Point.FMGradParams)ge2._1()).grad_v())), (Object)BoxesRunTime.boxToDouble((double)(ge1._2$mcD$sp() + ge2._2$mcD$sp())));
    }

    /** Registers this instance as the singleton {@link #MODULE$}. */
    private SparseFMAdamSGDOptimizer$() {
        MODULE$ = this;
    }
}

