package svm;

import base.DenseVector;
import base.Example;
import base.Structure;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/* loaded from: input_file:svm/SVMStruct.class */
/**
 * Cutting-plane trainer for a linear model over structured inputs {@code X} and outputs {@code Y}.
 *
 * <p>{@link #train} alternates between two steps:
 * <ul>
 *   <li>re-solving a QP (via {@link QPSolver}) over a working set of constraints;</li>
 *   <li>asking the problem-specific {@link SVMStructInstance} for a most-violated output for every
 *       training example, and folding those answers into one new {@link Constraint}.</li>
 * </ul>
 * The loop stops once {@code loss - <phi, w>} of the newest constraint is no larger than
 * {@code slack + eps}.
 *
 * @param <X> input structure type
 * @param <Y> output structure type
 */
public class SVMStruct<X extends Structure, Y extends Structure> {
    /** Features occurring fewer times than this over the training set are pruned. */
    private static final int MIN_FEATURE_COUNT = 5;
    /** A working-set constraint whose lastUsed lags the current iteration by more than this is dropped. */
    private static final int MAX_INACTIVE_ITERATIONS = 50;
    /** The training error is printed every this many iterations. */
    private static final int ERROR_REPORT_INTERVAL = 20;
    /** Oracle progress is printed every this many processed examples. */
    private static final int ORACLE_PROGRESS_INTERVAL = 1000;
    /** At most this many retained feature names are echoed after pruning. */
    private static final int FEATURE_PREVIEW_LIMIT = 200;

    /** Weight vector, one entry per index in {@link #wMap}; (re)created by {@link #train}. */
    public DenseVector w;
    /** Value returned by the most recent QP solve (0 while the working set is empty). */
    public double slack;
    /** Not written anywhere in this class; retained for external users of the field. */
    public double loss;
    /** Maps each retained feature name to its index in {@link #w}; built by {@link #train}. */
    public HashMap<String, Integer> wMap;
    /** Selects which oracle method is used and how each constraint term is scaled. */
    public Mode m;
    /** Trade-off constant forwarded unchanged to {@code QPSolver.solve}. */
    public double C;
    /** Termination tolerance of the cutting-plane loop in {@link #train}. */
    public double eps;
    /** Problem-specific oracle: features, feature map, loss, prediction and violator search. */
    public SVMStructInstance<X, Y> instance;

    /**
     * Creates a trainer. No work is done until {@link #train} is called.
     *
     * @param instance problem-specific oracle
     * @param c trade-off constant stored in {@link #C}
     * @param eps termination tolerance stored in {@link #eps}
     * @param mode oracle/scaling selector stored in {@link #m}
     */
    public SVMStruct(SVMStructInstance<X, Y> instance, double c, double eps, Mode mode) {
        this.C = c;
        this.eps = eps;
        this.m = mode;
        this.instance = instance;
    }

    /**
     * Returns the oracle's most-violated output for {@code (x, y)}.
     *
     * <p>Delegates to {@code instance.marginRescaling} or {@code instance.slackRescaling},
     * depending on {@link #m}.
     *
     * @throws IllegalStateException if {@link #m} is neither MARGIN_RESCALING nor SLACK_RESCALING
     *     (previously this silently returned {@code null}, which was then fed into the loss)
     */
    private Y mostViolatedConstraint(X x, Y y) {
        if (this.m == Mode.MARGIN_RESCALING) {
            return this.instance.marginRescaling(x, y, this);
        }
        if (this.m == Mode.SLACK_RESCALING) {
            return this.instance.slackRescaling(x, y, this);
        }
        throw new IllegalStateException("Unsupported mode: " + this.m);
    }

    /**
     * Trains {@link #w} on {@code list} with the cutting-plane loop described on the class.
     *
     * <p>Each iteration does the following:
     * <ol>
     *   <li>Re-solve the QP over the working set. This updates {@link #w} via the solver and sets
     *       {@link #slack}.</li>
     *   <li>Build a new cutting-plane constraint from the oracle's answers for all examples.</li>
     *   <li>Add that constraint, then drop constraints inactive for more than
     *       {@link #MAX_INACTIVE_ITERATIONS} iterations.</li>
     * </ol>
     * Progress and timings (milliseconds) are printed to standard output.
     *
     * @param list training examples
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public void train(List<Example<X, Y>> list) {
        if (list.isEmpty()) {
            // Fail fast: feature-space setup prints list.get(0) and the error estimate needs data.
            throw new IllegalArgumentException("Training set must not be empty");
        }
        System.out.println("C: " + this.C + "\tEpsilon: " + this.eps);
        long trainStart = System.currentTimeMillis();
        System.out.println("Number of Training Examples: " + list.size());
        QPSolver solver = new QPSolver();
        System.out.println("Initializing Feature Space (May Involve Precomputing Feature Maps)");
        this.w = new DenseVector(initializeFeatureSpace(list));
        ArrayList<Constraint<X, Y>> workingSet = new ArrayList<>();
        int iteration = 0;
        long solverMillis = 0;
        long oracleMillis = 0;
        double violation;
        do {
            iteration++;
            System.out.print(String.valueOf(iteration) + "\n");
            System.out.println("Number of Constraints: " + workingSet.size());

            long solverStart = System.currentTimeMillis();
            if (workingSet.isEmpty()) {
                this.slack = 0.0d;
            } else {
                this.slack = solver.solve(workingSet, this.w, this.C, iteration);
            }
            solverMillis += System.currentTimeMillis() - solverStart;

            System.out.println("Computing new Cuttingplane");
            Constraint<X, Y> cuttingPlane = new Constraint<>(this.w.length());
            long oracleStart = System.currentTimeMillis();
            addMostViolatedConstraints(list, cuttingPlane);
            oracleMillis += System.currentTimeMillis() - oracleStart;
            cuttingPlane.finalizeConstraint(list.size());
            cuttingPlane.lastUsed = iteration;
            workingSet.add(cuttingPlane);
            pruneInactiveConstraints(workingSet, solver, iteration);

            // How far the new constraint is violated under the current weights.
            double margin = cuttingPlane.phi.multiply(this.w);
            violation = cuttingPlane.loss - margin;
            System.out.println("Loss: " + cuttingPlane.loss);
            System.out.println("Margin: " + margin);
            System.out.println("L-M: " + violation);
            System.out.println("Slack: " + this.slack);
            System.out.println("Solver: " + solverMillis + "\tOracle: " + oracleMillis);
            if (iteration % ERROR_REPORT_INTERVAL == 0) {
                System.out.println("Training Error: " + empiricalError(list));
            }
        } while (violation > this.slack + this.eps);
        System.out.println();
        System.out.println("Error: " + empiricalError(list));
        System.out.println("L2: " + this.w.multiply(this.w));
        System.out.println("Number of selected features: " + this.w.countNonZero());
        System.out.println("");
        System.out.println("Time:" + (System.currentTimeMillis() - trainStart));
    }

    /**
     * Queries the oracle for every example, in parallel, and adds one term per example to
     * {@code cuttingPlane}.
     *
     * <p>The fourth argument passed to {@code addConstraint} is 1 under MARGIN_RESCALING. Under any
     * other mode it is the example's loss.
     */
    private void addMostViolatedConstraints(List<Example<X, Y>> list, Constraint<X, Y> cuttingPlane) {
        AtomicInteger processed = new AtomicInteger();
        // NOTE(review): addConstraint is called concurrently from the parallel stream; this is only
        // correct if Constraint.addConstraint is thread-safe — confirm in Constraint.
        list.parallelStream().forEach(example -> {
            Y violator = mostViolatedConstraint(example.x, example.y);
            int count = processed.incrementAndGet();
            if (count % ORACLE_PROGRESS_INTERVAL == 0) {
                System.out.print(String.valueOf(count) + "\t");
            }
            double exampleLoss = this.instance.loss(example.y, violator);
            double scale = this.m == Mode.MARGIN_RESCALING ? 1.0d : exampleLoss;
            cuttingPlane.addConstraint(
                    example.x, example.y, violator, scale, exampleLoss, this, this.instance);
        });
    }

    /**
     * Drops working-set constraints whose {@code lastUsed} lags {@code iteration} by more than
     * {@link #MAX_INACTIVE_ITERATIONS}, and tells the solver each removed index.
     *
     * <p>The newest constraint (last element) is never examined. The scan runs backwards so that a
     * removal does not shift indices still to be visited.
     *
     * <p>NOTE(review): this class only sets {@code lastUsed} on creation. Presumably QPSolver
     * refreshes it for active constraints — confirm.
     */
    private void pruneInactiveConstraints(
            List<Constraint<X, Y>> workingSet, QPSolver solver, int iteration) {
        for (int idx = workingSet.size() - 2; idx >= 0; idx--) {
            if (iteration - workingSet.get(idx).lastUsed > MAX_INACTIVE_ITERATIONS) {
                workingSet.remove(idx);
                solver.pruneConstraint(idx);
            }
        }
    }

    /**
     * Builds {@link #wMap} from the features emitted by {@code instance.features} over the
     * training set.
     *
     * <p>Feature names occurring fewer than {@link #MIN_FEATURE_COUNT} times in total are dropped.
     * The survivors get consecutive indices starting at 0.
     *
     * @param list training examples; must be non-empty (the first example is printed as a sample)
     * @return number of retained features, i.e. the length {@link #w} must have
     */
    private int initializeFeatureSpace(List<Example<X, Y>> list) {
        this.wMap = new HashMap<>();
        HashMap<String, Integer> featureCounts = new HashMap<>();
        for (Example<X, Y> example : list) {
            for (String feature : this.instance.features(example.x, example.y)) {
                if (!this.wMap.containsKey(feature)) {
                    this.wMap.put(feature, this.wMap.size());
                }
                featureCounts.merge(feature, 1, Integer::sum);
            }
        }
        System.out.println("Number of Features before Pruning: " + featureCounts.size());
        int nextIndex = 0;
        for (Map.Entry<String, Integer> entry : featureCounts.entrySet()) {
            if (entry.getValue() < MIN_FEATURE_COUNT) {
                this.wMap.remove(entry.getKey());
            } else {
                // Re-index survivors densely so w has no slots for pruned features.
                this.wMap.put(entry.getKey(), nextIndex);
                nextIndex++;
            }
        }
        System.out.println("Number of Features after Pruning: " + this.wMap.keySet().size());
        Iterator<String> names = this.wMap.keySet().iterator();
        for (int shown = 0; shown < Math.min(FEATURE_PREVIEW_LIMIT, this.wMap.size()); shown++) {
            System.out.print(String.valueOf(names.next()) + ", ");
        }
        System.out.println("...");
        System.out.println("Example Feature Vector:");
        System.out.println(this.instance.phi(list.get(0).x, list.get(0).y));
        return this.wMap.size();
    }

    /**
     * Returns the mean of {@code instance.loss(y, predict(x))} over {@code list}, computed with a
     * parallel stream.
     *
     * @param list examples to evaluate
     * @return the average loss, or {@code NaN} for an empty list (same result as
     *     {@link #testError}; previously this threw {@code NoSuchElementException})
     */
    public double empiricalError(List<Example<X, Y>> list) {
        return list.parallelStream()
                .mapToDouble(example -> this.instance.loss(example.y, predict(example.x)))
                .average()
                .orElse(Double.NaN);
    }

    /**
     * Predicts an output for {@code x} by delegating to {@code instance.predict} with this trainer.
     *
     * @param x input structure
     * @return the instance's prediction
     */
    public Y predict(X x) {
        return this.instance.predict(x, this);
    }

    /**
     * Returns the mean of {@code instance.loss(y, predict(x))} over {@code list}, computed
     * sequentially.
     *
     * @param list examples to evaluate
     * @return the average loss, or {@code NaN} for an empty list (0.0 / 0)
     */
    public double testError(List<Example<X, Y>> list) {
        double total = 0.0d;
        for (Example<X, Y> example : list) {
            total += this.instance.loss(example.y, predict(example.x));
        }
        return total / list.size();
    }
}
