package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.IndexLinearizer;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Pair;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/berkeley/nlp/crf/CRFObjectiveFunction.class */
/**
 * Differentiable training objective for a CRF over {@code trainingData}, with a quadratic
 * (Gaussian-prior) penalty on the weights.
 *
 * <p>For a weight vector {@code x}, the value is
 * <pre>
 *   - sum(empiricalCount * x[idx])  +  logNormalization(x)  +  sum(x[i]^2) / (2 * sigma^2)
 * </pre>
 * and the gradient is {@code expectedCount - empiricalCount + x / sigma^2}, where
 * {@code idx} is the linearized (feature, label) index. {@code logNormalization} and the expected
 * counts come from {@code Counts.getLogNormalizationAndExpectedCounts}. That value is presumably
 * the summed log partition function over the training sequences (verify in {@code Counts}).
 *
 * <p>The most recent (point, value, gradient) triple is cached. This lets a
 * {@code valueAt(x)} / {@code derivativeAt(x)} pair on the same point do the expensive
 * computation only once. The class is not thread-safe.
 */
public class CRFObjectiveFunction<V, E, F, L> implements DifferentiableFunction {
    private final List<? extends LabeledInstanceSequence<V, E, L>> trainingData;
    private final Encoding<F, L> encoding;
    private final Counts<V, E, F, L> counts;
    /** Maps (feature index, label index) pairs to positions in the flat weight vector. */
    private final IndexLinearizer il;
    /** Standard deviation of the quadratic penalty; the penalty on weight w is w^2 / (2 sigma^2). */
    private final double sigma;

    /** Objective value at {@link #lastX}; meaningful only once lastX is non-null. */
    double lastValue;
    /** Gradient at {@link #lastX}; handed out directly (not copied) by {@link #derivativeAt}. */
    double[] lastDerivative;
    /** Private copy of the point the cache was computed at, or null before the first evaluation. */
    double[] lastX;

    /**
     * Creates the objective.
     *
     * @param trainingData the labeled sequences the objective is summed over
     * @param encoding feature and label index encoding; also determines {@link #dimension()}
     * @param vertexFeatureExtractor feature extractor for objects of type {@code V}
     *     (first extractor handed to {@code Counts})
     * @param edgeFeatureExtractor feature extractor for objects of type {@code E}
     *     (second extractor handed to {@code Counts})
     * @param sigma standard deviation of the quadratic weight penalty; an infinite value
     *     disables the penalty
     * @throws IllegalArgumentException if {@code sigma} is zero or NaN, which would make every
     *     value and gradient NaN or infinite
     */
    public CRFObjectiveFunction(
            List<? extends LabeledInstanceSequence<V, E, L>> trainingData,
            Encoding<F, L> encoding,
            FeatureExtractor<V, F> vertexFeatureExtractor,
            FeatureExtractor<E, F> edgeFeatureExtractor,
            double sigma) {
        if (sigma == 0.0 || Double.isNaN(sigma)) {
            throw new IllegalArgumentException("sigma must be non-zero and not NaN, got " + sigma);
        }
        this.trainingData = trainingData;
        this.encoding = encoding;
        this.counts = new Counts<>(encoding, vertexFeatureExtractor, edgeFeatureExtractor);
        this.il = new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
        this.sigma = sigma;
    }

    /** Returns the number of weights: one per (feature, label) pair in the encoding. */
    @Override // edu.berkeley.nlp.math.Function
    public int dimension() {
        return this.il.getNumLinearIndexes();
    }

    /** Returns the objective value at {@code x}, recomputing only if {@code x} differs from the cached point. */
    @Override // edu.berkeley.nlp.math.Function
    public double valueAt(double[] x) {
        ensureCache(x);
        return this.lastValue;
    }

    /**
     * Returns the gradient at {@code x}, recomputing only if {@code x} differs from the cached point.
     * The returned array is the internal cache; callers must not modify it.
     */
    @Override // edu.berkeley.nlp.math.DifferentiableFunction
    public double[] derivativeAt(double[] x) {
        ensureCache(x);
        return this.lastDerivative;
    }

    /** Recomputes the cached value and gradient if {@code x} is not the point they were computed at. */
    private void ensureCache(double[] x) {
        if (requiresUpdate(this.lastX, x)) {
            Pair<Double, double[]> result = calculate(x);
            this.lastValue = result.getFirst();
            this.lastDerivative = result.getSecond();
            // Store a copy: if the caller later mutates x in place, keeping the reference would make
            // the next comparison see "no change" and return stale results for a different point.
            this.lastX = x.clone();
        }
    }

    /** Returns true unless {@code cachedX} exists and is element-wise equal to {@code x}. */
    private boolean requiresUpdate(double[] cachedX, double[] x) {
        if (cachedX == null || cachedX.length != x.length) {
            return true;
        }
        for (int i = 0; i < x.length; i++) {
            // != (rather than Arrays.equals) so that a NaN coordinate always forces recomputation.
            if (cachedX[i] != x[i]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes the objective value and gradient at {@code x} (see the class comment for the formula).
     *
     * @throws IllegalArgumentException if {@code x.length != dimension()}
     */
    private Pair<Double, double[]> calculate(double[] x) {
        int dim = dimension();
        if (x.length != dim) {
            throw new IllegalArgumentException(
                    "expected a weight vector of length " + dim + ", got " + x.length);
        }
        double[] gradient = new double[dim];

        // Empirical term: contributes -count * weight to the value and -count to the gradient.
        List<Counter<F>> empiricalCounts = this.counts.getEmpiricalCounts(this.trainingData);
        double value = -accumulateCounts(empiricalCounts, -1.0, x, gradient);

        // Normalizer term: its value is added as-is; expected counts are added to the gradient.
        Pair<Double, List<Counter<F>>> normalizationAndExpected =
                this.counts.getLogNormalizationAndExpectedCounts(this.trainingData, x);
        value += normalizationAndExpected.getFirst();
        accumulateCounts(normalizationAndExpected.getSecond(), 1.0, x, gradient);

        // Quadratic penalty: w^2 / (2 sigma^2) per weight, gradient w / sigma^2.
        double variance = this.sigma * this.sigma;
        double twiceVariance = 2.0 * variance;
        for (int i = 0; i < x.length; i++) {
            double w = x[i];
            value += (w * w) / twiceVariance;
            gradient[i] += w / variance;
        }
        return Pair.makePair(value, gradient);
    }

    /**
     * Adds {@code scale * count} to {@code gradient} for every (feature, count) entry in
     * {@code countsByLabel}. Each entry goes to the linear index of (feature, label), where the
     * label index is the entry's position in the list.
     *
     * @return the sum of {@code count * x[index]} over all entries
     */
    private double accumulateCounts(
            List<Counter<F>> countsByLabel, double scale, double[] x, double[] gradient) {
        double dot = 0.0;
        for (int label = 0; label < countsByLabel.size(); label++) {
            for (Map.Entry<F, Double> entry : countsByLabel.get(label).entrySet()) {
                int index = this.il.getLinearIndex(this.encoding.getFeatureIndex(entry.getKey()), label);
                double count = entry.getValue();
                dot += count * x[index];
                gradient[index] += scale * count;
            }
        }
        return dot;
    }
}
