package cn.com.duiba.nezha.compute.mllib.optimizing;

import cn.com.duiba.nezha.compute.api.constant.GlobalConstant;
import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.util.MLUtil$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.SparseMatrix;
import org.apache.spark.mllib.linalg.SparseVector;
import scala.MatchError;
import scala.Tuple2;
import scala.runtime.BoxesRunTime;

/* compiled from: FMGD.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/mllib/optimizing/FMGD$.class */
/**
 * Factorization-machine (FM) gradient-descent helpers.
 *
 * <p>Java form of the Scala object {@code FMGD}: a singleton with no instance state, reached
 * through {@link #MODULE$}. A parameter set ({@link Point.FMParams}) holds a scalar {@code w0},
 * a sparse vector {@code w} and a sparse matrix {@code v}; {@link Point.FMGradParams} mirrors that
 * layout for gradients ({@code grad_w0}, {@code grad_w}, {@code grad_v}).
 *
 * <p>NOTE(review): the element-wise meaning of the {@code SparseUtil} helpers used here (e.g. that
 * single-argument {@code multiply} squares every entry, as {@link #gradSquare} relies on) is
 * inferred from usage in this class, not from their source — confirm against {@code SparseUtil}.
 */
public final class FMGD$ {

    /**
     * The singleton instance, created eagerly during class initialization.
     *
     * <p>Initialized here rather than assigned from the constructor: a {@code static final}
     * field cannot be reassigned from a constructor in Java source.
     */
    public static final FMGD$ MODULE$ = new FMGD$();

    /**
     * Squared distance between two parameter sets.
     *
     * @param a first parameter set
     * @param b second parameter set
     * @return {@code (a.w0 - b.w0)^2 + (a.w - b.w)·(a.w - b.w) + sum((a.v - b.v)^2)}
     */
    public double p_delta(Point.FMParams a, Point.FMParams b) {
        double dw0 = a.w0() - b.w0();
        SparseVector dw = SparseUtil$.MODULE$.subtraction(a.w(), b.w());
        return dw0 * dw0
                + SparseUtil$.MODULE$.dot(dw, dw)
                + SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.subtraction(a.v(), b.v())));
    }

    /**
     * Euclidean norm of a parameter set: {@code sqrt(w0^2 + w·w + sum(v^2))}.
     *
     * <p>Despite the name, the sum of squares is NOT divided by an element count (compare
     * {@link #p_rmse}); behavior kept as-is for existing callers.
     *
     * @param params parameter set (typically a delta between two iterations)
     * @return the square root of the sum of squared components
     */
    public double p_delta_mse(Point.FMParams params) {
        double w0Sq = params.w0() * params.w0();
        double wSq = SparseUtil$.MODULE$.dot(params.w(), params.w());
        return Math.sqrt(w0Sq + wSq + SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.multiply(params.v())));
    }

    /**
     * Root-mean-square difference between two parameter sets.
     *
     * <p>The squared distance from {@link #p_delta} is divided by
     * {@code 1 + nnz(a.w) + nnz(a.v)}; only the FIRST argument's active-entry counts are used.
     *
     * @param a first parameter set (supplies the element count)
     * @param b second parameter set
     * @return {@code sqrt(p_delta(a, b) / (1 + a.w.numActives + a.v.numActives))}
     */
    public double p_rmse(Point.FMParams a, Point.FMParams b) {
        int count = 1 + a.w().numActives() + a.v().numActives();
        return Math.sqrt(p_delta(a, b) / count);
    }

    /**
     * Root-mean-square magnitude of a gradient.
     *
     * @param g gradient
     * @return {@code sqrt((g0^2 + gw·gw + sum(gv^2)) / (1 + nnz(gw) + nnz(gv)))}
     */
    public double g_rmse(Point.FMGradParams g) {
        double g0 = g.grad_w0();
        SparseVector gw = g.grad_w();
        double sumSq = g0 * g0 + SparseUtil$.MODULE$.dot(gw, gw) + SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.multiply(g.grad_v()));
        int count = 1 + gw.numActives() + g.grad_v().numActives();
        return Math.sqrt(sumSq / count);
    }

    /**
     * Divides every gradient component by {@code divisor} (e.g. to average a summed gradient).
     *
     * <p>No guard against {@code divisor == 0}; the result then contains infinities/NaN.
     *
     * @param g gradient
     * @param divisor value to divide by
     * @return a new gradient {@code g / divisor}
     */
    public Point.FMGradParams grad(Point.FMGradParams g, double divisor) {
        double scale = 1.0d / divisor;
        return new Point.FMGradParams(
                g.grad_w0() / divisor,
                SparseUtil$.MODULE$.multiply(g.grad_w(), scale),
                SparseUtil$.MODULE$.multiply(g.grad_v(), scale));
    }

    /**
     * Element-wise square root of a gradient.
     *
     * <p>{@code grad_w0} goes through {@link Math#sqrt}, so a negative value yields NaN; the
     * sparse parts are delegated to {@code SparseUtil.sqrt}.
     *
     * @param g gradient (typically an accumulated squared gradient)
     * @return a new gradient with each component square-rooted
     */
    public Point.FMGradParams gradSqrt(Point.FMGradParams g) {
        return new Point.FMGradParams(
                Math.sqrt(g.grad_w0()),
                SparseUtil$.MODULE$.sqrt(g.grad_w()),
                SparseUtil$.MODULE$.sqrt(g.grad_v()));
    }

    /**
     * Element-wise reciprocal of a gradient.
     *
     * <p>{@code grad_w0} becomes {@code 1 / (grad_w0 + GlobalConstant.EPSILON)}; the sparse parts
     * are delegated to {@code SparseUtil.inverse} (whether it also adds an epsilon is not visible
     * here — verify).
     *
     * @param g gradient
     * @return a new gradient with each component inverted
     */
    public Point.FMGradParams gradInverse(Point.FMGradParams g) {
        return new Point.FMGradParams(
                1 / (g.grad_w0() + GlobalConstant.EPSILON),
                SparseUtil$.MODULE$.inverse(g.grad_w()),
                SparseUtil$.MODULE$.inverse(g.grad_v()));
    }

    /**
     * Linear combination of two gradients: {@code alpha * g1 + beta * g2}, component-wise.
     *
     * @param g1 first gradient
     * @param g2 second gradient
     * @param alpha weight applied to {@code g1}
     * @param beta weight applied to {@code g2}
     * @return a new gradient
     */
    public Point.FMGradParams gradMergeAddUpdate(Point.FMGradParams g1, Point.FMGradParams g2, double alpha, double beta) {
        double w0 = (alpha * g1.grad_w0()) + (beta * g2.grad_w0());
        SparseVector w = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(g1.grad_w(), alpha),
                SparseUtil$.MODULE$.multiply(g2.grad_w(), beta));
        SparseMatrix v = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(g1.grad_v(), alpha),
                SparseUtil$.MODULE$.multiply(g2.grad_v(), beta));
        return new Point.FMGradParams(w0, w, v);
    }

    /**
     * Element-wise square of a gradient.
     *
     * @param g gradient
     * @return a new gradient with each component squared
     */
    public Point.FMGradParams gradSquare(Point.FMGradParams g) {
        return new Point.FMGradParams(
                g.grad_w0() * g.grad_w0(),
                SparseUtil$.MODULE$.multiply(g.grad_w()),
                SparseUtil$.MODULE$.multiply(g.grad_v()));
    }

    /**
     * Element-wise (Hadamard) product of two gradients.
     *
     * @param g1 first gradient
     * @param g2 second gradient
     * @return a new gradient {@code g1 ∘ g2}
     */
    public Point.FMGradParams gradMultiply(Point.FMGradParams g1, Point.FMGradParams g2) {
        return new Point.FMGradParams(
                g1.grad_w0() * g2.grad_w0(),
                SparseUtil$.MODULE$.multiply(g1.grad_w(), g2.grad_w()),
                SparseUtil$.MODULE$.multiply(g1.grad_v(), g2.grad_v()));
    }

    /**
     * Scales every gradient component by {@code factor}.
     *
     * @param g gradient
     * @param factor scalar multiplier
     * @return a new gradient {@code factor * g}
     */
    public Point.FMGradParams gradMultiply(Point.FMGradParams g, double factor) {
        return new Point.FMGradParams(
                g.grad_w0() * factor,
                SparseUtil$.MODULE$.multiply(g.grad_w(), factor),
                SparseUtil$.MODULE$.multiply(g.grad_v(), factor));
    }

    /**
     * Adds regularization terms to a gradient, component-wise:
     * {@code grad + l2 * param + l1 * sign(param)}.
     *
     * @param params current parameters
     * @param g raw gradient
     * @param l1 coefficient of the {@code sign(param)} term
     * @param l2 coefficient of the {@code param} term
     * @return a new, regularized gradient
     */
    public Point.FMGradParams gradWithRegularization(Point.FMParams params, Point.FMGradParams g, double l1, double l2) {
        return new Point.FMGradParams(
                gradWithRegularizationW0(params.w0(), g.grad_w0(), l1, l2),
                gradWithRegularizationW(params.w(), g.grad_w(), l1, l2),
                gradWithRegularizationV(params.v(), g.grad_v(), l1, l2));
    }

    /**
     * Scalar case of {@link #gradWithRegularization}.
     *
     * @param w0 current scalar parameter
     * @param gradW0 raw gradient for {@code w0}
     * @param l1 coefficient of {@code sign(w0)}
     * @param l2 coefficient of {@code w0}
     * @return {@code l2 * w0 + l1 * sign(w0) + gradW0}
     */
    public double gradWithRegularizationW0(double w0, double gradW0, double l1, double l2) {
        return (l2 * w0) + (l1 * MLUtil$.MODULE$.sign(w0)) + gradW0;
    }

    /**
     * Vector case of {@link #gradWithRegularization}.
     *
     * @param w current weight vector
     * @param gradW raw gradient for {@code w}
     * @param l1 coefficient of {@code sign(w)}
     * @param l2 coefficient of {@code w}
     * @return {@code gradW + (l2 * w + l1 * sign(w))}
     */
    public SparseVector gradWithRegularizationW(SparseVector w, SparseVector gradW, double l1, double l2) {
        SparseVector penalty = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(w, l2),
                SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(w), l1));
        return SparseUtil$.MODULE$.add(gradW, penalty);
    }

    /**
     * Matrix case of {@link #gradWithRegularization}.
     *
     * @param v current factor matrix
     * @param gradV raw gradient for {@code v}
     * @param l1 coefficient of {@code sign(v)}
     * @param l2 coefficient of {@code v}
     * @return {@code gradV + (l2 * v + l1 * sign(v))}
     */
    public SparseMatrix gradWithRegularizationV(SparseMatrix v, SparseMatrix gradV, double l1, double l2) {
        SparseMatrix penalty = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(v, l2),
                SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(v), l1));
        return SparseUtil$.MODULE$.add(gradV, penalty);
    }

    /**
     * Per-example gradient and squared error for {@link Point.FMParams}.
     *
     * <p>With {@code err = h(x) - y}, the gradient is
     * {@code (err, err * x, err * (dot_m(x, dot_row(v, x)) - multiply(x^2, k, v)))}. Presumably
     * the {@code v} term is the usual FM derivative {@code x_i * sum_j v_jf x_j - v_if * x_i^2}
     * — TODO confirm against {@code SparseUtil}.
     *
     * @param labeledSPoint training example ({@code x}, {@code y})
     * @param broadcast current model parameters
     * @param broadcast2 boxed {@code int} passed to {@code SparseUtil.multiply} alongside
     *     {@code v}; presumably the number of latent factors — verify against callers
     * @return {@code (gradient, boxed err * err)}
     */
    public Tuple2<Point.FMGradParams, Object> computeWithErr(Point.LabeledSPoint labeledSPoint, Broadcast<Point.FMParams> broadcast, Broadcast<Object> broadcast2) {
        Point.FMParams params = (Point.FMParams) broadcast.value();
        SparseVector x = labeledSPoint.x();
        double err = h(x, params) - labeledSPoint.y();
        SparseVector gradW = SparseUtil$.MODULE$.multiply(x, err);
        SparseMatrix v = params.v();
        int k = BoxesRunTime.unboxToInt(broadcast2.value());
        Point.FMGradParams gradient = new Point.FMGradParams(
                err,
                gradW,
                SparseUtil$.MODULE$.multiply(
                        SparseUtil$.MODULE$.subtraction(
                                SparseUtil$.MODULE$.dot_m(x, SparseUtil$.MODULE$.dot_row(v, x)),
                                SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(x), k, v)),
                        err));
        return new Tuple2<>(gradient, BoxesRunTime.boxToDouble(err * err));
    }

    /**
     * Same as {@link #computeWithErr} but for {@link Point.FMModelParams} (prediction via
     * {@link #h(SparseVector, Point.FMModelParams)}).
     *
     * @param labeledSPoint training example ({@code x}, {@code y})
     * @param broadcast current model parameters
     * @param broadcast2 boxed {@code int} passed to {@code SparseUtil.multiply} alongside
     *     {@code v}; presumably the number of latent factors — verify against callers
     * @return {@code (gradient, boxed err * err)}
     */
    public Tuple2<Point.FMGradParams, Object> computeWithErr2(Point.LabeledSPoint labeledSPoint, Broadcast<Point.FMModelParams> broadcast, Broadcast<Object> broadcast2) {
        Point.FMModelParams model = (Point.FMModelParams) broadcast.value();
        SparseVector x = labeledSPoint.x();
        double err = h(x, model) - labeledSPoint.y();
        SparseVector gradW = SparseUtil$.MODULE$.multiply(x, err);
        int k = BoxesRunTime.unboxToInt(broadcast2.value());
        Point.FMGradParams gradient = new Point.FMGradParams(
                err,
                gradW,
                SparseUtil$.MODULE$.multiply(
                        SparseUtil$.MODULE$.subtraction(
                                SparseUtil$.MODULE$.dot_m(x, SparseUtil$.MODULE$.dot_row(model.v(), x)),
                                SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(x), k, model.v())),
                        err));
        return new Tuple2<>(gradient, BoxesRunTime.boxToDouble(err * err));
    }

    /**
     * Per-example gradient only; the squared error from {@link #computeWithErr} is discarded.
     *
     * @param labeledSPoint training example
     * @param broadcast current model parameters
     * @param broadcast2 see {@link #computeWithErr}
     * @return the gradient for this example
     */
    public Point.FMGradParams compute(Point.LabeledSPoint labeledSPoint, Broadcast<Point.FMParams> broadcast, Broadcast<Object> broadcast2) {
        return computeWithErr(labeledSPoint, broadcast, broadcast2)._1();
    }

    /**
     * Component-wise sum of two parameter sets (e.g. applying an update step).
     *
     * @param params current parameters
     * @param delta update to add
     * @return a new parameter set {@code params + delta}
     */
    public Point.FMParams paramsUpdate(Point.FMParams params, Point.FMParams delta) {
        return new Point.FMParams(
                params.w0() + delta.w0(),
                SparseUtil$.MODULE$.add(params.w(), delta.w()),
                SparseUtil$.MODULE$.add(params.v(), delta.v()));
    }

    /**
     * Alternative prediction:
     * {@code sigmoid(w0 + x·w + 0.5 * (sum(dot_row(v, x)^2) - sum(dot_row(v^2, x^2))))}.
     *
     * @param sparseVector feature vector {@code x}
     * @param fMParams model parameters
     * @return the sigmoid of the FM score
     */
    public double h2(SparseVector sparseVector, Point.FMParams fMParams) {
        SparseVector projectedSq = SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.dot_row(fMParams.v(), sparseVector));
        SparseVector xSq = SparseUtil$.MODULE$.multiply(sparseVector);
        double linear = fMParams.w0() + SparseUtil$.MODULE$.dot(sparseVector, fMParams.w());
        double pairwise = 0.5d * (SparseUtil$.MODULE$.sum(projectedSq)
                - SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.dot_row(SparseUtil$.MODULE$.multiply(fMParams.v()), xSq)));
        return MLUtil$.MODULE$.sigmoid(linear + pairwise);
    }

    /**
     * Prediction for {@link Point.FMModelParams}: {@code sigmoid(w0 + dot3(x, w) + 0.5 * (A - B))},
     * where {@code M = multiply(x, v.numCols, v)}, {@code A = sum(sum_row(M)^2)} and
     * {@code B = sum(sum_row(M^2))}.
     *
     * @param sparseVector feature vector {@code x}
     * @param fMModelParams model parameters
     * @return the sigmoid of the FM score
     */
    public double h(SparseVector sparseVector, Point.FMModelParams fMModelParams) {
        SparseMatrix scaled = SparseUtil$.MODULE$.multiply(sparseVector, fMModelParams.v().numCols(), fMModelParams.v());
        SparseMatrix scaledSq = SparseUtil$.MODULE$.multiply(scaled);
        double linear = fMModelParams.w0() + SparseUtil$.MODULE$.dot3(sparseVector, fMModelParams.w());
        double pairwise = 0.5d * (SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.sum_row(scaled)))
                - SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.sum_row(scaledSq)));
        return MLUtil$.MODULE$.sigmoid(linear + pairwise);
    }

    /**
     * Prediction for {@link Point.FMParams}: {@code sigmoid(w0 + x·w + 0.5 * (A - B))}, where
     * {@code M = multiply(vector_copy(x, v.numCols), v)}, {@code A = sum(sum_row(M)^2)} and
     * {@code B = sum(sum_row(M^2))}.
     *
     * @param sparseVector feature vector {@code x}
     * @param fMParams model parameters
     * @return the sigmoid of the FM score
     */
    public double h(SparseVector sparseVector, Point.FMParams fMParams) {
        SparseMatrix scaled = SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.vector_copy(sparseVector, fMParams.v().numCols()), fMParams.v());
        SparseMatrix scaledSq = SparseUtil$.MODULE$.multiply(scaled);
        double linear = fMParams.w0() + SparseUtil$.MODULE$.dot(sparseVector, fMParams.w());
        double pairwise = 0.5d * (SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.sum_row(scaled)))
                - SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.sum_row(scaledSq)));
        return MLUtil$.MODULE$.sigmoid(linear + pairwise);
    }

    /** Singleton constructor; use {@link #MODULE$}. */
    private FMGD$() {
    }
}
