package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.IndexLinearizer;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
import java.io.Serializable;

/* loaded from: input_file:edu/berkeley/nlp/crf/ScoreCalculator.class */
/**
 * Computes per-position label scores for a sequence model from a flat weight vector.
 *
 * <p>Weights are addressed through an {@link IndexLinearizer} sized by the encoding's feature and
 * label counts: {@code weights[il.getLinearIndex(featureIndex, labelIndex)]} is the weight of one
 * (feature, label) pair. Vertex features come from {@code sequence.getVertexInstance(position)} and
 * edge features from {@code sequence.getEdgeInstance(position, label)}. Both kinds of feature are
 * scored against the same weight vector. The {@code getLinear*} methods return raw dot products.
 * The other methods return those values passed through {@code ArrayUtil.exp}, presumably the
 * element-wise exponential (log-potentials vs. potentials; confirm).
 *
 * <p>NOTE(review): Java serialization of this class only succeeds if the supplied encoding and
 * feature extractors are themselves serializable.
 */
public class ScoreCalculator<V, E, F, L> implements Serializable {
    private static final long serialVersionUID = 6864706229279071608L;
    private final Encoding<F, L> encoding;
    private final FeatureExtractor<V, F> vertexExtractor;
    private final FeatureExtractor<E, F> edgeExtractor;
    private final IndexLinearizer il;

    /**
     * Creates a calculator over the given encoding and feature extractors.
     *
     * @param encoding        maps features and labels to indices; its feature and label counts size
     *                        the weight-vector layout
     * @param vertexExtractor extracts features from vertex instances
     * @param edgeExtractor   extracts features from edge instances
     */
    public ScoreCalculator(Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor) {
        this.encoding = encoding;
        this.vertexExtractor = vertexExtractor;
        this.edgeExtractor = edgeExtractor;
        this.il = new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
    }

    /**
     * Returns {@link #getLinearScoreMatrix} with {@code ArrayUtil.exp} applied to every row.
     *
     * @param sequence the instance sequence being scored
     * @param position the position within {@code sequence}
     * @param weights  flat weight vector laid out by this calculator's {@link IndexLinearizer}
     * @return a {@code numLabels x numLabels} matrix indexed {@code [edgeLabel][label]}
     */
    public double[][] getScoreMatrix(InstanceSequence<V, E, L> sequence, int position, double[] weights) {
        double[][] matrix = getLinearScoreMatrix(sequence, position, weights);
        for (int row = 0; row < matrix.length; row++) {
            matrix[row] = ArrayUtil.exp(matrix[row]);
        }
        return matrix;
    }

    /**
     * Returns {@link #getLinearVertexScores} passed through {@code ArrayUtil.exp}.
     *
     * @param sequence the instance sequence being scored
     * @param position the position within {@code sequence}
     * @param weights  flat weight vector laid out by this calculator's {@link IndexLinearizer}
     * @return one score per label index
     */
    public double[] getVertexScores(InstanceSequence<V, E, L> sequence, int position, double[] weights) {
        return ArrayUtil.exp(getLinearVertexScores(sequence, position, weights));
    }

    /**
     * Computes the combined vertex and edge dot products at {@code position}.
     *
     * <p>Entry {@code [edgeLabel][label]} is the vertex score of {@code label} plus the score of
     * {@code label} under the features of {@code sequence.getEdgeInstance(position,
     * encoding.getLabel(edgeLabel))}. The row label is presumably the previous position's label
     * (confirm against {@code InstanceSequence}).
     *
     * @param sequence the instance sequence being scored
     * @param position the position within {@code sequence}
     * @param weights  flat weight vector laid out by this calculator's {@link IndexLinearizer}
     * @return a {@code numLabels x numLabels} matrix indexed {@code [edgeLabel][label]}
     */
    public double[][] getLinearScoreMatrix(InstanceSequence<V, E, L> sequence, int position, double[] weights) {
        int numLabels = this.encoding.getNumLabels();
        double[] vertexScores = getLinearVertexScores(sequence, position, weights);
        double[][] matrix = new double[numLabels][];
        for (int edgeLabel = 0; edgeLabel < numLabels; edgeLabel++) {
            // Edge features depend only on the row's label, so extract them once per row
            // rather than once per (row, column) pair.
            Counter<F> edgeFeatures = this.edgeExtractor.extractFeatures(sequence.getEdgeInstance(position, this.encoding.getLabel(edgeLabel)));
            double[] row = linearScores(edgeFeatures, weights);
            for (int label = 0; label < numLabels; label++) {
                row[label] = vertexScores[label] + row[label];
            }
            matrix[edgeLabel] = row;
        }
        return matrix;
    }

    /**
     * Computes the dot product of the vertex features at {@code position} with each label's
     * weights.
     *
     * @param sequence the instance sequence being scored
     * @param position the position within {@code sequence}
     * @param weights  flat weight vector laid out by this calculator's {@link IndexLinearizer}
     * @return one score per label index
     */
    public double[] getLinearVertexScores(InstanceSequence<V, E, L> sequence, int position, double[] weights) {
        Counter<F> vertexFeatures = this.vertexExtractor.extractFeatures(sequence.getVertexInstance(position));
        return linearScores(vertexFeatures, weights);
    }

    /**
     * Scores one feature counter against every label in a single pass over the features.
     *
     * <p>Features the encoding does not know ({@code hasFeature} is false) are skipped. Each
     * feature's index and count are looked up once and reused for all labels. Each label's sum
     * still starts at 0.0 and accumulates in the counter's key order.
     *
     * @param features feature counts to score
     * @param weights  flat weight vector laid out by this calculator's {@link IndexLinearizer}
     * @return an array of length {@code numLabels} holding {@code sum_f count(f) * w(f, label)}
     */
    private double[] linearScores(Counter<F> features, double[] weights) {
        int numLabels = this.encoding.getNumLabels();
        double[] scores = new double[numLabels];
        for (F feature : features.keySet()) {
            if (!this.encoding.hasFeature(feature)) {
                continue;
            }
            int featureIndex = this.encoding.getFeatureIndex(feature);
            double count = features.getCount(feature);
            for (int label = 0; label < numLabels; label++) {
                scores[label] += count * weights[this.il.getLinearIndex(featureIndex, label)];
            }
        }
        return scores;
    }
}
