/*
 * Decompiled with CFR 0.152.
 */
package com.mayabot.nlp.fasttext;

import com.mayabot.nlp.fasttext.ScoreIdPair;
import com.mayabot.nlp.fasttext.blas.DenseVector;
import com.mayabot.nlp.fasttext.blas.Matrix;
import com.mayabot.nlp.fasttext.loss.Loss;
import com.mayabot.nlp.fasttext.utils.IntArrayList;
import java.util.List;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.random.Random;
import kotlin.random.RandomKt;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000P\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u000b\n\u0002\b\t\n\u0002\u0010\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0010!\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\u0018\u0000 $2\u00020\u0001:\u0002$%B%\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\u0006\u0010\u0007\u001a\u00020\b\u00a2\u0006\u0002\u0010\tJ\u0018\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u0016H\u0002J8\u0010\u0017\u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0018\u001a\u00020\u00192\u0006\u0010\u001a\u001a\u00020\u001b2\u0010\u0010\u001c\u001a\f\u0012\u0004\u0012\u00020\u001e0\u001dj\u0002`\u001f2\u0006\u0010\u0015\u001a\u00020\u0016J.\u0010 \u001a\u00020\u00122\u0006\u0010\u0013\u001a\u00020\u00142\u0006\u0010!\u001a\u00020\u00142\u0006\u0010\"\u001a\u00020\u00192\u0006\u0010#\u001a\u00020\u001b2\u0006\u0010\u0015\u001a\u00020\u0016R\u0011\u0010\u0005\u001a\u00020\u0006\u00a2\u0006\b\n\u0000\u001a\u0004\b\n\u0010\u000bR\u0011\u0010\u0007\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000e\u0010\u000fR\u0011\u0010\u0004\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0010\u0010\u000f\u00a8\u0006&"}, d2={"Lcom/mayabot/nlp/fasttext/Model;", "", "wi", "Lcom/mayabot/nlp/fasttext/blas/Matrix;", "wo", "loss", "Lcom/mayabot/nlp/fasttext/loss/Loss;", "normalizeGradient", "", "(Lcom/mayabot/nlp/fasttext/blas/Matrix;Lcom/mayabot/nlp/fasttext/blas/Matrix;Lcom/mayabot/nlp/fasttext/loss/Loss;Z)V", "getLoss", "()Lcom/mayabot/nlp/fasttext/loss/Loss;", "getNormalizeGradient", "()Z", "getWi", 
"()Lcom/mayabot/nlp/fasttext/blas/Matrix;", "getWo", "computeHidden", "", "input", "Lcom/mayabot/nlp/fasttext/utils/IntArrayList;", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "predict", "k", "", "threshold", "", "heap", "", "Lcom/mayabot/nlp/fasttext/ScoreIdPair;", "Lcom/mayabot/nlp/fasttext/Predictions;", "update", "targets", "targetIndex", "lr", "Companion", "State", "fastText4j"})
public final class Model {
    /** Input matrix; the rows selected by the ids of an input list are averaged into the hidden vector. */
    @NotNull
    private final Matrix wi;
    /** Output matrix; its row count is the prediction limit when {@code k == kUnlimitedPredictions}. */
    @NotNull
    private final Matrix wo;
    /** Loss used both for {@link #predict} and for the forward step of {@link #update}. */
    @NotNull
    private final Loss loss;
    /** When true, {@link #update} divides the gradient by the input length before applying it to {@link #wi}. */
    private final boolean normalizeGradient;
    /** Sentinel for {@code k} in {@link #predict} meaning "use every row of {@link #wo}". */
    private static final int kUnlimitedPredictions = -1;
    /** Sentinel exposed via {@link Companion#getKAllLabelsAsTarget()} (presumably a {@code targetIndex} value meaning "all labels" — confirm in Loss). */
    private static final int kAllLabelsAsTarget = -1;
    public static final Companion Companion = new Companion(null);

    /**
     * Overwrites {@code state.getHidden()} with the average of the {@link #wi} rows whose indices
     * are listed in {@code input}.
     *
     * <p>For an empty input the hidden vector is left zeroed. Scaling it by {@code 1/0}
     * (= +Infinity) would otherwise turn every component into NaN.
     *
     * @param input row indices into {@link #wi}; only the first {@code input.size()} buffer slots are read
     * @param state scratch state whose hidden vector is overwritten
     */
    private final void computeHidden(IntArrayList input, State state) {
        DenseVector hidden = state.getHidden();
        hidden.zero();
        int size = input.size();
        if (size == 0) {
            return;
        }
        int[] ids = input.getBuffer();
        for (int i = 0; i < size; ++i) {
            // Kotlin default-argument bridge: mask 4 marks the third argument (passed as null) as defaulted.
            Matrix.DefaultImpls.addRowToVector$default(this.wi, hidden, ids[i], null, 4, null);
        }
        hidden.timesAssign(Float.valueOf(1.0f / (float) size));
    }

    /**
     * Computes the hidden representation of {@code input} and delegates prediction to the loss.
     *
     * @param input      row indices into {@link #wi} describing the example
     * @param k          number of predictions requested, or {@code kUnlimitedPredictions} (-1) to use
     *                   the row count of {@link #wo}
     * @param threshold2 threshold forwarded unchanged to {@code Loss.predict}
     * @param heap       list handed to the loss to receive the predictions
     * @param state      scratch state whose hidden vector is overwritten
     * @throws IllegalArgumentException if the resolved {@code k} is not positive
     */
    public final void predict(@NotNull IntArrayList input, int k, float threshold2, @NotNull List<ScoreIdPair> heap, @NotNull State state) {
        Intrinsics.checkParameterIsNotNull((Object)input, (String)"input");
        Intrinsics.checkParameterIsNotNull(heap, (String)"heap");
        Intrinsics.checkParameterIsNotNull((Object)state, (String)"state");
        // Resolve the "unlimited" sentinel to the real output size before validating.
        int kk = k == kUnlimitedPredictions ? this.wo.getRow() : k;
        if (kk <= 0) {
            throw new IllegalArgumentException("k needs to be 1 or higher");
        }
        this.computeHidden(input, state);
        // Forward the resolved count, not the raw -1 sentinel.
        this.loss.predict(kk, threshold2, heap, state);
    }

    /**
     * Runs one training step for a single example. Empty inputs are ignored.
     *
     * <p>The step does the following, in order:
     * <ol>
     *   <li>Computes the hidden vector.</li>
     *   <li>Zeroes the gradient buffer.</li>
     *   <li>Runs {@code loss.forward} and records the returned loss in {@code state}. The loss
     *       presumably writes the hidden-layer gradient into {@code state.getGrad()}; confirm in Loss.</li>
     *   <li>Optionally normalizes the gradient by the input length.</li>
     *   <li>Adds the gradient to every {@link #wi} row referenced by {@code input}.</li>
     * </ol>
     *
     * @param input       row indices into {@link #wi}
     * @param targets     target ids forwarded to the loss
     * @param targetIndex target index forwarded to the loss
     * @param lr          learning rate forwarded to the loss
     * @param state       scratch state (hidden vector, gradient, loss statistics)
     */
    public final void update(@NotNull IntArrayList input, @NotNull IntArrayList targets, int targetIndex, float lr, @NotNull State state) {
        Intrinsics.checkParameterIsNotNull((Object)input, (String)"input");
        Intrinsics.checkParameterIsNotNull((Object)targets, (String)"targets");
        Intrinsics.checkParameterIsNotNull((Object)state, (String)"state");
        int size = input.size();
        if (size == 0) {
            return;
        }
        this.computeHidden(input, state);
        DenseVector grad = state.getGrad();
        grad.zero();
        float lossValue = this.loss.forward(targets, targetIndex, state, lr, true);
        state.incrementNExamples(lossValue);
        if (this.normalizeGradient) {
            grad.timesAssign(Float.valueOf(1.0f / (float) size));
        }
        int[] ids = input.getBuffer();
        for (int i = 0; i < size; ++i) {
            this.wi.addVectorToRow(grad, ids[i], 1.0f);
        }
    }

    /** @return the input matrix */
    @NotNull
    public final Matrix getWi() {
        return this.wi;
    }

    /** @return the output matrix */
    @NotNull
    public final Matrix getWo() {
        return this.wo;
    }

    /** @return the loss used for prediction and training */
    @NotNull
    public final Loss getLoss() {
        return this.loss;
    }

    /** @return whether {@link #update} normalizes the gradient by the input length */
    public final boolean getNormalizeGradient() {
        return this.normalizeGradient;
    }

    /**
     * @param wi                input matrix
     * @param wo                output matrix
     * @param loss              loss used for prediction and training
     * @param normalizeGradient whether to divide the gradient by the input length in {@link #update}
     */
    public Model(@NotNull Matrix wi, @NotNull Matrix wo, @NotNull Loss loss, boolean normalizeGradient) {
        Intrinsics.checkParameterIsNotNull((Object)wi, (String)"wi");
        Intrinsics.checkParameterIsNotNull((Object)wo, (String)"wo");
        Intrinsics.checkParameterIsNotNull((Object)loss, (String)"loss");
        this.wi = wi;
        this.wo = wo;
        this.loss = loss;
        this.normalizeGradient = normalizeGradient;
    }

    /**
     * Mutable scratch buffers and running loss statistics used by {@link #predict} and {@link #update}.
     */
    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u00000\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\u0007\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0000\u0018\u00002\u00020\u0001B\u001d\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0006J\u000e\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\r\u001a\u00020\u000eR\u0011\u0010\u0007\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\t\u0010\nR\u0011\u0010\u000b\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\f\u0010\nR\u0011\u0010\r\u001a\u00020\u000e8F\u00a2\u0006\u0006\u001a\u0004\b\u000f\u0010\u0010R\u000e\u0010\u0011\u001a\u00020\u000eX\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u0012\u001a\u00020\u0003X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u0011\u0010\u0013\u001a\u00020\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0014\u0010\nR\u0011\u0010\u0015\u001a\u00020\u0016\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0017\u0010\u0018\u00a8\u0006\u001b"}, d2={"Lcom/mayabot/nlp/fasttext/Model$State;", "", "hiddenSize", "", "outputSize", "seed", "(III)V", "grad", "Lcom/mayabot/nlp/fasttext/blas/DenseVector;", "getGrad", "()Lcom/mayabot/nlp/fasttext/blas/DenseVector;", "hidden", "getHidden", "loss", "", "getLoss", "()F", "lossValue", "nexamples", "output", "getOutput", "rng", "Lkotlin/random/Random;", "getRng", "()Lkotlin/random/Random;", "incrementNExamples", "", "fastText4j"})
    public static final class State {
        /** Sum of all losses passed to {@link #incrementNExamples}. */
        private float lossValue;
        /** Number of calls to {@link #incrementNExamples}. */
        private int nexamples;
        @NotNull
        private final DenseVector hidden;
        @NotNull
        private final DenseVector output;
        @NotNull
        private final DenseVector grad;
        @NotNull
        private final Random rng;

        /** @return the hidden vector buffer (length {@code hiddenSize}) */
        @NotNull
        public final DenseVector getHidden() {
            return this.hidden;
        }

        /** @return the output vector buffer (length {@code outputSize}) */
        @NotNull
        public final DenseVector getOutput() {
            return this.output;
        }

        /** @return the gradient buffer (length {@code hiddenSize}) */
        @NotNull
        public final DenseVector getGrad() {
            return this.grad;
        }

        /** @return the random generator seeded in the constructor */
        @NotNull
        public final Random getRng() {
            return this.rng;
        }

        /**
         * @return the mean of all recorded losses; {@code NaN} (0f / 0) if no example has been recorded yet
         */
        public final float getLoss() {
            return this.lossValue / (float)this.nexamples;
        }

        /**
         * Records one example's loss.
         *
         * @param loss loss value to add to the running sum
         */
        public final void incrementNExamples(float loss) {
            this.lossValue += loss;
            this.nexamples++;
        }

        /**
         * @param hiddenSize length of the hidden and gradient vectors
         * @param outputSize length of the output vector
         * @param seed       seed for {@link #getRng()}
         */
        public State(int hiddenSize, int outputSize, int seed) {
            this.hidden = new DenseVector(hiddenSize);
            this.output = new DenseVector(outputSize);
            this.grad = new DenseVector(hiddenSize);
            this.rng = RandomKt.Random((int)seed);
        }
    }

    /** Kotlin companion exposing the model's sentinel constants. */
    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000\u0014\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0005\b\u0086\u0003\u0018\u00002\u00020\u0001B\u0007\b\u0002\u00a2\u0006\u0002\u0010\u0002R\u0014\u0010\u0003\u001a\u00020\u0004X\u0086D\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0005\u0010\u0006R\u0014\u0010\u0007\u001a\u00020\u0004X\u0086D\u00a2\u0006\b\n\u0000\u001a\u0004\b\b\u0010\u0006\u00a8\u0006\t"}, d2={"Lcom/mayabot/nlp/fasttext/Model$Companion;", "", "()V", "kAllLabelsAsTarget", "", "getKAllLabelsAsTarget", "()I", "kUnlimitedPredictions", "getKUnlimitedPredictions", "fastText4j"})
    public static final class Companion {
        /** @return the {@code k} sentinel (-1) meaning "unlimited predictions" */
        public final int getKUnlimitedPredictions() {
            return kUnlimitedPredictions;
        }

        /** @return the all-labels target sentinel (-1) */
        public final int getKAllLabelsAsTarget() {
            return kAllLabelsAsTarget;
        }

        private Companion() {
        }

        public /* synthetic */ Companion(DefaultConstructorMarker $constructor_marker) {
            this();
        }
    }
}

