/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.optimizing.ad;

import cn.com.duiba.nezha.compute.api.constant.GlobalConstant;
import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.mllib.linalg.SparseMatrix;
import org.apache.spark.mllib.linalg.SparseVector;
import scala.math.package$;

/**
 * AdaDelta-style update rules for factorization-machine (FM) parameters.
 *
 * <p>This is the Java view of a Scala {@code object}; the single instance is exposed through
 * {@link #MODULE$}. Every method works on the three FM parameter groups carried by
 * {@link Point.FMParams} / {@link Point.FMGradParams}: the scalar {@code w0}, the sparse vector
 * {@code w} and the sparse matrix {@code v}.
 *
 * <p>Element-wise arithmetic on the sparse groups is delegated to {@code SparseUtil$}.
 * NOTE(review): the exact semantics of {@code SparseUtil$.multiply(x)} (single argument, assumed
 * to square element-wise), {@code sqrt} and {@code inverse} are defined outside this file — confirm.
 */
public final class FMADUpdater$ {
    /** The singleton instance. */
    public static final FMADUpdater$ MODULE$;

    static {
        // A static final field must be assigned in a static initializer, not in a constructor.
        MODULE$ = new FMADUpdater$();
    }

    /**
     * Computes the AdaDelta step for all FM parameter groups.
     *
     * @param ps_delta_mse running mean of squared parameter deltas
     * @param grad current gradient
     * @param grad_mse running mean of squared gradients
     * @param r1 currently unused (NOTE(review): presumably a regularization coefficient — confirm)
     * @param r2 currently unused (NOTE(review): presumably a regularization coefficient — confirm)
     * @return the per-group parameter deltas to add to the current parameters
     */
    public Point.FMParams paramsDelta(Point.FMParams ps_delta_mse, Point.FMGradParams grad, Point.FMGradParams grad_mse, double r1, double r2) {
        double w0_new = this.paramsDeltaW0(ps_delta_mse.w0(), grad.grad_w0(), grad_mse.grad_w0(), r1, r2);
        SparseVector w_new = this.paramsDeltaW(ps_delta_mse.w(), grad.grad_w(), grad_mse.grad_w(), r1, r2);
        SparseMatrix v_new = this.paramsDeltaV(ps_delta_mse.v(), grad.grad_v(), grad_mse.grad_v(), r1, r2);
        return new Point.FMParams(w0_new, w_new, v_new);
    }

    /**
     * Computes the scalar step {@code -sqrt(w_delta_mse) / (sqrt(grad_mse) + EPSILON) * grad}.
     *
     * <p>NOTE(review): textbook AdaDelta also adds epsilon inside the numerator's square root; here
     * a zero {@code w_delta_mse} yields a zero step, so updates cannot start from an all-zero delta
     * MSE — confirm how the delta MSE is initialised.
     *
     * @param r1 unused
     * @param r2 unused
     */
    public double paramsDeltaW0(double w_delta_mse, double grad, double grad_mse, double r1, double r2) {
        return -Math.sqrt(w_delta_mse) / (Math.sqrt(grad_mse) + GlobalConstant.EPSILON) * grad;
    }

    /**
     * Computes the element-wise step {@code -grad * sqrt(w_delta_mse) * inverse(sqrt(grad_mse))}
     * for the sparse weight vector.
     *
     * <p>NOTE(review): unlike {@link #paramsDeltaW0}, no {@code EPSILON} is added to the
     * denominator here — confirm that {@code SparseUtil$.inverse} guards against zero entries.
     *
     * @param r1 unused
     * @param r2 unused
     */
    public SparseVector paramsDeltaW(SparseVector w_delta_mse, SparseVector grad, SparseVector grad_mse, double r1, double r2) {
        SparseVector negGrad = SparseUtil$.MODULE$.multiply(grad, -1.0);
        SparseVector deltaRms = SparseUtil$.MODULE$.sqrt(w_delta_mse);
        SparseVector invGradRms = SparseUtil$.MODULE$.inverse(SparseUtil$.MODULE$.sqrt(grad_mse));
        return SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(negGrad, deltaRms), invGradRms);
    }

    /**
     * Computes the element-wise step {@code -grad * sqrt(w_delta_mse) * inverse(sqrt(grad_mse))}
     * for the sparse factor matrix.
     *
     * <p>NOTE(review): no {@code EPSILON} is added to the denominator here, unlike
     * {@link #paramsDeltaW0} — confirm that {@code SparseUtil$.inverse} guards against zeros.
     *
     * @param r1 unused
     * @param r2 unused
     */
    public SparseMatrix paramsDeltaV(SparseMatrix w_delta_mse, SparseMatrix grad, SparseMatrix grad_mse, double r1, double r2) {
        SparseMatrix negGrad = SparseUtil$.MODULE$.multiply(grad, -1.0);
        SparseMatrix deltaRms = SparseUtil$.MODULE$.sqrt(w_delta_mse);
        SparseMatrix invGradRms = SparseUtil$.MODULE$.inverse(SparseUtil$.MODULE$.sqrt(grad_mse));
        return SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.multiply(negGrad, deltaRms), invGradRms);
    }

    /**
     * Applies a step: returns {@code psOld + psDelta} for each parameter group.
     */
    public Point.FMParams paramsUpdate(Point.FMParams psOld, Point.FMParams psDelta) {
        double params_w0 = psOld.w0() + psDelta.w0();
        SparseVector params_w = SparseUtil$.MODULE$.add(psOld.w(), psDelta.w());
        SparseMatrix params_v = SparseUtil$.MODULE$.add(psOld.v(), psDelta.v());
        return new Point.FMParams(params_w0, params_w, params_v);
    }

    /**
     * Updates the running mean of squared gradients:
     * {@code eta * grad_mse + (1 - eta) * grad²} for each group.
     *
     * @param eta decay factor applied to the previous running mean
     */
    public Point.FMGradParams gradMseUpdate(Point.FMGradParams grad_mse, Point.FMGradParams grad, double eta) {
        double grad_w0_mse = emaOfSquares(grad_mse.grad_w0(), grad.grad_w0(), eta);
        SparseVector grad_w_mse = emaOfSquares(grad_mse.grad_w(), grad.grad_w(), eta);
        SparseMatrix grad_v_mse = emaOfSquares(grad_mse.grad_v(), grad.grad_v(), eta);
        return new Point.FMGradParams(grad_w0_mse, grad_w_mse, grad_v_mse);
    }

    /**
     * Updates the running mean of squared parameter deltas:
     * {@code eta * params_delta_mse + (1 - eta) * params_delta²} for each group.
     *
     * @param eta decay factor applied to the previous running mean
     */
    public Point.FMParams paramsDeltaUpdate(Point.FMParams params_delta_mse, Point.FMParams params_delta, double eta) {
        double params_w0_mse = emaOfSquares(params_delta_mse.w0(), params_delta.w0(), eta);
        SparseVector params_w_mse = emaOfSquares(params_delta_mse.w(), params_delta.w(), eta);
        SparseMatrix params_v_mse = emaOfSquares(params_delta_mse.v(), params_delta.v(), eta);
        return new Point.FMParams(params_w0_mse, params_w_mse, params_v_mse);
    }

    /**
     * Returns the difference {@code params_new - params_old} for each parameter group.
     */
    public Point.FMParams paramsDelta(Point.FMParams params_new, Point.FMParams params_old) {
        double params_w0_delta = params_new.w0() - params_old.w0();
        SparseVector params_w_delta = SparseUtil$.MODULE$.subtraction(params_new.w(), params_old.w());
        SparseMatrix params_v_delta = SparseUtil$.MODULE$.subtraction(params_new.v(), params_old.v());
        return new Point.FMParams(params_w0_delta, params_w_delta, params_v_delta);
    }

    /** Scalar moving average of squares: {@code eta * mse + (1 - eta) * x * x}. */
    private static double emaOfSquares(double mse, double x, double eta) {
        return eta * mse + (1.0 - eta) * x * x;
    }

    /** Sparse-vector moving average of squares: {@code eta * mse + (1 - eta) * x²}. */
    private static SparseVector emaOfSquares(SparseVector mse, SparseVector x, double eta) {
        SparseVector decayed = SparseUtil$.MODULE$.multiply(mse, eta);
        SparseVector squared = SparseUtil$.MODULE$.multiply(x);
        SparseVector weighted = SparseUtil$.MODULE$.multiply(squared, 1.0 - eta);
        return SparseUtil$.MODULE$.add(decayed, weighted);
    }

    /** Sparse-matrix moving average of squares: {@code eta * mse + (1 - eta) * x²}. */
    private static SparseMatrix emaOfSquares(SparseMatrix mse, SparseMatrix x, double eta) {
        SparseMatrix decayed = SparseUtil$.MODULE$.multiply(mse, eta);
        SparseMatrix squared = SparseUtil$.MODULE$.multiply(x);
        SparseMatrix weighted = SparseUtil$.MODULE$.multiply(squared, 1.0 - eta);
        return SparseUtil$.MODULE$.add(decayed, weighted);
    }

    private FMADUpdater$() {
        // Singleton: the only instance is created by the static initializer.
    }
}

