package svm.instances.dependency;

import base.CachedKernel;
import base.Example;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import svm.Mode;
import svm.SVMStruct;
import svm.SVMStructKernel;
import svm.instances.dependency.kernel.TokenSequenceKernel;

/* loaded from: input_file:svm/instances/dependency/DependencyParserProblem.class */
/**
 * Command-line driver that trains a structured SVM for dependency parsing on a
 * CoNLL treebank and prints evaluation figures to standard output.
 *
 * <p>Two experiment variants exist: {@link #explicit()} uses {@code SVMStruct}
 * and prints its weight vector, {@link #implicit()} uses {@code SVMStructKernel}
 * with a {@code TokenSequenceKernel} as input kernel. {@link #main(String[])}
 * currently runs only the kernelized variant.
 */
public class DependencyParserProblem {
    /** Fraction of the loaded treebank used for training. */
    private static final double TRAIN_FRACTION = 0.5d;

    /**
     * Program entry point; runs the kernelized experiment.
     *
     * @param strArr command-line arguments (ignored)
     */
    public static void main(String[] strArr) {
        implicit();
    }

    /**
     * Trains an {@code SVMStruct} on the first half of the default treebank, prints
     * every entry of its weight map with the matching weight, and prints the
     * empirical error.
     */
    private static void explicit() {
        SVMStruct sVMStruct = new SVMStruct(new DependencyParserInstance(), 200.0d, 0.05d, Mode.MARGIN_RESCALING);
        System.out.print("loading treebank...");
        List<Example<TokenSequence, DependencyTree>> treebank = ConllLoader.loadTreebank();
        System.out.println("done.");

        int split = (int) (TRAIN_FRACTION * treebank.size());
        sVMStruct.train(treebank.subList(0, split));

        for (Map.Entry<String, Integer> entry : sVMStruct.wMap.entrySet()) {
            System.out.println(entry.getKey() + ": " + sVMStruct.w.x[entry.getValue()]);
        }

        // NOTE(review): the error is measured on the whole treebank, i.e. including
        // the training half, although it is labelled "Test Error" - confirm intent.
        System.out.println("Test Error: " + sVMStruct.empiricalError(treebank.subList(0, treebank.size())));
    }

    /**
     * Runs the kernelized experiment: loads up to the requested number of sentences,
     * splits them into a training and a held-out half, warms up the input kernel,
     * trains the model, predicts the held-out half and prints per-token output,
     * average loss, timings and the share of predictions that pass the tree check.
     */
    private static void implicit() {
        TokenSequenceKernel tokenSequenceKernel = new TokenSequenceKernel();
        // Second constructor argument of SVMStructKernel; presumably the
        // regularisation constant - TODO confirm against SVMStructKernel.
        for (double c : new double[]{1000.0d}) {
            long kernelStart = System.currentTimeMillis();
            DependencyParserKernelInstance kernelInstance = new DependencyParserKernelInstance();
            kernelInstance.in = tokenSequenceKernel;
            SVMStructKernel model = new SVMStructKernel(kernelInstance, c, 0.05d, Mode.MARGIN_RESCALING);

            System.out.print("loading treebank...");
            List<Example<TokenSequence, DependencyTree>> treebank = ConllLoader.loadTreebank("data/tuebadz.conll", 100);
            System.out.println("done. (" + treebank.size() + ")");

            // Both halves share the same boundary index so that no example is dropped
            // and an empty treebank yields two empty (but valid) sub-lists.
            int split = (int) (TRAIN_FRACTION * treebank.size());
            List<Example<TokenSequence, DependencyTree>> train = treebank.subList(0, split);
            List<Example<TokenSequence, DependencyTree>> test = treebank.subList(split, treebank.size());

            if (kernelInstance.in instanceof CachedKernel) {
                ((CachedKernel) kernelInstance.in).initKernelCache(treebank.size() + test.size());
            }

            // Evaluate the input kernel for every (training example, any example) pair.
            // The printed number is the mean of the per-training-example row sums.
            double averageRowSum = train.parallelStream().mapToDouble(row -> {
                double rowSum = 0.0d;
                for (Example<TokenSequence, DependencyTree> column : treebank) {
                    double value = model.inputKernel.k(row.x, column.x);
                    if (Double.isNaN(value)) {
                        throw new RuntimeException("HOLY MOLY" + row.x + " " + column.x);
                    }
                    rowSum += value;
                }
                return rowSum;
            }).average().orElse(Double.NaN); // NaN instead of an exception when there is no training data
            System.out.println("Average Input Kernel Value: " + averageRowSum + " [ really just precomputing the input kernel matrix ;-) ]");
            System.out.println((System.currentTimeMillis() - kernelStart) + "ms for computing/loading the kernel");

            long trainStart = System.currentTimeMillis();
            model.train(train);
            Iterator<Double> alphaIterator = model.alphas.iterator();
            while (alphaIterator.hasNext()) {
                System.out.print(alphaIterator.next().doubleValue() + "\t");
            }
            System.out.println("");
            System.out.println((System.currentTimeMillis() - trainStart) + "ms for computing the model");

            long testStart = System.currentTimeMillis();
            List<DependencyTree> predictions = test.parallelStream()
                    .map(example -> (DependencyTree) model.predict(example.x))
                    .collect(Collectors.toList());
            System.out.println(predictions.size());

            int validTrees = 0;
            double totalLoss = 0.0d;
            for (int i = 0; i < predictions.size(); i++) {
                TokenSequence sentence = test.get(i).x;
                DependencyTree predicted = predictions.get(i);
                DependencyTree gold = test.get(i).y;

                // Token index 0 is skipped; heads[] is shifted by one relative to tokens[].
                for (int t = 1; t < sentence.tokens.length; t++) {
                    System.out.println(sentence.tokens[t] + "\t" + sentence.posTag[t] + "\t" + sentence.posTag2[t]
                            + "\t" + gold.heads[t - 1] + "\t" + predicted.heads[t - 1]);
                }
                System.out.println();

                totalLoss += model.instance.loss(gold, predicted);
                if (visitsEveryNodeOnce(predicted.heads)) {
                    validTrees++;
                }
            }

            System.out.println("Test Error:  " + (totalLoss / test.size()));
            System.out.println((System.currentTimeMillis() - testStart) + "ms for testing the model");
            System.out.println("Number of valid trees: " + ((double) validTrees / test.size()));
        }
    }

    /**
     * Checks a head array by running {@link #tiefensuche(int, int[], int[])} from every
     * node whose head is {@code 0} and verifying that each node ended up with a visit
     * count of exactly one.
     *
     * @param heads head of each node; {@code 0} marks a traversal start, any other value
     *              {@code h} refers to the node at index {@code h - 1}
     * @return {@code true} if every node was counted exactly once (also for an empty array)
     */
    private static boolean visitsEveryNodeOnce(int[] heads) {
        int[] visits = new int[heads.length];
        for (int node = 0; node < heads.length; node++) {
            if (heads[node] == 0) {
                tiefensuche(node, heads, visits);
            }
        }
        for (int count : visits) {
            if (count != 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Depth-first traversal ("Tiefensuche") over the structure described by a head array.
     *
     * <p>Increments the visit counter of node {@code i}, then looks at every node whose
     * head refers to {@code i}: an unvisited one is traversed recursively, one that has
     * been visited exactly once is marked with the value {@code 2}.
     *
     * @param i     index of the node to visit
     * @param iArr  head array; entry {@code h} at position {@code k} means node {@code k}
     *              hangs below node {@code h - 1}
     * @param iArr2 per-node visit counters, updated in place
     */
    static void tiefensuche(int i, int[] iArr, int[] iArr2) {
        iArr2[i]++;
        for (int child = 0; child < iArr.length; child++) {
            if (iArr[child] - 1 != i) {
                continue;
            }
            if (iArr2[child] == 0) {
                tiefensuche(child, iArr, iArr2);
            } else if (iArr2[child] == 1) {
                iArr2[child] = 2;
            }
        }
    }
}
