/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.optimizing.mt;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.util.MLUtil$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.mllib.linalg.SparseMatrix;
import org.apache.spark.mllib.linalg.SparseVector;
import scala.MatchError;
import scala.Tuple2;

public final class SparseFMMTUpdater$ {
    /** Singleton instance (Scala {@code object} encoding). */
    public static final SparseFMMTUpdater$ MODULE$;

    static {
        // A static final must be assigned in the class initializer; the decompiled
        // form assigned it from the constructor, which javac rejects.
        MODULE$ = new SparseFMMTUpdater$();
    }

    /**
     * Performs one momentum-SGD step on all FM parameter groups (w0, w, v).
     *
     * <p>For each group the raw gradient is first augmented with the regularization
     * terms (see {@link #deltaW0}, {@link #deltaW}, {@link #deltaV}); then
     * {@code delta = mtRate * deltaOld + learningRate * regGrad} and
     * {@code param_new = param_old - delta}.
     *
     * @param psOld        current parameters
     * @param grad         raw (unregularized) gradients
     * @param deltaOld     the deltas returned by the previous call (momentum state)
     * @param mtRate       momentum coefficient applied to {@code deltaOld}
     * @param learningRate step size applied to the regularized gradient
     * @param r1           L1 regularization strength
     * @param r2           L2 regularization strength
     * @return (updated parameters, new deltas to pass as {@code deltaOld} next time)
     */
    public Tuple2<Point.FMParams, Point.FMGradParams> update(Point.FMParams psOld, Point.FMGradParams grad, Point.FMGradParams deltaOld, double mtRate, double learningRate, double r1, double r2) {
        // 1. Regularized gradients.
        double gradW0 = this.deltaW0(psOld.w0(), grad.grad_w0(), r1, r2);
        SparseVector gradW = this.deltaW(psOld.w(), grad.grad_w(), r1, r2);
        SparseMatrix gradV = this.deltaV(psOld.v(), grad.grad_v(), r1, r2);

        // 2. Momentum steps, in the same order as before (w0, w, v).
        //    These helpers always return a fresh non-null tuple, so no null/MatchError
        //    handling is needed here.
        Tuple2<Object, Object> w0Step = this.updateW0(psOld.w0(), gradW0, deltaOld.grad_w0(), mtRate, learningRate);
        Tuple2<SparseVector, SparseVector> wStep = this.updateW(psOld.w(), gradW, deltaOld.grad_w(), mtRate, learningRate);
        Tuple2<SparseMatrix, SparseMatrix> vStep = this.updateV(psOld.v(), gradV, deltaOld.grad_v(), mtRate, learningRate);

        Point.FMParams paramsNew = new Point.FMParams(w0Step._1$mcD$sp(), wStep._1(), vStep._1());
        Point.FMGradParams deltaNew = new Point.FMGradParams(w0Step._2$mcD$sp(), wStep._2(), vStep._2());
        return new Tuple2<>(paramsNew, deltaNew);
    }

    /**
     * Momentum step for the scalar bias.
     *
     * @return a specialized (double, double) tuple {@code (w - delta, delta)} where
     *         {@code delta = mtRate * delta_old + learningRate * grad}
     */
    public Tuple2<Object, Object> updateW0(double w, double grad, double delta_old, double mtRate, double learningRate) {
        double delta = mtRate * delta_old + learningRate * grad;
        // Fully qualified so no extra import is needed; this is the Scala
        // double-specialized Tuple2 the original bytecode constructs.
        return new scala.Tuple2$mcDD$sp(w - delta, delta);
    }

    /**
     * Momentum step for the sparse linear weights.
     *
     * @return {@code (w - delta, delta)} where
     *         {@code delta = mtRate * delta_old + learningRate * grad}
     */
    public Tuple2<SparseVector, SparseVector> updateW(SparseVector w, SparseVector grad, SparseVector delta_old, double mtRate, double learningRate) {
        SparseVector momentum = SparseUtil$.MODULE$.multiply(delta_old, mtRate);
        SparseVector step = SparseUtil$.MODULE$.multiply(grad, learningRate);
        SparseVector delta = SparseUtil$.MODULE$.add(momentum, step);
        return new Tuple2<>(SparseUtil$.MODULE$.subtraction(w, delta), delta);
    }

    /**
     * Momentum step for the sparse factor matrix.
     *
     * @return {@code (v - delta, delta)} where
     *         {@code delta = mtRate * delta_old + learningRate * grad}
     */
    public Tuple2<SparseMatrix, SparseMatrix> updateV(SparseMatrix v, SparseMatrix grad, SparseMatrix delta_old, double mtRate, double learningRate) {
        SparseMatrix momentum = SparseUtil$.MODULE$.multiply(delta_old, mtRate);
        SparseMatrix step = SparseUtil$.MODULE$.multiply(grad, learningRate);
        SparseMatrix delta = SparseUtil$.MODULE$.add(momentum, step);
        return new Tuple2<>(SparseUtil$.MODULE$.subtraction(v, delta), delta);
    }

    /**
     * Adds L2 and L1 regularization to the bias gradient:
     * {@code grad + r2 * w + r1 * sign(w)}.
     *
     * <p>The L1 term is the subgradient {@code r1 * sign(w)}. The previous code used
     * {@code r1 * sign(w) * w} (= {@code r1 * |w|}), which is always non-negative and
     * therefore drove every weight downward instead of toward zero.
     * NOTE(review): assumes MLUtil.sign returns -1/0/+1 — confirm.
     */
    public double deltaW0(double w, double grad, double r1, double r2) {
        double l2 = r2 * w;
        double l1 = r1 * MLUtil$.MODULE$.sign(w);
        return l2 + l1 + grad;
    }

    /**
     * Element-wise regularized gradient for the linear weights:
     * {@code grad + (r2 * w + r1 * sign(w))}.
     *
     * <p>The L1 term uses the subgradient {@code r1 * sign(w)} (previously
     * {@code r1 * sign(w) * w}, i.e. {@code r1 * |w|}, which is incorrect).
     * NOTE(review): assumes MLUtil.sign(SparseVector) returns an element-wise
     * -1/0/+1 SparseVector — confirm.
     */
    public SparseVector deltaW(SparseVector w, SparseVector grad, double r1, double r2) {
        SparseVector l2 = SparseUtil$.MODULE$.multiply(w, r2);
        SparseVector l1 = SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(w), r1);
        return SparseUtil$.MODULE$.add(grad, SparseUtil$.MODULE$.add(l2, l1));
    }

    /**
     * Element-wise regularized gradient for the factor matrix:
     * {@code grad + (r2 * v + r1 * sign(v))}.
     *
     * <p>The L1 term uses the subgradient {@code r1 * sign(v)} (previously
     * {@code r1 * sign(v) * v}, i.e. {@code r1 * |v|}, which is incorrect).
     * NOTE(review): assumes MLUtil.sign(SparseMatrix) returns an element-wise
     * -1/0/+1 SparseMatrix — confirm.
     */
    public SparseMatrix deltaV(SparseMatrix v, SparseMatrix grad, double r1, double r2) {
        SparseMatrix l2 = SparseUtil$.MODULE$.multiply(v, r2);
        SparseMatrix l1 = SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(v), r1);
        return SparseUtil$.MODULE$.add(grad, SparseUtil$.MODULE$.add(l2, l1));
    }

    /**
     * Computes a new momentum delta for every parameter group:
     * {@code delta_new = mtRate * delta_old - learningRate * grad}.
     *
     * <p>Note the sign convention differs from {@link #update}: here the delta already
     * carries the minus sign (presumably to be added to the parameters by the caller —
     * confirm at call sites).
     *
     * @param params_delta_old    previous deltas, packed as FMParams
     * @param grad_regularization regularized gradients
     * @return new deltas, packed as FMParams
     */
    public Point.FMParams paramsDelta(Point.FMParams params_delta_old, Point.FMGradParams grad_regularization, double mtRate, double learningRate) {
        double deltaW0 = this.paramsDeltaW0(params_delta_old.w0(), grad_regularization.grad_w0(), mtRate, learningRate);
        SparseVector deltaW = this.paramsDeltaW(params_delta_old.w(), grad_regularization.grad_w(), mtRate, learningRate);
        SparseMatrix deltaV = this.paramsDeltaV(params_delta_old.v(), grad_regularization.grad_v(), mtRate, learningRate);
        return new Point.FMParams(deltaW0, deltaW, deltaV);
    }

    /** Scalar delta: {@code mtRate * w_delta - learningRate * grad}. */
    public double paramsDeltaW0(double w_delta, double grad, double mtRate, double learningRate) {
        return mtRate * w_delta + -learningRate * grad;
    }

    /** Sparse-vector delta: {@code mtRate * w_delta + (-learningRate) * grad}. */
    public SparseVector paramsDeltaW(SparseVector w_delta, SparseVector grad, double mtRate, double learningRate) {
        SparseVector momentum = SparseUtil$.MODULE$.multiply(w_delta, mtRate);
        SparseVector step = SparseUtil$.MODULE$.multiply(grad, -1.0 * learningRate);
        return SparseUtil$.MODULE$.add(momentum, step);
    }

    /** Sparse-matrix delta: {@code mtRate * v_delta + (-learningRate) * grad}. */
    public SparseMatrix paramsDeltaV(SparseMatrix v_delta, SparseMatrix grad, double mtRate, double learningRate) {
        SparseMatrix momentum = SparseUtil$.MODULE$.multiply(v_delta, mtRate);
        SparseMatrix step = SparseUtil$.MODULE$.multiply(grad, -1.0 * learningRate);
        return SparseUtil$.MODULE$.add(momentum, step);
    }

    private SparseFMMTUpdater$() {
    }
}

