/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.optimizing.gd;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.optimizing.FMGD$;
import cn.com.duiba.nezha.compute.mllib.optimizing.SparseFMUpdater$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Function1;
import scala.Function2;
import scala.Predef$;
import scala.Serializable;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Driver-side optimizer that fits sparse factorization-machine parameters ({@link Point.FMParams})
 * on a Spark {@link RDD} by repeated gradient-descent steps.
 *
 * <p>NOTE(review): this file is CFR output decompiled from Scala. The anonymous-class expressions
 * passed to {@code map}/{@code reduce} are the decompiler's rendering of Scala closures (presumably
 * {@code AbstractFunction1/2} subclasses) and are not valid Java as written. Treat the Scala source
 * as authoritative and mirror any fix there.
 */
public final class SparseFMGDOptimizer$ {
    /** Singleton instance, following the Scala {@code object} ({@code MODULE$}) convention. */
    public static final SparseFMGDOptimizer$ MODULE$ = new SparseFMGDOptimizer$();

    /**
     * Fits FM parameters on {@code data} with gradient descent over the whole data set per step.
     *
     * <p>Each iteration broadcasts the current parameters and maps every point to a gradient via
     * {@code FMGD.compute}. It sums all gradients with {@code reduce} (component-wise on
     * {@code grad_w0}, {@code grad_w}, {@code grad_v}) and post-processes the sum with
     * {@code FMGD.grad(sum, N)} (presumably averaging over the {@code N} points; confirm). It then
     * applies the result with {@code SparseFMUpdater.update}. The loop stops after
     * {@code MAX_ITERATIONS} iterations, or earlier once {@code i >= MIN_ITERATIONS} and the change
     * reported by {@code FMGD.p_delta(old, new)} is {@code <= DELTA_THRESHOLD}.
     *
     * @param data training points. The feature dimension {@code D} is read from
     *     {@code data.first()}, so {@code data} presumably must be non-empty and all points
     *     presumably share that dimension; confirm with callers.
     * @param F second dimension of the factor matrix created by {@code SparseUtil.rand(D, F, 0)}
     *     (presumably the number of latent factors). It is also broadcast to the gradient
     *     computation.
     * @param learningRate learning rate handed to {@code SparseFMUpdater.update}
     * @param r1 forwarded to {@code SparseFMUpdater.update} (presumably a regularization weight;
     *     confirm)
     * @param r2 forwarded to {@code SparseFMUpdater.update} (presumably a regularization weight;
     *     confirm)
     * @param MAX_ITERATIONS maximum number of iterations
     * @param MIN_ITERATIONS convergence is only checked from iteration index {@code MIN_ITERATIONS}
     *     onward
     * @param DELTA_THRESHOLD stop once the measured parameter change is at or below this value
     * @return the parameters after the last update, or the random initialization if no iteration
     *     ran (e.g. {@code MAX_ITERATIONS <= 0})
     */
    public Point.FMParams run(RDD<Point.LabeledSPoint> data, int F, double learningRate, double r1, double r2, int MAX_ITERATIONS, int MIN_ITERATIONS, double DELTA_THRESHOLD) {
        int D = ((Point.LabeledSPoint)data.first()).x().size();
        long N = data.count();
        Point.FMParams fm_params = new Point.FMParams(SparseUtil$.MODULE$.rand(0), SparseUtil$.MODULE$.rand(D, 0), SparseUtil$.MODULE$.rand(D, F, 0));
        // Start above every possible threshold. The previous sentinel of 1.0 made any
        // DELTA_THRESHOLD >= 1.0 end the loop before the first iteration, ignoring MIN_ITERATIONS.
        double delta = Double.POSITIVE_INFINITY;
        // F is loop-invariant: broadcast it once and release it when training finishes.
        Broadcast F_b = data.context().broadcast((Object)BoxesRunTime.boxToInteger((int)F), ClassTag$.MODULE$.Int());
        try {
            for (int i = 0; delta > DELTA_THRESHOLD && i < MAX_ITERATIONS; ++i) {
                Predef$.MODULE$.println((Object)("On iteration " + i));
                Broadcast fm_params_b = data.context().broadcast((Object)fm_params, ClassTag$.MODULE$.apply(Point.FMParams.class));
                // The gradient RDD is consumed by exactly one reduce, so it is not cached.
                // Caching it without unpersisting accumulated one cached RDD per iteration.
                RDD data_g = data.map((Function1)new Serializable(fm_params_b, F_b){
                    public static final long serialVersionUID = 0L;
                    private final Broadcast fm_params_b$1;
                    private final Broadcast F_b$1;

                    public final Point.FMGradParams apply(Point.LabeledSPoint p) {
                        return FMGD$.MODULE$.compute(p, (Broadcast<Point.FMParams>)this.fm_params_b$1, (Broadcast<Object>)this.F_b$1);
                    }
                    {
                        this.fm_params_b$1 = fm_params_b$1;
                        this.F_b$1 = F_b$1;
                    }
                }, ClassTag$.MODULE$.apply(Point.FMGradParams.class));
                Point.FMGradParams grad_batch;
                try {
                    // Sum per-point gradients component-wise.
                    grad_batch = (Point.FMGradParams)data_g.reduce((Function2)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final Point.FMGradParams apply(Point.FMGradParams g1, Point.FMGradParams g2) {
                            return new Point.FMGradParams(g1.grad_w0() + g2.grad_w0(), SparseUtil$.MODULE$.add(g1.grad_w(), g2.grad_w()), SparseUtil$.MODULE$.add(g1.grad_v(), g2.grad_v()));
                        }
                    });
                } finally {
                    // This iteration's parameter snapshot is not used again, so free its broadcast blocks.
                    fm_params_b.destroy();
                }
                Point.FMGradParams grad = FMGD$.MODULE$.grad(grad_batch, N);
                Point.FMParams params_new = SparseFMUpdater$.MODULE$.update(fm_params, grad, learningRate, r1, r2);
                if (i >= MIN_ITERATIONS) {
                    delta = FMGD$.MODULE$.p_delta(fm_params, params_new);
                }
                fm_params = params_new;
            }
        } finally {
            F_b.destroy();
        }
        return fm_params;
    }

    private SparseFMGDOptimizer$() {
    }
}

