package com.clearnlp.classification.algorithm;

import com.clearnlp.classification.instance.IntInstance;
import com.clearnlp.classification.model.StringModelAD;
import com.clearnlp.classification.vector.SparseFeatureVector;
import com.clearnlp.util.UTMath;

/* loaded from: input_file:com/clearnlp/classification/algorithm/AdaGradOnlineLogisticRegression.class */
/**
 * Online multinomial logistic regression trained with AdaGrad.
 *
 * <p>For each training instance the per-label gradient is {@code 1[k == gold] - score[k]}, where
 * {@code score} is the array returned by {@code StringModelAD.getScores(vector, true)}
 * (presumably normalized label probabilities; confirm in {@code StringModelAD}). Squared
 * per-weight gradients are accumulated into the inherited {@code d_gradients} array. Each weight
 * is then adjusted through the inherited {@code updateWeight} hook.
 */
public class AdaGradOnlineLogisticRegression extends AbstractAdaGrad {
    /**
     * An instance is skipped when the gold label's gradient ({@code 1 - score[gold]}) is at most
     * this value, i.e. when the gold score is already roughly 0.99 or higher.
     */
    private static final double MIN_GOLD_GRADIENT = 0.01d;

    /**
     * Creates the learner. All three arguments are forwarded unchanged to the
     * {@link AbstractAdaGrad} constructor.
     * NOTE(review): their meaning (learning rate / ridge / bias flag?) is defined by the
     * superclass; confirm there.
     */
    public AdaGradOnlineLogisticRegression(double d, double d2, boolean z) {
        super(d, d2, z);
    }

    /**
     * Performs one AdaGrad step on a single training instance.
     *
     * @param stringModelAD model whose scores are read and whose weights are updated
     * @param intInstance   training instance providing the gold label and sparse features
     * @param i             forwarded unchanged to {@code updateWeight}; presumably a step or
     *                      averaging counter (TODO confirm in AbstractAdaGrad)
     * @return {@code false} without touching the model when the gold label's gradient is
     *         {@code <= 0.01}; otherwise {@code true} after updating accumulators and weights
     */
    @Override // com.clearnlp.classification.algorithm.AbstractAdaGrad
    protected boolean update(StringModelAD stringModelAD, IntInstance intInstance, int i) {
        final double[] gradients = getGradients(stringModelAD, intInstance);
        final double goldGradient = gradients[intInstance.getLabel()];

        // Guard kept in "<=" form on purpose: a NaN gradient must still fall through to the update.
        if (goldGradient <= MIN_GOLD_GRADIENT) {
            return false;
        }

        updateCounts(stringModelAD, intInstance, gradients);
        updateWeights(stringModelAD, intInstance, gradients, i);
        return true;
    }

    /**
     * Computes {@code 1[k == gold] - score[k]} for every label {@code k}.
     * The array returned by {@code getScores(..., true)} is overwritten in place and returned.
     */
    private double[] getGradients(StringModelAD stringModelAD, IntInstance intInstance) {
        final double[] gradients = stringModelAD.getScores(intInstance.getFeatureVector(), true);

        for (int k = 0; k < gradients.length; k++) {
            gradients[k] *= -1.0d;
        }
        gradients[intInstance.getLabel()] += 1.0d;

        return gradients;
    }

    /**
     * Adds each weight's squared gradient to the inherited {@code d_gradients} accumulator.
     * For a feature with value {@code x} and label {@code k}, the weight gradient is
     * {@code gradients[k] * x}. The slot {@code getWeightIndex(k, featureIndex)} therefore
     * receives {@code x^2 * gradients[k]^2}.
     *
     * @param stringModelAD model used to map (label, feature) pairs to weight indices
     * @param intInstance   instance whose sparse features are visited
     * @param gradients     per-label gradients, one entry per label
     */
    protected void updateCounts(StringModelAD stringModelAD, IntInstance intInstance, double[] gradients) {
        final SparseFeatureVector vector = intInstance.getFeatureVector();
        final int featureCount = vector.size();
        final int labelCount = stringModelAD.getLabelSize();

        // Squared label gradients are shared by every feature, so compute them once up front.
        final double[] squaredGradients = new double[labelCount];
        for (int k = 0; k < labelCount; k++) {
            squaredGradients[k] = gradients[k] * gradients[k];
        }

        for (int f = 0; f < featureCount; f++) {
            final double squaredValue = UTMath.sq(vector.getWeight(f));
            for (int k = 0; k < labelCount; k++) {
                d_gradients[stringModelAD.getWeightIndex(k, vector.getIndex(f))] += squaredValue * squaredGradients[k];
            }
        }
    }

    /**
     * Issues one {@code updateWeight} call per (feature, label) pair, passing
     * {@code gradients[k] * x} as that weight's gradient, where {@code x} is the feature value.
     */
    private void updateWeights(StringModelAD stringModelAD, IntInstance intInstance, double[] gradients, int i) {
        final SparseFeatureVector vector = intInstance.getFeatureVector();
        final int featureCount = vector.size();
        final int labelCount = stringModelAD.getLabelSize();

        for (int f = 0; f < featureCount; f++) {
            final int featureIndex = vector.getIndex(f);
            final double featureValue = vector.getWeight(f);
            for (int k = 0; k < labelCount; k++) {
                updateWeight(stringModelAD, k, featureIndex, gradients[k] * featureValue, i);
            }
        }
    }
}
