/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle160.crypto.test;

import java.io.BufferedReader;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import org.bouncycastle160.crypto.digests.SHAKEDigest;
import org.bouncycastle160.util.Arrays;
import org.bouncycastle160.util.encoders.Hex;
import org.bouncycastle160.util.test.SimpleTest;

/**
 * Known-answer tests for {@link SHAKEDigest}, driven by the {@code SHAKETestVectors.txt}
 * classpath resource located next to this class.
 *
 * <p>Each vector in the resource starts with a header line containing {@code " sample of "}:
 * the text before it names the algorithm (e.g. {@code SHAKE-128}), and the text after it, up to
 * the first {@code '-'}, is the message length in bits. The header is followed by a
 * {@code "Msg as bit string"} block and an {@code "Output val is"} hex block. Blocks end at a
 * blank line; anything after {@code '#'} on a line is ignored.
 */
public class SHAKEDigestTest
extends SimpleTest {
    /** Classpath resource (relative to this class) holding the test vectors. */
    private static final String VECTOR_RESOURCE = "SHAKETestVectors.txt";

    SHAKEDigestTest() {
    }

    @Override
    public String getName() {
        return "SHAKE";
    }

    @Override
    public void performTest() throws Exception {
        this.testVectors();
    }

    /**
     * Reads every vector from {@link #VECTOR_RESOURCE} and checks the digest against it.
     *
     * @throws IOException if the resource is missing or malformed
     * @throws Exception propagated from the digest or from {@link SimpleTest} failure reporting
     */
    public void testVectors() throws Exception {
        InputStream in = this.getClass().getResourceAsStream(VECTOR_RESOURCE);
        if (in == null) {
            // getResourceAsStream signals "not found" with null; fail with a clear message instead of an NPE.
            throw new IOException("Test vector resource not found: " + VECTOR_RESOURCE);
        }
        // The vector file is plain text (hex and bit strings); decode explicitly rather than via the platform charset.
        BufferedReader r = new BufferedReader(new InputStreamReader(in, "UTF-8"));
        try {
            String line;
            while ((line = this.readLine(r)) != null) {
                if (line.length() == 0) {
                    continue;
                }
                this.runTestVector(this.readTestVector(r, line));
            }
        }
        finally {
            // Close even when a vector fails, so the resource stream is not leaked.
            r.close();
        }
    }

    /**
     * Creates a digest for an algorithm name of the form {@code SHAKE-<bits>}.
     *
     * @throws IllegalArgumentException if the name does not start with {@code SHAKE-}
     */
    private MySHAKEDigest createDigest(String algorithm) throws Exception {
        String prefix = "SHAKE-";
        if (algorithm.startsWith(prefix)) {
            return new MySHAKEDigest(this.parseDecimal(algorithm.substring(prefix.length())));
        }
        throw new IllegalArgumentException("Unknown algorithm: " + algorithm);
    }

    /**
     * Packs a string of '0'/'1' characters into bytes, eight characters per byte.
     * The final byte may hold fewer than eight bits.
     */
    private byte[] decodeBinary(String block) {
        int bits = block.length();
        byte[] result = new byte[(bits + 7) / 8];
        for (int i = 0; i < result.length; ++i) {
            int start = i * 8;
            int end = Math.min(start + 8, bits);
            // Within each group the first character is the least-significant bit, hence the reversal
            // before parsing (Integer.parseInt treats the first character as most significant).
            result[i] = (byte)this.parseBinary(this.reverse(block.substring(start, end)));
        }
        return result;
    }

    private int parseBinary(String s) {
        return Integer.parseInt(s, 2);
    }

    private int parseDecimal(String s) {
        return Integer.parseInt(s);
    }

    /** Concatenates block lines (spaces removed) until a blank line or end of input. */
    private String readBlock(BufferedReader r) throws IOException {
        StringBuilder b = new StringBuilder();
        String line;
        while ((line = this.readBlockLine(r)) != null) {
            b.append(line);
        }
        return b.toString();
    }

    /**
     * Returns the next line with all space characters removed.
     * Returns {@code null} at end of input or on a blank line, which ends the block.
     */
    private String readBlockLine(BufferedReader r) throws IOException {
        String line = this.readLine(r);
        if (line == null || line.length() == 0) {
            return null;
        }
        return line.replace(" ", "");
    }

    /**
     * Parses one vector whose header line has already been read.
     *
     * @throws IOException if the header or the expected section headers are malformed
     * @throws IllegalStateException if the message bit string length disagrees with the header
     */
    private TestVector readTestVector(BufferedReader r, String header) throws IOException {
        String[] parts = this.splitAround(header, TestVector.SAMPLE_OF);
        if (parts.length < 2) {
            throw new IOException("Malformed test vector header: " + header);
        }
        String algorithm = parts[0];
        int bits = this.parseDecimal(this.stripFromChar(parts[1], '-'));

        this.skipUntil(r, TestVector.MSG_HEADER);
        String messageBlock = this.readBlock(r);
        if (messageBlock.length() != bits) {
            throw new IllegalStateException("Test vector length mismatch: header says " + bits
                + " bits but message has " + messageBlock.length());
        }
        byte[] message = this.decodeBinary(messageBlock);

        this.skipUntil(r, TestVector.OUTPUT_HEADER);
        byte[] output = Hex.decode(this.readBlock(r));
        return new TestVector(algorithm, bits, message, output);
    }

    /** Reads a line with any '#' comment stripped and surrounding whitespace trimmed; null at EOF. */
    private String readLine(BufferedReader r) throws IOException {
        String line = r.readLine();
        return line == null ? null : this.stripFromChar(line, '#').trim();
    }

    /** Like {@link #readLine} but treats end of input as an error. */
    private String requireLine(BufferedReader r) throws IOException {
        String line = this.readLine(r);
        if (line == null) {
            throw new EOFException();
        }
        return line;
    }

    private String reverse(String s) {
        return new StringBuilder(s).reverse().toString();
    }

    /**
     * Checks one vector.
     *
     * <p>The first check is a single {@code doFinal}, or the partial-bit {@code doFinal} when the
     * message length is not a whole number of bytes. For byte-aligned messages it also checks
     * split squeezing via two {@code doOutput} calls, and {@code doOutput} followed by
     * {@code doFinal}.
     */
    private void runTestVector(TestVector v) throws Exception {
        String algorithm = v.getAlgorithm();
        int bits = v.getBits();
        int partialBits = bits % 8;
        byte[] expected = v.getOutput();
        int outLen = expected.length;
        int half = outLen / 2;
        byte[] m = v.getMessage();

        MySHAKEDigest d = this.createDigest(algorithm);
        byte[] output = new byte[outLen];
        if (partialBits == 0) {
            d.update(m, 0, m.length);
            d.doFinal(output, 0, outLen);
        }
        else {
            // The last message byte carries only partialBits valid bits and goes in via the partial-bit doFinal.
            d.update(m, 0, m.length - 1);
            d.myDoFinal(output, 0, outLen, m[m.length - 1], partialBits);
        }
        if (!Arrays.areEqual(expected, output)) {
            this.fail(algorithm + " " + bits + "-bit test vector hash mismatch");
        }
        if (partialBits != 0) {
            return;
        }

        // Split squeeze with doOutput. Use a fresh buffer so bytes left from the previous check
        // cannot make a short or no-op squeeze look correct.
        d = this.createDigest(algorithm);
        d.update(m, 0, m.length);
        output = new byte[outLen];
        d.doOutput(output, 0, half);
        d.doOutput(output, half, outLen - half);
        if (!Arrays.areEqual(expected, output)) {
            this.fail(algorithm + " " + bits + "-bit test vector extended hash mismatch");
        }
        // After doOutput the digest is squeezing; absorbing more input must be rejected.
        try {
            d.update((byte)1);
            this.fail("no exception");
        }
        catch (IllegalStateException e) {
            this.isTrue("wrong exception", "attempt to absorb while squeezing".equals(e.getMessage()));
        }

        // doOutput followed by doFinal must produce the same stream (again into a fresh buffer).
        d = this.createDigest(algorithm);
        d.update(m, 0, m.length);
        output = new byte[outLen];
        d.doOutput(output, 0, half);
        d.doFinal(output, half, outLen - half);
        if (!Arrays.areEqual(expected, output)) {
            this.fail(algorithm + " " + bits + "-bit test vector extended doFinal hash mismatch");
        }
        // Expected to succeed: presumably doFinal resets the digest so absorbing is allowed again.
        d.update((byte)1);
    }

    /**
     * Skips blank lines, then requires the next line to equal {@code header}.
     *
     * @throws EOFException if input ends first
     * @throws IOException if a different non-blank line is found
     */
    private void skipUntil(BufferedReader r, String header) throws IOException {
        String line = this.requireLine(r);
        while (line.length() == 0) {
            line = this.requireLine(r);
        }
        if (!line.equals(header)) {
            throw new IOException("Expected: " + header + " but found: " + line);
        }
    }

    /**
     * Splits {@code s} on each occurrence of {@code separator}.
     * An occurrence at position 0 of the remaining text is not treated as a split point.
     */
    private String[] splitAround(String s, String separator) {
        ArrayList<String> pieces = new ArrayList<String>();
        String remaining = s;
        int index;
        while ((index = remaining.indexOf(separator)) > 0) {
            pieces.add(remaining.substring(0, index));
            remaining = remaining.substring(index + separator.length());
        }
        pieces.add(remaining);
        return pieces.toArray(new String[pieces.size()]);
    }

    /** Returns {@code s} truncated before the first occurrence of {@code c}, or unchanged if absent. */
    private String stripFromChar(String s, char c) {
        int i = s.indexOf(c);
        return i >= 0 ? s.substring(0, i) : s;
    }

    public static void main(String[] args) {
        SHAKEDigestTest.runTest(new SHAKEDigestTest());
    }

    /** One parsed vector: algorithm name, message length in bits, packed message, and expected output. */
    private static class TestVector {
        private static final String SAMPLE_OF = " sample of ";
        private static final String MSG_HEADER = "Msg as bit string";
        private static final String OUTPUT_HEADER = "Output val is";
        private final String algorithm;
        private final int bits;
        private final byte[] message;
        private final byte[] output;

        private TestVector(String algorithm, int bits, byte[] message, byte[] output) {
            this.algorithm = algorithm;
            this.bits = bits;
            this.message = message;
            this.output = output;
        }

        public String getAlgorithm() {
            return this.algorithm;
        }

        public int getBits() {
            return this.bits;
        }

        public byte[] getMessage() {
            return this.message;
        }

        public byte[] getOutput() {
            return this.output;
        }
    }

    /**
     * Test-only subclass that exposes {@link SHAKEDigest}'s partial-bit {@code doFinal} overload.
     * It is presumably not publicly accessible; verify against SHAKEDigest.
     */
    static class MySHAKEDigest
    extends SHAKEDigest {
        MySHAKEDigest(int bitLength) {
            super(bitLength);
        }

        int myDoFinal(byte[] out, int outOff, int outLen, byte partialByte, int partialBits) {
            return this.doFinal(out, outOff, outLen, partialByte, partialBits);
        }
    }
}

