/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.optimizing.mt;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.optimizing.FMGD$;
import cn.com.duiba.nezha.compute.mllib.optimizing.mt.SparseFMMTUpdater$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import java.util.Random;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Mini-batch gradient-descent optimizer for factorization-machine parameters
 * ({@code Point.FMParams}) over an RDD of sparse labeled points. Each step's parameter change
 * is produced by {@code SparseFMMTUpdater.paramsDelta}.
 *
 * <p>This is CFR-decompiled output of what appears to be a Scala {@code object}: the single
 * instance is stored in {@link #MODULE$} by the constructor, which the static initializer runs.
 */
public final class SparseFMMTSGDOptimizer$ {
    public static final SparseFMMTSGDOptimizer$ MODULE$;

    static {
        new SparseFMMTSGDOptimizer$();
    }

    /**
     * Trains FM parameters by repeated passes over {@code data}, each pass split randomly into
     * mini-batches.
     *
     * <p>Per batch, the gradient and an error sum are computed on the executors and reduced on
     * the driver. The gradient is normalized by the batch size and regularized. It is then turned
     * into a parameter delta by {@code SparseFMMTUpdater.paramsDelta}, which also receives the
     * previous delta (presumably a momentum update; TODO confirm). Empty random splits are
     * skipped.
     *
     * @param data            training points; cached by this method and left cached on return
     * @param batchSize       target batch size; each pass uses {@code N / batchSize + 1} equally
     *                        weighted random splits. Must be positive.
     * @param F               used as the second dimension of the factor matrix (see
     *                        {@code SparseUtil.rand(D, F, 0)}); also broadcast to the gradient
     *                        computation
     * @param mtRate          forwarded to {@code SparseFMMTUpdater.paramsDelta}
     * @param learningRate    forwarded to {@code SparseFMMTUpdater.paramsDelta}
     * @param r1              forwarded to {@code FMGD.gradWithRegularization} (presumably the
     *                        L1 strength; TODO confirm)
     * @param r2              forwarded to {@code FMGD.gradWithRegularization} (presumably the
     *                        L2 strength; TODO confirm)
     * @param MAX_ITERATIONS  maximum number of passes over the data
     * @param MIN_ITERATIONS  passes before which the convergence measure is not updated
     * @param DELTA_THRESHOLD training stops once the tracked {@code p_delta} is at or below this
     * @return the final parameters
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public Point.FMParams run(RDD<Point.LabeledSPoint> data, int batchSize, int F, double mtRate, double learningRate, double r1, double r2, int MAX_ITERATIONS, int MIN_ITERATIONS, double DELTA_THRESHOLD) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        // Fixed seed so the sequence of random splits is reproducible across runs.
        Random rand2 = new Random(11L);
        int D = ((Point.LabeledSPoint)data.first()).x().size();
        long N = data.cache().count();
        Point.FMParams fm_params = new Point.FMParams(SparseUtil$.MODULE$.rand(0), SparseUtil$.MODULE$.rand(D, 0), SparseUtil$.MODULE$.rand(D, F, 0));
        double delta = 999999.0;
        long batchNums = N / (long)batchSize + 1L;
        // Equal weight for every split; randomSplit normalizes the weights itself.
        double[] batchWeightArray = (double[])Array$.MODULE$.tabulate((int)batchNums, (Function1)new Serializable(batchNums){
            public static final long serialVersionUID = 0L;
            private final long batchNums$1;

            public final double apply(int i) {
                return this.apply$mcDI$sp(i);
            }

            public double apply$mcDI$sp(int i) {
                return 1.0 / (double)this.batchNums$1;
            }
            {
                this.batchNums$1 = batchNums$1;
            }
        }, ClassTag$.MODULE$.Double());
        Point.FMParams params_delta = new Point.FMParams(0.0, SparseUtil$.MODULE$.zero(D), SparseUtil$.MODULE$.zero(D, F));
        // F never changes, so broadcast it once instead of once per batch.
        Broadcast F_b = data.context().broadcast((Object)BoxesRunTime.boxToInteger((int)F), ClassTag$.MODULE$.Int());
        for (int i = 0; delta > DELTA_THRESHOLD && i < MAX_ITERATIONS; ++i) {
            RDD[] splitsData = data.randomSplit(batchWeightArray, (long)rand2.nextInt());
            int j = 0;
            while (delta > DELTA_THRESHOLD && (long)j < batchNums) {
                // Cached because it is scanned twice: once for count(), once for the gradient map.
                long splitsDataSize = splitsData[j].cache().count();
                Predef$.MODULE$.println((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"On iteration(i=", ",j=", "), traindate.size=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)i), BoxesRunTime.boxToInteger((int)j), BoxesRunTime.boxToLong((long)splitsDataSize)})));
                // randomSplit may yield an empty split. reduce() would throw on it and the
                // gradient would be normalized by 0, so skip it without updating anything.
                if (splitsDataSize > 0L) {
                    Broadcast fm_params_b = data.context().broadcast((Object)fm_params, ClassTag$.MODULE$.apply(Point.FMParams.class));
                    // Not cached: this RDD is consumed by exactly one reduce().
                    RDD data_g_e = splitsData[j].map((Function1)new Serializable(fm_params_b, F_b){
                        public static final long serialVersionUID = 0L;
                        private final Broadcast fm_params_b$1;
                        private final Broadcast F_b$1;

                        public final Tuple2<Point.FMGradParams, Object> apply(Point.LabeledSPoint p) {
                            return FMGD$.MODULE$.computeWithErr(p, (Broadcast<Point.FMParams>)this.fm_params_b$1, (Broadcast<Object>)this.F_b$1);
                        }
                        {
                            this.fm_params_b$1 = fm_params_b$1;
                            this.F_b$1 = F_b$1;
                        }
                    }, ClassTag$.MODULE$.apply(Tuple2.class));
                    // Sums gradients component-wise and adds up the per-point error values.
                    Tuple2 grad_ps_e = (Tuple2)data_g_e.reduce((Function2)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final Tuple2<Point.FMGradParams, Object> apply(Tuple2<Point.FMGradParams, Object> ge1, Tuple2<Point.FMGradParams, Object> ge2) {
                            return new Tuple2((Object)new Point.FMGradParams(((Point.FMGradParams)ge1._1()).grad_w0() + ((Point.FMGradParams)ge2._1()).grad_w0(), SparseUtil$.MODULE$.add(((Point.FMGradParams)ge1._1()).grad_w(), ((Point.FMGradParams)ge2._1()).grad_w()), SparseUtil$.MODULE$.add(((Point.FMGradParams)ge1._1()).grad_v(), ((Point.FMGradParams)ge2._1()).grad_v())), (Object)BoxesRunTime.boxToDouble((double)(ge1._2$mcD$sp() + ge2._2$mcD$sp())));
                        }
                    });
                    // The per-batch parameter snapshot is no longer needed on the executors.
                    fm_params_b.unpersist(false);
                    Point.FMGradParams grad_batch = (Point.FMGradParams)grad_ps_e._1();
                    Point.FMGradParams grad = FMGD$.MODULE$.grad(grad_batch, splitsDataSize);
                    Point.FMGradParams grad_regularization = FMGD$.MODULE$.gradWithRegularization(fm_params, grad, r1, r2);
                    params_delta = SparseFMMTUpdater$.MODULE$.paramsDelta(params_delta, grad_regularization, mtRate, learningRate);
                    Point.FMParams params_new = FMGD$.MODULE$.paramsUpdate(fm_params, params_delta);
                    // Assumes computeWithErr's second tuple element is a squared error; TODO confirm.
                    double rmse = package$.MODULE$.sqrt(grad_ps_e._2$mcD$sp() / (double)splitsDataSize);
                    double grad_rmse = FMGD$.MODULE$.g_rmse(grad);
                    double p_delta_rmse = FMGD$.MODULE$.p_rmse(fm_params, params_new);
                    double p_delta = FMGD$.MODULE$.p_delta(fm_params, params_new);
                    Predef$.MODULE$.println((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"delta=", ", p_delta = ", ", p_delta_rmse=", ", grad_rmse=", ",rmse=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)delta), BoxesRunTime.boxToDouble((double)p_delta), BoxesRunTime.boxToDouble((double)p_delta_rmse), BoxesRunTime.boxToDouble((double)grad_rmse), BoxesRunTime.boxToDouble((double)rmse)})));
                    // The convergence measure is only tracked after MIN_ITERATIONS passes and
                    // never from the last batch of a pass (intent not visible here).
                    if (i >= MIN_ITERATIONS && (long)j < batchNums - 1L) {
                        delta = p_delta;
                    }
                    fm_params = params_new;
                }
                // Release this split's cache so cached splits do not accumulate across passes.
                splitsData[j].unpersist(false);
                ++j;
            }
        }
        F_b.unpersist(false);
        return fm_params;
    }

    private SparseFMMTSGDOptimizer$() {
        MODULE$ = this;
    }
}

