package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/math/CachingObjectiveDifferentiableFunction.class */
/**
 * A {@link CachingDifferentiableFunction} whose value and gradient are assembled from
 * per-item contributions.
 *
 * <p>On each evaluation the current point is pushed into every item function, one
 * {@link Mapper} is created per item function, the items are handed to
 * {@link AsynchronousMapper#doMapping} together with those mappers, and the per-mapper
 * objective values and gradients are summed. If a {@link Regularizer} was supplied, its
 * contribution is added last.
 *
 * @param <I> the type of the items the objective is computed over
 */
public class CachingObjectiveDifferentiableFunction<I> extends CachingDifferentiableFunction {
    /** Item functions; one {@link Mapper} is created for each of them per evaluation. */
    private final List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;

    /** Optional regularizer; {@code null} means no regularization term is added. */
    private final Regularizer regularizer;

    /** The items passed to the mapping step on every evaluation. */
    private final Collection<I> items;

    /**
     * Accumulates the objective value and gradient contributed by the items routed to
     * one item function. Each mapper owns its own gradient buffer, so mappers do not
     * write into each other's accumulators.
     */
    public class Mapper implements SimpleMapper<I> {
        /** The item function this mapper delegates to. */
        ObjectiveItemDifferentiableFunction<I> itemFn;

        /** Running sum of the values returned by {@code itemFn.update}. */
        double objVal = 0.0d;

        /** Gradient buffer handed to {@code itemFn.update}; sized to the function's dimension. */
        double[] localGrad;

        /**
         * Creates a mapper with a zeroed gradient buffer for the given item function.
         *
         * @param objectiveItemDifferentiableFunction the item function to delegate to
         */
        Mapper(ObjectiveItemDifferentiableFunction<I> objectiveItemDifferentiableFunction) {
            this.itemFn = objectiveItemDifferentiableFunction;
            this.localGrad = new double[objectiveItemDifferentiableFunction.dimension()];
        }

        /**
         * Processes one item: adds the value returned by the item function to
         * {@link #objVal} and passes {@link #localGrad} along for the function to use.
         *
         * @param i the item to process
         */
        @Override
        public void map(I i) {
            // NOTE(review): update(...) presumably accumulates this item's gradient into
            // localGrad in place - confirm against ObjectiveItemDifferentiableFunction.
            this.objVal += this.itemFn.update(i, this.localGrad);
        }
    }

    /**
     * Creates an objective over {@code collection} using the given item functions.
     *
     * <p>The arguments are stored by reference, not copied.
     *
     * @param collection  the items to evaluate
     * @param list        the item functions; {@link #dimension()} reads the first element,
     *                    so the list must be non-empty by the time the function is evaluated
     * @param regularizer the regularizer, or {@code null} for none
     */
    public CachingObjectiveDifferentiableFunction(Collection<I> collection, List<? extends ObjectiveItemDifferentiableFunction<I>> list, Regularizer regularizer) {
        this.itemFns = list;
        this.regularizer = regularizer;
        this.items = collection;
    }

    /**
     * Creates an objective over {@code collection} using a single item function.
     *
     * @param collection                          the items to evaluate
     * @param objectiveItemDifferentiableFunction the only item function
     * @param regularizer                         the regularizer, or {@code null} for none
     */
    public CachingObjectiveDifferentiableFunction(Collection<I> collection, ObjectiveItemDifferentiableFunction<I> objectiveItemDifferentiableFunction, Regularizer regularizer) {
        this(collection, Collections.singletonList(objectiveItemDifferentiableFunction), regularizer);
    }

    /** Builds a fresh list with one new {@link Mapper} per item function, in list order. */
    private List<Mapper> getMappers() {
        List<Mapper> mappers = new ArrayList<Mapper>(this.itemFns.size());
        for (ObjectiveItemDifferentiableFunction<I> itemFn : this.itemFns) {
            mappers.add(new Mapper(itemFn));
        }
        return mappers;
    }

    /**
     * Computes the objective value and gradient at {@code dArr}.
     *
     * @param dArr the point (weights) at which to evaluate
     * @return a pair of the objective value and a newly allocated gradient array of
     *         length {@link #dimension()}
     */
    @Override
    protected Pair<Double, double[]> calculate(double[] dArr) {
        // Every item function must see the current weights before any item is mapped.
        for (ObjectiveItemDifferentiableFunction<I> itemFn : this.itemFns) {
            itemFn.setWeights(dArr);
        }

        List<Mapper> mappers = getMappers();
        AsynchronousMapper.doMapping(this.items, mappers);

        // Reduce the per-mapper accumulators into a single value and gradient.
        double objective = 0.0d;
        double[] gradient = new double[dimension()];
        for (Mapper mapper : mappers) {
            objective += mapper.objVal;
            DoubleArrays.addInPlace(gradient, mapper.localGrad);
        }

        if (this.regularizer != null) {
            // NOTE(review): update(...) presumably also adds the regularizer's gradient
            // into `gradient`, with 1.0 as a scale factor - confirm against Regularizer.
            objective += this.regularizer.update(dArr, gradient, 1.0d);
        }
        return Pair.newPair(Double.valueOf(objective), gradient);
    }

    /**
     * Returns the dimension reported by the first item function.
     *
     * <p>Requires a non-empty item-function list; only the first element is consulted.
     *
     * @return the dimension of the first item function
     */
    @Override
    public int dimension() {
        return this.itemFns.get(0).dimension();
    }
}
