package com.clearnlp.classification.algorithm;

import com.clearnlp.classification.instance.IntInstance;
import com.clearnlp.classification.model.StringModelAD;
import java.util.Arrays;

/* loaded from: input_file:com/clearnlp/classification/algorithm/AbstractAdaGrad.class */
/**
 * Base class for AdaGrad-style online learners that train a {@link StringModelAD}.
 *
 * <p>The step size for a weight is {@code d_alpha / (d_rho + sqrt(G))}, where {@code G} is that
 * weight's entry in {@link #d_gradients}. This class only resets the accumulator. Subclasses are
 * expected to maintain it in {@link #update} (presumably as a running sum of squared gradients,
 * which is what the AdaGrad name suggests — confirm against the subclasses).
 *
 * <p>When {@link #b_average} is set, every weight update is also accumulated, scaled by its
 * timestamp, into {@link #d_average}. That array is handed to
 * {@code StringModelAD.setAverageWeights} at the end of a training pass.
 */
public abstract class AbstractAdaGrad extends AbstractAlgorithm {
    /** Per-weight accumulator read by {@link #getCost}; indexed via {@code getWeightIndex}. */
    protected double[] d_gradients;
    /** Per-weight sum of {@code cost * timestamp}; only allocated and used when averaging is on. */
    protected double[] d_average;
    /** Whether to accumulate {@link #d_average} and average the model weights after training. */
    protected boolean b_average;
    /** Numerator of the per-weight step size. */
    protected double d_alpha;
    /** Added to {@code sqrt(G)} in the step-size denominator (presumably to avoid division by zero). */
    protected double d_rho;

    /**
     * Performs the model update for a single training instance.
     *
     * @param stringModelAD the model being trained
     * @param intInstance the training instance
     * @param i 1-based timestamp of this instance within the current pass; {@link #train} passes
     *     {@code 1..instanceSize} and uses {@code instanceSize + 1} as the final count
     * @return a subclass-defined flag; {@link #train} ignores it
     */
    protected abstract boolean update(StringModelAD stringModelAD, IntInstance intInstance, int i);

    /**
     * Creates the learner and stores its hyper-parameters via {@link #init}.
     *
     * @param d learning-rate numerator ({@link #d_alpha})
     * @param d2 denominator offset ({@link #d_rho})
     * @param z whether to average weights ({@link #b_average})
     */
    public AbstractAdaGrad(double d, double d2, boolean z) {
        super((byte) 0);
        // NOTE(review): init() is overridable and is called from the constructor, so an override
        // runs before the subclass's own field initializers.
        init(d, d2, z);
    }

    /**
     * (Re)sets the hyper-parameters. This may be called between training passes; {@link #train}
     * re-checks its buffers every pass, so turning averaging on later is safe.
     *
     * @param d learning-rate numerator ({@link #d_alpha})
     * @param d2 denominator offset ({@link #d_rho})
     * @param z whether to average weights ({@link #b_average})
     */
    public void init(double d, double d2, boolean z) {
        this.d_alpha = d;
        this.d_rho = d2;
        this.b_average = z;
    }

    /**
     * Runs one training pass over the model's instances in shuffled order.
     *
     * <p>The accumulators are zeroed at the start of the pass. They are reallocated only when
     * missing or when the model's weight count ({@code labelSize * featureSize}) has changed.
     * If averaging is enabled, the accumulated sums are passed to
     * {@code setAverageWeights} with the final count {@code instanceSize + 1}.
     *
     * @param stringModelAD the model to train
     */
    @Override // com.clearnlp.classification.algorithm.AbstractAlgorithm
    public void train(StringModelAD stringModelAD) {
        int weightSize = stringModelAD.getLabelSize() * stringModelAD.getFeatureSize();
        int instanceSize = stringModelAD.getInstanceSize();
        stringModelAD.shuffleIndices();

        this.d_gradients = zeroedBuffer(this.d_gradients, weightSize);
        if (this.b_average) {
            // Checked independently of d_gradients: averaging may have been enabled via init()
            // after the gradient buffer was allocated, leaving d_average null or stale-sized.
            this.d_average = zeroedBuffer(this.d_average, weightSize);
        }

        for (int i = 0; i < instanceSize; i++) {
            update(stringModelAD, stringModelAD.getInstance(stringModelAD.getShuffledIndex(i)), i + 1);
        }

        if (this.b_average) {
            stringModelAD.setAverageWeights(this.d_average, instanceSize + 1);
        }
    }

    /** Returns a zero-filled array of length {@code size}, reusing {@code buffer} when it fits. */
    private static double[] zeroedBuffer(double[] buffer, int size) {
        if (buffer == null || buffer.length != size) {
            return new double[size];
        }
        Arrays.fill(buffer, 0.0d);
        return buffer;
    }

    /**
     * Applies one AdaGrad step to a single weight.
     *
     * <p>The step is {@code getCost(i, i2) * d}. It is added to the model weight as a
     * {@code float}. When averaging is on, {@code step * i3} is also added to the matching
     * {@link #d_average} entry.
     *
     * @param stringModelAD the model being trained
     * @param i first index passed to {@code getWeightIndex} (presumably the label — confirm)
     * @param i2 second index passed to {@code getWeightIndex} (presumably the feature — confirm)
     * @param d the raw update (gradient) for this weight
     * @param i3 timestamp of the current instance, as passed to {@link #update}
     */
    public void updateWeight(StringModelAD stringModelAD, int i, int i2, double d, int i3) {
        double cost = getCost(stringModelAD, i, i2) * d;
        stringModelAD.updateWeight(i, i2, (float) cost);
        if (this.b_average) {
            this.d_average[stringModelAD.getWeightIndex(i, i2)] += cost * i3;
        }
    }

    /**
     * Returns the per-weight step size {@code d_alpha / (d_rho + sqrt(G))}, where {@code G} is
     * the weight's entry in {@link #d_gradients}.
     *
     * @param stringModelAD the model that maps {@code (i, i2)} to a weight index
     * @param i first weight-index coordinate
     * @param i2 second weight-index coordinate
     * @return the step size for that weight
     */
    protected double getCost(StringModelAD stringModelAD, int i, int i2) {
        return this.d_alpha / (this.d_rho + Math.sqrt(this.d_gradients[stringModelAD.getWeightIndex(i, i2)]));
    }
}
