/*
 * Decompiled with CFR 0.152.
 */
package com.mayabot.nlp.fasttext.loss;

import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.blas.DenseVector;
import com.mayabot.nlp.fasttext.blas.Matrix;
import com.mayabot.nlp.fasttext.loss.Loss;
import com.mayabot.nlp.fasttext.utils.IntArrayList;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u00008\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0000\u0018\u00002\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0004J\u0010\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\bH\u0016J0\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\u000e2\u0006\u0010\u0007\u001a\u00020\b2\u0006\u0010\u000f\u001a\u00020\n2\u0006\u0010\u0010\u001a\u00020\u0011H\u0016\u00a8\u0006\u0012"}, d2={"Lcom/mayabot/nlp/fasttext/loss/SoftmaxLoss;", "Lcom/mayabot/nlp/fasttext/loss/Loss;", "wo", "Lcom/mayabot/nlp/fasttext/blas/Matrix;", "(Lcom/mayabot/nlp/fasttext/blas/Matrix;)V", "computeOutput", "", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "forward", "", "targets", "Lcom/mayabot/nlp/fasttext/utils/IntArrayList;", "targetIndex", "", "lr", "backprop", "", "fastText4j"})
public final class SoftmaxLoss
extends Loss {
    @Override
    public void computeOutput(@NotNull Model.State state) {
        int i;
        Intrinsics.checkParameterIsNotNull((Object)state, (String)"state");
        DenseVector output = state.getOutput();
        output.mul(this.getWo(), state.getHidden());
        float max = output.get(0);
        float z = 0.0f;
        int osz = output.length();
        int n = 0;
        int n2 = osz;
        while (n < n2) {
            max = Math.max(output.get(i), max);
            ++i;
        }
        n2 = osz;
        for (i = 0; i < n2; ++i) {
            double d = output.get(i) - max;
            int n3 = i;
            DenseVector denseVector = output;
            boolean bl = false;
            double d2 = Math.exp(d);
            denseVector.set(n3, (float)d2);
            z += output.get(i);
        }
        n2 = osz;
        for (i = 0; i < n2; ++i) {
            output.set(i, output.get(i) / z);
        }
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public float forward(@NotNull IntArrayList targets, int targetIndex, @NotNull Model.State state, float lr, boolean backprop) {
        Intrinsics.checkParameterIsNotNull((Object)targets, (String)"targets");
        Intrinsics.checkParameterIsNotNull((Object)state, (String)"state");
        this.computeOutput(state);
        int target = targets.get(targetIndex);
        if (backprop) {
            int osz = this.getWo().getRow();
            int n = 0;
            int n2 = osz;
            while (n < n2) {
                void i;
                float label = i == target ? 1.0f : 0.0f;
                float alpha = lr * (label - state.getOutput().get((int)i));
                state.getGrad().addRow(this.getWo(), (int)i, alpha);
                this.getWo().addVectorToRow(state.getHidden(), (int)i, alpha);
                ++i;
            }
        }
        float t = -Loss.Companion.log(state.getOutput().get(target));
        return t;
    }

    public SoftmaxLoss(@NotNull Matrix wo) {
        Intrinsics.checkParameterIsNotNull((Object)wo, (String)"wo");
        super(wo);
    }
}

