package se.lth.cs.srl.pipeline;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.zip.ZipEntry;
import java.util.zip.ZipException;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import se.lth.cs.srl.Learn;
import se.lth.cs.srl.SemanticRoleLabeler;
import se.lth.cs.srl.corpus.ArgMap;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.features.Feature;
import se.lth.cs.srl.features.FeatureGenerator;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.io.AllCoNLL09Reader;
import se.lth.cs.srl.languages.Language;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.Label;
import se.lth.cs.srl.ml.liblinear.LibLinearLearningProblem;
import se.lth.cs.srl.options.LearnOptions;
import se.lth.cs.srl.options.ParseOptions;
import se.lth.cs.srl.util.BrownCluster;

/* loaded from: input_file:se/lth/cs/srl/pipeline/Reranker.class */
/**
 * Global reranker for semantic role labeling.
 *
 * <p>For every predicate, the pipeline's argument identification (AI) and argument
 * classification (AC) steps produce a beam of candidate argument maps. Each candidate is scored by
 * a binary LibLinear classifier whose features are the pipeline's AI/AC feature indices for every
 * argument plus one "global" feature identifying the predicate's core argument label sequence.
 * The candidate maximizing {@code localProb * rerankProb^alfa} is assigned to the predicate.
 *
 * <p>Training uses cross-training: the corpus is partitioned, a pipeline is trained on all but one
 * partition, and that pipeline generates the candidates for the held-out partition.
 */
public class Reranker extends SemanticRoleLabeler {
    /** Name of the zip entry holding the serialized reranker model and label-sequence map. */
    public static final String FILENAME = "global";

    /** Exponent applied to the reranker probability when combining it with the local probability. */
    private final double alfa;
    /** When true, predicate identification is skipped and predicates are assumed to be given. */
    private final boolean noPI;
    /** Beam width for argument identification. */
    private final int aiBeam;
    /** Beam width for argument classification. */
    private final int acBeam;

    private Model model;
    private List<String> argLabels;
    private List<Feature> aiFeatures;
    private List<Feature> acFeatures;
    private int sizeAIFeatures;
    private int sizeACFeatures;
    /** Size of the pipeline feature space; global feature indices are offset past it. */
    private int sizePipelineFeatures;
    /** Maps core argument label sequences to global feature ids (assigned from 1 upwards). */
    private Map<String, Integer> calsMap;
    /** Next id to assign in {@code calsMap} during training. */
    private int calsCounter;
    private Pipeline pipeline;
    private ArgumentIdentifier aiModule;
    private ArgumentClassifier acModule;
    /** Histogram of the beam position of the candidate chosen at parse time. */
    private int[] rankCount;
    /** Number of predicates for which the chosen candidate has no arguments. */
    private int zeroArgMapCount;

    /**
     * Loads a trained reranker, together with its underlying pipeline, for parsing.
     *
     * @param parseOptions supplies the model file, beam widths, alfa and whether to skip PI
     * @throws ZipException if the model file is not a valid zip archive
     * @throws IOException if the model file cannot be read or has no {@code FILENAME} entry
     * @throws ClassNotFoundException if a serialized class in the model cannot be resolved
     */
    public Reranker(ParseOptions parseOptions) throws ZipException, IOException, ClassNotFoundException {
        this(parseOptions.global_alfa, parseOptions.skipPI, parseOptions.global_aiBeam, parseOptions.global_acBeam);
        // try-with-resources so the archive is released even when loading fails part-way.
        try (ZipFile zipFile = new ZipFile(parseOptions.modelFile)) {
            this.pipeline = this.noPI
                    ? Pipeline.fromZipFile(zipFile, new Step[]{Step.pd, Step.ai, Step.ac})
                    : Pipeline.fromZipFile(zipFile);
            System.out.println("Loading reranker from " + zipFile.getName());
            if (this.noPI) {
                System.out.println("Skipping predicate identification. Input is assumed to have predicates identified.");
            }
            this.argLabels = this.pipeline.getArgLabels();
            populateRerankerFeatureSets(this.pipeline.getFeatureSets(), this.pipeline.getFg());
            readRerankerModel(zipFile);
            // With PI skipped the steps are [pd, ai, ac]. Otherwise AI is expected at index 2
            // (presumably after pi and pd).
            int aiIndex = this.noPI ? 1 : 2;
            this.aiModule = (ArgumentIdentifier) this.pipeline.steps.get(aiIndex);
            this.acModule = (ArgumentClassifier) this.pipeline.steps.get(aiIndex + 1);
        }
    }

    /** Shared initialisation of the beam/alfa settings and statistics counters. */
    private Reranker(double alfa, boolean noPI, int aiBeam, int acBeam) {
        this.calsCounter = 1;
        this.zeroArgMapCount = 0;
        this.alfa = alfa;
        this.noPI = noPI;
        this.aiBeam = aiBeam;
        this.acBeam = acBeam;
        this.rankCount = new int[aiBeam * acBeam];
    }

    /**
     * Trains a full pipeline plus a reranker on {@code learnOptions.inputCorpus}. Both are written
     * to {@code zipOutputStream}; the reranker goes under the {@code FILENAME} entry.
     *
     * @param learnOptions training options (corpus, beams, cross-training folds, temp dir, ...)
     * @param zipOutputStream destination archive; flushed but intentionally not closed
     * @throws IOException if reading the corpus or writing the model fails
     */
    public Reranker(LearnOptions learnOptions, ZipOutputStream zipOutputStream) throws IOException {
        this(1.0d, false, learnOptions.global_aiBeam, learnOptions.global_acBeam);
        List<Sentence> corpus = new AllCoNLL09Reader(learnOptions.inputCorpus).readAll();
        // NOTE(review): reads brownClusterFile from the static Learn.learnOptions rather than the
        // learnOptions parameter. Confirm these are always the same object.
        Pipeline fullPipeline = Pipeline.trainNewPipeline(corpus, learnOptions.getFeatureFiles(), zipOutputStream,
                Learn.learnOptions.brownClusterFile == null ? null : new BrownCluster(Learn.learnOptions.brownClusterFile));
        FeatureGenerator fg = fullPipeline.getFg();
        this.argLabels = fullPipeline.getArgLabels();
        // Cross-training pipelines are built without predicate identification, so drop its features.
        HashMap<Step, FeatureSet> featureSets = new HashMap<>(fullPipeline.getFeatureSets());
        featureSets.remove(Step.pi);
        populateRerankerFeatureSets(featureSets, fg);

        LibLinearLearningProblem problem = new LibLinearLearningProblem(new File(learnOptions.tempDir, FILENAME), true);
        List<List<Sentence>> partitions = partitionCorpus(corpus, learnOptions.global_numberOfCrossTrain);
        this.calsMap = new HashMap<>();
        for (int heldOut = 0; heldOut < partitions.size(); heldOut++) {
            // Reuse the corpus list as this fold's training set: every partition but the held-out one.
            corpus.clear();
            for (int p = 0; p < partitions.size(); p++) {
                if (p != heldOut) {
                    corpus.addAll(partitions.get(p));
                }
            }
            Pipeline foldPipeline = Pipeline.trainNewPipeline(corpus, fg, (ZipOutputStream) null, featureSets);
            ArgumentIdentifier foldAI = (ArgumentIdentifier) foldPipeline.steps.get(1);
            ArgumentClassifier foldAC = (ArgumentClassifier) foldPipeline.steps.get(2);
            for (Sentence sentence : partitions.get(heldOut)) {
                for (Predicate predicate : sentence.getPredicates()) {
                    this.predCount++;
                    addTrainingInstances(problem, predicate, foldAI, foldAC, learnOptions);
                }
            }
        }
        problem.done();
        this.model = problem.train();

        zipOutputStream.putNextEntry(new ZipEntry(FILENAME));
        ObjectOutputStream out = new ObjectOutputStream(zipOutputStream);
        out.writeObject(this.model);
        out.writeObject(this.calsMap);
        // Flush only: closing the object stream would also close the caller's zip stream.
        out.flush();
    }

    /**
     * Reads the serialized model and label-sequence map from the {@code FILENAME} entry.
     *
     * @throws IOException if the entry is missing or cannot be read
     */
    private void readRerankerModel(ZipFile zipFile) throws IOException, ClassNotFoundException {
        ZipEntry entry = zipFile.getEntry(FILENAME);
        if (entry == null) {
            throw new IOException("Model file " + zipFile.getName() + " has no '" + FILENAME
                    + "' entry; was it trained with the reranker?");
        }
        try (ObjectInputStream in = new ObjectInputStream(zipFile.getInputStream(entry))) {
            this.model = (Model) in.readObject();
            // Written as a Map<String, Integer> by the training constructor.
            @SuppressWarnings("unchecked")
            Map<String, Integer> labelSequences = (Map<String, Integer>) in.readObject();
            this.calsMap = labelSequences;
        }
    }

    /**
     * Generates training instances for one held-out predicate.
     *
     * <p>The candidates that score best against the gold argument map become positives. The
     * remaining beam candidates become negatives.
     */
    private void addTrainingInstances(LibLinearLearningProblem problem, Predicate predicate,
            ArgumentIdentifier ai, ArgumentClassifier ac, LearnOptions learnOptions) {
        List<ArgMap> candidates = ac.beamSearch(predicate, ai.beamSearch(predicate, learnOptions.global_aiBeam), learnOptions.global_acBeam);
        Set<ArgMap> positives = new HashSet<>();
        // Moves the best-scoring candidates out of 'candidates' and into 'positives'.
        double bestScore = partitionBestArgMaps(candidates, predicate.getArgMap(), positives);
        // Presumably a score of 1.0 denotes an exact match with the gold map. If none matched,
        // the gold map is optionally added as an extra positive.
        if (learnOptions.global_insertGoldMapForTrain && bestScore != 1.0d) {
            positives.add(new ArgMap(predicate.getArgMap()));
        }
        // Positives must be added before negatives: addAndCollectGlobalFeatures assigns
        // label-sequence ids in encounter order.
        for (ArgMap positive : positives) {
            addInstance(problem, AbstractStep.POSITIVE.intValue(), predicate, positive);
        }
        for (ArgMap negative : candidates) {
            addInstance(problem, AbstractStep.NEGATIVE.intValue(), predicate, negative);
        }
    }

    /** Adds one sorted feature vector, with the given label, to the learning problem. */
    private void addInstance(LibLinearLearningProblem problem, int label, Predicate predicate, ArgMap argMap) {
        List<Integer> indices = new ArrayList<>();
        collectPipelineFeatureIndices(predicate, argMap, indices);
        addAndCollectGlobalFeatures(predicate, argMap, indices);
        Collections.sort(indices);
        problem.addInstance(label, indices);
    }

    /**
     * Labels {@code sentence}. It runs the first pipeline step (and the second unless PI is
     * skipped), then replaces each predicate's argument map with the reranker's choice from the
     * AI/AC beam.
     */
    @Override // se.lth.cs.srl.SemanticRoleLabeler
    protected void parse(Sentence sentence) {
        this.pipeline.steps.get(0).parse(sentence);
        if (!this.noPI) {
            this.pipeline.steps.get(1).parse(sentence);
        }
        for (Predicate predicate : sentence.getPredicates()) {
            List<ArgMap> candidates = this.acModule.beamSearch(predicate, this.aiModule.beamSearch(predicate, this.aiBeam), this.acBeam);
            if (candidates.isEmpty()) {
                // Nothing to choose from; leave the predicate's current argument map untouched.
                continue;
            }
            for (ArgMap candidate : candidates) {
                assignRerankProb(predicate, candidate);
            }
            int best = selectBest(candidates);
            // NOTE(review): assumes the beam holds at most aiBeam * acBeam candidates.
            this.rankCount[best]++;
            ArgMap chosen = candidates.get(best);
            if (chosen.size() == 0) {
                this.zeroArgMapCount++;
            }
            predicate.setArgMap(chosen);
        }
    }

    /** Scores one candidate with the reranker model and stores the non-negative-label probability on it. */
    private void assignRerankProb(Predicate predicate, ArgMap candidate) {
        // NOTE(review): training sorts the indices but parsing does not. Confirm that
        // Model.classifyProb is insensitive to index order.
        Collection<Integer> indices = new ArrayList<>();
        collectPipelineFeatureIndices(predicate, candidate, indices);
        collectGlobalFeatures(predicate, candidate, indices);
        for (Label label : this.model.classifyProb(indices)) {
            if (!label.getLabel().equals(AbstractStep.NEGATIVE)) {
                candidate.setRerankProb(label.getProb());
                candidate.resetProb();
            }
        }
    }

    /**
     * Returns the index of the candidate maximizing {@code prob * rerankProb^alfa}.
     *
     * <p>As a side effect, each candidate's prob is set to
     * {@code idProb * lblProb^(1/size)}, or to {@code idProb} for an empty map. The earliest
     * candidate wins ties. If no candidate scores above 0 (e.g. all zero or NaN), index 0 is
     * returned instead of an invalid index.
     *
     * @param list non-empty list of candidates
     */
    private int selectBest(List<ArgMap> list) {
        for (ArgMap candidate : list) {
            double localProb = candidate.getIdProb();
            if (candidate.size() != 0) {
                localProb *= Math.pow(candidate.getLblProb(), 1.0d / candidate.size());
            }
            candidate.setProb(localProb);
        }
        double bestScore = 0.0d;
        int best = -1;
        for (int i = 0; i < list.size(); i++) {
            ArgMap candidate = list.get(i);
            double score = candidate.getProb() * Math.pow(candidate.getRerankProb(), this.alfa);
            if (score > bestScore) {
                best = i;
                bestScore = score;
            } else if (score == bestScore) {
                System.out.println("!same score..");
            }
        }
        // Previously -1 escaped here and crashed the caller with an out-of-bounds index.
        return best < 0 ? 0 : best;
    }

    /**
     * Appends to {@code collection} the AI and AC feature indices of every argument in {@code argMap}.
     *
     * <p>AI features occupy {@code [0, sizeAIFeatures)}. AC features are placed in a separate
     * block of {@code sizeACFeatures} indices per argument label, after the AI block.
     *
     * @return {@code collection}, for convenience
     */
    private Collection<Integer> collectPipelineFeatureIndices(Predicate predicate, ArgMap argMap, Collection<Integer> collection) {
        for (Word arg : argMap.keySet()) {
            HashSet<Integer> indices = new HashSet<>();
            Integer offset = 0;
            for (Feature feature : this.aiFeatures) {
                feature.addFeatures(indices, predicate, arg, offset, false);
                offset += feature.size(false);
            }
            // NOTE(review): indexOf returns -1 for a label missing from argLabels, which would
            // shift its AC features into the AI range. Confirm labels always come from argLabels.
            offset = this.sizeAIFeatures + (this.sizeACFeatures * this.argLabels.indexOf(argMap.get(arg)));
            for (Feature feature : this.acFeatures) {
                feature.addFeatures(indices, predicate, arg, offset, false);
                offset += feature.size(false);
            }
            collection.addAll(indices);
        }
        return collection;
    }

    /**
     * Training-time variant of {@link #collectGlobalFeatures}. Unseen label sequences are assigned
     * the next free id before the feature is added.
     */
    private void addAndCollectGlobalFeatures(Predicate predicate, ArgMap argMap, Collection<Integer> collection) {
        String sequence = Language.getLanguage().getCoreArgumentLabelSequence(predicate, argMap);
        Integer id = this.calsMap.get(sequence);
        if (id == null) {
            id = this.calsCounter++;
            this.calsMap.put(sequence, id);
        }
        collection.add(this.sizePipelineFeatures + id);
    }

    /** Parse-time global feature: added only for label sequences seen during training. */
    private void collectGlobalFeatures(Predicate predicate, ArgMap argMap, Collection<Integer> collection) {
        Integer id = this.calsMap.get(Language.getLanguage().getCoreArgumentLabelSequence(predicate, argMap));
        if (id != null) {
            collection.add(this.sizePipelineFeatures + id);
        }
    }

    /**
     * Flattens the AI and AC feature sets into lists and computes the feature-space sizes.
     * {@code argLabels} must already be set. The feature generator is currently unused.
     */
    private void populateRerankerFeatureSets(Map<Step, FeatureSet> map, FeatureGenerator featureGenerator) {
        this.aiFeatures = new ArrayList<>();
        for (Map.Entry<String, List<Feature>> entry : map.get(Step.ai).entrySet()) {
            this.aiFeatures.addAll(entry.getValue());
        }
        this.acFeatures = new ArrayList<>();
        for (Map.Entry<String, List<Feature>> entry : map.get(Step.ac).entrySet()) {
            this.acFeatures.addAll(entry.getValue());
        }
        this.sizeAIFeatures = 0;
        for (Feature feature : this.aiFeatures) {
            this.sizeAIFeatures += feature.size(false);
        }
        this.sizeACFeatures = 0;
        for (Feature feature : this.acFeatures) {
            this.sizeACFeatures += feature.size(false);
        }
        this.sizePipelineFeatures = this.sizeAIFeatures + (this.argLabels.size() * this.sizeACFeatures);
    }

    /**
     * Moves every candidate that achieves the highest {@code computeScore(gold)} from {@code list}
     * into {@code set}. Ties are all kept.
     *
     * @return the highest score found, or 0 if no candidate scored above 0
     */
    private static double partitionBestArgMaps(List<ArgMap> list, Map<Word, String> map, Set<ArgMap> set) {
        double bestScore = 0.0d;
        for (ArgMap candidate : list) {
            double score = candidate.computeScore(map);
            if (score > bestScore) {
                bestScore = score;
                set.clear();
                set.add(candidate);
            } else if (score == bestScore) {
                set.add(candidate);
            }
        }
        list.removeAll(set);
        return bestScore;
    }

    /**
     * Splits the corpus into {@code numberOfPartitions} lists. Assignment is round-robin when
     * {@code Learn.learnOptions.deterministicReranker} is set, and uniformly random otherwise.
     *
     * @throws IllegalArgumentException if {@code numberOfPartitions < 1}
     */
    private static List<List<Sentence>> partitionCorpus(Iterable<Sentence> iterable, int numberOfPartitions) {
        if (numberOfPartitions < 1) {
            throw new IllegalArgumentException(
                    "Number of cross-training partitions must be at least 1, got " + numberOfPartitions);
        }
        List<List<Sentence>> partitions = new ArrayList<>(numberOfPartitions);
        for (int p = 0; p < numberOfPartitions; p++) {
            partitions.add(new ArrayList<>());
        }
        if (Learn.learnOptions.deterministicReranker) {
            int position = 0;
            for (Sentence sentence : iterable) {
                partitions.get(position % numberOfPartitions).add(sentence);
                position++;
            }
        } else {
            for (Sentence sentence : iterable) {
                partitions.get((int) Math.floor(Math.random() * numberOfPartitions)).add(sentence);
            }
        }
        return partitions;
    }

    /** Reports the beam settings, the rank histogram of chosen candidates, and the empty-argmap count. */
    @Override // se.lth.cs.srl.SemanticRoleLabeler
    protected String getSubStatus() {
        StringBuilder sb = new StringBuilder("Reranker status:\n");
        sb.append("AI beam:\t\t").append(this.aiBeam).append('\n');
        sb.append("AC beam:\t\t").append(this.acBeam).append('\n');
        sb.append("Alfa:\t\t\t").append(this.alfa).append('\n');
        sb.append('\n');
        sb.append("Reranker choices:\n");
        sb.append("Rank\tFrequency\n");
        for (int rank = 0; rank < this.rankCount.length; rank++) {
            sb.append(rank + 1).append('\t').append(this.rankCount[rank]).append('\n');
        }
        sb.append('\n');
        sb.append("Number of zero size argmaps:\t").append(this.zeroArgMapCount).append('\n');
        return sb.toString();
    }
}
