/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.optimizing;

import cn.com.duiba.nezha.compute.api.constant.GlobalConstant;
import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.util.MLUtil$;
import cn.com.duiba.nezha.compute.mllib.util.SparseUtil$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.SparseMatrix;
import org.apache.spark.mllib.linalg.SparseVector;
import scala.MatchError;
import scala.Tuple2;
import scala.runtime.BoxesRunTime;

/**
 * Gradient-descent helpers for the model whose prediction is computed by {@link #h}:
 * {@code sigmoid(w0 + x.w + 0.5 * interaction)}, with parameters {@code (w0, w, v)}. Judging by
 * that formula this is a factorization machine.
 *
 * <p>This is the Java form of a Scala {@code object} (the file header says it was decompiled), so
 * callers reach it through the singleton {@link #MODULE$}. All sparse vector/matrix arithmetic is
 * delegated to {@code SparseUtil$} and {@code MLUtil$}. Their exact semantics (for example whether
 * the single-argument {@code multiply} squares element-wise) are defined elsewhere and are noted as
 * assumptions below where they matter.
 */
public final class FMGD$ {
    /** Singleton instance, mirroring the Scala {@code object} encoding. */
    public static final FMGD$ MODULE$ = new FMGD$();

    private FMGD$() {
    }

    /**
     * Computes {@code w0^2 + w.w + sum(multiply(v))} for a (scalar, vector, matrix) triple.
     *
     * <p>NOTE(review): assumes {@code SparseUtil$.multiply(SparseMatrix)} squares element-wise, so
     * this is the squared L2 norm of the triple. TODO confirm.
     */
    private double sumOfSquares(double w0, SparseVector w, SparseMatrix v) {
        double w0Sq = w0 * w0;
        double wSq = SparseUtil$.MODULE$.dot(w, w);
        double vSq = SparseUtil$.MODULE$.sum(SparseUtil$.MODULE$.multiply(v));
        return w0Sq + wSq + vSq;
    }

    /**
     * Returns the squared distance between two parameter sets: the sum of squares of
     * {@code old - new} over {@code w0}, {@code w} and {@code v}.
     *
     * @param fp_old previous parameters
     * @param fp_new updated parameters
     * @return {@code (dw0)^2 + dw.dw + sum(multiply(dV))}
     */
    public double p_delta(Point.FMParams fp_old, Point.FMParams fp_new) {
        double dW0 = fp_old.w0() - fp_new.w0();
        SparseVector dW = SparseUtil$.MODULE$.subtraction(fp_old.w(), fp_new.w());
        SparseMatrix dV = SparseUtil$.MODULE$.subtraction(fp_old.v(), fp_new.v());
        return sumOfSquares(dW0, dW, dV);
    }

    /**
     * Returns the square root of the sum of squares of an already-computed parameter delta.
     *
     * <p>Despite the name, nothing is averaged: this is a norm, not a mean squared error.
     *
     * @param fp_delta parameter difference
     * @return {@code sqrt(w0^2 + w.w + sum(multiply(v)))}
     */
    public double p_delta_mse(Point.FMParams fp_delta) {
        return Math.sqrt(sumOfSquares(fp_delta.w0(), fp_delta.w(), fp_delta.v()));
    }

    /**
     * Returns the root-mean-square difference between two parameter sets.
     *
     * <p>The mean is taken over {@code 1 + numActives(old.w) + numActives(old.v)} entries. The
     * count comes from {@code fp_old} only.
     *
     * @param fp_old previous parameters (also defines the entry count)
     * @param fp_new updated parameters
     * @return {@code sqrt(p_delta(fp_old, fp_new) / count)}
     */
    public double p_rmse(Point.FMParams fp_old, Point.FMParams fp_new) {
        int p_size = 1 + fp_old.w().numActives() + fp_old.v().numActives();
        return Math.sqrt(p_delta(fp_old, fp_new) / (double) p_size);
    }

    /**
     * Returns the root-mean-square of a gradient.
     *
     * <p>The mean is taken over {@code 1 + numActives(grad_w) + numActives(grad_v)} entries.
     *
     * @param fmps gradient
     * @return RMS of the gradient entries
     */
    public double g_rmse(Point.FMGradParams fmps) {
        int p_size = 1 + fmps.grad_w().numActives() + fmps.grad_v().numActives();
        double total = sumOfSquares(fmps.grad_w0(), fmps.grad_w(), fmps.grad_v());
        return Math.sqrt(total / (double) p_size);
    }

    /**
     * Averages an accumulated gradient by dividing every component by {@code batchSize}.
     *
     * <p>A {@code batchSize} of 0 yields Infinity/NaN, following Java double arithmetic. It is not
     * rejected here.
     *
     * @param fmps      accumulated (summed) gradient
     * @param batchSize number of samples that were summed
     * @return the per-sample average gradient
     */
    public Point.FMGradParams grad(Point.FMGradParams fmps, double batchSize) {
        double grad_w0 = fmps.grad_w0() / batchSize;
        SparseVector grad_w = SparseUtil$.MODULE$.multiply(fmps.grad_w(), 1.0 / batchSize);
        SparseMatrix grad_v = SparseUtil$.MODULE$.multiply(fmps.grad_v(), 1.0 / batchSize);
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /**
     * Returns the component-wise square root of a gradient.
     *
     * <p>A negative {@code grad_w0} yields NaN. The vector and matrix parts use
     * {@code SparseUtil$.sqrt}.
     */
    public Point.FMGradParams gradSqrt(Point.FMGradParams grad_1) {
        double grad_w0 = Math.sqrt(grad_1.grad_w0());
        SparseVector grad_w = SparseUtil$.MODULE$.sqrt(grad_1.grad_w());
        SparseMatrix grad_v = SparseUtil$.MODULE$.sqrt(grad_1.grad_v());
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /**
     * Returns the component-wise reciprocal of a gradient.
     *
     * <p>The scalar part is {@code 1 / (grad_w0 + GlobalConstant.EPSILON)}.
     * NOTE(review): whether {@code SparseUtil$.inverse} adds the same epsilon for the vector and
     * matrix parts is not visible here. TODO confirm.
     */
    public Point.FMGradParams gradInverse(Point.FMGradParams grad_1) {
        double grad_w0 = 1.0 / (grad_1.grad_w0() + GlobalConstant.EPSILON);
        SparseVector grad_w = SparseUtil$.MODULE$.inverse(grad_1.grad_w());
        SparseMatrix grad_v = SparseUtil$.MODULE$.inverse(grad_1.grad_v());
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /**
     * Returns the linear combination {@code a * grad_1 + b * grad_2}, applied component-wise.
     *
     * <p>This form is suitable for moving averages such as
     * {@code beta * acc + (1 - beta) * g}.
     */
    public Point.FMGradParams gradMergeAddUpdate(Point.FMGradParams grad_1, Point.FMGradParams grad_2, double a, double b) {
        double grad_w0 = a * grad_1.grad_w0() + b * grad_2.grad_w0();
        SparseVector grad_w = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(grad_1.grad_w(), a),
                SparseUtil$.MODULE$.multiply(grad_2.grad_w(), b));
        SparseMatrix grad_v = SparseUtil$.MODULE$.add(
                SparseUtil$.MODULE$.multiply(grad_1.grad_v(), a),
                SparseUtil$.MODULE$.multiply(grad_2.grad_v(), b));
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /**
     * Returns the component-wise square of a gradient.
     *
     * <p>The scalar part is {@code grad_w0^2}. The vector and matrix parts use the single-argument
     * {@code SparseUtil$.multiply}, presumably an element-wise square (TODO confirm).
     */
    public Point.FMGradParams gradSquare(Point.FMGradParams grad_1) {
        double grad_w0 = grad_1.grad_w0() * grad_1.grad_w0();
        SparseVector grad_w = SparseUtil$.MODULE$.multiply(grad_1.grad_w());
        SparseMatrix grad_v = SparseUtil$.MODULE$.multiply(grad_1.grad_v());
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /**
     * Multiplies two gradients component by component.
     *
     * <p>Presumably element-wise (Hadamard) for the vector and matrix parts. TODO confirm against
     * {@code SparseUtil$.multiply}.
     */
    public Point.FMGradParams gradMultiply(Point.FMGradParams grad_1, Point.FMGradParams grad_2) {
        double grad_w0 = grad_1.grad_w0() * grad_2.grad_w0();
        SparseVector grad_w = SparseUtil$.MODULE$.multiply(grad_1.grad_w(), grad_2.grad_w());
        SparseMatrix grad_v = SparseUtil$.MODULE$.multiply(grad_1.grad_v(), grad_2.grad_v());
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /** Scales every component of a gradient by {@code factor}. */
    public Point.FMGradParams gradMultiply(Point.FMGradParams grad_1, double factor) {
        double grad_w0 = grad_1.grad_w0() * factor;
        SparseVector grad_w = SparseUtil$.MODULE$.multiply(grad_1.grad_w(), factor);
        SparseMatrix grad_v = SparseUtil$.MODULE$.multiply(grad_1.grad_v(), factor);
        return new Point.FMGradParams(grad_w0, grad_w, grad_v);
    }

    /**
     * Adds L1 and L2 penalty terms to a gradient.
     *
     * <p>Each component becomes {@code grad + r2 * param + r1 * sign(param)}.
     *
     * @param psOld current parameters the penalties are computed from
     * @param grad  raw data gradient
     * @param r1    L1 coefficient (multiplies {@code sign(param)})
     * @param r2    L2 coefficient (multiplies {@code param})
     * @return the regularized gradient
     */
    public Point.FMGradParams gradWithRegularization(Point.FMParams psOld, Point.FMGradParams grad, double r1, double r2) {
        double w0_new = this.gradWithRegularizationW0(psOld.w0(), grad.grad_w0(), r1, r2);
        SparseVector w_new = this.gradWithRegularizationW(psOld.w(), grad.grad_w(), r1, r2);
        SparseMatrix v_new = this.gradWithRegularizationV(psOld.v(), grad.grad_v(), r1, r2);
        return new Point.FMGradParams(w0_new, w_new, v_new);
    }

    /** Returns {@code r2 * w + r1 * sign(w) + grad} for the scalar bias term. */
    public double gradWithRegularizationW0(double w, double grad, double r1, double r2) {
        double l2Term = r2 * w;
        double l1Term = r1 * MLUtil$.MODULE$.sign(w);
        return l2Term + l1Term + grad;
    }

    /** Returns {@code grad + (r2 * w + r1 * sign(w))} for the linear weights. */
    public SparseVector gradWithRegularizationW(SparseVector w, SparseVector grad, double r1, double r2) {
        SparseVector l2Term = SparseUtil$.MODULE$.multiply(w, r2);
        SparseVector l1Term = SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(w), r1);
        return SparseUtil$.MODULE$.add(grad, SparseUtil$.MODULE$.add(l2Term, l1Term));
    }

    /** Returns {@code grad + (r2 * v + r1 * sign(v))} for the factor matrix. */
    public SparseMatrix gradWithRegularizationV(SparseMatrix v, SparseMatrix grad, double r1, double r2) {
        SparseMatrix l2Term = SparseUtil$.MODULE$.multiply(v, r2);
        SparseMatrix l1Term = SparseUtil$.MODULE$.multiply(MLUtil$.MODULE$.sign(v), r1);
        return SparseUtil$.MODULE$.add(grad, SparseUtil$.MODULE$.add(l2Term, l1Term));
    }

    /**
     * Computes the per-sample gradient and squared error for one labeled point.
     *
     * <p>With {@code err = h(x) - y}:
     * <ul>
     *   <li>{@code grad_w0 = err}</li>
     *   <li>{@code grad_w = err * x}</li>
     *   <li>{@code grad_v = err * (dot_m(x, dot_row(V, x)) - multiply(x^2, F, V))}</li>
     * </ul>
     * The {@code grad_v} expression is presumably the FM term {@code x_i * sum_j v_jf x_j - v_if x_i^2}.
     * TODO confirm against {@code SparseUtil$}.
     *
     * @param p     labeled sample
     * @param fm_ps broadcast model parameters
     * @param F     broadcast factor count, unboxed as an {@code int}
     * @return (gradient, boxed {@code err * err})
     */
    public Tuple2<Point.FMGradParams, Object> computeWithErr(Point.LabeledSPoint p, Broadcast<Point.FMParams> fm_ps, Broadcast<Object> F) {
        Point.FMParams params = fm_ps.value();
        SparseVector x = p.x();
        double err = this.h(x, params) - p.y();
        double grad_w0 = err;
        SparseVector grad_w = SparseUtil$.MODULE$.multiply(x, err);
        SparseVector vDotX = SparseUtil$.MODULE$.dot_row(params.v(), x);
        SparseMatrix grad_v_p1 = SparseUtil$.MODULE$.dot_m(x, vDotX);
        SparseVector x_2 = SparseUtil$.MODULE$.multiply(x);
        // The (Matrix) cast selects the same SparseUtil$.multiply overload the original called.
        SparseMatrix grad_v_p3 = SparseUtil$.MODULE$.multiply(x_2, BoxesRunTime.unboxToInt(F.value()), (Matrix) params.v());
        SparseMatrix grad_v_p4 = SparseUtil$.MODULE$.subtraction(grad_v_p1, grad_v_p3);
        SparseMatrix grad_v = SparseUtil$.MODULE$.multiply(grad_v_p4, err);
        return new Tuple2<Point.FMGradParams, Object>(
                new Point.FMGradParams(grad_w0, grad_w, grad_v), BoxesRunTime.boxToDouble(err * err));
    }

    /**
     * Same as {@link #computeWithErr}, but for {@code Point.FMModelParams}.
     *
     * <p>The prediction uses {@link #h(SparseVector, Point.FMModelParams)}.
     *
     * @return (gradient, boxed {@code err * err})
     */
    public Tuple2<Point.FMGradParams, Object> computeWithErr2(Point.LabeledSPoint p, Broadcast<Point.FMModelParams> fm_ps, Broadcast<Object> F) {
        Point.FMModelParams params = fm_ps.value();
        SparseVector x = p.x();
        double err = this.h(x, params) - p.y();
        double grad_w0 = err;
        SparseVector grad_w = SparseUtil$.MODULE$.multiply(x, err);
        SparseVector vDotX = SparseUtil$.MODULE$.dot_row(params.v(), x);
        SparseMatrix grad_v_p1 = SparseUtil$.MODULE$.dot_m(x, vDotX);
        SparseVector x_2 = SparseUtil$.MODULE$.multiply(x);
        SparseMatrix grad_v_p3 = SparseUtil$.MODULE$.multiply(x_2, BoxesRunTime.unboxToInt(F.value()), params.v());
        SparseMatrix grad_v_p4 = SparseUtil$.MODULE$.subtraction(grad_v_p1, grad_v_p3);
        SparseMatrix grad_v = SparseUtil$.MODULE$.multiply(grad_v_p4, err);
        return new Tuple2<Point.FMGradParams, Object>(
                new Point.FMGradParams(grad_w0, grad_w, grad_v), BoxesRunTime.boxToDouble(err * err));
    }

    /**
     * Returns only the gradient part of {@link #computeWithErr}.
     *
     * <p>The squared error is discarded.
     */
    public Point.FMGradParams compute(Point.LabeledSPoint p, Broadcast<Point.FMParams> fm_ps, Broadcast<Object> F) {
        return this.computeWithErr(p, fm_ps, F)._1();
    }

    /**
     * Applies an additive update.
     *
     * @return {@code psOld + psDelta}, component-wise
     */
    public Point.FMParams paramsUpdate(Point.FMParams psOld, Point.FMParams psDelta) {
        double params_w0 = psOld.w0() + psDelta.w0();
        SparseVector params_w = SparseUtil$.MODULE$.add(psOld.w(), psDelta.w());
        SparseMatrix params_v = SparseUtil$.MODULE$.add(psOld.v(), psDelta.v());
        return new Point.FMParams(params_w0, params_w, params_v);
    }

    /**
     * Alternative prediction using the {@code dot_row} formulation.
     *
     * <p>Computes {@code sigmoid(w0 + x.w + 0.5 * (sum((Vx)^2) - sum(V^2 x^2)))}. This assumes the
     * single-argument {@code multiply} squares element-wise (TODO confirm).
     */
    public double h2(SparseVector x, Point.FMParams fm_ps) {
        SparseVector inter_1 = SparseUtil$.MODULE$.dot_row(fm_ps.v(), x);
        SparseVector inter_1_2 = SparseUtil$.MODULE$.multiply(inter_1);
        SparseVector x_2 = SparseUtil$.MODULE$.multiply(x);
        SparseMatrix v_2 = SparseUtil$.MODULE$.multiply(fm_ps.v());
        SparseVector inter_2 = SparseUtil$.MODULE$.dot_row(v_2, x_2);
        double interaction = SparseUtil$.MODULE$.sum(inter_1_2) - SparseUtil$.MODULE$.sum(inter_2);
        double inx = fm_ps.w0() + SparseUtil$.MODULE$.dot(x, fm_ps.w()) + 0.5 * interaction;
        return MLUtil$.MODULE$.sigmoid(inx);
    }

    /**
     * Prediction for {@code Point.FMModelParams}.
     *
     * <p>Computes {@code sigmoid(w0 + dot3(x, w) + 0.5 * interaction)}, where {@code interaction}
     * is {@code sum(rowsum(x*V)^2) - sum(rowsum((x*V)^2))}. NOTE(review): the exact semantics of
     * {@code dot3} and of the three-argument {@code multiply} are defined in {@code SparseUtil$}.
     */
    public double h(SparseVector x, Point.FMModelParams fm_ps) {
        SparseMatrix v_x = SparseUtil$.MODULE$.multiply(x, fm_ps.v().numCols(), fm_ps.v());
        SparseMatrix v2_x2 = SparseUtil$.MODULE$.multiply(v_x);
        SparseVector v_x_rcount2 = SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.sum_row(v_x));
        SparseVector v2_x2_rcount = SparseUtil$.MODULE$.sum_row(v2_x2);
        double interaction = SparseUtil$.MODULE$.sum(v_x_rcount2) - SparseUtil$.MODULE$.sum(v2_x2_rcount);
        double inx = fm_ps.w0() + SparseUtil$.MODULE$.dot3(x, fm_ps.w()) + 0.5 * interaction;
        return MLUtil$.MODULE$.sigmoid(inx);
    }

    /**
     * Prediction for {@code Point.FMParams}.
     *
     * <p>{@code x} is replicated across {@code v.numCols()} columns via {@code vector_copy} and
     * multiplied with {@code V}. The result is
     * {@code sigmoid(w0 + x.w + 0.5 * (sum(rowsum(xV)^2) - sum(rowsum((xV)^2))))}.
     */
    public double h(SparseVector x, Point.FMParams fm_ps) {
        SparseMatrix x_copy_F = SparseUtil$.MODULE$.vector_copy(x, fm_ps.v().numCols());
        SparseMatrix v_x = SparseUtil$.MODULE$.multiply(x_copy_F, fm_ps.v());
        SparseMatrix v2_x2 = SparseUtil$.MODULE$.multiply(v_x);
        SparseVector v_x_rcount2 = SparseUtil$.MODULE$.multiply(SparseUtil$.MODULE$.sum_row(v_x));
        SparseVector v2_x2_rcount = SparseUtil$.MODULE$.sum_row(v2_x2);
        double interaction = SparseUtil$.MODULE$.sum(v_x_rcount2) - SparseUtil$.MODULE$.sum(v2_x2_rcount);
        double inx = fm_ps.w0() + SparseUtil$.MODULE$.dot(x, fm_ps.w()) + 0.5 * interaction;
        return MLUtil$.MODULE$.sigmoid(inx);
    }
}

