package cn.com.duiba.nezha.compute.mllib.optimizing.mt;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.util.MLUtil$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.mllib.linalg.SparseMatrix;
import org.apache.spark.mllib.linalg.SparseVector;
import scala.MatchError;
import scala.Tuple2;

/* compiled from: SparseFMMTUpdater.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/mllib/optimizing/mt/SparseFMMTUpdater$.class */
/**
 * Stateless helper that advances the parameters held in a {@link Point.FMParams}
 * (scalar {@code w0}, sparse vector {@code w}, sparse matrix {@code v}).
 *
 * <p>This is a Java rendering of the Scala object {@code SparseFMMTUpdater}. The singleton
 * instance is {@link #MODULE$}, mirroring the Scala {@code object} encoding. The class holds no
 * state of its own; all vector/matrix arithmetic is delegated to {@code SparseUtil$} and
 * {@code MLUtil$}.
 */
public final class SparseFMMTUpdater$ {
    /** The singleton instance (Scala {@code object} encoding). */
    public static final SparseFMMTUpdater$ MODULE$ = new SparseFMMTUpdater$();

    private SparseFMMTUpdater$() {
        // Only ever instantiated once, through MODULE$.
    }

    /**
     * Performs one update of all three parameter groups ({@code w0}, {@code w}, {@code v}).
     *
     * <p>For each group {@code p} with gradient {@code g} and previous step {@code s} this computes:
     * <pre>
     *   delta = linearCoef * p + signCoef * (sign(p) * p) + g
     *   step  = prevStepCoef * s + deltaCoef * delta
     *   p'    = p - step
     * </pre>
     * This is the heavy-ball (momentum) update form.
     *
     * @param params       current parameters
     * @param grads        gradients for {@code params}
     * @param prevSteps    steps applied by the previous update. Presumably this is the second
     *                     element of an earlier result of this method; confirm with callers.
     * @param prevStepCoef coefficient applied to the previous step (momentum-like factor)
     * @param deltaCoef    coefficient applied to the regularized gradient (learning-rate-like factor)
     * @param signCoef     coefficient of the {@code sign(p) * p} term in the regularized gradient
     * @param linearCoef   coefficient of the {@code p} term in the regularized gradient
     * @return a pair of (updated parameters, steps just applied). The steps are packed into a
     *         {@link Point.FMGradParams} so they can be passed back in as {@code prevSteps}.
     */
    public Tuple2<Point.FMParams, Point.FMGradParams> update(Point.FMParams params, Point.FMGradParams grads, Point.FMGradParams prevSteps, double prevStepCoef, double deltaCoef, double signCoef, double linearCoef) {
        double deltaW0 = deltaW0(params.w0(), grads.grad_w0(), signCoef, linearCoef);
        SparseVector deltaW = deltaW(params.w(), grads.grad_w(), signCoef, linearCoef);
        SparseMatrix deltaV = deltaV(params.v(), grads.grad_v(), signCoef, linearCoef);

        // Each result is (updated parameter, step applied). The callees always return a fresh
        // tuple, so the decompiler's null/MatchError checks were dead code.
        Tuple2<Object, Object> w0Result = updateW0(params.w0(), deltaW0, prevSteps.grad_w0(), prevStepCoef, deltaCoef);
        Tuple2<SparseVector, SparseVector> wResult = updateW(params.w(), deltaW, prevSteps.grad_w(), prevStepCoef, deltaCoef);
        Tuple2<SparseMatrix, SparseMatrix> vResult = updateV(params.v(), deltaV, prevSteps.grad_v(), prevStepCoef, deltaCoef);

        Point.FMParams updated = new Point.FMParams(w0Result._1$mcD$sp(), wResult._1(), vResult._1());
        Point.FMGradParams steps = new Point.FMGradParams(w0Result._2$mcD$sp(), wResult._2(), vResult._2());
        return new Tuple2<>(updated, steps);
    }

    /**
     * Applies one step to the scalar parameter {@code w0}.
     *
     * @param w0           current value
     * @param deltaW0      regularized gradient, as produced by {@link #deltaW0}
     * @param prevStepW0   step applied in the previous update
     * @param prevStepCoef coefficient applied to {@code prevStepW0}
     * @param deltaCoef    coefficient applied to {@code deltaW0}
     * @return a pair of (w0 - step, step), where {@code step = prevStepCoef * prevStepW0 +
     *         deltaCoef * deltaW0}. Both elements are boxed {@code Double}s and can be read
     *         with {@code _1$mcD$sp()} / {@code _2$mcD$sp()}.
     */
    public Tuple2<Object, Object> updateW0(double w0, double deltaW0, double prevStepW0, double prevStepCoef, double deltaCoef) {
        double step = (prevStepCoef * prevStepW0) + (deltaCoef * deltaW0);
        // A plain Tuple2 with boxed doubles replaces the specialized Tuple2$mcDD$sp, which the
        // decompiler rendered as the invalid nested name "Tuple2.mcDD.sp".
        return new Tuple2<Object, Object>(w0 - step, step);
    }

    /**
     * Applies one step to the sparse vector parameter {@code w}.
     *
     * @param w            current value
     * @param deltaW       regularized gradient, as produced by {@link #deltaW}
     * @param prevStepW    step applied in the previous update
     * @param prevStepCoef coefficient applied to {@code prevStepW}
     * @param deltaCoef    coefficient applied to {@code deltaW}
     * @return a pair of (w - step, step), where {@code step = prevStepCoef * prevStepW +
     *         deltaCoef * deltaW}
     */
    public Tuple2<SparseVector, SparseVector> updateW(SparseVector w, SparseVector deltaW, SparseVector prevStepW, double prevStepCoef, double deltaCoef) {
        SparseVector step = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(prevStepW, prevStepCoef),
                SparseUtil$.MODULE$.multiply(deltaW, deltaCoef));
        return new Tuple2<>(SparseUtil$.MODULE$.subtraction(w, step), step);
    }

    /**
     * Applies one step to the sparse matrix parameter {@code v}. This is the matrix counterpart
     * of {@link #updateW}.
     *
     * @param v            current value
     * @param deltaV       regularized gradient, as produced by {@link #deltaV}
     * @param prevStepV    step applied in the previous update
     * @param prevStepCoef coefficient applied to {@code prevStepV}
     * @param deltaCoef    coefficient applied to {@code deltaV}
     * @return a pair of (v - step, step), where {@code step = prevStepCoef * prevStepV +
     *         deltaCoef * deltaV}
     */
    public Tuple2<SparseMatrix, SparseMatrix> updateV(SparseMatrix v, SparseMatrix deltaV, SparseMatrix prevStepV, double prevStepCoef, double deltaCoef) {
        SparseMatrix step = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(prevStepV, prevStepCoef),
                SparseUtil$.MODULE$.multiply(deltaV, deltaCoef));
        return new Tuple2<>(SparseUtil$.MODULE$.subtraction(v, step), step);
    }

    /**
     * Computes the regularized gradient for {@code w0}:
     * {@code linearCoef * w0 + signCoef * sign(w0) * w0 + gradW0}.
     *
     * <p>NOTE(review): the sign term is {@code sign(w0) * w0}, which equals {@code |w0|} if
     * {@code MLUtil.sign} returns -1/0/1. An L1 subgradient would be {@code sign(w0)} alone.
     * The same form is used by {@link #deltaW} and {@link #deltaV}. Confirm this is intended.
     *
     * @param w0         current value
     * @param gradW0     raw gradient
     * @param signCoef   coefficient of the {@code sign(w0) * w0} term
     * @param linearCoef coefficient of the {@code w0} term
     * @return the regularized gradient
     */
    public double deltaW0(double w0, double gradW0, double signCoef, double linearCoef) {
        return (linearCoef * w0) + (signCoef * MLUtil$.MODULE$.sign(w0) * w0) + gradW0;
    }

    /**
     * Computes the regularized gradient for {@code w}:
     * {@code gradW + (linearCoef * w + signCoef * (sign(w) * w))}. The product {@code sign(w) * w}
     * is computed by {@code SparseUtil.multiply} on two vectors.
     *
     * @param w          current value
     * @param gradW      raw gradient
     * @param signCoef   coefficient of the {@code sign(w) * w} term
     * @param linearCoef coefficient of the {@code w} term
     * @return the regularized gradient
     */
    public SparseVector deltaW(SparseVector w, SparseVector gradW, double signCoef, double linearCoef) {
        return SparseUtil$.MODULE$.add(
                gradW,
                SparseUtil$.MODULE$.add(
                        SparseUtil$.MODULE$.multiply(w, linearCoef),
                        SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(w), w), signCoef)));
    }

    /**
     * Computes the regularized gradient for {@code v}:
     * {@code gradV + (linearCoef * v + signCoef * (sign(v) * v))}. This is the matrix
     * counterpart of {@link #deltaW}.
     *
     * @param v          current value
     * @param gradV      raw gradient
     * @param signCoef   coefficient of the {@code sign(v) * v} term
     * @param linearCoef coefficient of the {@code v} term
     * @return the regularized gradient
     */
    public SparseMatrix deltaV(SparseMatrix v, SparseMatrix gradV, double signCoef, double linearCoef) {
        return SparseUtil$.MODULE$.add(
                gradV,
                SparseUtil$.MODULE$.add(
                        SparseUtil$.MODULE$.multiply(v, linearCoef),
                        SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(v), v), signCoef)));
    }

    /**
     * Computes {@code paramCoef * p - gradCoef * g} for every parameter group.
     *
     * @param params    parameters {@code p}
     * @param grads     gradients {@code g}
     * @param paramCoef coefficient applied to the parameters
     * @param gradCoef  coefficient applied to the gradients (subtracted)
     * @return a new {@link Point.FMParams} holding the combination
     */
    public Point.FMParams paramsDelta(Point.FMParams params, Point.FMGradParams grads, double paramCoef, double gradCoef) {
        return new Point.FMParams(
                paramsDeltaW0(params.w0(), grads.grad_w0(), paramCoef, gradCoef),
                paramsDeltaW(params.w(), grads.grad_w(), paramCoef, gradCoef),
                paramsDeltaV(params.v(), grads.grad_v(), paramCoef, gradCoef));
    }

    /**
     * Returns {@code paramCoef * w0 - gradCoef * gradW0}.
     *
     * @param w0        parameter value
     * @param gradW0    gradient value
     * @param paramCoef coefficient applied to {@code w0}
     * @param gradCoef  coefficient applied to {@code gradW0} (subtracted)
     * @return the combination
     */
    public double paramsDeltaW0(double w0, double gradW0, double paramCoef, double gradCoef) {
        // Exactly equal to the original "a + (-1) * c * g": negation is exact in IEEE-754.
        return (paramCoef * w0) - (gradCoef * gradW0);
    }

    /**
     * Returns {@code paramCoef * w + (-gradCoef) * gradW}.
     *
     * @param w         parameter vector
     * @param gradW     gradient vector
     * @param paramCoef coefficient applied to {@code w}
     * @param gradCoef  coefficient applied to {@code gradW} (subtracted)
     * @return the combination
     */
    public SparseVector paramsDeltaW(SparseVector w, SparseVector gradW, double paramCoef, double gradCoef) {
        return SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(w, paramCoef),
                SparseUtil$.MODULE$.multiply(gradW, -gradCoef));
    }

    /**
     * Returns {@code paramCoef * v + (-gradCoef) * gradV}. This is the matrix counterpart of
     * {@link #paramsDeltaW}.
     *
     * @param v         parameter matrix
     * @param gradV     gradient matrix
     * @param paramCoef coefficient applied to {@code v}
     * @param gradCoef  coefficient applied to {@code gradV} (subtracted)
     * @return the combination
     */
    public SparseMatrix paramsDeltaV(SparseMatrix v, SparseMatrix gradV, double paramCoef, double gradCoef) {
        return SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(v, paramCoef),
                SparseUtil$.MODULE$.multiply(gradV, -gradCoef));
    }
}
