package cn.com.duiba.nezha.compute.biz.app.ml;

import breeze.linalg.DenseMatrix;
import breeze.linalg.DenseMatrix$;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.package;
import breeze.linalg.sum$;
import breeze.linalg.support.LiteralRow$;
import breeze.storage.Zero$DoubleZero$;
import cn.com.duiba.nezha.compute.biz.app.ml.FM;
import java.util.Random;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple3;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: FM.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/biz/app/ml/FM$.class */
/**
 * Singleton module class behind the Scala object {@code FM}: a small factorization-machine
 * style binary classifier trained with batch gradient updates over a Spark {@link RDD}, using
 * Breeze dense vectors/matrices for the model parameters.
 *
 * <p>The model consists of a bias {@code w0} (double), a weight vector {@code w} and a factor
 * matrix {@code v}. Several steps are delegated to compiler-generated closure classes
 * ({@code FM$$anonfun$...}) whose bodies are not part of this file; comments about what those
 * closures do are marked as assumptions.
 */
public final class FM$ {
    /**
     * The sole instance. Created eagerly during class initialisation so the field is never
     * observable as {@code null} and never reassigned.
     */
    public static final FM$ MODULE$ = new FM$();

    /**
     * Generates {@code i2} synthetic data points by tabulating the indices {@code 0 .. i2-1}
     * through the generator closure, which is constructed with {@code i} and {@code d}.
     *
     * @param i  forwarded to the generator closure; presumably the feature dimension handed to
     *           {@code generatePoint$1} (confirm against the closure)
     * @param i2 number of data points to create
     * @param d  forwarded to the generator closure; presumably the label-dependent offset added
     *           to the random features (confirm against the closure)
     * @param i3 random seed. NOTE(review): the seed has no effect -- nothing in this method
     *           consumes it and the features come from Breeze's default random source, so two
     *           calls with the same seed produce different data.
     * @return the generated points
     */
    public FM.DataPoint[] generateData(int i, int i2, double d, int i3) {
        return (FM.DataPoint[]) Array$.MODULE$.tabulate(i2, new FM$$anonfun$generateData$1(i, d), ClassTag$.MODULE$.apply(FM.DataPoint.class));
    }

    /**
     * Logistic function {@code 1 / (1 + e^-d)}.
     *
     * @param d raw score
     * @return value in {@code [0, 1]}
     */
    public double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Maps every data point to a {@code FM.Prdlabel} using the given model parameters and
     * decision threshold. The per-point work is done by the closure {@code FM$$anonfun$2};
     * presumably it pairs {@link #predLabel} with the true label (not visible here).
     *
     * @param d           bias term w0
     * @param denseVector weight vector w
     * @param denseMatrix factor matrix v
     * @param rdd         points to score
     * @param d2          decision threshold forwarded to the closure
     * @return one prediction/label record per input point
     */
    public RDD<FM.Prdlabel> mertix(double d, DenseVector<Object> denseVector, DenseMatrix<Object> denseMatrix, RDD<FM.DataPoint> rdd, double d2) {
        return rdd.map(new FM$$anonfun$2(d, denseVector, denseMatrix, d2), ClassTag$.MODULE$.apply(FM.Prdlabel.class));
    }

    /**
     * Scores a single prediction record.
     *
     * @param prdlabel record holding the predicted value ({@code py}) and the actual one ({@code y})
     * @return 1 when the two values differ by less than 0.1, otherwise 0
     */
    public int labelCompare(FM.Prdlabel prdlabel) {
        return Math.abs(prdlabel.py() - prdlabel.y()) < 0.1d ? 1 : 0;
    }

    /**
     * Computes the fraction of correct predictions in {@code rdd}.
     *
     * <p>The records are mapped to ints by {@code FM$$anonfun$3} and combined by
     * {@code FM$$anonfun$1}; presumably these are {@link #labelCompare} and integer addition.
     * The result is divided by the exact record count, so a perfect classifier scores 1.0.
     *
     * @param rdd prediction records
     * @return correct / total, or 0.0 for an empty RDD (reducing an empty RDD would throw)
     */
    public double accuracy(RDD<FM.Prdlabel> rdd) {
        long total = rdd.count();
        if (total == 0L) {
            return 0.0d;
        }
        int hits = BoxesRunTime.unboxToInt(rdd.map(new FM$$anonfun$3(), ClassTag$.MODULE$.Int()).reduce(new FM$$anonfun$1()));
        return (double) hits / total;
    }

    /**
     * Two-valued sign.
     *
     * @param d input value
     * @return 1.0 when {@code d > 0}, otherwise -1.0 (zero and NaN therefore map to -1.0)
     */
    public double sign(double d) {
        return d > 0.0d ? 1.0d : -1.0d;
    }

    /**
     * Applies the closure {@code FM$$anonfun$sign$1} to every element of the vector; presumably
     * that closure is the scalar {@link #sign(double)}.
     *
     * @param denseVector input vector
     * @return a new vector of mapped values
     */
    public DenseVector<Object> sign(DenseVector<Object> denseVector) {
        return (DenseVector) denseVector.map(new FM$$anonfun$sign$1(), DenseVector$.MODULE$.canMapValues(ClassTag$.MODULE$.Double()));
    }

    /**
     * Applies the closure {@code FM$$anonfun$sign$2} to every element of the matrix; presumably
     * that closure is the scalar {@link #sign(double)}.
     *
     * @param denseMatrix input matrix
     * @return a new matrix of mapped values
     */
    public DenseMatrix<Object> sign(DenseMatrix<Object> denseMatrix) {
        return (DenseMatrix) denseMatrix.map(new FM$$anonfun$sign$2(), DenseMatrix$.MODULE$.canMapValues(ClassTag$.MODULE$.Double()));
    }

    /**
     * Thresholds a score into a +/-1 label.
     *
     * @param d  score
     * @param d2 threshold
     * @return 1.0 when {@code d >= d2}, otherwise -1.0
     */
    public double signLabel(double d, double d2) {
        return d >= d2 ? 1.0d : -1.0d;
    }

    /**
     * Element-wise (Breeze {@code :*}) product of two matrices.
     *
     * @param denseMatrix  left operand
     * @param denseMatrix2 right operand
     * @return the element-wise product
     */
    public DenseMatrix<Object> multiply(DenseMatrix<Object> denseMatrix, DenseMatrix<Object> denseMatrix2) {
        return (DenseMatrix) denseMatrix.$colon$times(denseMatrix2, DenseMatrix$.MODULE$.op_DM_DM_Double_OpMulScalar());
    }

    /**
     * Squared distance between two parameter sets: the squared bias difference plus the sum of
     * squared element differences of the weight vectors and of the factor matrices.
     *
     * @param d            first bias
     * @param denseVector  first weight vector
     * @param denseMatrix  first factor matrix
     * @param d2           second bias
     * @param denseVector2 second weight vector
     * @param denseMatrix2 second factor matrix
     * @return the summed squared differences
     */
    public double paramsDelta(double d, DenseVector<Object> denseVector, DenseMatrix<Object> denseMatrix, double d2, DenseVector<Object> denseVector2, DenseMatrix<Object> denseMatrix2) {
        // Each difference is computed once and reused for its own square.
        double biasDiff = d - d2;
        DenseVector<Object> weightDiff = (DenseVector) denseVector.$minus(denseVector2, DenseVector$.MODULE$.canSubD());
        DenseMatrix<Object> factorDiff = m_m_sub(denseMatrix, denseMatrix2);

        double biasTerm = biasDiff * biasDiff;
        double weightTerm = BoxesRunTime.unboxToDouble(weightDiff.dot(weightDiff, DenseVector$.MODULE$.canDotD()));
        double factorTerm = BoxesRunTime.unboxToDouble(sum$.MODULE$.apply(multiply(factorDiff, factorDiff), sum$.MODULE$.reduce_Double(DenseMatrix$.MODULE$.canIterateValues())));
        return biasTerm + weightTerm + factorTerm;
    }

    /**
     * Trains the model with full-batch gradient updates.
     *
     * <p>Parameters start as a random weight vector of length {@code i}, a zero bias and a random
     * {@code i x i2} factor matrix. Each iteration maps every point to a {@code FM.GradParams}
     * (closure {@code FM$$anonfun$4}, which reads the current parameters through the shared
     * refs), reduces them with {@code FM$$anonfun$5} (presumably a component-wise sum), and
     * applies {@link #updateW0}, {@link #updateW} and {@link #updateV}.
     *
     * <p>The loop ends after {@code i3} iterations, or earlier once the parameter change measured
     * by {@link #paramsDelta} is at most {@code d4}; that change is only measured on iterations
     * with index greater than {@code i4}.
     *
     * @param rdd training points; scanned once for the count and once per iteration
     * @param i   length of the weight vector / row count of the factor matrix
     * @param i2  column count of the factor matrix (also forwarded to the gradient closure)
     * @param d   step size, passed as the first coefficient to the update methods
     * @param d2  coefficient of the {@code sign(w) * w} penalty term in the update methods
     * @param d3  coefficient of the multiplicative shrink term in the update methods
     * @param i3  maximum number of iterations
     * @param i4  iteration index that must be exceeded before convergence is checked
     * @param d4  convergence tolerance on {@link #paramsDelta}
     * @return the trained {@code (w0, w, v)}
     */
    public Tuple3<Object, DenseVector<Object>, DenseMatrix<Object>> gradAscent(RDD<FM.DataPoint> rdd, int i, int i2, double d, double d2, double d3, int i3, int i4, double d4) {
        // Mutable holders: the gradient closure captures these and reads the latest parameters.
        ObjectRef objectRef = new ObjectRef(DenseVector$.MODULE$.rand(i, DenseVector$.MODULE$.rand$default$2(), ClassTag$.MODULE$.Double()));
        DoubleRef doubleRef = new DoubleRef(0.0d);
        ObjectRef objectRef2 = new ObjectRef(DenseMatrix$.MODULE$.rand(i, i2, DenseMatrix$.MODULE$.rand$default$3(), ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$));
        long count = rdd.count();

        double delta = 1.0d;
        for (int iteration = 0; iteration < i3; iteration++) {
            if (delta <= d4) {
                break;
            }
            Predef$.MODULE$.println(new StringBuilder().append("On iteration ").append(BoxesRunTime.boxToInteger(iteration)).toString());

            // The mapped RDD is consumed by exactly one action, so it is deliberately not
            // cached: caching it every iteration would only pile up never-unpersisted blocks.
            FM.GradParams gradParams = (FM.GradParams) rdd.map(new FM$$anonfun$4(i2, objectRef, doubleRef, objectRef2), ClassTag$.MODULE$.apply(FM.GradParams.class)).reduce(new FM$$anonfun$5());

            double nextW0 = updateW0(doubleRef.elem, gradParams.w0(), d, d2, d3, count);
            DenseVector<Object> nextW = updateW((DenseVector) objectRef.elem, gradParams.w(), d, d2, d3, count);
            DenseMatrix<Object> nextV = updateV((DenseMatrix) objectRef2.elem, gradParams.v(), d, d2, d3, count);

            // Convergence is only evaluated after the warm-up iterations.
            if (iteration > i4) {
                delta = paramsDelta(doubleRef.elem, (DenseVector) objectRef.elem, (DenseMatrix) objectRef2.elem, nextW0, nextW, nextV);
            }

            doubleRef.elem = nextW0;
            objectRef.elem = nextW;
            objectRef2.elem = nextV;
        }
        return new Tuple3<>(BoxesRunTime.boxToDouble(doubleRef.elem), (DenseVector) objectRef.elem, (DenseMatrix) objectRef2.elem);
    }

    /**
     * Matrix product of two matrices.
     *
     * @param denseMatrix  left operand
     * @param denseMatrix2 right operand
     * @return {@code denseMatrix * denseMatrix2}
     */
    public DenseMatrix<Object> m_m_dot(DenseMatrix<Object> denseMatrix, DenseMatrix<Object> denseMatrix2) {
        return (DenseMatrix) denseMatrix.$times(denseMatrix2, DenseMatrix$.MODULE$.implOpMulMatrix_DMD_DMD_eq_DMD());
    }

    /**
     * Element-wise difference of two matrices.
     *
     * @param denseMatrix  minuend
     * @param denseMatrix2 subtrahend
     * @return {@code denseMatrix - denseMatrix2}
     */
    public DenseMatrix<Object> m_m_sub(DenseMatrix<Object> denseMatrix, DenseMatrix<Object> denseMatrix2) {
        return (DenseMatrix) denseMatrix.$minus(denseMatrix2, DenseMatrix$.MODULE$.op_DM_DM_Double_OpSub());
    }

    /**
     * Predicts a +/-1 label by thresholding {@link #predDouble}.
     *
     * @param dataPoint   point to classify
     * @param d           bias term w0
     * @param denseVector weight vector w
     * @param denseMatrix factor matrix v
     * @param d2          decision threshold
     * @return 1.0 when the predicted probability is at least {@code d2}, otherwise -1.0
     */
    public double predLabel(FM.DataPoint dataPoint, double d, DenseVector<Object> denseVector, DenseMatrix<Object> denseMatrix, double d2) {
        return signLabel(predDouble(dataPoint, d, denseVector, denseMatrix), d2);
    }

    /**
     * Computes the model output for one point:
     * {@code sigmoid(w0 + x.w + sum((xV) .* (xV) - (x .* x)(V .* V)))}, where {@code .*} is the
     * element-wise product.
     *
     * <p>NOTE(review): the interaction term is used without the 1/2 factor found in the usual
     * factorization-machine formulation -- confirm this is intended.
     *
     * @param dataPoint   point whose features {@code x()} are scored
     * @param d           bias term w0
     * @param denseVector weight vector w
     * @param denseMatrix factor matrix v
     * @return the sigmoid of the raw score
     */
    public double predDouble(FM.DataPoint dataPoint, double d, DenseVector<Object> denseVector, DenseMatrix<Object> denseMatrix) {
        // Features as a matrix so they can be multiplied with the factor matrix; built once.
        DenseMatrix<Object> features = dataPoint.x().toDenseMatrix$mcD$sp();
        DenseMatrix<Object> projected = m_m_dot(features, denseMatrix);

        double linear = BoxesRunTime.unboxToDouble(dataPoint.x().dot(denseVector, DenseVector$.MODULE$.canDotD()));

        DenseMatrix<Object> squareOfSums = multiply(projected, projected);
        DenseMatrix<Object> sumOfSquares = m_m_dot(multiply(features, features), multiply(denseMatrix, denseMatrix));
        double interaction = BoxesRunTime.unboxToDouble(sum$.MODULE$.apply(m_m_sub(squareOfSums, sumOfSquares), sum$.MODULE$.reduce_Double(DenseMatrix$.MODULE$.canIterateValues())));

        return sigmoid(d + linear + interaction);
    }

    /**
     * Updates the bias: {@code (1 - d3*d5)*d - d3*d4*sign(d)*d - (d3/j)*d2}.
     *
     * <p>NOTE(review): {@code sign(d) * d} equals {@code |d|}, which is not the usual L1
     * subgradient {@code sign(d)} -- confirm the penalty term is intended. The same form is used
     * by {@link #updateW} and {@link #updateV}.
     *
     * @param d  current bias
     * @param d2 accumulated gradient for the bias
     * @param d3 step size
     * @param d4 coefficient of the {@code sign(d) * d} term
     * @param d5 coefficient of the multiplicative shrink term
     * @param j  divisor applied to the gradient (the training-set size in {@link #gradAscent});
     *           zero yields a non-finite result
     * @return the updated bias
     */
    public double updateW0(double d, double d2, double d3, double d4, double d5, long j) {
        double shrunk = (1 - (d3 * d5)) * d;
        double penalty = ((d3 * d4) * sign(d)) * d;
        double gradientStep = (d3 / j) * d2;
        return (shrunk - penalty) - gradientStep;
    }

    /**
     * Vector counterpart of {@link #updateW0}:
     * {@code (1 - d*d3)*w - (d*d2*sign(w)) .* w - (d/j)*grad}.
     *
     * @param denseVector  current weights w
     * @param denseVector2 accumulated gradient for w
     * @param d            step size
     * @param d2           coefficient of the {@code sign(w) .* w} term
     * @param d3           coefficient of the multiplicative shrink term
     * @param j            divisor applied to the gradient
     * @return the updated weight vector (a new vector)
     */
    public DenseVector<Object> updateW(DenseVector<Object> denseVector, DenseVector<Object> denseVector2, double d, double d2, double d3, long j) {
        DenseVector<Object> shrunk = (DenseVector) new package.InjectNumericOps(breeze.linalg.package$.MODULE$.InjectNumericOps(BoxesRunTime.boxToDouble(1 - (d * d3)))).$times(denseVector, DenseVector$.MODULE$.s_dv_Op_Double_OpMulMatrix());
        DenseVector<Object> scaledSign = (DenseVector) new package.InjectNumericOps(breeze.linalg.package$.MODULE$.InjectNumericOps(BoxesRunTime.boxToDouble(d * d2))).$times(sign(denseVector), DenseVector$.MODULE$.s_dv_Op_Double_OpMulMatrix());
        DenseVector<Object> penalty = (DenseVector) scaledSign.$colon$times(denseVector, DenseVector$.MODULE$.dv_dv_Op_Double_OpMulScalar());
        DenseVector<Object> gradientStep = (DenseVector) new package.InjectNumericOps(breeze.linalg.package$.MODULE$.InjectNumericOps(BoxesRunTime.boxToDouble(d / j))).$times(denseVector2, DenseVector$.MODULE$.s_dv_Op_Double_OpMulMatrix());

        DenseVector<Object> afterPenalty = (DenseVector) shrunk.$minus(penalty, DenseVector$.MODULE$.canSubD());
        return (DenseVector) afterPenalty.$minus(gradientStep, DenseVector$.MODULE$.canSubD());
    }

    /**
     * Matrix counterpart of {@link #updateW0}:
     * {@code (1 - d*d3)*V - (d*d2*sign(V)) .* V - (d/j)*grad}.
     *
     * @param denseMatrix  current factor matrix V
     * @param denseMatrix2 accumulated gradient for V
     * @param d            step size
     * @param d2           coefficient of the {@code sign(V) .* V} term
     * @param d3           coefficient of the multiplicative shrink term
     * @param j            divisor applied to the gradient
     * @return the updated factor matrix (a new matrix)
     */
    public DenseMatrix<Object> updateV(DenseMatrix<Object> denseMatrix, DenseMatrix<Object> denseMatrix2, double d, double d2, double d3, long j) {
        DenseMatrix<Object> shrunk = (DenseMatrix) new package.InjectNumericOps(breeze.linalg.package$.MODULE$.InjectNumericOps(BoxesRunTime.boxToDouble(1 - (d * d3)))).$times(denseMatrix, DenseMatrix$.MODULE$.s_dm_op_Double_OpMulMatrix());
        DenseMatrix<Object> scaledSign = (DenseMatrix) new package.InjectNumericOps(breeze.linalg.package$.MODULE$.InjectNumericOps(BoxesRunTime.boxToDouble(d * d2))).$times(sign(denseMatrix), DenseMatrix$.MODULE$.s_dm_op_Double_OpMulMatrix());
        DenseMatrix<Object> penalty = multiply(scaledSign, denseMatrix);
        DenseMatrix<Object> gradientStep = (DenseMatrix) new package.InjectNumericOps(breeze.linalg.package$.MODULE$.InjectNumericOps(BoxesRunTime.boxToDouble(d / j))).$times(denseMatrix2, DenseMatrix$.MODULE$.s_dm_op_Double_OpMulMatrix());

        return m_m_sub(m_m_sub(shrunk, penalty), gradientStep);
    }

    /**
     * Builds a matrix from a vector by running the closure
     * {@code FM$$anonfun$denseVector2Matrix$1} once for each k in {@code 1..i} with a shared
     * buffer and the vector's values (presumably appending the values each time -- the closure
     * is not visible here), then reshaping the buffer to {@code length x i} and transposing.
     *
     * @param denseVector source vector
     * @param i           number of closure invocations; also the row count of the result
     * @return an {@code i x denseVector.length()} matrix
     */
    public DenseMatrix<Object> denseVector2Matrix(DenseVector<Object> denseVector, int i) {
        ArrayBuffer buffer = ArrayBuffer$.MODULE$.apply(Nil$.MODULE$);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), i).foreach(new FM$$anonfun$denseVector2Matrix$1(buffer, denseVector.toArray$mcD$sp(ClassTag$.MODULE$.Double())));
        // The buffer becomes a single-row matrix that is then reshaped and transposed.
        DenseMatrix singleRow = DenseMatrix$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ArrayBuffer[]{buffer}), LiteralRow$.MODULE$.seq(Predef$.MODULE$.conforms()), ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$);
        return (DenseMatrix) singleRow.reshape$mcD$sp(denseVector.length(), i, singleRow.reshape$default$3()).t(DenseMatrix$.MODULE$.canTranspose());
    }

    /**
     * Matrix overload: flattens the transposed input to an array, runs the closure
     * {@code FM$$anonfun$denseVector2Matrix$2} once for each k in {@code 1..i} with a shared
     * buffer and that array (presumably appending the values each time -- the closure is not
     * visible here), then reshapes the buffer to {@code cols x (rows * i)} and transposes.
     *
     * @param denseMatrix source matrix
     * @param i           number of closure invocations
     * @return a {@code (rows * i) x cols} matrix
     */
    public DenseMatrix<Object> denseVector2Matrix(DenseMatrix<Object> denseMatrix, int i) {
        ArrayBuffer buffer = ArrayBuffer$.MODULE$.apply(Nil$.MODULE$);
        int rows = denseMatrix.rows();
        int cols = denseMatrix.cols();
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), i).foreach(new FM$$anonfun$denseVector2Matrix$2(buffer, ((DenseMatrix) denseMatrix.t(DenseMatrix$.MODULE$.canTranspose())).toDenseVector$mcD$sp().toArray$mcD$sp(ClassTag$.MODULE$.Double())));
        // The buffer becomes a single-row matrix that is then reshaped and transposed.
        DenseMatrix singleRow = DenseMatrix$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ArrayBuffer[]{buffer}), LiteralRow$.MODULE$.seq(Predef$.MODULE$.conforms()), ClassTag$.MODULE$.Double(), Zero$DoubleZero$.MODULE$);
        return (DenseMatrix) singleRow.reshape$mcD$sp(cols, rows * i, singleRow.reshape$default$3()).t(DenseMatrix$.MODULE$.canTranspose());
    }

    /**
     * Adds two scalars.
     *
     * @param d  first addend
     * @param d2 second addend
     * @return {@code d + d2}
     */
    public double add(double d, double d2) {
        return d + d2;
    }

    /**
     * Adds two vectors element-wise.
     *
     * @param denseVector  first addend
     * @param denseVector2 second addend
     * @return the element-wise sum
     */
    public DenseVector<Object> add(DenseVector<Object> denseVector, DenseVector<Object> denseVector2) {
        return (DenseVector) denseVector.$plus(denseVector2, DenseVector$.MODULE$.canAddD());
    }

    /**
     * Adds two matrices element-wise.
     *
     * @param denseMatrix  first addend
     * @param denseMatrix2 second addend
     * @return the element-wise sum
     */
    public DenseMatrix<Object> add(DenseMatrix<Object> denseMatrix, DenseMatrix<Object> denseMatrix2) {
        return (DenseMatrix) denseMatrix.$plus(denseMatrix2, DenseMatrix$.MODULE$.op_DM_DM_Double_OpAdd());
    }

    /**
     * Demo entry point: starts a local three-thread Spark context, trains on 1000 generated
     * points, prints the learned parameters and prints the accuracy measured on a second,
     * separately generated set of 1000 points. The context is always stopped before returning,
     * including when training or evaluation throws.
     *
     * @param strArr command-line arguments (unused)
     */
    public void main(String[] strArr) {
        Predef$.MODULE$.println("init spark context ... ");
        SparkContext sparkContext = new SparkContext(new SparkConf().setAppName("AdvertCTRLR").setMaster("local[3]"));
        try {
            RDD<FM.DataPoint> training = sparkContext.parallelize(Predef$.MODULE$.wrapRefArray(generateData(20, 1000, 0.7d, 42)), sparkContext.parallelize$default$2(), ClassTag$.MODULE$.apply(FM.DataPoint.class));
            Tuple3<Object, DenseVector<Object>, DenseMatrix<Object>> model = gradAscent(training, 20, 5, 0.05d, 0.25d, 2.0d, 100, 50, 1.0E-5d);

            double w0 = BoxesRunTime.unboxToDouble(model._1());
            DenseVector<Object> w = model._2();
            DenseMatrix<Object> v = model._3();
            Predef$.MODULE$.println(new StringBuilder().append("w0= ").append(BoxesRunTime.boxToDouble(w0)).toString());
            Predef$.MODULE$.println(new StringBuilder().append("w= ").append(w).toString());
            Predef$.MODULE$.println(new StringBuilder().append("v= ").append(v).toString());

            // Evaluation uses freshly generated points, thresholding the probability at 0.5.
            RDD<FM.DataPoint> evaluation = sparkContext.parallelize(Predef$.MODULE$.wrapRefArray(generateData(20, 1000, 0.7d, 42)), sparkContext.parallelize$default$2(), ClassTag$.MODULE$.apply(FM.DataPoint.class));
            Predef$.MODULE$.println(new StringBuilder().append("accuracy(prdMertix)= ").append(BoxesRunTime.boxToDouble(accuracy(mertix(w0, w, v, evaluation, 0.5d)))).toString());
        } finally {
            // Release the local executors and UI port even when the run fails.
            sparkContext.stop();
        }
    }

    /**
     * Creates one synthetic point. Even indices get label -1 and odd indices label +1; the
     * features are {@code i2} random values shifted by {@code label * d}.
     *
     * @param i  index of the point, used only to pick the label
     * @param i2 number of features
     * @param d  magnitude of the label-dependent shift
     * @return the generated point
     */
    public final FM.DataPoint cn$com$duiba$nezha$compute$biz$app$ml$FM$$generatePoint$1(int i, int i2, double d) {
        int label = i % 2 == 0 ? -1 : 1;
        return new FM.DataPoint((DenseVector) DenseVector$.MODULE$.rand(i2, DenseVector$.MODULE$.rand$default$2(), ClassTag$.MODULE$.Double()).$plus(BoxesRunTime.boxToDouble(label * d), DenseVector$.MODULE$.dv_s_Op_Double_OpAdd()), label);
    }

    /** Not instantiable from outside; the only instance is {@link #MODULE$}. */
    private FM$() {
    }
}
