/*
 * Decompiled with CFR 0.152.
 */
package com.mayabot.nlp.fasttext.loss;

import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.ScoreIdPair;
import com.mayabot.nlp.fasttext.blas.Matrix;
import com.mayabot.nlp.fasttext.blas.Vector;
import com.mayabot.nlp.fasttext.loss.BinaryLogisticLoss;
import com.mayabot.nlp.fasttext.loss.LossKt;
import com.mayabot.nlp.fasttext.utils.IntArrayList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.comparisons.ComparisonsKt;
import kotlin.jvm.JvmField;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000x\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0016\n\u0002\b\u0002\n\u0002\u0010!\n\u0002\u0010\u0018\n\u0002\b\u0003\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0007\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\u0018\u00002\u00020\u0001:\u0001,B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0002\u0010\u0006J>\u0010\u0016\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\r2\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\r2\u0006\u0010\u001c\u001a\u00020\u001a2\f\u0010\u001d\u001a\b\u0012\u0004\u0012\u00020\u001e0\b2\u0006\u0010\u001f\u001a\u00020 H\u0002J0\u0010!\u001a\u00020\u001a2\u0006\u0010\"\u001a\u00020#2\u0006\u0010$\u001a\u00020\r2\u0006\u0010%\u001a\u00020&2\u0006\u0010'\u001a\u00020\u001a2\u0006\u0010(\u001a\u00020)H\u0016J2\u0010*\u001a\u00020\u00172\u0006\u0010\u0018\u001a\u00020\r2\u0006\u0010\u0019\u001a\u00020\u001a2\u0010\u0010\u001d\u001a\f\u0012\u0004\u0012\u00020\u001e0\bj\u0002`+2\u0006\u0010%\u001a\u00020&H\u0016R\u0017\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\t0\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\n\u0010\u000bR\u0011\u0010\f\u001a\u00020\r\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000e\u0010\u000fR\u0017\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00110\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0012\u0010\u000bR\u0017\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00140\b\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0015\u0010\u000b\u00a8\u0006-"}, d2={"Lcom/mayabot/nlp/fasttext/loss/HierarchicalSoftmaxLoss;", "Lcom/mayabot/nlp/fasttext/loss/BinaryLogisticLoss;", "wo", 
"Lcom/mayabot/nlp/fasttext/blas/Matrix;", "targetCounts", "", "(Lcom/mayabot/nlp/fasttext/blas/Matrix;[J)V", "codes", "", "", "getCodes", "()Ljava/util/List;", "osz", "", "getOsz", "()I", "paths", "", "getPaths", "tree", "Lcom/mayabot/nlp/fasttext/loss/HierarchicalSoftmaxLoss$Node;", "getTree", "dfs", "", "k", "threshold", "", "node", "score", "heap", "Lcom/mayabot/nlp/fasttext/ScoreIdPair;", "hidden", "Lcom/mayabot/nlp/fasttext/blas/Vector;", "forward", "targets", "Lcom/mayabot/nlp/fasttext/utils/IntArrayList;", "targetIndex", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "lr", "backprop", "", "predict", "Lcom/mayabot/nlp/fasttext/Predictions;", "Node", "fastText4j"})
public final class HierarchicalSoftmaxLoss
extends BinaryLogisticLoss {
    private final int osz;
    @NotNull
    private final List<int[]> paths;
    @NotNull
    private final List<boolean[]> codes;
    @NotNull
    private final List<Node> tree;

    public final int getOsz() {
        return this.osz;
    }

    @NotNull
    public final List<int[]> getPaths() {
        return this.paths;
    }

    @NotNull
    public final List<boolean[]> getCodes() {
        return this.codes;
    }

    @NotNull
    public final List<Node> getTree() {
        return this.tree;
    }

    private final void dfs(int k, float threshold2, int node, float score, List<ScoreIdPair> heap, Vector hidden) {
        if ((double)score < LossKt.stdLog(threshold2)) {
            return;
        }
        if (heap.size() == k && score < heap.get(heap.size() - 1).getScore()) {
            return;
        }
        if (this.tree.get((int)node).left == -1 && this.tree.get((int)node).right == -1) {
            Comparator comparator;
            boolean bl;
            List<ScoreIdPair> list;
            heap.add(new ScoreIdPair(score, node));
            List<ScoreIdPair> $this$sortByDescending$iv = heap;
            boolean $i$f$sortByDescending = false;
            if ($this$sortByDescending$iv.size() > 1) {
                list = $this$sortByDescending$iv;
                bl = false;
                comparator = new Comparator<T>(){

                    public final int compare(T a, T b) {
                        boolean bl = false;
                        ScoreIdPair it = (ScoreIdPair)b;
                        boolean bl2 = false;
                        Comparable comparable = Float.valueOf(it.getScore());
                        it = (ScoreIdPair)a;
                        Comparable comparable2 = comparable;
                        bl2 = false;
                        Float f = Float.valueOf(it.getScore());
                        return ComparisonsKt.compareValues((Comparable)comparable2, (Comparable)f);
                    }
                };
                CollectionsKt.sortWith(list, (Comparator)comparator);
            }
            if (heap.size() > k) {
                $this$sortByDescending$iv = heap;
                $i$f$sortByDescending = false;
                if ($this$sortByDescending$iv.size() > 1) {
                    list = $this$sortByDescending$iv;
                    bl = false;
                    comparator = new Comparator<T>(){

                        public final int compare(T a, T b) {
                            boolean bl = false;
                            ScoreIdPair it = (ScoreIdPair)b;
                            boolean bl2 = false;
                            Comparable comparable = Float.valueOf(it.getScore());
                            it = (ScoreIdPair)a;
                            Comparable comparable2 = comparable;
                            bl2 = false;
                            Float f = Float.valueOf(it.getScore());
                            return ComparisonsKt.compareValues((Comparable)comparable2, (Comparable)f);
                        }
                    };
                    CollectionsKt.sortWith(list, (Comparator)comparator);
                }
                heap.remove(heap.size() - 1);
            }
            return;
        }
        float f = this.getWo().dotRow(hidden, node - this.osz);
        float f2 = -f;
        float f3 = 1.0f;
        float f4 = 1.0f;
        boolean bl = false;
        float f5 = (float)Math.exp(f2);
        f = f4 / (f3 + f5);
        this.dfs(k, threshold2, this.tree.get((int)node).left, score + (float)LossKt.stdLog(1.0f - f), heap, hidden);
        this.dfs(k, threshold2, this.tree.get((int)node).right, score + (float)LossKt.stdLog(f), heap, hidden);
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public float forward(@NotNull IntArrayList targets, int targetIndex, @NotNull Model.State state, float lr, boolean backprop) {
        Intrinsics.checkParameterIsNotNull((Object)targets, (String)"targets");
        Intrinsics.checkParameterIsNotNull((Object)state, (String)"state");
        float loss = 0.0f;
        int target = targets.get(targetIndex);
        boolean[] binaryCode = this.codes.get(target);
        int[] pathToRoot = this.paths.get(target);
        int n = 0;
        int n2 = pathToRoot.length;
        while (n < n2) {
            void i;
            loss += this.binaryLogistic(pathToRoot[i], state, binaryCode[i], lr, backprop);
            ++i;
        }
        return loss;
    }

    @Override
    public void predict(int k, float threshold2, @NotNull List<ScoreIdPair> heap, @NotNull Model.State state) {
        Intrinsics.checkParameterIsNotNull(heap, (String)"heap");
        Intrinsics.checkParameterIsNotNull((Object)state, (String)"state");
        this.dfs(k, threshold2, 2 * this.osz - 2, 0.0f, heap, state.getHidden());
        List<ScoreIdPair> $this$sortByDescending$iv = heap;
        boolean $i$f$sortByDescending = false;
        if ($this$sortByDescending$iv.size() > 1) {
            List<ScoreIdPair> list = $this$sortByDescending$iv;
            boolean bl = false;
            Comparator comparator = new Comparator<T>(){

                public final int compare(T a, T b) {
                    boolean bl = false;
                    ScoreIdPair it = (ScoreIdPair)b;
                    boolean bl2 = false;
                    Comparable comparable = Float.valueOf(it.getScore());
                    it = (ScoreIdPair)a;
                    Comparable comparable2 = comparable;
                    bl2 = false;
                    Float f = Float.valueOf(it.getScore());
                    return ComparisonsKt.compareValues((Comparable)comparable2, (Comparable)f);
                }
            };
            CollectionsKt.sortWith(list, (Comparator)comparator);
        }
    }

    /*
     * WARNING - void declaration
     */
    public HierarchicalSoftmaxLoss(@NotNull Matrix wo, @NotNull long[] targetCounts) {
        int i;
        int i2;
        boolean bl;
        int n;
        Intrinsics.checkParameterIsNotNull((Object)wo, (String)"wo");
        Intrinsics.checkParameterIsNotNull((Object)targetCounts, (String)"targetCounts");
        super(wo);
        this.osz = targetCounts.length;
        long[] counts = targetCounts;
        int osz = wo.getRow();
        ArrayList<int[]> pathsLocal = new ArrayList<int[]>(osz);
        ArrayList<boolean[]> codesLocal = new ArrayList<boolean[]>(osz);
        ArrayList<Node> treeLocal = new ArrayList<Node>(2 * osz - 1);
        int n2 = 0;
        int n3 = 2 * osz - 1;
        while (n2 < n3) {
            Node node = new Node();
            ArrayList<Node> arrayList = treeLocal;
            n = 0;
            boolean bl2 = false;
            Node $this$apply = node;
            bl = false;
            $this$apply.parent = -1;
            $this$apply.left = -1;
            $this$apply.right = -1;
            $this$apply.count = 1000000000000000L;
            $this$apply.binary = false;
            Node node2 = node;
            arrayList.add(node2);
            ++i2;
        }
        n3 = osz;
        for (i2 = 0; i2 < n3; ++i2) {
            ((Node)treeLocal.get((int)i2)).count = counts[i2];
        }
        int leaf = osz - 1;
        int node = osz;
        int n4 = osz;
        n = 2 * osz - 1;
        while (n4 < n) {
            int[] mini = new int[2];
            boolean $this$apply = false;
            bl = true;
            while ($this$apply <= bl) {
                void j;
                mini[j] = leaf >= 0 && ((Node)treeLocal.get((int)leaf)).count < ((Node)treeLocal.get((int)node)).count ? leaf-- : node++;
                ++j;
            }
            Object j = treeLocal.get(i);
            bl = false;
            boolean bl3 = false;
            Node $this$apply2 = (Node)j;
            boolean bl4 = false;
            $this$apply2.left = mini[0];
            $this$apply2.right = mini[1];
            $this$apply2.count = ((Node)treeLocal.get((int)mini[0])).count + ((Node)treeLocal.get((int)mini[1])).count;
            ((Node)treeLocal.get((int)mini[0])).parent = i;
            ((Node)treeLocal.get((int)mini[1])).parent = i++;
            ((Node)treeLocal.get((int)mini[1])).binary = true;
        }
        n = osz;
        for (i = 0; i < n; ++i) {
            ArrayList<Integer> path = new ArrayList<Integer>();
            ArrayList<Boolean> code = new ArrayList<Boolean>();
            int j = i;
            while (((Node)treeLocal.get((int)j)).parent != -1) {
                path.add(((Node)treeLocal.get((int)j)).parent - osz);
                code.add(((Node)treeLocal.get((int)j)).binary);
                j = ((Node)treeLocal.get((int)j)).parent;
            }
            pathsLocal.add(CollectionsKt.toIntArray((Collection)path));
            codesLocal.add(CollectionsKt.toBooleanArray((Collection)code));
        }
        this.paths = pathsLocal;
        this.codes = codesLocal;
        this.tree = treeLocal;
    }

    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000 \n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0000\n\u0002\u0010\t\n\u0000\n\u0002\u0010\b\n\u0002\b\u0003\u0018\u00002\u00020\u0001B\u0005\u00a2\u0006\u0002\u0010\u0002R\u0012\u0010\u0003\u001a\u00020\u00048\u0006@\u0006X\u0087\u000e\u00a2\u0006\u0002\n\u0000R\u0012\u0010\u0005\u001a\u00020\u00068\u0006@\u0006X\u0087\u000e\u00a2\u0006\u0002\n\u0000R\u0012\u0010\u0007\u001a\u00020\b8\u0006@\u0006X\u0087\u000e\u00a2\u0006\u0002\n\u0000R\u0012\u0010\t\u001a\u00020\b8\u0006@\u0006X\u0087\u000e\u00a2\u0006\u0002\n\u0000R\u0012\u0010\n\u001a\u00020\b8\u0006@\u0006X\u0087\u000e\u00a2\u0006\u0002\n\u0000\u00a8\u0006\u000b"}, d2={"Lcom/mayabot/nlp/fasttext/loss/HierarchicalSoftmaxLoss$Node;", "", "()V", "binary", "", "count", "", "left", "", "parent", "right", "fastText4j"})
    public static final class Node {
        @JvmField
        public int parent;
        @JvmField
        public int left;
        @JvmField
        public int right;
        @JvmField
        public long count;
        @JvmField
        public boolean binary;
    }
}

