package cn.com.duiba.nezha.compute.mllib.lr.ftrl;

import cn.com.duiba.nezha.compute.core.model.local.LocalModel;
import cn.com.duiba.nezha.compute.core.model.local.LocalVector;
import cn.com.duiba.nezha.compute.core.model.local.LocalVector$;
import cn.com.duiba.nezha.compute.core.model.ops.VectorOps$;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Array$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag$;
import scala.runtime.ObjectRef;

/* compiled from: SparseLRWithFTRL.scala */
/* loaded from: input_file:cn/com/duiba/nezha/compute/mllib/lr/ftrl/SparseLRWithFTRL$.class */
/**
 * Singleton module class of the Scala object {@code SparseLRWithFTRL}, as recovered by a decompiler.
 *
 * <p>Provides the FTRL training helpers used with {@link SparseLRWithFTRL}: a batch {@link #train}
 * entry point, factories for {@link LocalModel} instances whose second map is keyed by
 * {@code FTRL.w_z()} / {@code FTRL.w_n()}, and the scalar per-coordinate update formulas.
 */
public final class SparseLRWithFTRL$ implements Serializable {
    /**
     * The sole instance of this module.
     *
     * <p>Initialized here rather than through {@code MODULE$ = this} in the constructor: a
     * {@code static final} field that already has an initializer cannot be reassigned in Java.
     */
    public static final SparseLRWithFTRL$ MODULE$ = new SparseLRWithFTRL$();

    /**
     * Runs the per-point FTRL update over a batch and stores the resulting state on the model.
     *
     * <p>Reads the current z and n maps and the {@code alpha}, {@code beta}, {@code lambda1},
     * {@code lambda2} hyper-parameters from {@code sparseLRWithFTRL}. Each point is handed to the
     * closure {@code SparseLRWithFTRL$$anonfun$train$1} (defined elsewhere), together with mutable
     * references to z, n and two initially-empty accumulator maps. Afterwards the z and n references
     * are written back with {@code setLocalZ}/{@code setLocalN}, and an incremental model is
     * installed. That model is built from the two accumulators: the first goes under
     * {@code FTRL.w_z()} and the second under {@code FTRL.w_n()}.
     *
     * @param sparseLRWithFTRL model whose state is read and replaced
     * @param labeledPointArr  training batch; may be {@code null}
     * @return {@code false} (and the model is left untouched) when {@code labeledPointArr} is
     *         {@code null}; {@code true} otherwise
     */
    public boolean train(SparseLRWithFTRL sparseLRWithFTRL, LabeledPoint[] labeledPointArr) {
        if (labeledPointArr == null) {
            return false;
        }
        ObjectRef create = ObjectRef.create(sparseLRWithFTRL.localZ());
        ObjectRef create2 = ObjectRef.create(sparseLRWithFTRL.localN());
        double alpha = sparseLRWithFTRL.alpha();
        double beta = sparseLRWithFTRL.beta();
        double lambda1 = sparseLRWithFTRL.lambda1();
        double lambda2 = sparseLRWithFTRL.lambda2();
        // Accumulators filled by the per-point closure; presumably the batch's z / n increments.
        ObjectRef create3 = ObjectRef.create(Predef$.MODULE$.Map().apply(Nil$.MODULE$));
        ObjectRef create4 = ObjectRef.create(Predef$.MODULE$.Map().apply(Nil$.MODULE$));
        Predef$.MODULE$.refArrayOps(labeledPointArr).foreach(new SparseLRWithFTRL$$anonfun$train$1(sparseLRWithFTRL, create, create2, alpha, beta, lambda1, lambda2, create3, create4));
        sparseLRWithFTRL.setLocalZ((Map) create.elem);
        sparseLRWithFTRL.setLocalN((Map) create2.elem);
        sparseLRWithFTRL.setLocalIncrModel(getLocalModel(sparseLRWithFTRL.dim(), (Map) create3.elem, (Map) create4.elem));
        return true;
    }

    /**
     * Builds a {@link LocalModel} whose {@code FTRL.w_z()} and {@code FTRL.w_n()} vectors are created
     * with dimension {@code i} from empty maps.
     *
     * @param i vector dimension
     * @return the model produced by {@link #getLocalModel(int, Map, Map)} with two empty maps
     */
    public LocalModel searchModel(int i) {
        return getLocalModel(i, (Map) Predef$.MODULE$.Map().apply(Nil$.MODULE$), (Map) Predef$.MODULE$.Map().apply(Nil$.MODULE$));
    }

    /**
     * Builds a lookup model for the given points, taking the dimension from the first point's
     * feature vector.
     *
     * @param labeledPointArr points to cover; may be {@code null} or empty
     * @return {@code null} if {@code labeledPointArr} is {@code null} or empty; otherwise
     *         {@code searchModel(labeledPointArr, labeledPointArr[0].features().size())}
     */
    public LocalModel searchModel(LabeledPoint[] labeledPointArr) {
        if (labeledPointArr == null || labeledPointArr.length <= 0) {
            return null;
        }
        return searchModel(labeledPointArr, labeledPointArr[0].features().size());
    }

    /**
     * Builds a lookup model for the given points with an explicit dimension.
     *
     * <p>Every point is mapped to a {@link SparseVector} by the closure
     * {@code SparseLRWithFTRL$$anonfun$6} (defined elsewhere; presumably it extracts the features).
     * The vectors are combined with {@code VectorOps.toIndexSV(..., i)}, and the resulting
     * {@link LocalVector} is stored under both {@code FTRL.w_z()} and {@code FTRL.w_n()}. The first
     * and third maps of the returned model are always empty.
     *
     * @param labeledPointArr points to cover; when {@code null} the second map is empty too
     * @param i               vector dimension
     * @return the constructed model; never {@code null}
     */
    public LocalModel searchModel(LabeledPoint[] labeledPointArr, int i) {
        Map apply = Predef$.MODULE$.Map().apply(Nil$.MODULE$);
        if (labeledPointArr != null) {
            LocalVector localVector = new LocalVector(VectorOps$.MODULE$.toIndexSV(
                    (SparseVector[]) Predef$.MODULE$.refArrayOps(labeledPointArr).map(
                            new SparseLRWithFTRL$$anonfun$6(),
                            Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(SparseVector.class))),
                    i));
            // The same index vector serves as the lookup key set for both z and n.
            apply = apply
                    .$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(FTRL$.MODULE$.w_z()), localVector))
                    .$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(FTRL$.MODULE$.w_n()), localVector));
        }
        return new LocalModel(Predef$.MODULE$.Map().apply(Nil$.MODULE$), apply, Predef$.MODULE$.Map().apply(Nil$.MODULE$));
    }

    /**
     * Wraps z and n maps into a {@link LocalModel}.
     *
     * <p>The model's first and third maps are empty. Its second map holds
     * {@code FTRL.w_z() -> LocalVector.toLocalVector(i, map)} and
     * {@code FTRL.w_n() -> LocalVector.toLocalVector(i, map2)}.
     *
     * @param i    vector dimension
     * @param map  entries stored under {@code FTRL.w_z()}
     * @param map2 entries stored under {@code FTRL.w_n()}
     * @return the constructed model
     */
    public LocalModel getLocalModel(int i, Map<Object, Object> map, Map<Object, Object> map2) {
        return new LocalModel(
                Predef$.MODULE$.Map().apply(Nil$.MODULE$),
                Predef$.MODULE$.Map().apply(Nil$.MODULE$)
                        .$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                                Predef$.MODULE$.ArrowAssoc(FTRL$.MODULE$.w_z()),
                                LocalVector$.MODULE$.toLocalVector(i, map)))
                        .$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                                Predef$.MODULE$.ArrowAssoc(FTRL$.MODULE$.w_n()),
                                LocalVector$.MODULE$.toLocalVector(i, map2))),
                Predef$.MODULE$.Map().apply(Nil$.MODULE$));
    }

    /**
     * Folds the {@code (key, value)} pairs of {@code tuple2Arr} into {@code map}.
     *
     * <p>The combination logic lives in the closure
     * {@code SparseLRWithFTRL$$anonfun$incrementVector$1} (defined elsewhere; presumably it adds each
     * value to the existing entry). {@code map} is an immutable Scala map, so the result is returned
     * rather than written into the argument.
     *
     * @param map       starting map
     * @param tuple2Arr pairs to apply, in array order
     * @return the map held by the closure's reference after all pairs are applied
     */
    public Map<Object, Object> incrementVector(Map<Object, Object> map, Tuple2<Object, Object>[] tuple2Arr) {
        ObjectRef create = ObjectRef.create(map);
        Predef$.MODULE$.refArrayOps(tuple2Arr).foreach(new SparseLRWithFTRL$$anonfun$incrementVector$1(create));
        return (Map) create.elem;
    }

    /**
     * Computes {@code localVector2 * (FTRL.predict(localVector2, localVector) - d)} as a map.
     *
     * <p>Presumably {@code localVector} holds the weights, {@code localVector2} the features and
     * {@code d} the label, which gives the logistic-loss gradient. Confirm against
     * {@code FTRL.predict}.
     *
     * @param localVector  second argument passed to {@code FTRL.predict}
     * @param d            value subtracted from the prediction
     * @param localVector2 first argument to {@code FTRL.predict}; also the vector that is scaled
     * @return the scaled vector as a map
     */
    public Map<Object, Object> getGradLoss(LocalVector localVector, double d, LocalVector localVector2) {
        return getGradLoss(localVector2, FTRL$.MODULE$.predict(localVector2, localVector) - d);
    }

    /**
     * Scales {@code localVector} by {@code d} and returns it as a map.
     *
     * @param localVector vector to scale
     * @param d           scale factor (for example a prediction error)
     * @return {@code localVector.mutiply(d).toMap()}
     */
    public Map<Object, Object> getGradLoss(LocalVector localVector, double d) {
        return localVector.mutiply(d).toMap();
    }

    /**
     * Collects per-key increments by visiting every key of {@code map}.
     *
     * <p>The per-key logic is in the closure {@code SparseLRWithFTRL$$anonfun$getIncrementZAndN$1}
     * (defined elsewhere). It receives all four maps, {@code d} and two initially-empty buffers;
     * presumably it appends z and n increments to them.
     *
     * @return the two buffers as arrays, in order (first buffer, second buffer)
     */
    public Tuple2<Tuple2<Object, Object>[], Tuple2<Object, Object>[]> getIncrementZAndN(Map<Object, Object> map, Map<Object, Object> map2, Map<Object, Object> map3, Map<Object, Object> map4, double d) {
        ObjectRef create = ObjectRef.create(new ArrayBuffer());
        ObjectRef create2 = ObjectRef.create(new ArrayBuffer());
        map.keySet().foreach(new SparseLRWithFTRL$$anonfun$getIncrementZAndN$1(map, map2, map3, map4, d, create, create2));
        return new Tuple2<>(((ArrayBuffer) create.elem).toArray(ClassTag$.MODULE$.apply(Tuple2.class)), ((ArrayBuffer) create2.elem).toArray(ClassTag$.MODULE$.apply(Tuple2.class)));
    }

    /**
     * Computes the FTRL-Proximal per-coordinate step {@code sigma = (sqrt(n + g^2) - sqrt(n)) / alpha}.
     *
     * @param d  gradient {@code g} for this coordinate
     * @param d2 accumulated squared gradient {@code n} before this update
     * @param d3 learning-rate parameter {@code alpha}
     * @return {@code sigma}; the {@code 1e-8} keeps the denominator non-zero when {@code d3 == 0}
     */
    public double updateSigmaOnId(double d, double d2, double d3) {
        // Whole difference is divided by alpha. The previous form, sqrt(n + g^2) - sqrt(n) / alpha,
        // divided only sqrt(n) because of operator precedence, which can make sigma negative.
        return (Math.sqrt(d2 + (d * d)) - Math.sqrt(d2)) / (d3 + 1.0E-8d);
    }

    /**
     * Returns {@code d - d2 * d3}, the FTRL z increment {@code g - sigma * w}.
     *
     * @param d  gradient {@code g}
     * @param d2 one factor of the product ({@code sigma} or {@code w})
     * @param d3 the other factor of the product
     * @return {@code d - d2 * d3}
     */
    public double incrementZOnId(double d, double d2, double d3) {
        return d - (d2 * d3);
    }

    /**
     * Returns {@code d * d}, the FTRL n increment {@code g^2}.
     *
     * @param d gradient {@code g}
     * @return {@code d * d}
     */
    public double incrementNOnId(double d) {
        return d * d;
    }

    /** Returns the singleton so that deserialization does not produce a second instance. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Private: the only instance is {@link #MODULE$}. */
    private SparseLRWithFTRL$() {
    }
}
