package cn.com.duiba.nezha.compute.mllib.optimizing;

import cn.com.duiba.nezha.compute.api.constant.GlobalConstant;
import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.mllib.linalg.SparseMatrix;
import org.apache.spark.mllib.linalg.SparseVector;
import scala.math.package$;

/* compiled from: SparseFMADUpdater.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/mllib/optimizing/SparseFMADUpdater$.class */
/**
 * Stateless helper combining factorization-machine style parameter sets
 * ({@code w0} scalar, {@code w} sparse vector, {@code v} sparse matrix) for an adaptive
 * update step.
 *
 * <p>This is the Java encoding of a Scala {@code object}: callers use the {@link #MODULE$}
 * singleton.
 *
 * <p>NOTE(review): the {@code rho * E + (1 - rho) * x^2} recurrences and the
 * {@code -sqrt(E[dx^2]) / sqrt(E[g^2]) * g} delta look like AdaDelta. Confirm against the
 * callers. The {@code SparseUtil$} helpers are assumed to be element-wise; their semantics
 * are not visible here.
 */
public final class SparseFMADUpdater$ {

    /**
     * The singleton instance.
     *
     * <p>It is initialised eagerly here. The previous form declared this field
     * {@code static final ... = null} and reassigned it from the constructor, which does not
     * compile as Java.
     */
    public static final SparseFMADUpdater$ MODULE$ = new SparseFMADUpdater$();

    private SparseFMADUpdater$() {
        // Singleton; obtain the instance via MODULE$.
    }

    /**
     * Computes the parameter delta for one step, component by component.
     *
     * @param fMParams      accumulator whose components are square-rooted in the numerator
     *                      (presumably the running average of squared deltas; confirm with caller)
     * @param fMGradParams  current gradient
     * @param fMGradParams2 accumulator square-rooted in the denominator (presumably the running
     *                      average of squared gradients)
     * @param d             forwarded to the component methods, which currently ignore it
     * @param d2            forwarded to the component methods, which currently ignore it
     * @return a new {@link Point.FMParams} holding the deltas for {@code w0}, {@code w} and {@code v}
     */
    public Point.FMParams paramsDelta(Point.FMParams fMParams, Point.FMGradParams fMGradParams,
                                      Point.FMGradParams fMGradParams2, double d, double d2) {
        return new Point.FMParams(
                paramsDeltaW0(fMParams.w0(), fMGradParams.grad_w0(), fMGradParams2.grad_w0(), d, d2),
                paramsDeltaW(fMParams.w(), fMGradParams.grad_w(), fMGradParams2.grad_w(), d, d2),
                paramsDeltaV(fMParams.v(), fMGradParams.grad_v(), fMGradParams2.grad_v(), d, d2));
    }

    /**
     * Scalar delta for the bias term: {@code -sqrt(d) / (sqrt(d3) + EPSILON) * d2}.
     *
     * @param d  accumulator placed under the numerator square root
     * @param d2 gradient
     * @param d3 accumulator placed under the denominator square root
     * @param d4 unused
     * @param d5 unused
     * @return the signed delta
     */
    public double paramsDeltaW0(double d, double d2, double d3, double d4, double d5) {
        // NOTE(review): there is no epsilon under the numerator root. If the numerator
        // accumulator starts at 0, the delta stays 0 on every step; check its initialisation.
        return -Math.sqrt(d) / (Math.sqrt(d3) + GlobalConstant.EPSILON) * d2;
    }

    /**
     * Vector delta for the linear weights: {@code (-grad) * sqrt(acc) * inverse(sqrt(gradAcc))}.
     *
     * @param sparseVector  accumulator placed under the numerator square root
     * @param sparseVector2 gradient
     * @param sparseVector3 accumulator placed under the denominator square root
     * @param d             unused
     * @param d2            unused
     * @return the delta vector
     */
    public SparseVector paramsDeltaW(SparseVector sparseVector, SparseVector sparseVector2,
                                     SparseVector sparseVector3, double d, double d2) {
        SparseVector negatedGrad = SparseUtil$.MODULE$.multiply(sparseVector2, -1.0d);
        SparseVector numeratorRms = SparseUtil$.MODULE$.sqrt(sparseVector);
        // NOTE(review): unlike paramsDeltaW0, no EPSILON is added before inverting. Zero
        // entries are safe only if SparseUtil.inverse guards them; verify.
        SparseVector inverseDenominatorRms =
                SparseUtil$.MODULE$.inverse(SparseUtil$.MODULE$.sqrt(sparseVector3));
        return SparseUtil$.MODULE$.multiply(
                SparseUtil$.MODULE$.multiply(negatedGrad, numeratorRms), inverseDenominatorRms);
    }

    /**
     * Matrix delta for the factor weights. It uses the same formula as
     * {@link #paramsDeltaW}, applied to matrices.
     *
     * @param sparseMatrix  accumulator placed under the numerator square root
     * @param sparseMatrix2 gradient
     * @param sparseMatrix3 accumulator placed under the denominator square root
     * @param d             unused
     * @param d2            unused
     * @return the delta matrix
     */
    public SparseMatrix paramsDeltaV(SparseMatrix sparseMatrix, SparseMatrix sparseMatrix2,
                                     SparseMatrix sparseMatrix3, double d, double d2) {
        SparseMatrix negatedGrad = SparseUtil$.MODULE$.multiply(sparseMatrix2, -1.0d);
        SparseMatrix numeratorRms = SparseUtil$.MODULE$.sqrt(sparseMatrix);
        // NOTE(review): no EPSILON guard here either; see the note in paramsDeltaW.
        SparseMatrix inverseDenominatorRms =
                SparseUtil$.MODULE$.inverse(SparseUtil$.MODULE$.sqrt(sparseMatrix3));
        return SparseUtil$.MODULE$.multiply(
                SparseUtil$.MODULE$.multiply(negatedGrad, numeratorRms), inverseDenominatorRms);
    }

    /**
     * Applies a delta: returns {@code params + delta} component-wise.
     *
     * @param fMParams  current parameters
     * @param fMParams2 delta to add
     * @return a new {@link Point.FMParams} with the summed components
     */
    public Point.FMParams paramsUpdate(Point.FMParams fMParams, Point.FMParams fMParams2) {
        return new Point.FMParams(
                fMParams.w0() + fMParams2.w0(),
                SparseUtil$.MODULE$.add(fMParams.w(), fMParams2.w()),
                SparseUtil$.MODULE$.add(fMParams.v(), fMParams2.v()));
    }

    /**
     * Updates the decayed average of squared gradients:
     * {@code d * previous + (1 - d) * grad^2}, component-wise.
     *
     * @param fMGradParams  previous running average
     * @param fMGradParams2 current gradient (squared before mixing)
     * @param d             decay factor applied to the previous average
     * @return the new running average
     */
    public Point.FMGradParams gradMseUpdate(Point.FMGradParams fMGradParams,
                                            Point.FMGradParams fMGradParams2, double d) {
        return new Point.FMGradParams(
                decayedAverage(fMGradParams.grad_w0(), fMGradParams2.grad_w0(), d),
                decayedAverage(fMGradParams.grad_w(), fMGradParams2.grad_w(), d),
                decayedAverage(fMGradParams.grad_v(), fMGradParams2.grad_v(), d));
    }

    /**
     * Updates the decayed average of squared deltas:
     * {@code d * previous + (1 - d) * delta^2}, component-wise.
     *
     * @param fMParams  previous running average
     * @param fMParams2 latest delta (squared before mixing)
     * @param d         decay factor applied to the previous average
     * @return the new running average
     */
    public Point.FMParams paramsDeltaUpdate(Point.FMParams fMParams, Point.FMParams fMParams2, double d) {
        return new Point.FMParams(
                decayedAverage(fMParams.w0(), fMParams2.w0(), d),
                decayedAverage(fMParams.w(), fMParams2.w(), d),
                decayedAverage(fMParams.v(), fMParams2.v(), d));
    }

    /**
     * Returns {@code a - b} component-wise.
     *
     * <p>Unrelated to the five-argument {@code paramsDelta} overload despite the shared name.
     *
     * @param fMParams  minuend
     * @param fMParams2 subtrahend
     * @return a new {@link Point.FMParams} holding the differences
     */
    public Point.FMParams paramsDelta(Point.FMParams fMParams, Point.FMParams fMParams2) {
        return new Point.FMParams(
                fMParams.w0() - fMParams2.w0(),
                SparseUtil$.MODULE$.subtraction(fMParams.w(), fMParams2.w()),
                SparseUtil$.MODULE$.subtraction(fMParams.v(), fMParams2.v()));
    }

    /** Scalar decayed average: {@code rho * previous + (1 - rho) * current * current}. */
    private static double decayedAverage(double previous, double current, double rho) {
        return rho * previous + (1.0d - rho) * current * current;
    }

    /**
     * Vector decayed average. The single-argument {@code SparseUtil.multiply} presumably
     * squares element-wise, matching the scalar overload.
     */
    private static SparseVector decayedAverage(SparseVector previous, SparseVector current, double rho) {
        return SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(previous, rho),
                SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(current), 1.0d - rho));
    }

    /** Matrix decayed average; same shape as the vector overload. */
    private static SparseMatrix decayedAverage(SparseMatrix previous, SparseMatrix current, double rho) {
        return SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(previous, rho),
                SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(current), 1.0d - rho));
    }
}
