package se.lth.cs.srl.ml.liblinear;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.lucene.util.packed.PackedInts;

/* loaded from: input_file:se/lth/cs/srl/ml/liblinear/WeightVector.class */
/**
 * Weights of a linear classifier that scores a set of active (binary-valued) features.
 *
 * <p>Feature identifiers passed to the scoring methods are 1-based: feature {@code k} uses the
 * weight row at index {@code k - 1}. Identifiers whose row falls outside {@code [0, features)}
 * contribute nothing. The row at index {@code features} is reserved for the bias weights; it is
 * added to every score, scaled by {@link #bias}, only when {@code bias > 0}.
 *
 * <p>Instances are {@link Serializable} with a fixed {@code serialVersionUID}, so the names and
 * types of the fields in this class and its nested classes are part of the serialized form and
 * must not be changed.
 */
public abstract class WeightVector implements Serializable {
    private static final long serialVersionUID = 1;

    /** Scale applied to the bias row; the bias row is used only when this is positive. */
    protected double bias;
    /** Number of regular (non-bias) feature rows; also the index of the bias row. */
    protected int features;
    /** Number of output classes. */
    protected int classes;

    /**
     * @param bias bias scale; the bias row is used only when positive
     * @param features number of regular feature rows
     * @param classes number of output classes
     */
    public WeightVector(double bias, int features, int classes) {
        this.bias = bias;
        this.features = features;
        this.classes = classes;
    }

    /**
     * Reads a weight vector from {@code reader}, choosing the implementation by class count and
     * storage layout. The reader is consumed to end-of-stream but not closed.
     *
     * <p>NOTE(review): the input is presumably the weight section of a LibLinear model file; the
     * exact format is only what the constructors below parse — confirm against the caller.
     *
     * @param reader source of weight rows, one row per line
     * @param features number of regular feature rows
     * @param classes number of output classes; {@code 2} selects a binary (single-column) vector
     * @param bias bias scale; the bias row is used only when positive
     * @param sparse {@code true} to keep only non-zero rows in a hash map, {@code false} for arrays
     * @return the parsed weight vector
     * @throws IOException if reading fails
     */
    public static WeightVector parseWeights(BufferedReader reader, int features, int classes,
            double bias, boolean sparse) throws IOException {
        if (classes == 2) {
            return sparse
                    ? new BinarySparseVector(reader, features, bias)
                    : new BinaryLibLinearVector(reader, features, bias);
        }
        return sparse
                ? new MultipleSparseVector(reader, features, classes, bias)
                : new MultipleLibLinearVector(reader, features, classes, bias);
    }

    /**
     * Computes one probability-like value per class for the given active features.
     *
     * @param collection 1-based ids of the active features
     * @return an array of length {@code classes} whose entries sum to 1
     */
    public abstract double[] computeAllProbs(Collection<Integer> collection);

    /**
     * Returns the index of the highest-scoring class for the given active features.
     *
     * @param collection 1-based ids of the active features
     * @return the predicted class index
     */
    public abstract short computeBestClass(Collection<Integer> collection);

    /** Logistic function {@code 1 / (1 + e^-score)}. */
    private static double logistic(double score) {
        return 1.0d / (1.0d + Math.exp(-score));
    }

    /** Two-class vector producing a single score; class 0 is predicted when the score is positive. */
    public abstract static class BinaryVector extends WeightVector {
        private static final long serialVersionUID = 1;

        public BinaryVector(double bias, int features, int classes) {
            super(bias, features, classes);
        }

        /** Returns the linear score (bias term plus the weights of the active features). */
        protected abstract double computeScore(Collection<Integer> collection);

        /** Returns {@code {p, 1 - p}} where {@code p} is the logistic of the score. */
        @Override
        public double[] computeAllProbs(Collection<Integer> collection) {
            double p = logistic(computeScore(collection));
            return new double[] {p, 1.0d - p};
        }

        /** Returns 0 when the score is strictly positive, otherwise 1. */
        @Override
        public short computeBestClass(Collection<Integer> collection) {
            return computeScore(collection) > 0.0d ? (short) 0 : (short) 1;
        }
    }

    /** Dense two-class vector: one weight per feature row plus a trailing bias weight. */
    public static class BinaryLibLinearVector extends BinaryVector {
        private static final long serialVersionUID = 1;
        /** Dense weights of length {@code features + 1}; the last slot is the bias weight. */
        private float[] weights;

        /**
         * Reads one weight per line until the reader is exhausted (it is not closed).
         *
         * @param reader source of weight lines, each parsed with {@link Float#parseFloat}
         * @param featureCount number of regular feature rows
         * @param biasValue bias scale; the bias row is used only when positive
         * @throws IOException if reading fails
         * @throws NumberFormatException if a line is not a valid float
         * @throws ArrayIndexOutOfBoundsException if more than {@code featureCount + 1} lines are read
         */
        public BinaryLibLinearVector(BufferedReader reader, int featureCount, double biasValue)
                throws IOException {
            super(biasValue, featureCount, 2);
            this.weights = new float[featureCount + 1];
            int row = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                this.weights[row++] = Float.parseFloat(line);
            }
        }

        @Override
        protected double computeScore(Collection<Integer> collection) {
            double score = this.bias > 0.0d ? this.bias * this.weights[this.features] : 0.0d;
            for (Integer feature : collection) {
                int row = feature - 1;
                // Id 0 (row -1) and ids past the model are ignored, matching the sparse vectors.
                if (row >= 0 && row < this.features) {
                    score += this.weights[row];
                }
            }
            return score;
        }
    }

    /** Sparse two-class vector storing only non-zero weights, keyed by 0-based row. */
    public static class BinarySparseVector extends BinaryVector {
        private static final long serialVersionUID = 1;
        /** Non-zero weights by row; the bias weight, when present, is at key {@code features}. */
        private HashMap<Integer, Float> weightMap;

        /**
         * Builds a sparse copy of {@code dense}, dropping zero feature weights. The bias weight is
         * always copied when {@code bias > 0}, even if it is zero.
         */
        public BinarySparseVector(BinaryLibLinearVector dense) {
            super(dense.bias, dense.features, 2);
            this.weightMap = new HashMap<>();
            for (int row = 0; row < dense.features; row++) {
                float weight = dense.weights[row];
                if (weight != 0.0f) {
                    this.weightMap.put(row, weight);
                }
            }
            if (this.bias > 0.0d) {
                this.weightMap.put(this.features, dense.weights[this.features]);
            }
        }

        /**
         * Reads one weight per line until the reader is exhausted (it is not closed), keeping
         * only non-zero weights.
         *
         * @param reader source of weight lines, each parsed with {@link Float#parseFloat}
         * @param featureCount number of regular feature rows
         * @param biasValue bias scale; the bias row is used only when positive
         * @throws IOException if reading fails
         * @throws NumberFormatException if a line is not a valid float
         */
        public BinarySparseVector(BufferedReader reader, int featureCount, double biasValue)
                throws IOException {
            super(biasValue, featureCount, 2);
            this.weightMap = new HashMap<>();
            int row = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                float weight = Float.parseFloat(line);
                if (weight != 0.0f) {
                    this.weightMap.put(row, weight);
                }
                row++;
            }
        }

        @Override
        protected double computeScore(Collection<Integer> collection) {
            double score = 0.0d;
            if (this.bias > 0.0d) {
                Float biasWeight = this.weightMap.get(this.features);
                if (biasWeight != null) {
                    score = this.bias * biasWeight;
                }
            }
            for (Integer feature : collection) {
                int row = feature - 1;
                // The upper bound also keeps feature ids from reaching the bias row.
                if (row < 0 || row >= this.features) {
                    continue;
                }
                Float weight = this.weightMap.get(row);
                if (weight != null) {
                    score += weight;
                }
            }
            return score;
        }
    }

    /**
     * Multi-class vector producing one score per class. Probabilities are the per-class logistic
     * values normalized to sum to 1.
     */
    public abstract static class MultipleVector extends WeightVector {
        private static final long serialVersionUID = 1;

        public MultipleVector(double bias, int features, int classes) {
            super(bias, features, classes);
        }

        /** Returns the linear score of every class, indexed by class. */
        protected abstract double[] computeScores(Collection<Integer> collection);

        @Override
        public double[] computeAllProbs(Collection<Integer> collection) {
            double[] scores = computeScores(collection);
            double[] probs = new double[this.classes];
            double total = 0.0d;
            for (int c = 0; c < this.classes; c++) {
                probs[c] = logistic(scores[c]);
                total += probs[c];
            }
            for (int c = 0; c < this.classes; c++) {
                probs[c] /= total;
            }
            return probs;
        }

        /** Returns the first class with the maximal score (ties resolve to the lowest index). */
        @Override
        public short computeBestClass(Collection<Integer> collection) {
            double[] scores = computeScores(collection);
            int best = 0;
            for (int c = 1; c < this.classes; c++) {
                if (scores[c] > scores[best]) {
                    best = c;
                }
            }
            return (short) best;
        }
    }

    /** Dense multi-class vector: {@code weights[class][row]}, with the bias at row {@code features}. */
    public static class MultipleLibLinearVector extends MultipleVector {
        private static final long serialVersionUID = 1;
        /** Dense weights, shape {@code [classes][features + 1]}. */
        private float[][] weights;

        /**
         * Reads one row per line until the reader is exhausted (it is not closed). Each line holds
         * space-separated weights; the first {@code classCount} tokens are used.
         *
         * @param reader source of weight rows
         * @param featureCount number of regular feature rows
         * @param classCount number of output classes
         * @param biasValue bias scale; the bias row is used only when positive
         * @throws IOException if reading fails
         * @throws NumberFormatException if a token is not a valid float
         * @throws ArrayIndexOutOfBoundsException if a line has fewer than {@code classCount} tokens
         *     or more than {@code featureCount + 1} lines are read
         */
        public MultipleLibLinearVector(BufferedReader reader, int featureCount, int classCount,
                double biasValue) throws IOException {
            super(biasValue, featureCount, classCount);
            this.weights = new float[classCount][featureCount + 1];
            int row = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = line.split(" ");
                for (int c = 0; c < classCount; c++) {
                    this.weights[c][row] = Float.parseFloat(tokens[c]);
                }
                row++;
            }
        }

        @Override
        protected double[] computeScores(Collection<Integer> collection) {
            double[] scores = new double[this.classes];
            for (int c = 0; c < this.classes; c++) {
                float[] classWeights = this.weights[c];
                double score = this.bias > 0.0d ? this.bias * classWeights[this.features] : 0.0d;
                for (Integer feature : collection) {
                    int row = feature - 1;
                    // Id 0 (row -1) and ids past the model are ignored, matching the sparse vectors.
                    if (row >= 0 && row < this.features) {
                        score += classWeights[row];
                    }
                }
                scores[c] = score;
            }
            return scores;
        }
    }

    /** Sparse multi-class vector storing only rows with at least one non-zero class weight. */
    public static class MultipleSparseVector extends MultipleVector {
        private static final long serialVersionUID = 1;
        /** Rows keyed by 0-based index; the bias row, when present, is at key {@code features}. */
        private HashMap<Integer, WeightArray> weightMap;

        /** Per-class weights of a single row. */
        private static class WeightArray implements Serializable {
            private static final long serialVersionUID = 1;
            float[] weights;

            public WeightArray(int size) {
                this.weights = new float[size];
            }
        }

        /**
         * Builds a sparse copy of {@code dense}, dropping rows whose class weights are all zero.
         * The bias row is always copied when {@code bias > 0}.
         */
        public MultipleSparseVector(MultipleLibLinearVector dense) {
            super(dense.bias, dense.features, dense.classes);
            this.weightMap = new HashMap<>();
            for (int row = 0; row < dense.features; row++) {
                WeightArray rowWeights = new WeightArray(this.classes);
                boolean anyNonZero = false;
                for (int c = 0; c < dense.classes; c++) {
                    rowWeights.weights[c] = dense.weights[c][row];
                    anyNonZero |= rowWeights.weights[c] != 0.0f;
                }
                if (anyNonZero) {
                    this.weightMap.put(row, rowWeights);
                }
            }
            if (this.bias > 0.0d) {
                WeightArray biasRow = new WeightArray(this.classes);
                for (int c = 0; c < this.classes; c++) {
                    biasRow.weights[c] = dense.weights[c][this.features];
                }
                this.weightMap.put(this.features, biasRow);
            }
        }

        /**
         * Reads one row per line until the reader is exhausted (it is not closed). Each line holds
         * space-separated weights; rows whose weights are all zero are not stored.
         *
         * @param reader source of weight rows
         * @param featureCount number of regular feature rows
         * @param classCount number of output classes
         * @param biasValue bias scale; the bias row is used only when positive
         * @throws IOException if reading fails
         * @throws NumberFormatException if a token is not a valid float
         * @throws ArrayIndexOutOfBoundsException if a line has more than {@code classCount} tokens
         */
        public MultipleSparseVector(BufferedReader reader, int featureCount, int classCount,
                double biasValue) throws IOException {
            super(biasValue, featureCount, classCount);
            this.weightMap = new HashMap<>();
            int row = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                WeightArray rowWeights = new WeightArray(classCount);
                boolean anyNonZero = false;
                String[] tokens = line.split(" ");
                for (int c = 0; c < tokens.length; c++) {
                    float weight = Float.parseFloat(tokens[c]);
                    rowWeights.weights[c] = weight;
                    anyNonZero |= weight != 0.0f;
                }
                if (anyNonZero) {
                    this.weightMap.put(row, rowWeights);
                }
                row++;
            }
        }

        @Override
        protected double[] computeScores(Collection<Integer> collection) {
            double[] scores = new double[this.classes];
            if (this.bias > 0.0d) {
                WeightArray biasRow = this.weightMap.get(this.features);
                if (biasRow != null) {
                    for (int c = 0; c < this.classes; c++) {
                        scores[c] = this.bias * biasRow.weights[c];
                    }
                }
            }
            // Each row is looked up once and added to every class; per-class summation order
            // (bias first, then features in iteration order) is the same as a class-outer loop.
            for (Integer feature : collection) {
                int row = feature - 1;
                // The upper bound also keeps feature ids from reaching the bias row.
                if (row < 0 || row >= this.features) {
                    continue;
                }
                WeightArray rowWeights = this.weightMap.get(row);
                if (rowWeights == null) {
                    continue;
                }
                for (int c = 0; c < this.classes; c++) {
                    scores[c] += rowWeights.weights[c];
                }
            }
            return scores;
        }
    }
}
