package cn.com.duiba.nezha.compute.mllib.optimizing.mt;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.optimizing.FMGD$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import java.util.Random;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: SparseFMMTSGDOptimizer.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/mllib/optimizing/mt/SparseFMMTSGDOptimizer$.class */
/**
 * Mini-batch SGD optimizer for a sparse factorization-machine (FM) model, run on Spark.
 *
 * <p>This is the Java encoding of a Scala {@code object}: the single instance is exposed as
 * {@link #MODULE$}. The Spark closures it uses ({@code SparseFMMTSGDOptimizer$$anonfun$1..3})
 * are separately compiled classes in this package.
 */
public final class SparseFMMTSGDOptimizer$ {
    /** Singleton instance (Scala {@code object} encoding). */
    public static final SparseFMMTSGDOptimizer$ MODULE$ = new SparseFMMTSGDOptimizer$();

    /** Fixed seed for the generator that produces the per-epoch {@code randomSplit} seeds. */
    private static final long SPLIT_SEED = 11L;

    /**
     * Trains FM parameters with mini-batch gradient descent over {@code rdd}.
     *
     * <p>Each epoch, the data is randomly split into {@code count / i + 1} parts. One update step is
     * taken per part. Training stops after {@code i3} epochs, or as soon as the tracked parameter
     * change ({@code FMGD.p_delta}) is no longer greater than {@code d5}.
     *
     * @param rdd training samples; this method caches it and leaves it cached
     * @param i   mini-batch size used to derive the number of splits per epoch; must be positive
     * @param i2  passed to {@code SparseUtil.rand/zero(size, i2, ...)} and broadcast to the gradient
     *            closure; presumably the FM factor dimension (k) — TODO confirm
     * @param d   forwarded to {@code SparseFMMTUpdater.paramsDelta}; presumably a step-size-like
     *            coefficient — TODO confirm
     * @param d2  forwarded to {@code SparseFMMTUpdater.paramsDelta}; meaning not visible here
     * @param d3  forwarded to {@code FMGD.gradWithRegularization}; presumably a regularization weight
     * @param d4  forwarded to {@code FMGD.gradWithRegularization}; presumably a regularization weight
     * @param i3  maximum number of epochs (full passes over the splits)
     * @param i4  number of leading epochs during which the convergence value is not updated
     * @param d5  convergence tolerance; training stops once the tracked {@code p_delta} is {@code <= d5}
     * @return the parameters after the last update step
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public Point.FMParams run(RDD<Point.LabeledSPoint> rdd, int i, int i2, double d, double d2, double d3, double d4, int i3, int i4, double d5) {
        if (i <= 0) {
            throw new IllegalArgumentException("mini-batch size must be positive, got " + i);
        }
        Random splitSeeds = new Random(SPLIT_SEED);
        // Cache before the first action so first() and count() share the cached data.
        rdd.cache();
        int featureCount = ((Point.LabeledSPoint) rdd.first()).x().size();
        long sampleCount = rdd.count();

        Point.FMParams params = new Point.FMParams(SparseUtil$.MODULE$.rand(0), SparseUtil$.MODULE$.rand(featureCount, 0), SparseUtil$.MODULE$.rand(featureCount, i2, 0));
        // Sentinel so the loops start; replaced by FMGD.p_delta once epoch >= i4.
        double lastParamDelta = 999999.0d;
        long numSplits = (sampleCount / i) + 1;
        double[] splitWeights = (double[]) Array$.MODULE$.tabulate((int) numSplits, new SparseFMMTSGDOptimizer$$anonfun$1(numSplits), ClassTag$.MODULE$.Double());
        // Running update state fed back into paramsDelta on every step.
        Point.FMParams delta = new Point.FMParams(0.0d, SparseUtil$.MODULE$.zero(featureCount), SparseUtil$.MODULE$.zero(featureCount, i2));

        // i2 never changes, so broadcast it once instead of once per mini-batch.
        Broadcast factorDimBc = rdd.context().broadcast(BoxesRunTime.boxToInteger(i2), ClassTag$.MODULE$.Int());
        try {
            for (int epoch = 0; lastParamDelta > d5 && epoch < i3; epoch++) {
                RDD[] splits = rdd.randomSplit(splitWeights, splitSeeds.nextInt());
                for (int s = 0; lastParamDelta > d5 && s < numSplits; s++) {
                    RDD batch = splits[s];
                    long batchSize = batch.cache().count();
                    Predef$.MODULE$.println("On iteration(i=" + epoch + ",j=" + s + "), traindate.size=" + batchSize);
                    if (batchSize == 0) {
                        // randomSplit may produce an empty part; reduce() would throw on it and
                        // the per-sample averages below would divide by zero.
                        batch.unpersist(false);
                        Predef$.MODULE$.println("Skipping empty split (i=" + epoch + ",j=" + s + ")");
                        continue;
                    }

                    Broadcast paramsBc = rdd.context().broadcast(params, ClassTag$.MODULE$.apply(Point.FMParams.class));
                    Tuple2 reduced;
                    try {
                        // The mapped RDD is reduced exactly once, so it is not cached.
                        reduced = (Tuple2) batch.map(new SparseFMMTSGDOptimizer$$anonfun$2(paramsBc, factorDimBc), ClassTag$.MODULE$.apply(Tuple2.class)).reduce(new SparseFMMTSGDOptimizer$$anonfun$3());
                    } finally {
                        // Neither this broadcast nor this cached batch is used after the reduce;
                        // releasing them keeps memory flat across epochs.
                        paramsBc.destroy();
                        batch.unpersist(false);
                    }

                    Point.FMGradParams grad = FMGD$.MODULE$.grad((Point.FMGradParams) reduced._1(), batchSize);
                    delta = SparseFMMTUpdater$.MODULE$.paramsDelta(delta, FMGD$.MODULE$.gradWithRegularization(params, grad, d3, d4), d, d2);
                    Point.FMParams updated = FMGD$.MODULE$.paramsUpdate(params, delta);

                    double rmse = package$.MODULE$.sqrt(reduced._2$mcD$sp() / batchSize);
                    double gradRmse = FMGD$.MODULE$.g_rmse(grad);
                    double paramRmse = FMGD$.MODULE$.p_rmse(params, updated);
                    double paramDelta = FMGD$.MODULE$.p_delta(params, updated);
                    // "delta" is the convergence value tracked before this step.
                    Predef$.MODULE$.println("delta=" + lastParamDelta + ", p_delta = " + paramDelta + ", p_delta_rmse=" + paramRmse + ", grad_rmse=" + gradRmse + ",rmse=" + rmse);

                    // Only track convergence after the warm-up epochs, and never from the last
                    // split of an epoch.
                    if (epoch >= i4 && s < numSplits - 1) {
                        lastParamDelta = paramDelta;
                    }
                    params = updated;
                }
            }
        } finally {
            factorDimBc.destroy();
        }
        return params;
    }

    private SparseFMMTSGDOptimizer$() {
    }
}
