package com.clearnlp.classification.model;

import com.carrotsearch.hppc.cursors.ObjectCursor;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.vector.SparseFeatureVector;
import com.clearnlp.collection.map.ObjectIntHashMap;
import com.clearnlp.util.UTArray;
import com.clearnlp.util.UTCollection;
import com.clearnlp.util.pair.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import opennlp.tools.tokenize.TokenizerME;

/* loaded from: input_file:com/clearnlp/classification/model/AbstractModel.class */
/**
 * Base class for linear classification models that score sparse feature vectors
 * against a flat weight array.
 *
 * <p>Weight layout in {@link #d_weights}:
 * <ul>
 *   <li>binary models ({@code n_labels == 2}): one weight per feature, position = feature index;</li>
 *   <li>other models: position = {@code feature * n_labels + label} (see {@link #getWeightIndex}).</li>
 * </ul>
 *
 * <p>Feature index 0 acts as the bias term. Scoring is seeded from the index-0 weight(s), and
 * {@link #isRange(int)} excludes index 0 from the per-feature sum.
 *
 * <p>Labels are stored 1-based in {@link #m_labels} and exposed 0-based through
 * {@link #getLabelIndex(String)}.
 */
public abstract class AbstractModel implements Serializable {
    private static final long serialVersionUID = 1851285537812008020L;

    /** Model weights; layout depends on {@link #isBinaryLabel()} (see class comment). */
    protected float[] d_weights;
    /** Labels indexed by their 0-based label index; rebuilt by {@link #initLabelArray()}. */
    protected String[] a_labels;
    /** Not read or written in this class; presumably used by subclasses or trainers — confirm. */
    protected double[] t_weights;

    /**
     * Label for the "true" outcome of a boolean classification.
     * NOTE(review): the decompiler attributed this value to {@code TokenizerME.SPLIT}. It is most
     * likely an inlined string constant; confirm, and consider using the literal so the classifier
     * no longer depends on the OpenNLP tokenizer.
     */
    public static String LABEL_TRUE = TokenizerME.SPLIT;
    /** Label for the "false" outcome of a boolean classification. */
    public static String LABEL_FALSE = "F";

    /** Number of distinct labels registered through {@link #addLabel(String)}. */
    protected int n_labels = 0;
    /** Number of features; index 0 is reserved for the bias term. */
    protected int n_features = 1;
    /** Maps each label to its 1-based index (0 is never assigned). */
    protected ObjectIntHashMap<String> m_labels = new ObjectIntHashMap<>();

    /**
     * Rebuilds {@link #a_labels} from {@link #m_labels}, so that
     * {@code a_labels[getLabelIndex(label)] == label} for every registered label.
     */
    public void initLabelArray() {
        this.a_labels = new String[this.n_labels];
        // Explicit iterator kept: only Iterator access on keys() is known to be supported here.
        Iterator<ObjectCursor<String>> it = this.m_labels.keys().iterator();
        while (it.hasNext()) {
            String label = it.next().value;
            this.a_labels[getLabelIndex(label)] = label;
        }
    }

    /**
     * Allocates a zero-filled weight array sized for the current counts: {@code n_features} entries
     * for a binary model, {@code n_features * n_labels} otherwise.
     */
    public void initWeightVector() {
        int size = isBinaryLabel() ? this.n_features : this.n_features * this.n_labels;
        this.d_weights = new float[size];
    }

    /** @return the number of registered labels */
    public int getLabelSize() {
        return this.n_labels;
    }

    /** @return the number of features, including the bias at index 0 */
    public int getFeatureSize() {
        return this.n_features;
    }

    /**
     * Returns the 0-based index of {@code label}.
     * NOTE(review): for an unregistered label the result depends on what {@code ObjectIntHashMap.get}
     * returns for a missing key (presumably 0, which yields -1) — confirm.
     *
     * @param label the label to look up
     * @return the 0-based label index
     */
    public int getLabelIndex(String label) {
        return this.m_labels.get(label) - 1;
    }

    /**
     * Returns the position in {@link #d_weights} of the weight for the pair
     * {@code (labelIndex, featureIndex)}, using the multi-class (feature-major) layout.
     */
    // Decompiler note: access was widened from protected to public.
    public int getWeightIndex(int labelIndex, int featureIndex) {
        return featureIndex * this.n_labels + labelIndex;
    }

    /** @return the label at the given 0-based index */
    public String getLabel(int labelIndex) {
        return this.a_labels[labelIndex];
    }

    /** @return the internal label array (not a copy) */
    public String[] getLabels() {
        return this.a_labels;
    }

    /** @return the internal weight array (not a copy) */
    public float[] getWeights() {
        return this.d_weights;
    }

    /**
     * Extracts the weights of one label across all features, using the multi-class layout.
     *
     * @param labelIndex 0-based label index
     * @return a new array of length {@code n_features}
     */
    public float[] getWeights(int labelIndex) {
        float[] weights = new float[this.n_features];
        for (int feature = 0; feature < this.n_features; feature++) {
            weights[feature] = this.d_weights[getWeightIndex(labelIndex, feature)];
        }
        return weights;
    }

    /**
     * Registers {@code label} if it is new and assigns it the next 1-based index.
     * Already-registered labels are ignored.
     */
    public void addLabel(String label) {
        if (!this.m_labels.containsKey(label)) {
            this.m_labels.put(label, ++this.n_labels);
        }
    }

    /** Replaces the weight array with {@code weights} (no copy is made). */
    public void setWeights(float[] weights) {
        this.d_weights = weights;
    }

    /** Copies the first {@code n_features} entries of {@code weights} into the weight array. */
    public void copyWeights(float[] weights) {
        System.arraycopy(weights, 0, this.d_weights, 0, this.n_features);
    }

    /**
     * Writes one label's per-feature weights into the multi-class weight layout.
     *
     * @param weights    per-feature weights, at least {@code n_features} long
     * @param labelIndex 0-based label index
     */
    public void copyWeights(float[] weights, int labelIndex) {
        for (int feature = 0; feature < this.n_features; feature++) {
            this.d_weights[getWeightIndex(labelIndex, feature)] = weights[feature];
        }
    }

    /** @return {@code true} when exactly two labels are registered (binary weight layout) */
    public boolean isBinaryLabel() {
        return this.n_labels == 2;
    }

    /**
     * @return {@code true} for a feature index in {@code [1, n_features)}. Index 0 (the bias) is
     *         excluded because scoring already seeds from it.
     */
    public boolean isRange(int featureIndex) {
        return 0 < featureIndex && featureIndex < this.n_features;
    }

    /**
     * Restores labels, the label map and the weights, in the order written by
     * {@link #saveDefault(ObjectOutputStream)}. Label and feature counts are then derived from
     * the array lengths.
     */
    // Decompiler note: access was widened from protected to public.
    @SuppressWarnings("unchecked") // the stream holds the map written by saveDefault
    public void loadDefault(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        this.a_labels = (String[]) objectInputStream.readObject();
        this.m_labels = (ObjectIntHashMap<String>) objectInputStream.readObject();
        this.d_weights = (float[]) objectInputStream.readObject();
        this.n_labels = this.a_labels.length;
        this.n_features = this.d_weights.length;
        if (!isBinaryLabel()) {
            // Multi-class layout stores n_labels weights per feature.
            this.n_features /= this.n_labels;
        }
    }

    /**
     * Writes labels, the label map and the weights, in the order expected by
     * {@link #loadDefault(ObjectInputStream)}.
     */
    // Decompiler note: access was widened from protected to public.
    public void saveDefault(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeObject(this.a_labels);
        objectOutputStream.writeObject(this.m_labels);
        objectOutputStream.writeObject(this.d_weights);
    }

    /**
     * Computes one raw score per label for {@code vector}.
     *
     * @return scores indexed by 0-based label index
     */
    public double[] getScores(SparseFeatureVector vector) {
        return isBinaryLabel() ? getScoresBinary(vector) : getScoresMulti(vector);
    }

    /**
     * Binary scoring: a single dot product {@code s}. Label 0 receives {@code s} and label 1
     * receives {@code -s}.
     */
    private double[] getScoresBinary(SparseFeatureVector vector) {
        double score = this.d_weights[0]; // bias
        boolean weighted = vector.hasWeight();
        int size = vector.size();
        for (int i = 0; i < size; i++) {
            // Binary layout: one weight per feature, so the feature index addresses it directly.
            int featureIndex = vector.getIndex(i);
            if (isRange(featureIndex)) {
                if (weighted) {
                    score += this.d_weights[featureIndex] * vector.getWeight(i);
                } else {
                    score += this.d_weights[featureIndex];
                }
            }
        }
        return new double[]{score, -score};
    }

    /** Multi-class scoring: one dot product per label over the feature-major weight layout. */
    private double[] getScoresMulti(SparseFeatureVector vector) {
        // Seed each label with its bias weight (feature 0 occupies positions 0..n_labels-1).
        // Presumably UTArray.copyOf widens the first n_labels floats to doubles — confirm.
        double[] scores = UTArray.copyOf(this.d_weights, this.n_labels);
        boolean weighted = vector.hasWeight();
        int size = vector.size();
        for (int i = 0; i < size; i++) {
            int featureIndex = vector.getIndex(i);
            if (!isRange(featureIndex)) {
                continue;
            }
            if (weighted) {
                double value = vector.getWeight(i);
                for (int label = 0; label < this.n_labels; label++) {
                    scores[label] += this.d_weights[getWeightIndex(label, featureIndex)] * value;
                }
            } else {
                for (int label = 0; label < this.n_labels; label++) {
                    scores[label] += this.d_weights[getWeightIndex(label, featureIndex)];
                }
            }
        }
        return scores;
    }

    /** @return the highest-ranked prediction according to {@code StringPrediction}'s ordering */
    public StringPrediction predictBest(SparseFeatureVector vector) {
        return (StringPrediction) Collections.max(getPredictions(vector));
    }

    /** @return the two highest-scoring predictions for {@code vector} (best first) */
    public Pair<StringPrediction, StringPrediction> predictTwo(SparseFeatureVector vector) {
        return predictTwo(getPredictions(vector));
    }

    /**
     * Finds the two highest-scoring entries in a single pass. On ties, the earlier entry is kept
     * ahead. Requires at least two entries.
     *
     * @return a pair of (best, second best)
     */
    public Pair<StringPrediction, StringPrediction> predictTwo(List<StringPrediction> list) {
        StringPrediction first = list.get(0);
        StringPrediction second = list.get(1);
        if (first.score < second.score) {
            StringPrediction tmp = first;
            first = second;
            second = tmp;
        }
        int size = list.size();
        for (int i = 2; i < size; i++) {
            StringPrediction candidate = list.get(i);
            if (first.score < candidate.score) {
                second = first;
                first = candidate;
            } else if (second.score < candidate.score) {
                second = candidate;
            }
        }
        return new Pair<>(first, second);
    }

    /** @return all predictions, sorted by {@code UTCollection.sortReverseOrder} (presumably best first) */
    public List<StringPrediction> predictAll(SparseFeatureVector vector) {
        List<StringPrediction> predictions = getPredictions(vector);
        UTCollection.sortReverseOrder(predictions);
        return predictions;
    }

    /** @return one prediction per label, in label-index order and unsorted */
    public List<StringPrediction> getPredictions(SparseFeatureVector vector) {
        double[] scores = getScores(vector);
        List<StringPrediction> predictions = new ArrayList<>(this.n_labels);
        for (int i = 0; i < this.n_labels; i++) {
            predictions.add(new StringPrediction(this.a_labels[i], scores[i]));
        }
        return predictions;
    }

    /** @return {@link #LABEL_TRUE} for {@code true}, {@link #LABEL_FALSE} otherwise */
    public static String getBooleanLabel(boolean value) {
        return value ? LABEL_TRUE : LABEL_FALSE;
    }

    /** @return whether {@code label} equals {@link #LABEL_TRUE}; {@code label} must be non-null */
    public static boolean toBoolean(String label) {
        return label.equals(LABEL_TRUE);
    }
}
