package com.clearnlp.classification.model;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import com.clearnlp.classification.instance.IntInstance;
import com.clearnlp.classification.instance.StringInstance;
import com.clearnlp.classification.prediction.IntPrediction;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.train.InstanceCollector;
import com.clearnlp.classification.vector.SparseFeatureVector;
import com.clearnlp.classification.vector.StringFeatureVector;
import com.clearnlp.collection.list.FloatArrayList;
import com.clearnlp.collection.map.ObjectIntHashMap;
import com.clearnlp.util.UTArray;
import com.clearnlp.util.UTCollection;
import com.clearnlp.util.pair.ObjectIntPair;
import com.clearnlp.util.pair.Pair;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.log4j.Logger;
import org.apache.lucene.util.packed.PackedInts;

/* loaded from: input_file:com/clearnlp/classification/model/StringModelAD.class */
/**
 * Linear classification model whose labels and features are strings.
 *
 * <p>Labels and features are mapped to integer ids. Label ids are stored 1-based in
 * {@code m_labels}, so {@link #getLabelIndex} returns a 0-based index, or -1 for an unknown label.
 * Feature ids are assigned from 1. Id 0 is reserved: {@link #toSparseFeatureVector} never emits
 * it, and its weights seed every score, acting as a bias row.
 *
 * <p>All weights live in one flat list laid out feature-major. The weight of label {@code l} for
 * feature {@code f} is at {@code f * n_labels + l} (see {@link #getWeightIndex}).
 *
 * <p>Raw string instances are buffered in an {@link InstanceCollector}. {@link #build} filters
 * labels and features by frequency and converts the buffered instances into integer instances.
 */
public class StringModelAD implements Serializable {
    private static final long serialVersionUID = -8388835844936751367L;

    /** Label to 1-based label id. */
    protected ObjectIntHashMap<String> m_labels;
    /** Labels in id order; element {@code k} is the label whose index is {@code k}. */
    protected List<String> a_labels;
    /** Number of labels. */
    protected int n_labels;
    /** Feature type to (feature value to feature id). */
    protected Map<String, ObjectIntHashMap<String>> m_features;
    /** Number of feature ids in use, including the reserved id 0. */
    protected int n_features;
    /** Flat feature-major weight list, kept at {@code n_features * n_labels} entries. */
    protected FloatArrayList f_weights;
    /** Buffers raw string instances until {@link #build} consumes them. */
    protected InstanceCollector i_collector = new InstanceCollector();
    /** Integer instances created by {@link #build}; {@code null} before the first build. */
    protected List<IntInstance> l_instances;
    /** Positions {@code 0..n-1} into {@code l_instances}, shuffled by {@link #shuffleIndices}. */
    protected IntArrayList l_indices;
    /** Random source for {@link #shuffleIndices}, seeded in {@link #build}. */
    protected Random r_shuffle;

    /** Creates an empty model. */
    public StringModelAD() {
        init();
    }

    /** Resets labels, features and weights to an empty model (feature id 0 stays reserved). */
    public void init() {
        this.m_labels = new ObjectIntHashMap<>();
        this.a_labels = Lists.newArrayList();
        this.n_labels = 0;
        this.m_features = Maps.newHashMap();
        this.n_features = 1;
        this.f_weights = new FloatArrayList();
    }

    /**
     * Drops features whose weights are negligible for every label and compacts the weight list.
     *
     * <p>A feature survives when at least one of its label weights has an absolute value strictly
     * greater than {@code f}. The reserved row 0 is always kept. Survivors are renumbered
     * consecutively from 1 and the feature maps are rewritten to the new ids. Feature types left
     * with no features are removed.
     *
     * @param logger receives the feature count before and after trimming
     * @param f      absolute-weight threshold
     */
    public void trimFeatures(Logger logger, float f) {
        FloatArrayList kept = new FloatArrayList(this.f_weights.size());
        // Maps old feature id to new feature id. Dropped features are absent, and HPPC's get
        // returns 0 for absent keys.
        IntIntOpenHashMap remap = new IntIntOpenHashMap();
        int nextId = 1;

        logger.info("Trimming: ");

        // The reserved row 0 is copied unconditionally.
        for (int label = 0; label < this.n_labels; label++) {
            kept.add(this.f_weights.get(label));
        }

        for (int feature = 1; feature < this.n_features; feature++) {
            if (!hasWeightAbove(feature, f)) {
                continue;
            }
            remap.put(feature, nextId++);
            int base = feature * this.n_labels;
            for (int label = 0; label < this.n_labels; label++) {
                kept.add(this.f_weights.get(base + label));
            }
        }

        logger.info(String.format("%d -> %d\n", this.n_features, nextId));
        kept.trimToSize();

        // Iterate over a snapshot of the key set, because empty types are removed below.
        for (String type : new ArrayList<>(this.m_features.keySet())) {
            ObjectIntHashMap<String> values = this.m_features.get(type);
            for (ObjectIntPair<String> pair : values.toList()) {
                int newId = remap.get(pair.i);
                String value = (String) pair.o;
                if (newId > 0) {
                    values.put(value, newId);
                } else {
                    values.remove(value);
                }
            }
            if (values.isEmpty()) {
                this.m_features.remove(type);
            }
        }

        this.f_weights = kept;
        this.n_features = nextId;
    }

    /** Returns whether any label weight of {@code feature} has absolute value above {@code threshold}. */
    private boolean hasWeightAbove(int feature, float threshold) {
        int base = feature * this.n_labels;
        for (int label = 0; label < this.n_labels; label++) {
            if (Math.abs(this.f_weights.get(base + label)) > threshold) {
                return true;
            }
        }
        return false;
    }

    /**
     * Restores the fields written by {@link #writeObject}, in the same order. Training-only state
     * (instances, indices, random source) is not restored; a fresh collector is created.
     */
    @SuppressWarnings("unchecked")
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        this.m_labels = (ObjectIntHashMap<String>) objectInputStream.readObject();
        this.a_labels = (List<String>) objectInputStream.readObject();
        this.n_labels = ((Integer) objectInputStream.readObject()).intValue();
        this.m_features = (Map<String, ObjectIntHashMap<String>>) objectInputStream.readObject();
        this.n_features = ((Integer) objectInputStream.readObject()).intValue();
        this.f_weights = (FloatArrayList) objectInputStream.readObject();
        this.i_collector = new InstanceCollector();
    }

    /** Writes the label/feature maps, their sizes and the weights; training state is not written. */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeObject(this.m_labels);
        objectOutputStream.writeObject(this.a_labels);
        objectOutputStream.writeObject(Integer.valueOf(this.n_labels));
        objectOutputStream.writeObject(this.m_features);
        objectOutputStream.writeObject(Integer.valueOf(this.n_features));
        objectOutputStream.writeObject(this.f_weights);
    }

    /**
     * Registers {@code str} as a new label, appending a zero weight column for it. Does nothing if
     * the label is already known.
     *
     * @param str the label
     */
    public void addLabel(String str) {
        if (this.m_labels.containsKey(str)) {
            return;
        }
        this.m_labels.put(str, ++this.n_labels);
        this.a_labels.add(str);
        // Insert a zero at the end of every feature row. Rows are processed front to back, so each
        // insertion position already accounts for the slots inserted into earlier rows.
        for (int feature = 0; feature < this.n_features; feature++) {
            this.f_weights.insert((feature + 1) * this.n_labels - 1, 0f);
        }
    }

    /** Returns the label whose 0-based index is {@code i}. */
    public String getLabel(int i) {
        return this.a_labels.get(i);
    }

    /** Returns the labels in index order (the internal list, not a copy). */
    public List<String> getLabels() {
        return this.a_labels;
    }

    /** Returns the 0-based index of {@code str}, or -1 if it is not a known label. */
    public int getLabelIndex(String str) {
        return this.m_labels.get(str) - 1;
    }

    /** Returns the number of labels. */
    public int getLabelSize() {
        return this.n_labels;
    }

    /** Returns whether the model has exactly two labels (scored with the binary path). */
    public boolean isBinaryLabel() {
        return this.n_labels == 2;
    }

    /** Returns the feature type to (value to id) map (the internal map, not a copy). */
    public Map<String, ObjectIntHashMap<String>> getFeatureMap() {
        return this.m_features;
    }

    /** Returns the number of feature ids, including the reserved id 0. */
    public int getFeatureSize() {
        return this.n_features;
    }

    /**
     * Registers feature {@code (str, str2)} with the next free id and appends a zero weight row for
     * it. Does nothing if the feature already exists.
     *
     * @param str  feature type
     * @param str2 feature value
     */
    public void addFeature(String str, String str2) {
        ObjectIntHashMap<String> values = this.m_features.get(str);
        if (values == null) {
            values = new ObjectIntHashMap<>();
            this.m_features.put(str, values);
        }
        if (values.containsKey(str2)) {
            return;
        }
        values.put(str2, this.n_features++);
        for (int label = 0; label < this.n_labels; label++) {
            this.f_weights.add(0f);
        }
    }

    /**
     * Converts a string feature vector into a sparse vector of feature ids. Features unknown to
     * this model are skipped. Weights are copied when the input vector carries them.
     *
     * @param stringFeatureVector the string features
     * @return the sparse vector, trimmed to size
     */
    public SparseFeatureVector toSparseFeatureVector(StringFeatureVector stringFeatureVector) {
        SparseFeatureVector sparse = new SparseFeatureVector(stringFeatureVector.hasWeight());
        int size = stringFeatureVector.size();
        for (int k = 0; k < size; k++) {
            String type = stringFeatureVector.getType(k);
            String value = stringFeatureVector.getValue(k);
            ObjectIntHashMap<String> values = this.m_features.get(type);
            if (values == null) {
                continue;
            }
            int id = values.get(value);
            if (id <= 0) {
                continue;
            }
            if (sparse.hasWeight()) {
                sparse.addFeature(id, stringFeatureVector.getWeight(k));
            } else {
                sparse.addFeature(id);
            }
        }
        sparse.trimToSize();
        return sparse;
    }

    /** Returns whether {@code i} is within {@code [0, n_features)}. */
    public boolean isValidFeature(int i) {
        return 0 <= i && i < this.n_features;
    }

    /** Returns a copy of the weight list. */
    public FloatArrayList cloneWeights() {
        return this.f_weights.mo75clone();
    }

    /** Returns the internal weight list (not a copy). */
    public FloatArrayList getWeights() {
        return this.f_weights;
    }

    /**
     * Returns the flat index of the weight for a label/feature pair.
     *
     * @param i  0-based label index
     * @param i2 feature id
     * @return {@code i2 * n_labels + i}
     */
    public int getWeightIndex(int i, int i2) {
        return (i2 * this.n_labels) + i;
    }

    /** Replaces the weight list; the caller is responsible for keeping its size consistent. */
    public void setWeights(FloatArrayList floatArrayList) {
        this.f_weights = floatArrayList;
    }

    /**
     * Overwrites every current weight with the corresponding value of {@code dArr}, narrowed to
     * float. {@code dArr} must be at least as long as the weight list.
     */
    public void setWeights(double[] dArr) {
        int size = this.f_weights.size();
        for (int k = 0; k < size; k++) {
            this.f_weights.set(k, (float) dArr[k]);
        }
    }

    /**
     * Subtracts {@code dArr[k] / i} from weight {@code k} for every {@code k < dArr.length}.
     *
     * @param dArr accumulated values to subtract
     * @param i    divisor applied to each accumulated value
     */
    public void setAverageWeights(double[] dArr, int i) {
        double scale = 1.0d / i;
        for (int k = 0; k < dArr.length; k++) {
            this.f_weights.set(k, (float) (this.f_weights.get(k) - (dArr[k] * scale)));
        }
    }

    /**
     * Adds {@code f} to the weight of label {@code i} for feature {@code i2}.
     *
     * @param i  0-based label index
     * @param i2 feature id
     * @param f  amount to add
     */
    public void updateWeight(int i, int i2, float f) {
        int index = getWeightIndex(i, i2);
        this.f_weights.set(index, this.f_weights.get(index) + f);
    }

    /** Buffers every instance of {@code collection} for the next {@link #build}. */
    public void addInstances(Collection<StringInstance> collection) {
        for (StringInstance instance : collection) {
            addInstance(instance);
        }
    }

    /** Buffers {@code stringInstance} for the next {@link #build}. */
    public void addInstance(StringInstance stringInstance) {
        this.i_collector.addInstance(stringInstance);
    }

    /** Returns the {@code i}-th built instance. Requires a prior {@link #build}. */
    public IntInstance getInstance(int i) {
        return this.l_instances.get(i);
    }

    /**
     * Returns the number of built instances, or 0 if none have been built yet. This includes
     * models restored by deserialization, which does not restore instances.
     */
    public int getInstanceSize() {
        return this.l_instances == null ? 0 : this.l_instances.size();
    }

    /** Shuffles the instance positions in place using the seeded random source. */
    public void shuffleIndices() {
        UTArray.shuffle(this.r_shuffle, this.l_indices);
    }

    /** Returns the instance position stored at slot {@code i} of the (possibly shuffled) index list. */
    public int getShuffledIndex(int i) {
        return this.l_indices.get(i);
    }

    /**
     * Adds frequent labels and features from the collector, then drains the collector into integer
     * instances. Instances whose label is unknown or whose feature vector maps to nothing are
     * skipped.
     *
     * @param i  labels are added only if the collector's count for them exceeds this value
     * @param i2 features are added only if the collector's count for them exceeds this value
     * @param i3 seed for the shuffling random source
     * @param z  if {@code true}, the model is reset with {@link #init()} first
     */
    public void build(int i, int i2, int i3, boolean z) {
        if (z) {
            init();
        }
        buildLabels(i);
        buildFeatures(i2);
        this.l_instances = Lists.newArrayList();
        this.r_shuffle = new Random(i3);
        this.l_indices = new IntArrayList();

        StringInstance raw;
        while ((raw = this.i_collector.pollInstance()) != null) {
            int labelIndex = getLabelIndex(raw.getLabel());
            if (labelIndex < 0) {
                continue;
            }
            SparseFeatureVector vector = toSparseFeatureVector(raw.getFeatureVector());
            if (vector.isEmpty()) {
                continue;
            }
            this.l_instances.add(new IntInstance(labelIndex, vector));
            this.l_indices.add(this.l_indices.size());
        }
    }

    /** Adds every collected label whose count exceeds {@code i}, then clears the collector's labels. */
    private void buildLabels(int i) {
        for (String label : this.i_collector.getLabels()) {
            if (this.i_collector.getLabelCount(label) > i) {
                addLabel(label);
            }
        }
        this.i_collector.clearLabels();
    }

    /** Adds every collected feature whose count exceeds {@code i}, then clears the collector's features. */
    private void buildFeatures(int i) {
        for (String type : this.i_collector.getFeatureTypes()) {
            ObjectIntHashMap<String> counts = this.i_collector.getFeatureMap(type);
            Iterator<ObjectCursor<String>> it = counts.keys().iterator();
            while (it.hasNext()) {
                String value = it.next().value;
                if (counts.get(value) > i) {
                    addFeature(type, value);
                }
            }
        }
        this.i_collector.clearFeatures();
    }

    /** Returns the highest-scoring prediction for {@code sparseFeatureVector}. */
    public StringPrediction predictBest(SparseFeatureVector sparseFeatureVector) {
        return (StringPrediction) Collections.max(getStringPredictions(sparseFeatureVector));
    }

    /** Returns the highest-scoring prediction for {@code stringFeatureVector}. */
    public StringPrediction predictBest(StringFeatureVector stringFeatureVector) {
        return predictBest(toSparseFeatureVector(stringFeatureVector));
    }

    /** Returns the two highest-scoring predictions (best first). Requires at least two labels. */
    public Pair<StringPrediction, StringPrediction> predictTop2(SparseFeatureVector sparseFeatureVector) {
        return predictTop2(getStringPredictions(sparseFeatureVector));
    }

    /** Returns the two highest-scoring predictions (best first). Requires at least two labels. */
    public Pair<StringPrediction, StringPrediction> predictTop2(StringFeatureVector stringFeatureVector) {
        return predictTop2(toSparseFeatureVector(stringFeatureVector));
    }

    /**
     * Selects the best and second-best predictions from {@code list} in a single pass. On ties, the
     * earlier element wins.
     *
     * @param list predictions; must contain at least two elements
     * @return (best, second best)
     */
    public Pair<StringPrediction, StringPrediction> predictTop2(List<StringPrediction> list) {
        StringPrediction first = list.get(0);
        StringPrediction second = list.get(1);
        if (first.score < second.score) {
            StringPrediction tmp = first;
            first = second;
            second = tmp;
        }
        for (int k = 2; k < list.size(); k++) {
            StringPrediction candidate = list.get(k);
            if (first.score < candidate.score) {
                second = first;
                first = candidate;
            } else if (second.score < candidate.score) {
                second = candidate;
            }
        }
        return new Pair<>(first, second);
    }

    /** Returns predictions for all labels, sorted by descending score. */
    public List<StringPrediction> predictAll(SparseFeatureVector sparseFeatureVector) {
        List<StringPrediction> predictions = getStringPredictions(sparseFeatureVector);
        UTCollection.sortReverseOrder(predictions);
        return predictions;
    }

    /** Returns predictions for all labels, sorted by descending score. */
    public List<StringPrediction> predictAll(StringFeatureVector stringFeatureVector) {
        return predictAll(toSparseFeatureVector(stringFeatureVector));
    }

    /** Returns one prediction per label, in label-index order, with unnormalized scores. */
    public List<StringPrediction> getStringPredictions(SparseFeatureVector sparseFeatureVector) {
        double[] scores = getScores(sparseFeatureVector);
        List<StringPrediction> predictions = new ArrayList<>(this.n_labels);
        for (int label = 0; label < this.n_labels; label++) {
            predictions.add(new StringPrediction(this.a_labels.get(label), scores[label]));
        }
        return predictions;
    }

    /** Returns one prediction per label, in label-index order, with unnormalized scores. */
    public List<StringPrediction> getStringPredictions(StringFeatureVector stringFeatureVector) {
        return getStringPredictions(toSparseFeatureVector(stringFeatureVector));
    }

    /** Returns one index-based prediction per label, in label-index order. */
    public List<IntPrediction> getIntPredictions(SparseFeatureVector sparseFeatureVector) {
        double[] scores = getScores(sparseFeatureVector);
        List<IntPrediction> predictions = new ArrayList<>(this.n_labels);
        for (int label = 0; label < this.n_labels; label++) {
            predictions.add(new IntPrediction(label, scores[label]));
        }
        return predictions;
    }

    /** Returns one index-based prediction per label, in label-index order. */
    public List<IntPrediction> getIntPredictions(StringFeatureVector stringFeatureVector) {
        return getIntPredictions(toSparseFeatureVector(stringFeatureVector));
    }

    /** Returns the raw linear score of every label (binary or multi-class path). */
    public double[] getScores(SparseFeatureVector sparseFeatureVector) {
        return isBinaryLabel() ? getScoresBinary(sparseFeatureVector) : getScoresMulti(sparseFeatureVector);
    }

    /**
     * Returns the label scores, optionally converted in place to softmax probabilities.
     *
     * @param z if {@code true}, the scores are softmax-normalized
     */
    public double[] getScores(SparseFeatureVector sparseFeatureVector, boolean z) {
        double[] scores = getScores(sparseFeatureVector);
        if (z) {
            normalize(scores);
        }
        return scores;
    }

    /**
     * Two-label scoring. Only label 0's weights are used. The score of label 1 is the negation of
     * label 0's score. Out-of-range feature ids are ignored.
     */
    private double[] getScoresBinary(SparseFeatureVector sparseFeatureVector) {
        int size = sparseFeatureVector.size();
        double score = this.f_weights.get(0);
        for (int k = 0; k < size; k++) {
            int index = sparseFeatureVector.getIndex(k);
            if (isValidFeature(index)) {
                score += this.f_weights.get(getWeightIndex(0, index)) * sparseFeatureVector.getWeight(k);
            }
        }
        return new double[]{score, -score};
    }

    /**
     * Multi-class scoring. Starts from the weights of the reserved row 0 and adds
     * {@code weight * value} for every in-range feature.
     */
    private double[] getScoresMulti(SparseFeatureVector sparseFeatureVector) {
        int size = sparseFeatureVector.size();
        double[] scores = this.f_weights.toDoubleArray(0, this.n_labels);
        for (int k = 0; k < size; k++) {
            int index = sparseFeatureVector.getIndex(k);
            double value = sparseFeatureVector.getWeight(k);
            if (!isValidFeature(index)) {
                continue;
            }
            for (int label = 0; label < this.n_labels; label++) {
                scores[label] += this.f_weights.get(getWeightIndex(label, index)) * value;
            }
        }
        return scores;
    }

    /**
     * Converts {@code dArr} in place into softmax probabilities. The maximum is subtracted before
     * exponentiating, so large scores cannot overflow to Infinity and yield NaN. This does not
     * change the ratios between entries.
     */
    private void normalize(double[] dArr) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : dArr) {
            if (v > max) {
                max = v;
            }
        }
        // With an infinite (or absent) maximum the shift would itself produce NaN; skip it.
        if (Double.isInfinite(max)) {
            max = 0.0d;
        }
        double sum = 0.0d;
        for (int k = 0; k < dArr.length; k++) {
            dArr[k] = Math.exp(dArr[k] - max);
            sum += dArr[k];
        }
        for (int k = 0; k < dArr.length; k++) {
            dArr[k] /= sum;
        }
    }

    /** Logs the number of labels, features and built instances. */
    public void printInfo(Logger logger) {
        logger.info("- # of labels   : " + getLabelSize() + "\n");
        logger.info("- # of features : " + getFeatureSize() + "\n");
        logger.info("- # of instances: " + getInstanceSize() + "\n");
    }
}
