package se.lth.cs.srl.pipeline;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import se.lth.cs.srl.corpus.ArgMap;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.Label;

/**
 * Argument-classification step of the SRL pipeline.
 *
 * <p>For every predicate, it assigns a label from {@code argLabels} to each argument word already
 * present in the predicate's argument map. A classifier output {@code n} is interpreted as
 * {@code argLabels.get(n)}. Model files use the {@code "ac_"} prefix.
 */
public class ArgumentClassifier extends ArgumentStep {
    private static final String FILEPREFIX = "ac_";

    /** Label inventory; list positions serve as the classifier's class ids. */
    private final List<String> argLabels;

    /**
     * Creates an argument classifier.
     *
     * @param featureSet the feature set handed to {@link ArgumentStep}
     * @param list the argument label inventory; the index of a label is its class id
     */
    public ArgumentClassifier(FeatureSet featureSet, List<String> list) {
        super(featureSet);
        this.argLabels = list;
    }

    /**
     * Adds one training instance for every (predicate, argument word) pair in the sentence.
     *
     * @param sentence the sentence whose predicates' argument maps are read
     */
    @Override
    public void extractInstances(Sentence sentence) {
        for (Predicate predicate : sentence.getPredicates()) {
            for (Word argument : predicate.getArgMap().keySet()) {
                super.addInstance(predicate, argument);
            }
        }
    }

    /**
     * Labels every argument of every predicate in the sentence in place.
     *
     * <p>The label stored for each argument word is the one predicted by
     * {@link #classifyInstance}.
     *
     * @param sentence the sentence to label
     */
    @Override
    public void parse(Sentence sentence) {
        for (Predicate predicate : sentence.getPredicates()) {
            Map<Word, String> argMap = predicate.getArgMap();
            // put() on an existing key: the key set itself is not changed by this loop.
            for (Word argument : argMap.keySet()) {
                int labelIndex = super.classifyInstance(predicate, argument).intValue();
                argMap.put(argument, this.argLabels.get(labelIndex));
            }
        }
    }

    /**
     * Returns the gold class id of {@code word} for {@code predicate}.
     *
     * <p>The id is the index of the word's current label in {@code argLabels}.
     * NOTE(review): {@link List#indexOf} yields -1 when the label is not in the inventory.
     * Confirm that callers never see unknown labels.
     */
    @Override
    protected Integer getLabel(Predicate predicate, Word word) {
        return Integer.valueOf(this.argLabels.indexOf(predicate.getArgMap().get(word)));
    }

    /** Prepares training using the {@code "ac_"} file prefix. */
    @Override
    public void prepareLearning() {
        super.prepareLearning(FILEPREFIX);
    }

    /** Returns the model file name, derived from the step's file prefix ({@code "ac_.models"}). */
    @Override
    protected String getModelFileName() {
        return FILEPREFIX + ".models";
    }

    /**
     * Labels candidate argument maps with a beam search over label assignments.
     *
     * <p>The model is selected by the predicate's POS prefix. If no prefix is found, the first
     * configured prefix is used. For each candidate, argument words are labelled one after
     * another:
     * <ul>
     *   <li>every branch in the beam is extended with the model's first {@code beamSize}
     *       labels;</li>
     *   <li>only the first {@code beamSize} extensions, in {@code ArgMap.REVERSE_PROB_COMPARATOR}
     *       order, are kept.</li>
     * </ul>
     * Each surviving map gets {@code setLblProb(getProb())} followed by {@code resetProb()}.
     *
     * @param predicate the predicate whose arguments are labelled
     * @param candidates candidate argument maps whose key sets are the argument words to label
     * @param beamSize beam width; a value of 0 or less yields an empty result
     * @return up to {@code beamSize} labelled maps per candidate, concatenated in candidate order
     */
    public List<ArgMap> beamSearch(Predicate predicate, List<ArgMap> candidates, int beamSize) {
        List<ArgMap> result = new ArrayList<ArgMap>();
        String posPrefix = super.getPOSPrefix(predicate.getPOS());
        if (posPrefix == null) {
            // Unknown POS prefix: fall back to the first configured model.
            posPrefix = this.featureSet.POSPrefixes[0];
        }
        Model model = this.models.get(posPrefix);
        for (ArgMap candidate : candidates) {
            List<ArgMap> beam = new ArrayList<ArgMap>();
            beam.add(candidate);
            // NOTE(review): a TreeSet discards elements its comparator considers equal, so
            // expansions with tied scores may be dropped here. Confirm this is intended.
            TreeSet<ArgMap> expansions = new TreeSet<ArgMap>(ArgMap.REVERSE_PROB_COMPARATOR);
            for (Word argument : candidate.keySet()) {
                // Presumably ordered best-first: the leading entries are used as the top labels.
                List<Label> ranked = model.classifyProb(
                        super.collectIndices(predicate, argument, posPrefix, new TreeSet<Integer>()));
                // The model may return fewer labels than the beam is wide; never read past the end.
                int width = Math.min(beamSize, ranked.size());
                for (ArgMap branch : beam) {
                    for (int k = 0; k < width; k++) {
                        Label label = ranked.get(k);
                        ArgMap expanded = new ArgMap(branch);
                        expanded.put(argument, this.argLabels.get(label.getLabel().intValue()),
                                label.getProb());
                        expansions.add(expanded);
                    }
                }
                // Keep only the first beamSize expansions as the new beam.
                beam.clear();
                Iterator<ArgMap> best = expansions.iterator();
                for (int k = 0; k < beamSize && best.hasNext(); k++) {
                    beam.add(best.next());
                }
                expansions.clear();
            }
            for (int k = 0; k < beamSize && k < beam.size(); k++) {
                ArgMap labelled = beam.get(k);
                labelled.setLblProb(labelled.getProb());
                labelled.resetProb();
                result.add(labelled);
            }
        }
        return result;
    }
}
