/*
 * Decompiled with CFR 0.152.
 */
package cn.com.duiba.nezha.compute.mllib.test;

import cn.com.duiba.nezha.compute.api.point.Point;
import cn.com.duiba.nezha.compute.mllib.algorithm.SparseFM;
import cn.com.duiba.nezha.compute.mllib.evaluater.ClassifierEvaluater$;
import cn.com.duiba.nezha.compute.mllib.model.SparseFMModel;
import cn.com.duiba.nezha.compute.mllib.util.SparseDate$;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.rdd.RDD;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Manual train/evaluate experiment for {@link SparseFM}, decompiled (CFR) from the Scala object
 * {@code FMTest}.
 *
 * <p>It trains a sparse factorization-machine model on synthetic data in a local Spark context,
 * prints the learned parameters, and evaluates predictions on a second synthetic data set.
 *
 * <p>NOTE(review): this is decompiler output. Constructs such as {@code new Serializable(model){...}}
 * (an anonymous class over an interface with a constructor argument) and the assignment of the
 * {@code static final MODULE$} field from the constructor are how CFR renders Scala closures and
 * the Scala object singleton; they are not valid Java source. Prefer editing the Scala original.
 */
public final class FMTest$ {
    /** Singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static final FMTest$ MODULE$;

    static {
        new FMTest$();
    }

    /**
     * Trains a {@link SparseFMModel} with {@code runADSGD} on synthetic data, prints its parameters,
     * and evaluates it on a larger synthetic set (multiclass metrics and area under ROC).
     *
     * <p>The {@link SparkContext} created here is stopped before the method returns, including when
     * training or evaluation throws.
     *
     * @param args ignored
     */
    public void main(String[] args) {
        Logger.getLogger((String)"org").setLevel(Level.ERROR);
        Predef$.MODULE$.println((Object)"init spark context ... ");
        SparkConf conf = new SparkConf().setAppName("AdvertCTRLR").setMaster("local[3]");
        SparkContext sc = new SparkContext(conf);
        try {
            // Arguments for SparseDate.generateData2(D, N, R, seed).
            // NOTE(review): presumably N = number of samples and D = feature dimension; confirm in SparseDate.
            int N = 1000;
            int D = 100;
            double R = 3.0;
            // L2 regularization passed to setReg2.
            double r2 = 0.01;
            // Decision threshold used by the multiclass evaluation.
            double threshold = 0.5;
            Point.LabeledSPoint[] training_data = SparseDate$.MODULE$.generateData2(D, N, R, 42);
            RDD training_dataRdd = sc.parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])training_data), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(Point.LabeledSPoint.class)).persist();
            // NOTE(review): factorNums is 3 here although an unused "F = 5" existed in the original source.
            SparseFMModel model = new SparseFM().setLearningRate(0.1).setMtRate(0.8).setAdRate(0.8).setDeltaThreshold(1.0E-5).setBatchSize(100).setFactorNums(3).setMinIterations(0).setReg2(r2).setMaxIterations(1).runADSGD((RDD<Point.LabeledSPoint>)training_dataRdd);
            Point.FMParams fm_params = model.getFMParams();
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"w0= ").append((Object)BoxesRunTime.boxToDouble((double)fm_params.w0())).toString());
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"w= ").append((Object)fm_params.w()).toString());
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"v= ").append((Object)fm_params.v()).toString());
            model.clearThreshold();
            // NOTE(review): the test set uses the same seed (42) as the training set; if generateData2 is
            // deterministic per seed, the first N test points may duplicate the training data. Confirm.
            RDD test_dataRdd = sc.parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])SparseDate$.MODULE$.generateData2(D, 10000, R, 42)), sc.parallelize$default$2(), ClassTag$.MODULE$.apply(Point.LabeledSPoint.class));
            // Closure (decompiled): point -> model.predictPoint(point).
            RDD pAndL = test_dataRdd.map((Function1)new Serializable(model){
                public static final long serialVersionUID = 0L;
                private final SparseFMModel model$1;

                public final Tuple2<Object, Object> apply(Point.LabeledSPoint p) {
                    return this.model$1.predictPoint(p);
                }
                {
                    this.model$1 = model$1;
                }
            }, ClassTag$.MODULE$.apply(Tuple2.class));
            ClassifierEvaluater$.MODULE$.calMulticlassMetrics((RDD<Tuple2<Object, Object>>)pAndL, threshold);
            // Closure (decompiled): LabeledSPoint(x, y) -> (model.predict(x), y); a null point raises MatchError.
            RDD predictionAndLabels2 = test_dataRdd.map((Function1)new Serializable(model){
                public static final long serialVersionUID = 0L;
                private final SparseFMModel model$1;

                public final Tuple2<Object, Object> apply(Point.LabeledSPoint x0$1) {
                    Point.LabeledSPoint labeledSPoint = x0$1;
                    if (labeledSPoint != null) {
                        SparseVector features = labeledSPoint.x();
                        double label = labeledSPoint.y();
                        Tuple2.mcDD.sp sp2 = new Tuple2.mcDD.sp(this.model$1.predict(features), label);
                        return sp2;
                    }
                    throw new MatchError((Object)labeledSPoint);
                }
                {
                    this.model$1 = model$1;
                }
            }, ClassTag$.MODULE$.apply(Tuple2.class)).cache();
            // Prints every (prediction, label) pair; with a non-local master this output goes to executor logs.
            predictionAndLabels2.foreach((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final void apply(Object x) {
                    Predef$.MODULE$.println(x);
                }
            });
            BinaryClassificationMetrics metrics2 = new BinaryClassificationMetrics(predictionAndLabels2);
            double auRoc = ClassifierEvaluater$.MODULE$.calAuROC(metrics2);
            // Previously computed and discarded; report it so the run has a visible result.
            Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"auROC= ").append((Object)BoxesRunTime.boxToDouble(auRoc)).toString());
        } finally {
            // Release the local Spark context (and its persisted/cached RDDs) even on failure.
            sc.stop();
        }
    }

    private FMTest$() {
        MODULE$ = this;
    }
}

