package cn.com.duiba.nezha.compute.mllib.optimizing.adam;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.common.enums.DateStyle;
import cn.com.duiba.nezha.compute.common.util.DateUtil;
import cn.com.duiba.nezha.compute.mllib.optimizing.FMGD$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import java.util.Random;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: SparseFMAdamSGDOptimizer.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/mllib/optimizing/adam/SparseFMAdamSGDOptimizer$.class */
/**
 * Decompiled form of the Scala {@code object SparseFMAdamSGDOptimizer}.
 *
 * <p>It trains factorization-machine parameters ({@link Point.FMParams}) with mini-batch gradient
 * descent over a Spark {@link RDD}. Updates use running first- and second-moment gradient
 * estimates (Adam-style, per the package and class name).
 */
public final class SparseFMAdamSGDOptimizer$ {
    /** Singleton instance, mirroring the Scala {@code object}'s {@code MODULE$} field. */
    public static final SparseFMAdamSGDOptimizer$ MODULE$ = new SparseFMAdamSGDOptimizer$();

    /** Sentinel "very large" delta so the convergence test cannot stop training before it is armed. */
    private static final double INITIAL_DELTA = 999999.0d;

    /**
     * Trains FM parameters with mini-batch gradient descent using first/second gradient moments.
     *
     * <p>The input is split once into {@code count / i + 1} random batches. The split seed comes
     * from {@code new Random(11L)}, so it is deterministic. Each outer pass visits every batch in
     * order. For each batch, the method:
     * <ul>
     *   <li>broadcasts the current model as dense vectors;</li>
     *   <li>maps and reduces the batch into one accumulated gradient on the driver;</li>
     *   <li>updates the moment estimates and applies one parameter update.</li>
     * </ul>
     *
     * <p>Training stops when {@code i3} passes have run, or when the parameter change
     * {@code p_delta} drops to {@code d6} or below. That change is only recorded from pass
     * {@code i4} onward, and never for the last batch of a pass.
     *
     * @param rdd training points. NOTE(review): this method calls {@code cache()} on the caller's
     *            RDD and does not unpersist it.
     * @param i   batch size (logged as {@code batchSize}); must be positive
     * @param i2  second dimension of the {@code v} factor matrix (presumably the number of latent
     *            factors)
     * @param d   weight of the previous first-moment estimate; the new gradient gets
     *            {@code 1 - d} (presumably Adam beta1)
     * @param d2  weight of the previous second-moment estimate; the squared gradient gets
     *            {@code 1 - d2} (presumably Adam beta2)
     * @param d3  rate passed to {@code SparseFMAdamUpdater.learningRateUpdate} (presumably the base
     *            learning rate)
     * @param d4  first coefficient passed to {@code FMGD.gradWithRegularization}
     * @param d5  second coefficient passed to {@code FMGD.gradWithRegularization}
     * @param i3  maximum number of passes over all batches
     * @param i4  pass index from which the convergence check is armed
     * @param d6  convergence threshold on {@code p_delta}
     * @return the trained parameters
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public Point.FMParams run(RDD<Point.LabeledSPoint> rdd, int i, int i2, double d, double d2, double d3, double d4, double d5, int i3, int i4, double d6) {
        if (i <= 0) {
            throw new IllegalArgumentException("batch size must be positive, got " + i);
        }
        Random random = new Random(11L);
        int size = ((Point.LabeledSPoint) rdd.first()).x().size();
        long count = rdd.cache().count();
        Broadcast factorDimBroadcast = rdd.context().broadcast(BoxesRunTime.boxToInteger(i2), ClassTag$.MODULE$.Int());
        Point.FMParams params = new Point.FMParams(SparseUtil$.MODULE$.rand(0, 0.5d), SparseUtil$.MODULE$.rand(size, 0, 0.1d), SparseUtil$.MODULE$.rand(size, i2, 0, 0.01d));
        double delta = INITIAL_DELTA;
        // Global step counter across all passes; fed to the learning-rate schedule.
        int step = 0;
        long numBatches = (count / i) + 1;
        double[] splitWeights = (double[]) Array$.MODULE$.tabulate((int) numBatches, new SparseFMAdamSGDOptimizer$$anonfun$1(numBatches), ClassTag$.MODULE$.Double());
        Point.FMGradParams firstMoment = new Point.FMGradParams(0.0d, SparseUtil$.MODULE$.zero(size), SparseUtil$.MODULE$.zero(size, i2));
        Point.FMGradParams secondMoment = new Point.FMGradParams(0.0d, SparseUtil$.MODULE$.zero(size), SparseUtil$.MODULE$.zero(size, i2));
        // The split is computed once, so every pass visits the same batches in the same order.
        RDD[] batches = rdd.randomSplit(splitWeights, random.nextInt());
        try {
            for (int epoch = 0; delta > d6 && epoch < i3; epoch++) {
                for (int batch = 0; delta > d6 && batch < numBatches; batch++) {
                    Predef$.MODULE$.println(new StringBuilder().append(DateUtil.getCurrentTime(DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{", On iteration(i=", ",j=", ")"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(epoch), BoxesRunTime.boxToInteger(batch)}))).toString());
                    step++;
                    Broadcast modelBroadcast = rdd.context().broadcast(new Point.FMModelParams(params.w0(), params.w().toDense(), params.v().toDense()), ClassTag$.MODULE$.apply(Point.FMModelParams.class));
                    Tuple2 reduced;
                    try {
                        Predef$.MODULE$.println(new StringBuilder().append(DateUtil.getCurrentTime(DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{",grad computer start ,batchSize=", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(i)}))).toString());
                        reduced = (Tuple2) batches[batch].map(new SparseFMAdamSGDOptimizer$$anonfun$2(factorDimBroadcast, modelBroadcast), ClassTag$.MODULE$.apply(Tuple2.class)).reduce(new SparseFMAdamSGDOptimizer$$anonfun$3());
                    } finally {
                        // A fresh dense model copy is broadcast every step. Release it once the
                        // reduce action that consumes it has finished, so copies do not pile up on
                        // the executors.
                        modelBroadcast.unpersist();
                    }
                    Predef$.MODULE$.println(new StringBuilder().append(DateUtil.getCurrentTime(DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{",grad computer end "})).s(Nil$.MODULE$)).toString());
                    Point.FMGradParams grad = FMGD$.MODULE$.grad((Point.FMGradParams) reduced._1(), i);
                    Point.FMGradParams regularizedGrad = FMGD$.MODULE$.gradWithRegularization(params, grad, d4, d5);
                    firstMoment = FMGD$.MODULE$.gradMergeAddUpdate(firstMoment, regularizedGrad, d, 1 - d);
                    secondMoment = FMGD$.MODULE$.gradMergeAddUpdate(secondMoment, FMGD$.MODULE$.gradSquare(regularizedGrad), d2, 1 - d2);
                    Point.FMParams updated = SparseFMAdamUpdater$.MODULE$.paramsUpdate(params, SparseFMAdamUpdater$.MODULE$.gradNew(firstMoment, secondMoment), SparseFMAdamUpdater$.MODULE$.learningRateUpdate(d3, d, d2, step));
                    double gradRmse = FMGD$.MODULE$.g_rmse(grad);
                    double paramRmse = FMGD$.MODULE$.p_rmse(params, updated);
                    double paramDelta = FMGD$.MODULE$.p_delta(params, updated);
                    // Note: "delta=" logs the value recorded on an earlier step, not paramDelta.
                    Predef$.MODULE$.println(new StringBuilder().append(DateUtil.getCurrentTime(DateStyle.YYYY_MM_DD_HH_MM_SS_SSS)).append(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{", delta=", ", p_delta = ", ", p_delta_rmse=", ", grad_rmse=", ",rmse=", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToDouble(delta), BoxesRunTime.boxToDouble(paramDelta), BoxesRunTime.boxToDouble(paramRmse), BoxesRunTime.boxToDouble(gradRmse), BoxesRunTime.boxToDouble(package$.MODULE$.sqrt(reduced._2$mcD$sp() / i))}))).toString());
                    // Convergence is only tracked after warm-up pass i4, and the last batch of each
                    // pass is skipped (NOTE(review): presumably because that split may be smaller).
                    if (epoch >= i4 && batch < numBatches - 1) {
                        delta = paramDelta;
                    }
                    params = updated;
                }
            }
        } finally {
            factorDimBroadcast.unpersist();
        }
        return params;
    }

    /**
     * Combines two partial gradient accumulations; a thin alias for {@link #gradAcc}.
     *
     * @param tuple2  first accumulation (gradient triple, double stored as {@code Object})
     * @param tuple22 second accumulation
     * @return the element-wise sum of both accumulations
     */
    public Tuple2<Point.FMGradParams, Object> comb(Tuple2<Point.FMGradParams, Object> tuple2, Tuple2<Point.FMGradParams, Object> tuple22) {
        return gradAcc(tuple2, tuple22);
    }

    /**
     * Adds two partial gradient accumulations.
     *
     * <p>The {@code w0} gradients are added as doubles. The {@code w} and {@code v} gradients are
     * added via {@code SparseUtil.add}. The second (double) components are summed.
     *
     * @param tuple2  first accumulation
     * @param tuple22 second accumulation
     * @return a new tuple holding the summed gradient and the summed double component
     */
    public Tuple2<Point.FMGradParams, Object> gradAcc(Tuple2<Point.FMGradParams, Object> tuple2, Tuple2<Point.FMGradParams, Object> tuple22) {
        Point.FMGradParams left = (Point.FMGradParams) tuple2._1();
        Point.FMGradParams right = (Point.FMGradParams) tuple22._1();
        Point.FMGradParams summed = new Point.FMGradParams(left.grad_w0() + right.grad_w0(), SparseUtil$.MODULE$.add(left.grad_w(), right.grad_w()), SparseUtil$.MODULE$.add(left.grad_v(), right.grad_v()));
        return new Tuple2<>(summed, BoxesRunTime.boxToDouble(tuple2._2$mcD$sp() + tuple22._2$mcD$sp()));
    }

    /** Singleton constructor; the instance is created once by the {@code MODULE$} initializer. */
    private SparseFMAdamSGDOptimizer$() {
    }
}
