package se.lth.cs.srl.pipeline;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import se.lth.cs.srl.Learn;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.PredicateReference;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.features.Feature;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.ml.LearningProblem;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.LibLinearLearningProblem;

/* loaded from: input_file:se/lth/cs/srl/pipeline/PredicateDisambiguator.class */
/**
 * Pipeline step that assigns a sense to every predicate of a sentence.
 *
 * <p>One {@link Model} is kept per "file name" as returned by
 * {@link PredicateReference#getFileName}; the models are either produced by
 * {@link #train()} or loaded by {@link #readModels(ZipFile)}. Predicates for
 * which no POS prefix or no file name is available get the sense returned by
 * {@link PredicateReference#getSimpleSense} instead of a classifier decision.
 */
public class PredicateDisambiguator implements PipelineStep {
    /** Prefix used for the temporary training files and for the model archive entry. */
    public static final String FILE_PREFIX = "pd_";

    private final FeatureSet featureSet;
    private final PredicateReference predicateReference;

    /** Trained models keyed by file name; set by {@link #train()} or {@link #readModels(ZipFile)}. */
    protected Map<String, Model> models;

    /** Training predicates grouped by file name; created by {@link #prepareLearning()}. */
    private Map<String, List<Predicate>> instances;

    /**
     * @param featureSet         features (and POS prefixes) used to describe a predicate
     * @param predicateReference lookup used for file names, labels and senses
     */
    public PredicateDisambiguator(FeatureSet featureSet, PredicateReference predicateReference) {
        this.featureSet = featureSet;
        this.predicateReference = predicateReference;
    }

    /**
     * Sets the sense of every predicate in {@code sentence}.
     *
     * <p>{@link #models} must have been populated beforehand (by {@link #train()} or
     * {@link #readModels(ZipFile)}), and a model must exist for every file name the
     * predicate reference returns; otherwise a {@link NullPointerException} is thrown.
     *
     * @param sentence the sentence whose predicates are disambiguated in place
     */
    @Override
    public void parse(Sentence sentence) {
        for (Predicate predicate : sentence.getPredicates()) {
            predicate.setSense(chooseSense(predicate));
        }
    }

    /** Picks the sense for one predicate: classifier output when a model applies, simple sense otherwise. */
    private String chooseSense(Predicate predicate) {
        String posPrefix = getPOSPrefix(predicate);
        if (posPrefix == null) {
            // No configured POS prefix matches this predicate.
            return this.predicateReference.getSimpleSense(predicate, null);
        }
        String lemma = predicate.getLemma();
        String fileName = this.predicateReference.getFileName(lemma, posPrefix);
        if (fileName == null) {
            return this.predicateReference.getSimpleSense(predicate, posPrefix);
        }
        Model model = getModel(fileName);
        TreeSet<Integer> indices = collectFeatures(predicate, posPrefix);
        return this.predicateReference.getSense(lemma, posPrefix, model.classify(indices));
    }

    /** Returns the model stored under {@code str}, or {@code null} if there is none. */
    private Model getModel(String str) {
        return this.models.get(str);
    }

    /**
     * Collects the feature indices of {@code predicate} using the features registered
     * for {@code posPrefix}. Each feature is given the running sum of the sizes of the
     * features before it as its offset.
     */
    private TreeSet<Integer> collectFeatures(Predicate predicate, String posPrefix) {
        TreeSet<Integer> indices = new TreeSet<Integer>();
        int offset = 0;
        for (Feature feature : this.featureSet.get(posPrefix)) {
            // Boxed explicitly so the same addFeatures overload is selected as before.
            feature.addFeatures(indices, predicate, null, Integer.valueOf(offset), false);
            offset += feature.size(false);
        }
        return indices;
    }

    /**
     * Records the predicates of {@code sentence} as training instances, grouped by the
     * file name the predicate reference returns for them. Predicates without a file
     * name are ignored.
     *
     * <p>A predicate whose POS matches none of the configured prefixes is skipped when
     * {@code Learn.learnOptions.skipNonMatchingPredicates} is set; otherwise it is
     * filed under the first configured prefix.
     *
     * <p>{@link #prepareLearning()} must have been called first, since it creates the
     * instance map this method fills.
     *
     * @param sentence the training sentence
     */
    @Override
    public void extractInstances(Sentence sentence) {
        for (Predicate predicate : sentence.getPredicates()) {
            String posPrefix = getPOSPrefix(predicate);
            if (posPrefix == null) {
                if (Learn.learnOptions.skipNonMatchingPredicates) {
                    // Previously this fell through and looked up a file name with a
                    // null prefix instead of actually skipping the predicate.
                    continue;
                }
                posPrefix = this.featureSet.POSPrefixes[0];
            }
            String fileName = this.predicateReference.getFileName(predicate.getLemma(), posPrefix);
            if (fileName == null) {
                continue;
            }
            List<Predicate> bucket = this.instances.get(fileName);
            if (bucket == null) {
                bucket = new ArrayList<Predicate>();
                this.instances.put(fileName, bucket);
            }
            bucket.add(predicate);
        }
    }

    /**
     * Returns the first configured POS prefix that the predicate's POS tag starts
     * with, or {@code null} when none matches.
     */
    private String getPOSPrefix(Predicate predicate) {
        for (String str : this.featureSet.POSPrefixes) {
            if (predicate.getPOS().startsWith(str)) {
                return str;
            }
        }
        return null;
    }

    /** Creates an empty instance map for a new round of {@link #extractInstances(Sentence)} calls. */
    @Override
    public void prepareLearning() {
        this.instances = new HashMap<String, List<Predicate>>();
    }

    /**
     * Adds one predicate to {@code learningProblem}: its label is looked up from the
     * predicate reference by lemma, POS prefix and sense, and its features are the
     * collected indices. Falls back to the first configured POS prefix when none
     * matches, mirroring the non-skipping branch of {@link #extractInstances(Sentence)}.
     */
    private void addInstance(Predicate predicate, LearningProblem learningProblem) {
        String posPrefix = getPOSPrefix(predicate);
        if (posPrefix == null) {
            posPrefix = this.featureSet.POSPrefixes[0];
        }
        TreeSet<Integer> indices = collectFeatures(predicate, posPrefix);
        learningProblem.addInstance(Integer.valueOf(this.predicateReference.getLabel(predicate.getLemma(), posPrefix, predicate.getSense())).intValue(), indices);
    }

    /** Nothing to finalize for this step. */
    @Override
    public void done() {
    }

    /**
     * Trains one model per file name from the collected instances and stores it in
     * {@link #models}. Each group of instances is removed from the instance map as
     * soon as its model has been trained, so the map is empty afterwards.
     */
    @Override
    public void train() {
        this.models = new HashMap<String, Model>();
        Iterator<Map.Entry<String, List<Predicate>>> it = this.instances.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<String, List<Predicate>> entry = it.next();
            String fileName = entry.getKey();
            LibLinearLearningProblem problem = new LibLinearLearningProblem(new File(Learn.learnOptions.tempDir, FILE_PREFIX + fileName), false);
            for (Predicate predicate : entry.getValue()) {
                addInstance(predicate, problem);
            }
            problem.done();
            this.models.put(fileName, problem.train(true));
            // Drop the instances that have just been consumed.
            it.remove();
        }
    }

    /**
     * Writes {@link #models} to {@code zipOutputStream} via {@code AbstractStep.writeModels}.
     *
     * @throws IOException if writing fails
     */
    @Override
    public void writeModels(ZipOutputStream zipOutputStream) throws IOException {
        AbstractStep.writeModels(zipOutputStream, this.models, getModelFileName());
    }

    /**
     * Replaces {@link #models} with a fresh map filled from {@code zipFile} via
     * {@code AbstractStep.readModels}.
     *
     * @throws IOException            if reading fails
     * @throws ClassNotFoundException if a stored model class cannot be resolved
     */
    @Override
    public void readModels(ZipFile zipFile) throws IOException, ClassNotFoundException {
        this.models = new HashMap<String, Model>();
        AbstractStep.readModels(zipFile, this.models, getModelFileName());
    }

    /** Name passed to {@code AbstractStep} when reading/writing the models ({@code "pd_.models"}). */
    private String getModelFileName() {
        return FILE_PREFIX + ".models";
    }
}
