package svm;

import base.AtomicKernel;
import base.Example;
import base.Kernel;
import base.OutputKernel;
import base.SparseVector;
import base.Structure;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import solver.OJAlgoSolver;

/* loaded from: input_file:svm/SVMStructKernel.class */
public class SVMStructKernel<X extends Structure, Y extends Structure> implements Serializable {
    private static final long serialVersionUID = -3805778215058778725L;
    public Mode m;
    public double C;
    public double eps;
    public SVMStructKernelInstance<X, Y> instance;
    public Kernel<X> inputKernel;
    public OutputKernel<X, Y> outputKernel;
    public double slack;
    public List<Example<X, Y>> data;
    public LinkedList<Integer>[] oracleCaches;
    public AtomicConstraint<X, Y> w;
    public int oracleCacheSize = 0;
    public ArrayList<DualConstraint<X, Y>> constraints = new ArrayList<>();
    public ArrayList<Double> alphas = new ArrayList<>();

    /**
     * Creates an untrained model bound to a problem instance.
     *
     * <p>The input and output kernels are fetched from the instance once and cached in
     * {@link #inputKernel} and {@link #outputKernel}.
     *
     * @param sVMStructKernelInstance problem instance supplying both kernels (its {@code loss} is used elsewhere
     *                                in this class)
     * @param d                       value stored in {@link #C}
     * @param d2                      value stored in {@link #eps}
     * @param mode                    value stored in {@link #m}
     */
    public SVMStructKernel(SVMStructKernelInstance<X, Y> sVMStructKernelInstance, double d, double d2, Mode mode) {
        // Problem definition and the kernels derived from it (accessor order kept: input first, then output).
        this.instance = sVMStructKernelInstance;
        this.inputKernel = this.instance.inputKernel();
        this.outputKernel = this.instance.outputKernel();

        // Plain configuration values.
        this.C = d;
        this.eps = d2;
        this.m = mode;
    }

    /**
     * Scores the pair {@code (x, y)} under the current dual solution.
     *
     * <p>For every training example the input-kernel value {@code k(x, example.x)} is combined with each
     * sufficiently large alpha, weighting the output-kernel difference between the example's own output and the
     * output stored for that example in the corresponding constraint. The per-example sums are then averaged
     * over {@link #data}.
     *
     * <p>When the output kernel is an {@link AtomicKernel} this method returns {@code 0.0} unconditionally.
     *
     * @param x input structure to score
     * @param y candidate output structure
     * @return the averaged score, or {@code 0.0} for atomic output kernels
     * @throws java.util.NoSuchElementException if {@link #data} is empty (no average exists)
     */
    public double wTimesPhi(X x, Y y) {
        if (this.outputKernel instanceof AtomicKernel) {
            return 0.0d;
        }
        return this.data.parallelStream()
                .mapToDouble(example -> {
                    final double inputSimilarity = this.inputKernel.k(x, example.x);
                    double total = 0.0d;
                    for (int c = 0; c < this.alphas.size(); c++) {
                        final double alpha = this.alphas.get(c).doubleValue();
                        // Skip contributions whose alpha, or alpha-weighted similarity, is negligible.
                        final boolean significant =
                                Math.abs(alpha) >= 1.0E-5d && Math.abs(alpha * inputSimilarity) >= 1.0E-6d;
                        if (significant) {
                            final double towardsTruth = (alpha * inputSimilarity) * this.outputKernel.k(y, example.y);
                            final double towardsYbar = (alpha * inputSimilarity)
                                    * this.outputKernel.k(y, this.constraints.get(c).ybar.get(example.i));
                            total = (total + towardsTruth) - towardsYbar;
                        }
                    }
                    return total;
                })
                .average()
                .getAsDouble();
    }

    /**
     * Scores the difference between two outputs {@code y} and {@code y2} for input {@code x} under the current
     * dual solution.
     *
     * <p>Two evaluation paths exist:
     * <ul>
     *   <li><b>Atomic output kernel</b>: both outputs are expanded with {@code featureMap}; for every constraint
     *       and every feature present in the constraint's {@code phi}, the product of alpha, feature value,
     *       stored weight and {@code inputKernel.k(storedInput, x)} is added for {@code y} and subtracted for
     *       {@code y2}. No alpha threshold is applied on this path.</li>
     *   <li><b>Any other output kernel</b>: per training example, each significant alpha contributes the four
     *       output-kernel terms {@code k(y, y_i) - k(y2, y_i) - k(y, ybar_i) + k(y2, ybar_i)} weighted by
     *       {@code alpha * k(x, x_i)}; the per-example sums are averaged over {@link #data}.</li>
     * </ul>
     *
     * @param x  input structure
     * @param y  first output structure (counted positively)
     * @param y2 second output structure (counted negatively)
     * @return the score difference
     * @throws java.util.NoSuchElementException on the non-atomic path if {@link #data} is empty
     */
    public double wTimesPhiMinusPhi(X x, Y y, Y y2) {
        if (!(this.outputKernel instanceof AtomicKernel)) {
            return this.data.parallelStream().mapToDouble(example -> {
                double sum = 0.0d;
                final double k = this.inputKernel.k(x, example.x);
                for (int i = 0; i < this.alphas.size(); i++) {
                    final double alpha = this.alphas.get(i).doubleValue();
                    // Skip contributions whose alpha, or alpha-weighted similarity, is negligible.
                    if (Math.abs(alpha) >= 1.0E-5d && Math.abs(alpha * k) >= 1.0E-6d) {
                        final DualConstraint<X, Y> constraint = this.constraints.get(i);
                        final double weight = alpha * k;
                        final Y ybar = constraint.ybar.get(example.i);
                        sum = (((sum + (weight * this.outputKernel.k(y, example.y)))
                                - (weight * this.outputKernel.k(y2, example.y)))
                                - (weight * this.outputKernel.k(y, ybar)))
                                + (weight * this.outputKernel.k(y2, ybar));
                    }
                }
                return sum;
            }).average().getAsDouble();
        }

        double sum = 0.0d;
        SparseVector positive = ((AtomicKernel) this.outputKernel).featureMap(x, y);
        SparseVector negative = ((AtomicKernel) this.outputKernel).featureMap(x, y2);
        for (int i = 0; i < this.alphas.size(); i++) {
            final double alpha = this.alphas.get(i).doubleValue();
            // Parameterised cast (instead of the raw type) keeps phi's element types visible to the compiler.
            final AtomicConstraint<X, Y> atomic = (AtomicConstraint<X, Y>) this.constraints.get(i);
            for (Map.Entry<String, Double> feature : positive.entrySet()) {
                // Single lookup; a feature absent from this constraint contributes nothing.
                final Map<X, Double> stored = atomic.phi.get(feature.getKey());
                if (stored != null) {
                    for (Map.Entry<X, Double> support : stored.entrySet()) {
                        sum += alpha * feature.getValue().doubleValue() * support.getValue().doubleValue()
                                * this.inputKernel.k(support.getKey(), x);
                    }
                }
            }
            for (Map.Entry<String, Double> feature : negative.entrySet()) {
                final Map<X, Double> stored = atomic.phi.get(feature.getKey());
                if (stored != null) {
                    for (Map.Entry<X, Double> support : stored.entrySet()) {
                        sum -= ((alpha * feature.getValue().doubleValue()) * support.getValue().doubleValue())
                                * this.inputKernel.k(support.getKey(), x);
                    }
                }
            }
        }
        return sum;
    }

    /**
     * Trains the model on {@code list}, printing progress to standard output.
     *
     * <p>Each iteration of the main loop:
     * <ol>
     *   <li>re-solves the dual problem over the current {@link #constraints} (skipped while there are none) and,
     *       for atomic output kernels, rebuilds the compressed weight representation {@link #w};</li>
     *   <li>assembles a candidate constraint from the per-example oracle caches;</li>
     *   <li>if that candidate is not violated by more than {@link #eps} beyond the current slack (or on the
     *       first iteration), discards it and builds a new one by calling the separation oracle for every
     *       example in parallel;</li>
     *   <li>appends the constraint, prunes constraints whose {@code lastUsed} is more than 50 iterations old,
     *       and asks the solver to update its H matrix;</li>
     *   <li>computes loss minus margin for the new constraint and stops once it no longer exceeds
     *       {@code slack + eps}.</li>
     * </ol>
     * The last constraint added (which did not trigger another solve) is removed before the final statistics
     * are printed.
     *
     * <p>Pruning shifts the positions of later constraints, so the indices held in {@link #oracleCaches} are
     * dropped or shifted on every prune; otherwise they would refer to the wrong constraint or past the end of
     * the list whenever {@link #oracleCacheSize} is positive.
     *
     * @param list training examples; stored in {@link #data}
     */
    public void train(List<Example<X, Y>> list) {
        double violation;
        System.out.println("C: " + this.C + "\tEpsilon: " + this.eps);
        if (this.outputKernel instanceof AtomicKernel) {
            System.out.println("Using Atomic Output Kernel and Atomic Constraints");
        }
        this.data = list;
        System.out.println("Number of Training Examples: " + list.size());
        System.out.println("OJALGOSOLVER");
        OJAlgoSolver oJAlgoSolver = new OJAlgoSolver(this.m);
        this.constraints = new ArrayList<>();
        this.alphas = new ArrayList<>();
        this.oracleCaches = new LinkedList[list.size()];
        for (int i = 0; i < list.size(); i++) {
            this.oracleCaches[i] = new LinkedList<>();
        }
        int iteration = 0;
        do {
            iteration++;
            System.out.print(String.valueOf(iteration) + "\n");
            System.out.println("Number of Constraints: " + this.constraints.size());
            if (this.constraints.size() > 0) {
                System.out.println("SOLVING");
                this.slack = oJAlgoSolver.solve(this.constraints, this.data, this.alphas, this.C, this, iteration);
                if (this.outputKernel instanceof AtomicKernel) {
                    System.out.println("COMPRESSING w");
                    compressW();
                }
                System.out.println("SOLVING DONE");
                System.out.println("Number of Support Vectors: " + countSupportVectors());
            } else {
                this.slack = 0.0d;
            }

            System.out.println("COMPUTING CUTTINGPLANE");
            DualConstraint<X, Y> current = newConstraint();
            double cachedViolation = 0.0d;
            for (int e = 0; e < this.data.size(); e++) {
                final LinkedList<Integer> cache = this.oracleCaches[e];
                final Example<X, Y> example = this.data.get(e);
                double best = Double.NEGATIVE_INFINITY;
                int bestPos = -1;
                for (int pos = 0; pos < cache.size(); pos++) {
                    final DualConstraint<X, Y> cached = this.constraints.get(cache.get(pos).intValue());
                    final double candidate = cached.losses.get(e).doubleValue()
                            - wTimesPhiMinusPhi(example.x, example.y, cached.ybar.get(e));
                    if (candidate > best) {
                        best = candidate;
                        bestPos = pos;
                    }
                }
                if (bestPos != -1) {
                    cachedViolation += best;
                    final int constraintIndex = cache.get(bestPos).intValue();
                    final DualConstraint<X, Y> chosen = this.constraints.get(constraintIndex);
                    current.addConstraint(example.x, example.y, chosen.ybar.get(e), this.m,
                            chosen.losses.get(e).doubleValue());
                    // Move the winning entry to the front of this example's cache.
                    cache.remove(bestPos);
                    cache.addFirst(Integer.valueOf(constraintIndex));
                }
            }

            if (cachedViolation / this.data.size() < this.slack + this.eps || iteration == 1) {
                System.out.println("Nothing in Oracle Cache");
                // The cached candidate is not violated enough: start over and query the oracle for every example.
                current = newConstraint();
                List<Y> oracleOutputs = this.data.parallelStream()
                        .map(this::mostViolatedConstraint)
                        .collect(Collectors.toList());
                System.out.println("COMPUTING CUTTINGPLANE MIDWAY");
                for (int e = 0; e < oracleOutputs.size(); e++) {
                    final Y ybar = oracleOutputs.get(e);
                    final Example<X, Y> example = this.data.get(e);
                    current.addConstraint(example.x, example.y, ybar, this.m, this.instance.loss(example.y, ybar));
                    // constraints.size() is the index the new constraint receives when appended below.
                    this.oracleCaches[e].add(0, Integer.valueOf(this.constraints.size()));
                    if (this.oracleCaches[e].size() > this.oracleCacheSize) {
                        this.oracleCaches[e].removeLast();
                    }
                }
            }
            current.lastUsed = iteration;
            current.finalizeConstraint(this.data.size());
            this.constraints.add(current);

            // Prune stale constraints (never the one just added); walking backwards keeps lower indices valid.
            for (int idx = this.constraints.size() - 2; idx >= 0; idx--) {
                if (iteration - this.constraints.get(idx).lastUsed > 50) {
                    this.constraints.remove(idx);
                    this.alphas.remove(idx);
                    oJAlgoSolver.pruneConstraint(idx);
                    forgetPrunedConstraint(idx);
                }
            }
            System.out.println("COMPUTING CUTTINGPLANE DONE");
            System.out.println("UPDATING H MATRIX");
            oJAlgoSolver.updateH(this.constraints, list, this);
            System.out.println("Updating H MATRIX DONE");

            double margin = 0.0d;
            if (this.m == Mode.MARGIN_RESCALING) {
                // Column alphas.size() of H is read here; it is presumably the column of the newly added
                // constraint - TODO confirm against OJAlgoSolver.updateH.
                for (int a = 0; a < this.alphas.size(); a++) {
                    margin += this.alphas.get(a).doubleValue()
                            * oJAlgoSolver.H.get(a).get(this.alphas.size()).doubleValue();
                }
            } else {
                for (int e = 0; e < current.ybar.size(); e++) {
                    margin += current.losses.get(e).doubleValue()
                            * wTimesPhiMinusPhi(this.data.get(e).x, this.data.get(e).y, current.ybar.get(e));
                }
                margin /= this.data.size();
            }
            violation = current.loss - margin;
            System.out.println("Loss: " + current.loss);
            System.out.println("Margin: " + margin);
            System.out.println("L-M: " + violation);
            System.out.println("Slack: " + this.slack);
        } while (violation > this.slack + this.eps);

        // The final constraint was never passed to the solver, so it has no alpha; drop it.
        this.constraints.remove(this.constraints.size() - 1);
        System.out.println();
        System.out.println("Number of Support Vectors: " + countSupportVectors());
        System.out.println("Train Error: " + empiricalError(list));
        double quadraticForm = 0.0d;
        for (int a = 0; a < this.alphas.size(); a++) {
            for (int b = 0; b < this.alphas.size(); b++) {
                quadraticForm += this.alphas.get(a).doubleValue() * this.alphas.get(b).doubleValue()
                        * oJAlgoSolver.H.get(a).get(b).doubleValue();
            }
        }
        System.out.println("L2: " + quadraticForm);
        System.out.println("");
    }

    /** Creates an empty constraint of the kind matching the output kernel (atomic or plain dual). */
    private DualConstraint<X, Y> newConstraint() {
        if (this.outputKernel instanceof AtomicKernel) {
            return new AtomicConstraint<>((AtomicKernel) this.outputKernel);
        }
        return new DualConstraint<>();
    }

    /** Counts the alphas whose absolute value is at least 1e-5. */
    private int countSupportVectors() {
        int count = 0;
        for (Double alpha : this.alphas) {
            if (Math.abs(alpha.doubleValue()) >= 1.0E-5d) {
                count++;
            }
        }
        return count;
    }

    /**
     * Rebuilds {@link #w} as the alpha-weighted sum of the {@code phi} maps of all constraints whose alpha
     * exceeds 1e-6, and prints how many entries were merged.
     */
    private void compressW() {
        int visited = 0;
        int distinct = 0;
        this.w = new AtomicConstraint<>((AtomicKernel) this.outputKernel);
        for (int c = 0; c < this.constraints.size(); c++) {
            final double alpha = this.alphas.get(c).doubleValue();
            if (alpha > 1.0E-6d) {
                final AtomicConstraint<X, Y> source = (AtomicConstraint<X, Y>) this.constraints.get(c);
                for (Map.Entry<String, HashMap<X, Double>> feature : source.phi.entrySet()) {
                    final HashMap<X, Double> target =
                            this.w.phi.computeIfAbsent(feature.getKey(), key -> new HashMap<>());
                    for (Map.Entry<X, Double> support : feature.getValue().entrySet()) {
                        visited++;
                        final double scaled = support.getValue().doubleValue() * alpha;
                        final Double previous = target.get(support.getKey());
                        if (previous == null) {
                            distinct++;
                            target.put(support.getKey(), Double.valueOf(scaled));
                        } else {
                            target.put(support.getKey(), Double.valueOf(previous.doubleValue() + scaled));
                        }
                    }
                }
            }
        }
        System.out.println("Compression Rate: " + ((1.0d * distinct) / visited) + "\t Number of Entries: " + distinct
                + "\t under " + this.w.phi.size() + " keys");
    }

    /**
     * Keeps the oracle caches consistent after the constraint at {@code removedIndex} has been removed from
     * {@link #constraints}: entries pointing at it are dropped and entries pointing past it move down by one.
     */
    private void forgetPrunedConstraint(int removedIndex) {
        for (LinkedList<Integer> cache : this.oracleCaches) {
            cache.removeIf(index -> index.intValue() == removedIndex);
            cache.replaceAll(index ->
                    index.intValue() > removedIndex ? Integer.valueOf(index.intValue() - 1) : index);
        }
    }

    /**
     * Debug aid: for every pair of constraints covered by the solver's H matrix (all rows and columns except
     * the last), prints the H entry next to the average {@code instance.loss} between the two constraints'
     * stored outputs, separated by a tab.
     *
     * @param dualQPSolver solver whose {@code H} matrix is inspected
     */
    private void debugH(DualQPSolver<X, Y> dualQPSolver) {
        final int limit = dualQPSolver.H.size() - 1;
        for (int row = 0; row < limit; row++) {
            final DualConstraint<X, Y> rowConstraint = this.constraints.get(row);
            for (int col = 0; col < limit; col++) {
                final DualConstraint<X, Y> colConstraint = this.constraints.get(col);
                double lossSum = 0.0d;
                for (int e = 0; e < rowConstraint.ybar.size(); e++) {
                    lossSum += this.instance.loss(rowConstraint.ybar.get(e), colConstraint.ybar.get(e));
                }
                final double meanLoss = lossSum / rowConstraint.ybar.size();
                System.out.println(dualQPSolver.H.get(row).get(col) + "\t" + meanLoss);
            }
        }
    }

    /**
     * Computes the mean loss between each example's true output and the model's prediction for its input.
     * Examples are processed with a parallel stream.
     *
     * @param list examples to evaluate
     * @return the average of {@code instance.loss(example.y, predict(example.x))} over {@code list}
     * @throws java.util.NoSuchElementException if {@code list} is empty (no average exists)
     */
    public double empiricalError(List<Example<X, Y>> list) {
        return list.parallelStream()
                .mapToDouble(example -> this.instance.loss(example.y, predict(example.x)))
                .average()
                .getAsDouble();
    }

    /**
     * Predicts an output structure for {@code x}. Only {@code Mode.MARGIN_RESCALING} is supported.
     *
     * <p>With an {@link AtomicKernel} the prediction is delegated to the kernel's {@code preImage} using the
     * compressed weights {@link #w}. Otherwise a coefficient is accumulated per output structure - positive
     * for each training output, negative for each constraint's stored output - weighted by
     * {@code alpha * k(x, x_i) / data.size()}, and the output kernel's {@code preImage} is called on that map.
     *
     * @param x input structure
     * @return the output returned by the kernel's pre-image routine
     * @throws RuntimeException if the mode is not {@code MARGIN_RESCALING}
     */
    public Y predict(X x) {
        if (this.m != Mode.MARGIN_RESCALING) {
            throw new RuntimeException("Not yet implemented");
        }
        if (this.outputKernel instanceof AtomicKernel) {
            return (Y) ((AtomicKernel) this.outputKernel).preImage(this.w, x, this.inputKernel);
        }
        HashMap<Y, Double> coefficients = new HashMap<>();
        for (int e = 0; e < this.data.size(); e++) {
            final Example<X, Y> example = this.data.get(e);
            final double k = this.inputKernel.k(x, example.x) / this.data.size();
            for (int c = 0; c < this.alphas.size(); c++) {
                final double alpha = this.alphas.get(c).doubleValue();
                // Skip contributions whose alpha, or alpha-weighted similarity, is negligible.
                if (Math.abs(alpha) >= 1.0E-5d && Math.abs(alpha * k) >= 1.0E-6d) {
                    final Y ybar = this.constraints.get(c).ybar.get(e);
                    coefficients.merge(example.y, Double.valueOf(alpha * k), Double::sum);
                    coefficients.merge(ybar, Double.valueOf((-alpha) * k), Double::sum);
                }
            }
        }
        return this.outputKernel.preImage(coefficients, x);
    }

    /**
     * Separation oracle: finds the output for {@code example} via the kernel's penalised pre-image routine.
     * Only {@code Mode.MARGIN_RESCALING} is supported.
     *
     * <p>With an {@link AtomicKernel} the call is delegated to {@code preImageL2Penelized} using {@link #w}.
     * Otherwise a coefficient map is built as in {@link #predict}: each constraint's stored output receives
     * {@code -alpha * k} per significant alpha, and each training output receives the sum of the significant
     * {@code alpha * k} values for its example (an entry is written even when that sum is zero). The map, the
     * example's true output and its input are then passed to {@code preImagePenelized}.
     *
     * @param example training example to find the most violating output for
     * @return the output returned by the kernel's penalised pre-image routine
     * @throws RuntimeException if the mode is not {@code MARGIN_RESCALING}
     */
    private Y mostViolatedConstraint(Example<X, Y> example) {
        final X input = example.x;
        final Y truth = example.y;
        if (this.m != Mode.MARGIN_RESCALING) {
            throw new RuntimeException("Not yet implemented");
        }
        if (this.outputKernel instanceof AtomicKernel) {
            return (Y) ((AtomicKernel) this.outputKernel).preImageL2Penelized(this.w, truth, input, this.inputKernel);
        }
        HashMap<Y, Double> coefficients = new HashMap<>();
        for (int e = 0; e < this.data.size(); e++) {
            final double k = this.inputKernel.k(input, this.data.get(e).x) / this.data.size();
            double positiveMass = 0.0d;
            for (int c = 0; c < this.alphas.size(); c++) {
                final double alpha = this.alphas.get(c).doubleValue();
                // Skip contributions whose alpha, or alpha-weighted similarity, is negligible.
                if (Math.abs(alpha) >= 1.0E-5d && Math.abs(alpha * k) >= 1.0E-6d) {
                    positiveMass += alpha * k;
                    final Y ybar = this.constraints.get(c).ybar.get(e);
                    coefficients.merge(ybar, Double.valueOf((-alpha) * k), Double::sum);
                }
            }
            coefficients.merge(this.data.get(e).y, Double.valueOf(positiveMass), Double::sum);
        }
        return this.outputKernel.preImagePenelized(coefficients, truth, input);
    }
}
