package svm.instances.dependency;

import base.SparseVector;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import svm.SVMStruct;
import svm.SVMStructInstance;

/* loaded from: input_file:svm/instances/dependency/DependencyParserInstance.class */
/**
 * Structured-SVM problem definition for dependency parsing with edge-factored features.
 *
 * <p>The joint feature vector of a sentence and a tree is the sum of the feature vectors of the
 * tree's (head, dependent) edges. Decoding scores every (head, dependent) pair and hands the score
 * matrix to {@code Edmonds.edmonds} (presumably a maximum-spanning-tree decoder; confirm there).
 *
 * <p>Index conventions: {@code dependencyTree.heads[k]} is the head of token {@code k + 1}, so
 * token index 0 is never a dependent (presumably an artificial root). Score and feature matrices
 * are indexed {@code [head][dependent]} and are {@code lemmas.length} square.
 */
public class DependencyParserInstance implements SVMStructInstance<TokenSequence, DependencyTree> {
    /** Maximum number of sentences whose edge-feature matrices are kept in {@link #precomputed}. */
    private static final int BUFFERSIZE = 1000;

    /** Stand-in for a context tag that falls outside the sentence. */
    private static final String NULL_TAG = "<null>";

    /**
     * Cache of per-sentence edge feature vectors: {@code precomputed.get(s)[h][d]} holds the
     * features of the edge from head {@code h} to dependent {@code d}.
     *
     * <p>Entries are never removed. New entries are added only while the cache holds fewer than
     * {@code BUFFERSIZE} sentences. Under concurrent use the bound is best-effort and may be
     * exceeded slightly.
     */
    public ConcurrentHashMap<TokenSequence, SparseVector[][]> precomputed = new ConcurrentHashMap<>();

    /**
     * Returns the joint feature vector of a sentence and a dependency tree.
     *
     * <p>The result is the sum of the feature vectors of all edges in the tree. On a cache miss,
     * the sentence's full edge-feature matrix is computed and cached, provided the cache is not
     * full. When the cache is full, only the tree's own edges are computed.
     *
     * @param tokenSequence the sentence
     * @param dependencyTree the tree; {@code heads[k]} is the head of token {@code k + 1}
     * @return the summed edge feature vector
     */
    @Override // svm.SVMStructInstance
    public SparseVector phi(TokenSequence tokenSequence, DependencyTree dependencyTree) {
        SparseVector[][] matrix = this.precomputed.get(tokenSequence);
        if (matrix == null && this.precomputed.size() < BUFFERSIZE) {
            matrix = edgeFeatureMatrix(tokenSequence);
            SparseVector[][] existing = this.precomputed.putIfAbsent(tokenSequence, matrix);
            if (existing != null) {
                // Another caller cached this sentence first; reuse its (identical) matrix.
                matrix = existing;
            }
        }
        SparseVector sum = new SparseVector();
        for (int k = 0; k < dependencyTree.heads.length; k++) {
            int head = dependencyTree.heads[k];
            int dependent = k + 1;
            sum.add(matrix != null ? matrix[head][dependent] : edgeFeatures(tokenSequence, head, dependent));
        }
        return sum;
    }

    /**
     * Joins the given parts into a feature name.
     *
     * <p>Every part, including the last, is followed by {@code "::"}.
     */
    private static String concat(String... parts) {
        StringBuilder sb = new StringBuilder();
        for (String part : parts) {
            sb.append(part).append("::");
        }
        return sb.toString();
    }

    /**
     * Builds the direction/distance feature suffix for an edge.
     *
     * <p>The distance bucket is {@code "10"} when the distance exceeds 10, {@code "5"} when it
     * exceeds 5, and {@code distance - 1} otherwise. A self-edge therefore gets bucket
     * {@code "-1"}.
     */
    private static String directionSuffix(int h, int d) {
        String side = h < d ? "RIGHT" : "LEFT";
        int distance = Math.abs(h - d);
        String bucket = distance > 10 ? "10" : distance > 5 ? "5" : Integer.toString(distance - 1);
        return "::DIR-" + side + "-" + bucket;
    }

    /** Returns {@code tags[idx]}, or {@code "<null>"} when {@code idx} is outside the array. */
    private static String tagAt(String[] tags, int idx) {
        return idx >= 0 && idx < tags.length ? tags[idx] : NULL_TAG;
    }

    /**
     * Adds one count for every feature of the edge from head {@code h} to dependent {@code d}.
     *
     * <p>Each feature is emitted twice: once without and once with the direction/distance suffix.
     * The feature names must stay byte-identical, because learned weights are keyed by them.
     *
     * @param tokenSequence the sentence
     * @param h head index
     * @param d dependent index
     * @param sparseVector accumulator that receives the feature counts
     */
    private void addEdge(TokenSequence tokenSequence, int h, int d, SparseVector sparseVector) {
        final String[] tok = tokenSequence.tokens;
        final String[] lem = tokenSequence.lemmas;
        final String[] pos = tokenSequence.posTag;
        final String[] pos2 = tokenSequence.posTag2;
        for (String dir : new String[] {"", directionSuffix(h, d)}) {
            // Head/dependent pair features.
            sparseVector.plusOne(concat("THD", tok[h], tok[d], dir));
            sparseVector.plusOne(concat("LHD", lem[h], lem[d], dir));
            sparseVector.plusOne(concat("PHD", pos[h], pos[d], dir));
            sparseVector.plusOne(concat("P2HD", pos2[h], pos2[d], dir));
            sparseVector.plusOne(concat("TPHD", tok[h], tok[d], pos[h], pos[d], dir));
            sparseVector.plusOne(concat("TP2HD", tok[h], tok[d], pos2[h], pos2[d], dir));
            sparseVector.plusOne(concat("LPHD", lem[h], lem[d], pos[h], pos[d], dir));
            sparseVector.plusOne(concat("LP2HD", lem[h], lem[d], pos2[h], pos2[d], dir));
            // Head-only features.
            sparseVector.plusOne(concat("TH", tok[h], dir));
            sparseVector.plusOne(concat("LH", lem[h], dir));
            sparseVector.plusOne(concat("PH", pos[h], dir));
            sparseVector.plusOne(concat("P2H", pos2[h], dir));
            sparseVector.plusOne(concat("TPH", tok[h], pos[h], dir));
            sparseVector.plusOne(concat("TP2H", tok[h], pos2[h], dir));
            // Dependent-only features.
            sparseVector.plusOne(concat("TD", tok[d], dir));
            sparseVector.plusOne(concat("LD", lem[d], dir));
            sparseVector.plusOne(concat("PD", pos[d], dir));
            sparseVector.plusOne(concat("P2D", pos2[d], dir));
            sparseVector.plusOne(concat("TPD", tok[d], pos[d], dir));
            sparseVector.plusOne(concat("TP2D", tok[d], pos2[d], dir));
            // Tags of every token strictly between head and dependent.
            for (int b = Math.min(h, d) + 1; b < Math.max(h, d); b++) {
                sparseVector.plusOne(concat("THBD", pos[h], pos[b], pos[d], dir));
                sparseVector.plusOne(concat("T2HBD", pos2[h], pos2[b], pos2[d], dir));
            }
            // Neighbouring-tag context features; out-of-range neighbours become "<null>".
            sparseVector.plusOne(concat("TN1", tagAt(pos, h - 1), pos[h], pos[d], tagAt(pos, d + 1), dir));
            // NOTE: unlike TN1, T2N1 glues the direction suffix onto the last tag without a "::"
            // separator. This is kept as-is so feature names match existing trained weight maps.
            sparseVector.plusOne(concat("T2N1", tagAt(pos2, h - 1), pos2[h], pos2[d], tagAt(pos2, d + 1) + dir));
            sparseVector.plusOne(concat("TN2", pos[h], tagAt(pos, h + 1), tagAt(pos, d - 1), pos[d], dir));
            sparseVector.plusOne(concat("T2N2", pos2[h], tagAt(pos2, h + 1), tagAt(pos2, d - 1), pos2[d], dir));
            sparseVector.plusOne(concat("TN3", pos[h], tagAt(pos, h + 1), pos[d], tagAt(pos, d + 1), dir));
            sparseVector.plusOne(concat("T2N3", pos2[h], tagAt(pos2, h + 1), pos2[d], tagAt(pos2, d + 1), dir));
            sparseVector.plusOne(concat("TN4", tagAt(pos, h - 1), pos[h], tagAt(pos, d - 1), pos[d], dir));
            sparseVector.plusOne(concat("T2N4", tagAt(pos2, h - 1), pos2[h], tagAt(pos2, d - 1), pos2[d], dir));
        }
    }

    /** Returns a fresh feature vector for the single edge {@code h -> d}. */
    private SparseVector edgeFeatures(TokenSequence tokenSequence, int h, int d) {
        SparseVector edge = new SparseVector();
        addEdge(tokenSequence, h, d, edge);
        return edge;
    }

    /** Computes the feature vector of every (head, dependent) pair of the sentence. */
    private SparseVector[][] edgeFeatureMatrix(TokenSequence tokenSequence) {
        int n = tokenSequence.lemmas.length;
        SparseVector[][] matrix = new SparseVector[n][n];
        for (int h = 0; h < n; h++) {
            for (int d = 0; d < n; d++) {
                matrix[h][d] = edgeFeatures(tokenSequence, h, d);
            }
        }
        return matrix;
    }

    /**
     * Scores every (head, dependent) pair with the model's weights.
     *
     * <p>Cached edge features are used when available; otherwise each edge is computed on the
     * fly without populating the cache.
     *
     * @return an {@code n x n} matrix indexed {@code [head][dependent]}, where {@code n} is
     *     {@code tokenSequence.lemmas.length}
     */
    private double[][] edgeScores(TokenSequence tokenSequence, SVMStruct sVMStruct) {
        SparseVector[][] cached = this.precomputed.get(tokenSequence);
        int n = tokenSequence.lemmas.length;
        double[][] scores = new double[n][n];
        for (int h = 0; h < n; h++) {
            for (int d = 0; d < n; d++) {
                SparseVector edge = cached != null ? cached[h][d] : edgeFeatures(tokenSequence, h, d);
                scores[h][d] = edge.multiply(sVMStruct.w, sVMStruct.wMap);
            }
        }
        return scores;
    }

    /**
     * Predicts the highest-scoring tree for a sentence.
     *
     * <p>Every edge is scored with the model, and the score matrix is decoded with
     * {@code Edmonds.edmonds}.
     */
    @Override // svm.SVMStructInstance
    public DependencyTree predict(TokenSequence tokenSequence, SVMStruct sVMStruct) {
        return Edmonds.edmonds(tokenSequence, edgeScores(tokenSequence, sVMStruct));
    }

    /**
     * Loss-augmented decoding for margin rescaling.
     *
     * <p>Every edge that is not in the gold tree gets {@code 1 / heads.length} added to its score.
     * This matches the per-edge contribution of {@link #loss}. Column 0 is never a gold
     * dependent, so it always receives the bonus.
     *
     * @param tokenSequence the sentence
     * @param dependencyTree the gold tree
     * @param sVMStruct the current model
     * @return the tree maximising model score plus Hamming loss
     */
    @Override // svm.SVMStructInstance
    public DependencyTree marginRescaling(TokenSequence tokenSequence, DependencyTree dependencyTree, SVMStruct sVMStruct) {
        double[][] scores = edgeScores(tokenSequence, sVMStruct);
        int treeSize = dependencyTree.heads.length;
        // Guard the empty tree, which would otherwise add +Infinity (1.0 / 0).
        double perEdgeLoss = treeSize == 0 ? 0.0d : 1.0d / treeSize;
        for (int h = 0; h < scores.length; h++) {
            for (int d = 0; d < scores[h].length; d++) {
                if (d == 0 || h != dependencyTree.heads[d - 1]) {
                    scores[h][d] += perEdgeLoss;
                }
            }
        }
        return Edmonds.edmonds(tokenSequence, scores);
    }

    /**
     * Slack rescaling is not supported by this instance.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // svm.SVMStructInstance
    public DependencyTree slackRescaling(TokenSequence tokenSequence, DependencyTree dependencyTree, SVMStruct sVMStruct) {
        throw new UnsupportedOperationException("not yet implemented");
    }

    /**
     * Returns the fraction of tokens whose head differs between the two trees.
     *
     * <ul>
     *   <li>Trees of different size get the maximal loss 1.0, and a warning is written to stderr.
     *   <li>Two empty trees have loss 0.0.
     * </ul>
     *
     * @param dependencyTree one tree (typically gold)
     * @param dependencyTree2 the other tree; must not be null (checked only when assertions are
     *     enabled)
     * @return the Hamming loss over head assignments, in [0, 1]
     */
    @Override // svm.SVMStructInstance
    public double loss(DependencyTree dependencyTree, DependencyTree dependencyTree2) {
        assert dependencyTree2 != null;
        int size = dependencyTree.heads.length;
        if (size != dependencyTree2.heads.length) {
            System.err.println("DependencyParserInstance.loss: tree size mismatch (" + size + " vs "
                    + dependencyTree2.heads.length + "); returning maximal loss");
            return 1.0d;
        }
        if (size == 0) {
            // Nothing to get wrong; avoids 0.0 / 0 = NaN.
            return 0.0d;
        }
        int wrongHeads = 0;
        for (int k = 0; k < size; k++) {
            if (dependencyTree.heads[k] != dependencyTree2.heads[k]) {
                wrongHeads++;
            }
        }
        return (double) wrongHeads / size;
    }

    /** Returns the names of all features fired by the edges of the given tree. */
    @Override // svm.SVMStructInstance
    public Set<String> features(TokenSequence tokenSequence, DependencyTree dependencyTree) {
        SparseVector sparseVector = new SparseVector();
        for (int k = 0; k < dependencyTree.heads.length; k++) {
            addEdge(tokenSequence, dependencyTree.heads[k], k + 1, sparseVector);
        }
        return sparseVector.keySet();
    }
}
