package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/classify/MaximumEntropyClassifier.class */
public class MaximumEntropyClassifier<I, F, L> implements ProbabilisticClassifier<I, L>, Serializable {
    private static final long serialVersionUID = 1;
    private double[] weights;
    private Encoding<F, L> encoding;
    private IndexLinearizer indexLinearizer;
    private transient FeatureExtractor<I, F> featureExtractor;

    /* loaded from: input_file:edu/berkeley/nlp/classify/MaximumEntropyClassifier$EncodedDatum.class */
    /**
     * Sparse, index-based representation of a single datum: a label index plus two
     * parallel arrays holding the indexes and counts of its active features.
     */
    public static class EncodedDatum {
        /** Index of the datum's label; {@code -1} when the datum was encoded without a label. */
        int labelIndex;
        /** Feature indexes (as assigned by the encoding), parallel to {@link #featureCounts}. */
        int[] featureIndexes;
        /** Feature counts, parallel to {@link #featureIndexes}. */
        double[] featureCounts;

        /**
         * Creates a datum from already-encoded parts. The arrays are stored as given (no copy).
         *
         * @param i    label index
         * @param iArr feature indexes
         * @param dArr feature counts, parallel to {@code iArr}
         */
        public EncodedDatum(int i, int[] iArr, double[] dArr) {
            this.labelIndex = i;
            this.featureIndexes = iArr;
            this.featureCounts = dArr;
        }

        /**
         * Encodes an unlabeled feature vector. Features for which the encoding reports a
         * negative index are dropped; the resulting datum has label index {@code -1}.
         *
         * @param featureVector the features to encode
         * @param encoding      the feature/label index mapping
         * @return the encoded datum
         */
        public static <F, L> EncodedDatum encodeDatum(FeatureVector<F> featureVector, Encoding<F, L> encoding) {
            Counter<F> features = featureVector.getFeatures();

            // First pass: keep only the features the encoding has a non-negative index for.
            Counter<F> knownFeatures = new Counter<F>();
            for (F feature : features.keySet()) {
                if (encoding.getFeatureIndex(feature) >= 0) {
                    knownFeatures.incrementCount(feature, features.getCount(feature));
                }
            }

            // Second pass: flatten the surviving features into the two parallel arrays.
            int numActive = knownFeatures.keySet().size();
            int[] indexes = new int[numActive];
            double[] counts = new double[numActive];
            int position = 0;
            for (F feature : knownFeatures.keySet()) {
                indexes[position] = encoding.getFeatureIndex(feature);
                counts[position] = knownFeatures.getCount(feature);
                position++;
            }
            return new EncodedDatum(-1, indexes, counts);
        }

        /**
         * Encodes a labeled feature vector: the features are encoded as in
         * {@link #encodeDatum}, then the label index is looked up in the encoding.
         *
         * @param labeledFeatureVector the labeled features to encode
         * @param encoding             the feature/label index mapping
         * @return the encoded datum with its label index set
         */
        public static <F, L> EncodedDatum encodeLabeledDatum(LabeledFeatureVector<F, L> labeledFeatureVector, Encoding<F, L> encoding) {
            EncodedDatum encoded = encodeDatum(labeledFeatureVector, encoding);
            encoded.labelIndex = encoding.getLabelIndex(labeledFeatureVector.getLabel());
            return encoded;
        }

        /** @return the label index of this datum */
        public int getLabelIndex() {
            return this.labelIndex;
        }

        /** @return the number of active features (the length of the counts array) */
        public int getNumActiveFeatures() {
            return this.featureCounts.length;
        }

        /**
         * @param i position in the active-feature arrays
         * @return the feature index stored at that position
         */
        public int getFeatureIndex(int i) {
            return this.featureIndexes[i];
        }

        /**
         * @param i position in the active-feature arrays
         * @return the feature count stored at that position
         */
        public double getFeatureCount(int i) {
            return this.featureCounts[i];
        }
    }

    /* loaded from: input_file:edu/berkeley/nlp/classify/MaximumEntropyClassifier$Factory.class */
    /**
     * Trains {@link MaximumEntropyClassifier} instances from labeled data by minimizing an
     * {@link ObjectiveFunction} with {@link LBFGSMinimizer}.
     */
    public static class Factory<I, F, L> implements ProbabilisticClassifierFactory<I, L> {
        /** Regularization parameter handed to the objective function. */
        double sigma;
        /** Value handed to the {@link LBFGSMinimizer} constructor. */
        int iterations;
        /** Converts raw inputs into feature counts. */
        FeatureExtractor<I, F> featureExtractor;

        /**
         * @param sigma            regularization parameter for the objective function
         * @param iterations       value passed to the {@link LBFGSMinimizer} constructor
         * @param featureExtractor converts raw inputs into feature counts
         */
        public Factory(double sigma, int iterations, FeatureExtractor<I, F> featureExtractor) {
            this.sigma = sigma;
            this.iterations = iterations;
            this.featureExtractor = featureExtractor;
        }

        /**
         * Trains a classifier with progress logging enabled.
         *
         * @param trainingData the labeled training instances
         * @return the trained classifier
         */
        @Override // edu.berkeley.nlp.classify.ProbabilisticClassifierFactory
        public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> trainingData) {
            return trainClassifier(trainingData, true);
        }

        /**
         * Trains a classifier: builds the feature/label encoding, encodes the data, and
         * minimizes the objective starting from all-zero weights.
         *
         * @param trainingData the labeled training instances
         * @param verbose      when true, opens/closes {@link Logger} tracks around each phase;
         *                     the flag is also forwarded to the minimizer
         * @return the trained classifier
         */
        public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> trainingData, boolean verbose) {
            if (verbose) {
                Logger.i().startTrack("Building encoding");
            }
            Encoding<F, L> encoding = buildEncoding(trainingData);
            IndexLinearizer indexLinearizer = buildIndexLinearizer(encoding);
            double[] initialWeights = buildInitialWeights(indexLinearizer);
            EncodedDatum[] data = encodeData(trainingData, encoding);
            if (verbose) {
                Logger.i().endTrack();
            }

            LBFGSMinimizer minimizer = new LBFGSMinimizer(this.iterations);
            ObjectiveFunction<F, L> objective = new ObjectiveFunction<F, L>(encoding, data, indexLinearizer, this.sigma);
            if (verbose) {
                Logger.i().startTrack("Training weights");
            }
            // NOTE(review): 1.0E-4 is presumably the convergence tolerance — confirm against LBFGSMinimizer.minimize.
            double[] weights = minimizer.minimize(objective, initialWeights, 1.0E-4d, verbose);
            if (verbose) {
                Logger.i().endTrack();
            }
            return new MaximumEntropyClassifier<I, F, L>(weights, encoding, indexLinearizer, this.featureExtractor);
        }

        /** Returns an all-zero weight vector with one entry per linear index. */
        private double[] buildInitialWeights(IndexLinearizer indexLinearizer) {
            return DoubleArrays.constantArray(0.0d, indexLinearizer.getNumLinearIndexes());
        }

        /** Creates the linearizer sized by the encoding's feature and label counts. */
        private IndexLinearizer buildIndexLinearizer(Encoding<F, L> encoding) {
            return new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
        }

        /** Builds the encoding by passing every label and feature of the training data through an indexer. */
        private Encoding<F, L> buildEncoding(List<LabeledInstance<I, L>> trainingData) {
            Indexer<F> featureIndexer = new Indexer<F>();
            Indexer<L> labelIndexer = new Indexer<L>();
            for (LabeledInstance<I, L> instance : trainingData) {
                LabeledFeatureVector<F, L> vector = toFeatureVector(instance);
                // Return values are discarded: the calls are made for their effect on the
                // indexer (presumably registering unseen elements — confirm in Indexer).
                labelIndexer.getIndex(vector.getLabel());
                for (F feature : vector.getFeatures().keySet()) {
                    featureIndexer.getIndex(feature);
                }
            }
            return new Encoding<F, L>(featureIndexer, labelIndexer);
        }

        /** Encodes every training instance against the given encoding, preserving list order. */
        private EncodedDatum[] encodeData(List<LabeledInstance<I, L>> trainingData, Encoding<F, L> encoding) {
            EncodedDatum[] encoded = new EncodedDatum[trainingData.size()];
            int position = 0;
            for (LabeledInstance<I, L> instance : trainingData) {
                encoded[position] = EncodedDatum.encodeLabeledDatum(toFeatureVector(instance), encoding);
                position++;
            }
            return encoded;
        }

        /** Runs the feature extractor on an instance's input and pairs the result with its label. */
        private LabeledFeatureVector<F, L> toFeatureVector(LabeledInstance<I, L> instance) {
            L label = instance.getLabel();
            Counter<F> features = this.featureExtractor.extractFeatures(instance.getInput());
            return new BasicLabeledFeatureVector<>(label, features);
        }
    }

    /* loaded from: input_file:edu/berkeley/nlp/classify/MaximumEntropyClassifier$ObjectiveFunction.class */
    /**
     * Training objective over the encoded data: the negated sum, over all data, of
     * (score of the datum's label minus {@code SloppyMath.logAdd} of all label scores),
     * plus a quadratic penalty {@code w^2 / (2 * sigma^2)} on every weight.
     * The most recent evaluation point, value and derivative are cached so that
     * {@link #valueAt} and {@link #derivativeAt} at the same point share one computation.
     */
    public static class ObjectiveFunction<F, L> implements DifferentiableFunction {
        IndexLinearizer indexLinearizer;
        Encoding<F, L> encoding;
        EncodedDatum[] data;
        double sigma;
        /** Objective value at {@link #lastX}. */
        double lastValue;
        /** Derivative at {@link #lastX}. */
        double[] lastDerivative;
        /** Private snapshot of the point the cached value/derivative belong to; null before the first evaluation. */
        double[] lastX;

        /**
         * @param encoding        the feature/label index mapping
         * @param data            the encoded training data
         * @param indexLinearizer maps (feature index, label index) pairs to weight positions
         * @param sigma           regularization parameter used in the quadratic penalty
         */
        public ObjectiveFunction(Encoding<F, L> encoding, EncodedDatum[] data, IndexLinearizer indexLinearizer, double sigma) {
            this.indexLinearizer = indexLinearizer;
            this.encoding = encoding;
            this.data = data;
            this.sigma = sigma;
        }

        /** @return the number of weights, i.e. the linearizer's number of linear indexes */
        @Override // edu.berkeley.nlp.math.Function
        public int dimension() {
            return this.indexLinearizer.getNumLinearIndexes();
        }

        /**
         * @param x the weight vector
         * @return the objective value at {@code x}
         */
        @Override // edu.berkeley.nlp.math.Function
        public double valueAt(double[] x) {
            ensureCache(x);
            return this.lastValue;
        }

        /**
         * @param x the weight vector
         * @return the derivative at {@code x}; this is the cached array itself, not a copy
         */
        @Override // edu.berkeley.nlp.math.DifferentiableFunction
        public double[] derivativeAt(double[] x) {
            ensureCache(x);
            return this.lastDerivative;
        }

        /**
         * Not implemented.
         *
         * @param x ignored
         * @return always {@code null}
         */
        public double[] unregularizedDerivativeAt(double[] x) {
            return null;
        }

        /** Recomputes the cached value and derivative unless {@code x} equals the cached point. */
        private void ensureCache(double[] x) {
            if (requiresUpdate(this.lastX, x)) {
                Pair<Double, double[]> valueAndDerivative = calculate(x);
                this.lastValue = valueAndDerivative.getFirst().doubleValue();
                this.lastDerivative = valueAndDerivative.getSecond();
                // Keep a snapshot rather than the caller's array: if the caller later mutates
                // that array in place, comparing it against itself would never detect the change.
                this.lastX = x.clone();
            }
        }

        /** Returns true when there is no cached point or it differs from {@code x} in length or in any element. */
        private boolean requiresUpdate(double[] cachedX, double[] x) {
            if (cachedX == null || cachedX.length != x.length) {
                return true;
            }
            for (int i = 0; i < x.length; i++) {
                if (cachedX[i] != x[i]) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Evaluates the objective and its derivative at {@code x}.
         *
         * @return a pair of (objective value, derivative array of length {@link #dimension()})
         */
        private Pair<Double, double[]> calculate(double[] x) {
            final int numLabels = this.encoding.getNumLabels();
            double logLikelihood = 0.0d;
            double[] gradient = DoubleArrays.constantArray(0.0d, dimension());
            double[] scores = new double[numLabels];
            double[] probabilities = new double[numLabels];

            for (EncodedDatum datum : this.data) {
                int numActiveFeatures = datum.getNumActiveFeatures();

                // Score of each label: sum over active features of weight * count.
                for (int label = 0; label < numLabels; label++) {
                    double score = 0.0d;
                    for (int k = 0; k < numActiveFeatures; k++) {
                        score += x[this.indexLinearizer.getLinearIndex(datum.getFeatureIndex(k), label)] * datum.getFeatureCount(k);
                    }
                    scores[label] = score;
                }

                double logNormalizer = SloppyMath.logAdd(scores);
                int goldLabel = datum.getLabelIndex();
                logLikelihood += scores[goldLabel] - logNormalizer;

                for (int label = 0; label < numLabels; label++) {
                    probabilities[label] = SloppyMath.exp(scores[label] - logNormalizer);
                }

                // Accumulate (observed count for the datum's label) - (probability-weighted count per label).
                for (int k = 0; k < numActiveFeatures; k++) {
                    int featureIndex = datum.getFeatureIndex(k);
                    int goldPosition = this.indexLinearizer.getLinearIndex(featureIndex, goldLabel);
                    double count = datum.getFeatureCount(k);
                    gradient[goldPosition] += count;
                    for (int label = 0; label < numLabels; label++) {
                        gradient[this.indexLinearizer.getLinearIndex(featureIndex, label)] -= probabilities[label] * count;
                    }
                }
            }

            // The accumulated quantities are negated because the objective is to be minimized.
            double objective = -logLikelihood;
            DoubleArrays.scale(gradient, -1.0d);

            // Quadratic penalty: w^2 / (2 sigma^2) on the value, w / sigma^2 on the derivative.
            final double sigmaSquared = this.sigma * this.sigma;
            final double twoSigmaSquared = (2.0d * this.sigma) * this.sigma;
            for (int i = 0; i < x.length; i++) {
                double weight = x[i];
                objective += (weight * weight) / twoSigmaSquared;
                gradient[i] += weight / sigmaSquared;
            }
            return new Pair<Double, double[]>(Double.valueOf(objective), gradient);
        }
    }

    /**
     * Replaces the feature extractor used to turn raw inputs into feature counts.
     * The {@code featureExtractor} field is declared {@code transient}, so this setter is
     * the way to supply an extractor to an instance that does not have one.
     *
     * @param featureExtractor the extractor to use from now on
     */
    public void setFeatureExtractor(FeatureExtractor<I, F> featureExtractor) {
        this.featureExtractor = featureExtractor;
    }

    /**
     * Scores every label for one encoded datum and normalizes the scores by subtracting
     * {@code SloppyMath.logAdd} of all of them.
     *
     * @param encodedDatum    the datum's active feature indexes and counts
     * @param weights         the weight vector, addressed through {@code indexLinearizer}
     * @param encoding        supplies the number of labels
     * @param indexLinearizer maps (feature index, label index) to a position in {@code weights}
     * @return one normalized log score per label index
     */
    private static <F, L> double[] getLogProbabilities(EncodedDatum encodedDatum, double[] weights, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        final int numLabels = encoding.getNumLabels();
        final int numActive = encodedDatum.getNumActiveFeatures();
        double[] logProbabilities = new double[numLabels];

        for (int label = 0; label < numLabels; label++) {
            double score = 0.0d;
            for (int k = 0; k < numActive; k++) {
                int weightPosition = indexLinearizer.getLinearIndex(encodedDatum.getFeatureIndex(k), label);
                score += weights[weightPosition] * encodedDatum.getFeatureCount(k);
            }
            logProbabilities[label] = score;
        }

        double logNormalizer = SloppyMath.logAdd(logProbabilities);
        for (int label = 0; label < numLabels; label++) {
            logProbabilities[label] -= logNormalizer;
        }
        return logProbabilities;
    }

    /**
     * Extracts features from the input and returns the classifier's probability for each label.
     *
     * @param input the raw input to classify
     * @return a counter mapping each label to its probability
     * @throws IllegalStateException if no feature extractor is set (the field is transient,
     *         so an instance can exist without one until {@link #setFeatureExtractor} is called)
     */
    @Override // edu.berkeley.nlp.classify.ProbabilisticClassifier
    public Counter<L> getProbabilities(I input) {
        if (this.featureExtractor == null) {
            throw new IllegalStateException("No feature extractor set; call setFeatureExtractor(...) before classifying");
        }
        // A typed local (rather than a raw cast) selects the FeatureVector overload.
        FeatureVector<F> featureVector = new BasicFeatureVector<F>(this.featureExtractor.extractFeatures(input));
        return getProbabilities(featureVector);
    }

    /**
     * Computes label probabilities for an already-extracted feature vector: encode it,
     * compute normalized log scores with the trained weights, then exponentiate.
     *
     * @param featureVector the features of the input
     * @return a counter mapping each label to its probability
     */
    private Counter<L> getProbabilities(FeatureVector<F> featureVector) {
        EncodedDatum encoded = EncodedDatum.encodeDatum(featureVector, this.encoding);
        double[] logProbabilities = getLogProbabilities(encoded, this.weights, this.encoding, this.indexLinearizer);
        return logProbabiltyArrayToProbabiltyCounter(logProbabilities);
    }

    /**
     * Converts an array of log probabilities, indexed by label index, into a counter
     * keyed by label whose values are the exponentiated entries.
     *
     * @param dArr log probabilities, one per label index
     * @return a counter mapping each label to {@code Math.exp} of its entry
     */
    private Counter<L> logProbabiltyArrayToProbabiltyCounter(double[] dArr) {
        Counter<L> probabilities = new Counter<>();
        int labelIndex = 0;
        for (double logProbability : dArr) {
            L label = this.encoding.getLabel(labelIndex);
            probabilities.setCount(label, Math.exp(logProbability));
            labelIndex++;
        }
        return probabilities;
    }

    /**
     * Returns the label with the highest probability for the given input.
     *
     * @param i the raw input to classify
     * @return the {@code argMax} of the counter returned by {@link #getProbabilities(Object)}
     */
    @Override // edu.berkeley.nlp.classify.Classifier
    public L getLabel(I i) {
        // The input is passed through unchanged; it must not be cast to the classifier type.
        return getProbabilities(i).argMax();
    }

    /**
     * Creates a classifier from trained parts. All arguments are stored as given (no copies).
     *
     * @param dArr             the weight vector, addressed through {@code indexLinearizer}
     * @param encoding         the feature/label index mapping
     * @param indexLinearizer  maps (feature index, label index) pairs to positions in {@code dArr}
     * @param featureExtractor converts raw inputs into feature counts
     */
    public MaximumEntropyClassifier(double[] dArr, Encoding<F, L> encoding, IndexLinearizer indexLinearizer, FeatureExtractor<I, F> featureExtractor) {
        this.weights = dArr;
        this.encoding = encoding;
        this.indexLinearizer = indexLinearizer;
        this.featureExtractor = featureExtractor;
    }

    /**
     * Small demo: trains on three toy instances and prints the label probabilities for a
     * fourth, held-out instance.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        LabeledInstance<String[], String> smallFuzzyCat = new LabeledInstance<String[], String>("cat", new String[]{"fuzzy", "claws", "small"});
        LabeledInstance<String[], String> bigFuzzyBear = new LabeledInstance<String[], String>("bear", new String[]{"fuzzy", "claws", "big"});
        LabeledInstance<String[], String> mediumCat = new LabeledInstance<String[], String>("cat", new String[]{"claws", "medium"});
        LabeledInstance<String[], String> testInstance = new LabeledInstance<String[], String>("cat", new String[]{"claws", "small"});

        List<LabeledInstance<String[], String>> trainingData = new ArrayList<LabeledInstance<String[], String>>();
        trainingData.add(smallFuzzyCat);
        trainingData.add(bigFuzzyBear);
        trainingData.add(mediumCat);

        // Each string in the input array counts as one occurrence of that feature.
        FeatureExtractor<String[], String> featureExtractor = new FeatureExtractor<String[], String>() { // from class: edu.berkeley.nlp.classify.MaximumEntropyClassifier.1
            private static final long serialVersionUID = 8296036312980792350L;

            @Override // edu.berkeley.nlp.classify.FeatureExtractor
            public Counter<String> extractFeatures(String[] featureNames) {
                return new Counter<>(Arrays.asList(featureNames));
            }
        };

        Factory<String[], String, String> factory = new Factory<String[], String, String>(1.0d, 20, featureExtractor);
        ProbabilisticClassifier<String[], String> classifier = factory.trainClassifier(trainingData);
        System.out.println("Probabilities on test instance: " + classifier.getProbabilities(testInstance.getInput()));
    }
}
