/*
 * Decompiled with CFR 0.152.
 */
package com.mayabot.nlp.fasttext.train;

import com.mayabot.nlp.fasttext.FastText;
import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.args.Args;
import com.mayabot.nlp.fasttext.dictionary.Dictionary;
import com.mayabot.nlp.fasttext.loss.LossName;
import com.mayabot.nlp.fasttext.train.FastTextTrain$TrainThread$WhenMappings;
import com.mayabot.nlp.fasttext.train.LoopReader;
import com.mayabot.nlp.fasttext.train.SampleLine;
import com.mayabot.nlp.fasttext.utils.IntArrayList;
import com.mayabot.nlp.fasttext.utils.LogUtilsKt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.StringCompanionObject;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

@Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000f\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\t\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0000\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0010\u001c\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018\u00002\u00020\u0001:\u0002./B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\u0002\u0010\u0006J\b\u0010\"\u001a\u00020#H\u0002J \u0010$\u001a\u00020%2\u0006\u0010&\u001a\u00020'2\u0006\u0010\u0010\u001a\u00020\u00112\u0006\u0010(\u001a\u00020#H\u0002J\b\u0010&\u001a\u00020'H\u0002J\u001a\u0010)\u001a\u00020%2\u0012\u0010*\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020-0,0+R\u0011\u0010\u0007\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\b\u0010\tR\u0011\u0010\n\u001a\u00020\u000b\u00a2\u0006\b\n\u0000\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0004\u001a\u00020\u0005\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000e\u0010\u000fR\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0011\u0010\u0012\u001a\u00020\u0013\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0014\u0010\u0015R\u000e\u0010\u0016\u001a\u00020\u0013X\u0082\u000e\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0011\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0019\u0010\tR\"\u0010\u001a\u001a\n\u0018\u00010\u001bj\u0004\u0018\u0001`\u001cX\u0086\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u001d\u0010\u001e\"\u0004\b\u001f\u0010 R\u000e\u0010!\u001a\u00020\u0013X\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u00060"}, d2={"Lcom/mayabot/nlp/fasttext/train/FastTextTrain;", "", "trainArgs", "Lcom/mayabot/nlp/fasttext/args/Args;", 
"fastText", "Lcom/mayabot/nlp/fasttext/FastText;", "(Lcom/mayabot/nlp/fasttext/args/Args;Lcom/mayabot/nlp/fasttext/FastText;)V", "args", "getArgs", "()Lcom/mayabot/nlp/fasttext/args/Args;", "dict", "Lcom/mayabot/nlp/fasttext/dictionary/Dictionary;", "getDict", "()Lcom/mayabot/nlp/fasttext/dictionary/Dictionary;", "getFastText", "()Lcom/mayabot/nlp/fasttext/FastText;", "loss", "Lcom/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble;", "ntokens", "", "getNtokens", "()J", "startTime", "tokenCount", "Ljava/util/concurrent/atomic/AtomicLong;", "getTrainArgs", "trainException", "Ljava/lang/Exception;", "Lkotlin/Exception;", "getTrainException", "()Ljava/lang/Exception;", "setTrainException", "(Ljava/lang/Exception;)V", "wantProcessTotalTokens", "keepTraining", "", "printInfo", "", "progress", "", "stop", "startThreads", "sources", "", "", "Lcom/mayabot/nlp/fasttext/train/SampleLine;", "ShareDouble", "TrainThread", "fastText4j"})
public final class FastTextTrain {
    /** Tokens processed by all workers so far; each TrainThread adds its local count in batches. */
    private final AtomicLong tokenCount;
    /** Latest loss published by worker 0; initialised to -1 (nothing reported yet). */
    private final ShareDouble loss;
    /** Wall-clock time in milliseconds taken at construction; basis for speed/ETA/elapsed-time logs. */
    private long startTime;
    /**
     * Exception recorded via {@link #setTrainException} (e.g. by a failing worker), or {@code null}.
     * Volatile because it is written by one thread and polled by the other workers and the
     * coordinating thread through keepTraining(); without it a failure might never be observed.
     */
    @Nullable
    private volatile Exception trainException;
    /** Dictionary of the model being trained. */
    @NotNull
    private final Dictionary dict;
    /** Same object as {@link #trainArgs}. */
    @NotNull
    private final Args args;
    /** Dictionary token count captured at construction. */
    private final long ntokens;
    /** Token budget: epoch * ntokens; training stops once tokenCount reaches it. */
    private final long wantProcessTotalTokens;
    /** Training arguments supplied to the constructor. */
    @NotNull
    private final Args trainArgs;
    /** The model wrapper being trained. */
    @NotNull
    private final FastText fastText;

    /**
     * Returns the exception that aborted training (recorded via {@code setTrainException},
     * e.g. by a failing worker thread), or {@code null} if none has been recorded.
     */
    @Nullable
    public final Exception getTrainException() {
        return trainException;
    }

    /**
     * Records {@code exception} as the training failure. A non-null value makes keepTraining()
     * return false, so workers and the coordinating loop stop.
     *
     * @param exception the failure to record, or {@code null} to clear it
     */
    public final void setTrainException(@Nullable Exception exception) {
        trainException = exception;
    }

    /**
     * Returns the dictionary of the model being trained, as obtained from
     * {@code fastText.getDict()} when this trainer was constructed.
     */
    @NotNull
    public final Dictionary getDict() {
        return dict;
    }

    /** Returns the training arguments (the same instance as {@code getTrainArgs()}). */
    @NotNull
    public final Args getArgs() {
        return args;
    }

    /** Returns the dictionary's token count, captured when this trainer was constructed. */
    public final long getNtokens() {
        return ntokens;
    }

    /**
     * Tells whether training should continue: the shared token counter is still below the
     * budget (epoch * ntokens) and no training exception has been recorded.
     */
    private final boolean keepTraining() {
        boolean budgetLeft = tokenCount.longValue() < wantProcessTotalTokens;
        return budgetLeft && trainException == null;
    }

    /** Returns the processed fraction of the token budget, computed as tokenCount / (epoch * ntokens). */
    private final float progress() {
        float processed = tokenCount.floatValue();
        float budget = (float) wantProcessTotalTokens;
        return processed / budget;
    }

    /*
     * WARNING - void declaration
     */
    public final void startThreads(@NotNull List<? extends Iterable<SampleLine>> sources) {
        int i;
        Intrinsics.checkParameterIsNotNull(sources, (String)"sources");
        int thread = sources.size();
        ArrayList<Thread> threads = new ArrayList<Thread>();
        int n = 0;
        int n2 = thread;
        while (n < n2) {
            threads.add(new Thread(new TrainThread(i, sources.get(i))));
            ++i;
        }
        n2 = thread;
        for (i = 0; i < n2; ++i) {
            ((Thread)threads.get(i)).start();
        }
        long ntokens = this.dict.getNtokens();
        while (this.keepTraining()) {
            Thread.sleep(100L);
            if (!(this.loss.toFloat() >= 0.0f)) continue;
            float progress = this.progress();
            LogUtilsKt.logger("\r");
            this.printInfo(progress, this.loss, false);
        }
        int progress = 0;
        int n3 = thread;
        while (progress < n3) {
            void i2;
            ((Thread)threads.get((int)i2)).join();
            ++i2;
        }
        Exception exception = this.trainException;
        if (exception != null) {
            Exception exception2 = exception;
            n3 = 0;
            boolean bl = false;
            Exception it = exception2;
            boolean bl2 = false;
            throw (Throwable)it;
        }
        LogUtilsKt.logger("\r");
        this.printInfo(1.0f, this.loss, true);
        LogUtilsKt.loggerln();
        LogUtilsKt.loggerln("Train use time " + (System.currentTimeMillis() - this.startTime) + " ms");
    }

    /**
     * Logs a single-line training status report.
     *
     * <p>The line holds the progress percentage, words/sec/thread and the loss. Unless
     * {@code stop} is set, it also holds the current learning rate (initial lr scaled by
     * {@code 1 - progress}) and an ETA. The text itself contains no line terminator; callers
     * log the preceding {@code "\r"} themselves.
     *
     * @param progress fraction of the token budget processed, nominally in [0, 1]
     * @param loss     holder of the loss value to print
     * @param stop     {@code true} for the final report (omits lr and ETA)
     */
    private final void printInfo(float progress, ShareDouble loss, boolean stop) {
        // Elapsed time in fractional seconds. The previous integer division (millis / 1000)
        // truncated this to 0 during the first second, giving an infinite words/sec and a zero ETA.
        double t = (double) (System.currentTimeMillis() - this.startTime) / 1000.0;
        double lr = this.trainArgs.getLr() * (1.0 - (double) progress);
        float percent = progress * 100.0f;
        double wst = 0.0;
        // Placeholder ETA of 30 days (2592000 s) until an estimate is possible.
        long eta = 2592000L;
        if (progress > 0.0f && t > 0.0) {
            eta = (long) (t * (double) (100.0f - percent) / (double) percent);
            // Workers add token counts in batches, so progress can overshoot 1 slightly.
            eta = Math.max(0L, eta);
            wst = this.tokenCount.doubleValue() / t / (double) this.trainArgs.getThread();
        }
        long etah = eta / 3600L;
        long etam = eta % 3600L / 60L;
        long etas = eta % 60L;

        StringBuilder sb = new StringBuilder();
        sb.append("Progress: ")
                .append(String.format("%2.2f", percent))
                .append("% words/sec/thread: ")
                .append(String.format("%8.0f", wst));
        if (!stop) {
            sb.append(String.format(" lr: %2.5f", lr));
        }
        sb.append(String.format(" arg.loss: %2.5f", loss.toFloat()));
        if (!stop) {
            sb.append(" ETA: ").append(etah).append("h ").append(etam).append("m ").append(etas).append("s");
        }
        LogUtilsKt.logger(sb);
    }

    /** Returns the training arguments passed to the constructor. */
    @NotNull
    public final Args getTrainArgs() {
        return trainArgs;
    }

    /** Returns the FastText model wrapper this trainer updates. */
    @NotNull
    public final FastText getFastText() {
        return fastText;
    }

    /**
     * Prepares a training run of {@code fastText} with {@code trainArgs}.
     *
     * <p>Starts the wall clock used for progress reporting. It also captures the model's
     * dictionary and its token count, and sets the token budget to {@code epoch * ntokens}.
     *
     * @param trainArgs training arguments (also exposed as {@code getArgs()})
     * @param fastText  the model wrapper to train
     */
    public FastTextTrain(@NotNull Args trainArgs, @NotNull FastText fastText) {
        Intrinsics.checkParameterIsNotNull((Object) trainArgs, (String) "trainArgs");
        Intrinsics.checkParameterIsNotNull((Object) fastText, (String) "fastText");
        this.trainArgs = trainArgs;
        this.args = trainArgs;
        this.fastText = fastText;
        this.tokenCount = new AtomicLong(0L);
        // -1 marks "no loss published yet"; startThreads only logs once it becomes >= 0.
        this.loss = new ShareDouble(-1.0);
        this.startTime = System.currentTimeMillis();
        this.dict = fastText.getDict();
        this.ntokens = this.dict.getNtokens();
        this.wantProcessTotalTokens = this.ntokens * (long) trainArgs.getEpoch();
    }

    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0010\u001c\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\t\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\b\u0080\u0004\u0018\u00002\u00020\u0001B\u001b\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u00a2\u0006\u0002\u0010\u0007J(\u0010\u0015\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001cH\u0002J\b\u0010\u001d\u001a\u00020\u0016H\u0016J(\u0010\u001e\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001cH\u0002J0\u0010\u001f\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001c2\u0006\u0010 \u001a\u00020\u001cH\u0002R\u001a\u0010\b\u001a\u00020\u0003X\u0086\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\t\u0010\n\"\u0004\b\u000b\u0010\fR\u0011\u0010\r\u001a\u00020\u000e\u00a2\u0006\b\n\u0000\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u0011\u0010\u0011\u001a\u00020\u0012\u00a2\u0006\b\n\u0000\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u0006!"}, d2={"Lcom/mayabot/nlp/fasttext/train/FastTextTrain$TrainThread;", "Ljava/lang/Runnable;", "threadId", "", "parts", "", "Lcom/mayabot/nlp/fasttext/train/SampleLine;", "(Lcom/mayabot/nlp/fasttext/train/FastTextTrain;ILjava/lang/Iterable;)V", "localTokenCount", "getLocalTokenCount", "()I", 
"setLocalTokenCount", "(I)V", "ntokens", "", "getNtokens", "()J", "state", "Lcom/mayabot/nlp/fasttext/Model$State;", "getState", "()Lcom/mayabot/nlp/fasttext/Model$State;", "cbow", "", "model", "Lcom/mayabot/nlp/fasttext/Model;", "lr", "", "line", "Lcom/mayabot/nlp/fasttext/utils/IntArrayList;", "run", "skipgram", "supervised", "labels", "fastText4j"})
    public final class TrainThread
    implements Runnable {
        @NotNull
        private final Model.State state;
        private final long ntokens;
        private int localTokenCount;
        private final int threadId;
        private final Iterable<SampleLine> parts;

        @NotNull
        public final Model.State getState() {
            return this.state;
        }

        public final long getNtokens() {
            return this.ntokens;
        }

        public final int getLocalTokenCount() {
            return this.localTokenCount;
        }

        public final void setLocalTokenCount(int n) {
            this.localTokenCount = n;
        }

        @Override
        public void run() {
            int emptyCount = 0;
            LoopReader reader = new LoopReader(this.parts);
            try {
                IntArrayList line = new IntArrayList(0, null, 3, null);
                IntArrayList labels = new IntArrayList(0, null, 3, null);
                while (FastTextTrain.this.keepTraining()) {
                    float progress = FastTextTrain.this.progress();
                    float lr = (float)(FastTextTrain.this.getTrainArgs().getLr() * (1.0 - (double)progress));
                    if (reader.hasNext()) {
                        SampleLine sample = (SampleLine)reader.next();
                        if (sample.getWords().isEmpty()) {
                            if (++emptyCount <= 1000) continue;
                            break;
                        }
                        emptyCount = 0;
                        List<String> tokens = sample.getWords();
                        switch (FastTextTrain$TrainThread$WhenMappings.$EnumSwitchMapping$0[FastTextTrain.this.getArgs().getModel().ordinal()]) {
                            case 1: {
                                this.localTokenCount += FastTextTrain.this.getDict().getLine((Iterable<String>)tokens, line, labels);
                                this.supervised(this.state, FastTextTrain.this.getFastText().getModel(), lr, line, labels);
                                break;
                            }
                            case 2: {
                                this.localTokenCount += FastTextTrain.this.getDict().getLine(tokens, line, this.state.getRng());
                                this.cbow(this.state, FastTextTrain.this.getFastText().getModel(), lr, line);
                                break;
                            }
                            case 3: {
                                this.localTokenCount += FastTextTrain.this.getDict().getLine(tokens, line, this.state.getRng());
                                this.skipgram(this.state, FastTextTrain.this.getFastText().getModel(), lr, line);
                                break;
                            }
                        }
                        if (this.localTokenCount <= FastTextTrain.this.getArgs().getLrUpdateRate()) continue;
                        FastTextTrain.this.tokenCount.addAndGet(this.localTokenCount);
                        this.localTokenCount = 0;
                        if (this.threadId != 0) continue;
                        FastTextTrain.this.loss.set(this.state.getLoss());
                        continue;
                    }
                    String string = "\u4e0d\u53ef\u80fd\u4e3a\u7a7a";
                    boolean bl = false;
                    throw (Throwable)new IllegalStateException(string.toString());
                }
            }
            catch (Exception e) {
                FastTextTrain.this.setTrainException(e);
            }
        }

        private final void supervised(Model.State state, Model model, float lr, IntArrayList line, IntArrayList labels) {
            if (labels.size() == 0 || line.size() == 0) {
                return;
            }
            if (FastTextTrain.this.getArgs().getLoss() == LossName.ova) {
                model.update(line, labels, Model.Companion.getKAllLabelsAsTarget(), lr, state);
            } else {
                int i = state.getRng().nextInt(labels.size());
                model.update(line, labels, i, lr, state);
            }
        }

        /*
         * WARNING - void declaration
         */
        private final void cbow(Model.State state, Model model, float lr, IntArrayList line) {
            IntArrayList bow = new IntArrayList(0, null, 3, null);
            int n = 0;
            int n2 = line.size();
            while (n < n2) {
                void w;
                int boundary = state.getRng().nextInt(FastTextTrain.this.getArgs().getWs()) + 1;
                bow.clear();
                int n3 = -boundary;
                int n4 = boundary;
                if (n3 <= n4) {
                    while (true) {
                        void c;
                        if (c != false && w + c >= 0 && w + c < line.size()) {
                            IntArrayList ngrams = FastTextTrain.this.getDict().getSubwords(line.get((int)(w + c)));
                            bow.addAll(ngrams);
                        }
                        if (c == n4) break;
                        ++c;
                    }
                }
                model.update(bow, line, (int)w, lr, state);
                ++w;
            }
        }

        /*
         * WARNING - void declaration
         */
        private final void skipgram(Model.State state, Model model, float lr, IntArrayList line) {
            int n = 0;
            int n2 = line.size();
            while (n < n2) {
                void w;
                int boundary = state.getRng().nextInt(FastTextTrain.this.getArgs().getWs()) + 1;
                IntArrayList ngrams = FastTextTrain.this.getDict().getSubwords(line.get((int)w));
                int n3 = -boundary;
                int n4 = boundary;
                if (n3 <= n4) {
                    while (true) {
                        void c;
                        if (c != false && w + c >= 0 && w + c < line.size()) {
                            model.update(ngrams, line, (int)(w + c), lr, state);
                        }
                        if (c == n4) break;
                        ++c;
                    }
                }
                ++w;
            }
        }

        public TrainThread(@NotNull int threadId, Iterable<SampleLine> parts) {
            Intrinsics.checkParameterIsNotNull(parts, (String)"parts");
            this.threadId = threadId;
            this.parts = parts;
            this.state = new Model.State(FastTextTrain.this.getArgs().getDim(), FastTextTrain.this.getFastText().getOutput().getRow(), FastTextTrain.this.getTrainArgs().getSeed());
            this.ntokens = FastTextTrain.this.getDict().getNtokens();
        }
    }

    @Metadata(mv={1, 1, 16}, bv={1, 0, 3}, k=1, d1={"\u0000 \n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010\u0006\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0007\n\u0000\u0018\u00002\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u00a2\u0006\u0002\u0010\u0004J\u000e\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\u0003J\u0006\u0010\u000b\u001a\u00020\fR\u001a\u0010\u0002\u001a\u00020\u0003X\u0086\u000e\u00a2\u0006\u000e\n\u0000\u001a\u0004\b\u0005\u0010\u0006\"\u0004\b\u0007\u0010\u0004\u00a8\u0006\r"}, d2={"Lcom/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble;", "", "value", "", "(D)V", "getValue", "()D", "setValue", "set", "", "v", "toFloat", "", "fastText4j"})
    /*
     * Mutable double shared between threads. In FastTextTrain it carries the latest loss,
     * written by worker 0 and read by the coordinating thread.
     *
     * The field is volatile. Without it, the reader has no guarantee of ever seeing an update,
     * and a plain double write may be split into two 32-bit writes (JLS 17.7).
     */
    public static final class ShareDouble {
        private volatile double value;

        /** Returns the current value narrowed to float. */
        public final float toFloat() {
            return (float) this.value;
        }

        /** Replaces the current value with {@code v}. */
        public final void set(double v) {
            this.value = v;
        }

        /** Returns the current value. */
        public final double getValue() {
            return this.value;
        }

        /** Replaces the current value with {@code d} (property setter; same as {@link #set}). */
        public final void setValue(double d) {
            this.value = d;
        }

        /** @param value initial value */
        public ShareDouble(double value) {
            this.value = value;
        }
    }
}

