package se.lth.cs.srl.fs;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipOutputStream;
import se.lth.cs.srl.Learn;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.features.Feature;
import se.lth.cs.srl.features.FeatureFile;
import se.lth.cs.srl.features.FeatureGenerator;
import se.lth.cs.srl.features.FeatureName;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.io.AllCoNLL09Reader;
import se.lth.cs.srl.io.CoNLL09Writer;
import se.lth.cs.srl.io.DepsOnlyCoNLL09Reader;
import se.lth.cs.srl.options.FeatureSelectionOptions;
import se.lth.cs.srl.pipeline.Pipeline;
import se.lth.cs.srl.pipeline.Step;
import se.lth.cs.srl.util.BrownCluster;
import se.lth.cs.srl.util.scorer.AbstractScorer;
import se.lth.cs.srl.util.scorer.PredicateIdentificationScorer;

/* loaded from: input_file:se/lth/cs/srl/fs/SelectFeatures.class */
/**
 * Greedy forward feature selection for a single pipeline step.
 *
 * <p>Starting from an optional feature file, each round tries every remaining candidate feature:
 * the candidate is added to the current feature list, a pipeline is trained on the training data
 * of each evaluated fold and scored on the corresponding held-out sentences. The best candidate is
 * kept if it improves on the best score so far. Rounds repeat while the improvement exceeds
 * {@code options.threshold}; after each round the current feature list is written to
 * {@code <tempDir>/features/<step>-fs-<round>}. If quadratic features were not requested up front,
 * a second phase then runs with feature-pair candidates enabled.
 */
public class SelectFeatures {
    private static FeatureSelectionOptions options;
    private static FeatureGenerator fg;
    private static final FeatureName[] noArgsSingleFeatures = {FeatureName.PredWord, FeatureName.PredLemma, FeatureName.PredPOS, FeatureName.PredDeprel, FeatureName.PredParentWord, FeatureName.PredParentPOS, FeatureName.DepSubCat, FeatureName.ChildDepSet, FeatureName.ChildPOSSet, FeatureName.ChildWordSet};
    private static final FeatureName[] noArgsSingleFeatures_Feats = {FeatureName.PredFeats, FeatureName.PredParentFeats};
    // Currently empty: the ai/ac steps get no argument-specific single-feature candidates.
    private static final FeatureName[] argsSingleFeatures = new FeatureName[0];
    private static final FeatureName[] argsSingleFeatures_Feats = new FeatureName[0];

    /**
     * Corpus state shared between {@link #main} and {@link #iterate}: {@code parts.get(i)} is the
     * file holding partition {@code i}, {@code testSets.get(i)} is partition {@code i} itself, and
     * {@code trainingSets.get(i)} is the concatenation of every other partition.
     */
    public static class CorpusStruct {
        public static List<File> parts;
        public static List<List<Sentence>> trainingSets;
        public static List<List<Sentence>> testSets;
    }

    /**
     * Mutable state of one selection run.
     *
     * <p>{@code current} maps the POS prefix to the features selected so far; {@code additional}
     * holds the remaining candidates; {@code score} is the best average score reached so far;
     * {@code comments} holds one entry per selected feature ({@code null} for starting features).
     */
    public static class SelectionState {
        public Map<String, List<Feature>> current;
        public List<Feature> additional;
        public double score = 0.0d;
        public List<String> comments;
    }

    /** Requests a garbage collection, then prints total, free and used heap memory in kilobytes. */
    public static void printMemUsage() {
        System.gc();
        Runtime runtime = Runtime.getRuntime();
        long freeMemory = runtime.freeMemory();
        long totalMemory = runtime.totalMemory();
        System.out.println("Total memory: " + (totalMemory / 1024) + "kb");
        System.out.println("Free memory:  " + (freeMemory / 1024) + "kb");
        System.out.println("Used memory:  " + ((totalMemory - freeMemory) / 1024) + "kb");
    }

    /**
     * Entry point: parses the options, prepares the corpus partitions, and runs the linear
     * selection phase, followed by a quadratic phase unless quadratic features were enabled from
     * the start.
     *
     * @param strArr command-line arguments, parsed by {@link FeatureSelectionOptions}
     * @throws IOException if the working directories cannot be created or corpus/feature files
     *     cannot be written or read
     */
    public static void main(String[] strArr) throws IOException {
        options = new FeatureSelectionOptions(strArr);
        Learn.learnOptions = options.getLearnOptions();
        BrownCluster brownCluster = Learn.learnOptions.brownClusterFile == null ? null : new BrownCluster(Learn.learnOptions.brownClusterFile);
        File corpusDir = createDirectory(new File(options.tempDir, "corpora"));
        File featureDir = createDirectory(new File(options.tempDir, "features"));
        printMemUsage();
        List<Sentence> sentences = readSentences(options);
        printMemUsage();
        fg = new FeatureGenerator();
        SelectionState state = getStartingState(fg, getStartSetFromFile(fg, brownCluster), brownCluster);
        printMemUsage();
        fg.buildFeatureMaps(sentences);
        printMemUsage();
        List<List<Sentence>> partitions = partitionSentences(sentences, options);
        printMemUsage();
        CorpusStruct.parts = writeGoldPartitions(partitions, corpusDir);
        printMemUsage();
        CorpusStruct.testSets = partitions;
        CorpusStruct.trainingSets = buildTrainingSets(partitions);
        printMemUsage();
        int round = runSelection(state, featureDir, 0);
        printMemUsage();
        if (options.quadratic) {
            return;
        }

        // Second phase: allow feature-pair candidates, seeded with the linear selection result.
        options.quadratic = true;
        fg = new FeatureGenerator();
        printMemUsage();
        SelectionState quadraticState = getStartingState(fg, state.current.get(options.POSPrefix), brownCluster);
        quadraticState.score = state.score;
        quadraticState.comments = state.comments;
        printMemUsage();
        fg.buildFeatureMaps(sentences);
        printMemUsage();
        runSelection(quadraticState, featureDir, round);
    }

    /**
     * Ensures {@code dir} exists as a directory, creating missing parents as needed.
     *
     * @return {@code dir}
     * @throws IOException if the directory does not exist and cannot be created
     */
    private static File createDirectory(File dir) throws IOException {
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Could not create directory " + dir);
        }
        return dir;
    }

    /**
     * Writes partition {@code i} to {@code corpusDir/gold-i} and returns the written files in
     * partition order. Each writer is closed even if writing fails.
     */
    private static List<File> writeGoldPartitions(List<List<Sentence>> partitions, File corpusDir) throws IOException {
        List<File> files = new ArrayList<File>(partitions.size());
        for (int p = 0; p < partitions.size(); p++) {
            File goldFile = new File(corpusDir, "gold-" + p);
            files.add(goldFile);
            CoNLL09Writer writer = new CoNLL09Writer(goldFile);
            try {
                for (Sentence sentence : partitions.get(p)) {
                    writer.write(sentence);
                }
            } finally {
                writer.close();
            }
        }
        return files;
    }

    /** Builds, for each partition {@code i}, a training set made of all the other partitions. */
    private static List<List<Sentence>> buildTrainingSets(List<List<Sentence>> partitions) {
        List<List<Sentence>> trainingSets = new ArrayList<List<Sentence>>(partitions.size());
        for (int heldOut = 0; heldOut < partitions.size(); heldOut++) {
            List<Sentence> training = new ArrayList<Sentence>();
            for (int p = 0; p < partitions.size(); p++) {
                if (p != heldOut) {
                    training.addAll(partitions.get(p));
                }
            }
            trainingSets.add(training);
        }
        return trainingSets;
    }

    /**
     * Runs selection rounds on {@code state} until a round's improvement is no longer above
     * {@code options.threshold}. At least one round is always run. After every round the current
     * feature list is written to {@code featureDir/<step>-fs-<round>}.
     *
     * @param round number of rounds already completed (used to number the output files)
     * @return the updated round count
     */
    private static int runSelection(SelectionState state, File featureDir, int round) throws IOException {
        double increase;
        do {
            increase = iterate(state);
            round++;
            FeatureFile.writeToFile(state.current.get(options.POSPrefix), options.POSPrefix, state.comments, new File(featureDir, options.step + "-fs-" + round));
        } while (increase > options.threshold);
        return round;
    }

    /**
     * Runs one greedy selection round.
     *
     * <p>Each candidate in {@code selectionState.additional} is temporarily appended to the feature
     * list and scored on every evaluated fold (all partitions when {@code options.crossValidated},
     * otherwise only fold 0). If the best average score exceeds {@code selectionState.score}, that
     * candidate is moved into {@code current}, the score is updated and a comment is recorded.
     *
     * @return the best candidate's improvement over the previous score (zero or negative if nothing
     *     improved)
     */
    private static double iterate(SelectionState selectionState) throws IOException {
        int candidateCount = selectionState.additional.size();
        double[] scoreSums = new double[candidateCount];
        Map<Step, FeatureSet> featureSets = new HashMap<Step, FeatureSet>();
        FeatureSet featureSet = new FeatureSet(selectionState.current);
        featureSets.put(options.step, featureSet);
        AbstractScorer scorer = getScorer(options.step);
        int folds = options.crossValidated ? options.partitions : 1;
        List<Feature> trialFeatures = featureSet.get(options.POSPrefix);
        for (int fold = 0; fold < folds; fold++) {
            for (int c = 0; c < candidateCount; c++) {
                trialFeatures.add(selectionState.additional.get(c));
                try {
                    scoreSums[c] += scoreFold(fold, featureSets, scorer);
                } finally {
                    // Always undo the trial addition. NOTE(review): this list may be shared with
                    // selectionState.current, depending on how FeatureSet copies its input.
                    trialFeatures.remove(trialFeatures.size() - 1);
                }
            }
            System.out.println("Cross: " + fold);
        }

        // Only candidates with a strictly positive average score can be selected.
        double bestScore = 0.0d;
        int bestIndex = -1;
        for (int c = 0; c < scoreSums.length; c++) {
            double average = scoreSums[c] / folds;
            if (average > bestScore) {
                bestScore = average;
                bestIndex = c;
            }
        }
        double increase = bestScore - selectionState.score;
        if (bestIndex >= 0 && increase > 0.0d) {
            selectionState.score = bestScore;
            selectionState.current.get(options.POSPrefix).add(selectionState.additional.remove(bestIndex));
            selectionState.comments.add("F1: " + bestScore + ", increase: " + increase);
        } else {
            System.out.println("negative increase.");
        }
        return increase;
    }

    /**
     * Trains a pipeline on the training set of {@code fold} with {@code featureSets}, parses a
     * fresh copy of that fold read from disk with {@link DepsOnlyCoNLL09Reader}, and scores it
     * against the in-memory gold sentences. The reader is closed even if training or parsing fails.
     *
     * @return the scorer's average score for this fold
     */
    private static double scoreFold(int fold, Map<Step, FeatureSet> featureSets, AbstractScorer scorer) throws IOException {
        scorer.reset();
        DepsOnlyCoNLL09Reader reader = new DepsOnlyCoNLL09Reader(CorpusStruct.parts.get(fold));
        try {
            Iterator<Sentence> inputSentences = reader.iterator();
            Pipeline pipeline = Pipeline.trainNewPipeline(CorpusStruct.trainingSets.get(fold), fg, (ZipOutputStream) null, featureSets);
            for (Sentence gold : CorpusStruct.testSets.get(fold)) {
                Sentence parsed = inputSentences.next();
                pipeline.parseSentence(parsed);
                scorer.accScore(gold, parsed);
            }
        } finally {
            reader.close();
        }
        return scorer.getAvgScore();
    }

    /**
     * Returns the scorer for {@code step}. Only predicate identification ({@code pi}) is
     * supported; every other step throws an {@link Error}.
     */
    private static AbstractScorer getScorer(Step step) {
        switch (step) {
            case pi:
                return new PredicateIdentificationScorer();
            case pd:
            case ai:
            case ac:
            default:
                throw new Error("You are wrong here, check your code");
        }
    }

    /**
     * Loads the starting features for {@code options.POSPrefix} from
     * {@code options.startingFeatureFile}, or returns an empty list if no file was given.
     *
     * @throws Error if the file has no entry for {@code options.POSPrefix}
     */
    private static List<Feature> getStartSetFromFile(FeatureGenerator featureGenerator, BrownCluster brownCluster) throws IOException {
        List<Feature> startSet = new ArrayList<Feature>();
        if (options.startingFeatureFile == null) {
            return startSet;
        }
        List<String> names = FeatureFile.readFile(options.startingFeatureFile).get(options.POSPrefix);
        if (names == null) {
            throw new Error("The feature file provided does not contain the POSPrefix we want to explore. Aborting.");
        }
        boolean forPredicateIdentification = options.step == Step.pi;
        for (String name : names) {
            startSet.add(featureGenerator.getFeature(name, forPredicateIdentification, options.POSPrefix, brownCluster));
        }
        return startSet;
    }

    /**
     * Builds a selection state whose {@code current} features are {@code list} (stored as-is, not
     * copied, and later extended by {@link #iterate}). Its candidates are all generated features
     * whose names are not already in {@code list}. One {@code null} comment is recorded per
     * starting feature.
     */
    private static SelectionState getStartingState(FeatureGenerator featureGenerator, List<Feature> list, BrownCluster brownCluster) throws IOException {
        SelectionState state = new SelectionState();
        state.current = new HashMap<String, List<Feature>>();
        state.additional = new ArrayList<Feature>();
        state.comments = new ArrayList<String>();
        List<String> startNames = new ArrayList<String>(list.size());
        for (Feature feature : list) {
            startNames.add(feature.getName());
        }
        for (Feature candidate : getAllFeatures(featureGenerator, startNames, brownCluster)) {
            if (list.contains(candidate)) {
                // Candidates are filtered by name, so an equal feature should never appear here.
                System.out.println("We shouldnt end up here... Look into this.");
            } else {
                state.additional.add(candidate);
            }
        }
        state.current.put(options.POSPrefix, list);
        for (int i = 0; i < list.size(); i++) {
            state.comments.add(null);
        }
        return state;
    }

    /**
     * Generates every candidate feature for the configured step whose name is not in {@code list}.
     * These are the single features, plus ordered pairs of distinct single features when
     * {@code options.quadratic} is set.
     */
    private static Collection<Feature> getAllFeatures(FeatureGenerator featureGenerator, List<String> list, BrownCluster brownCluster) {
        List<FeatureName> names = new ArrayList<FeatureName>();
        Collections.addAll(names, noArgsSingleFeatures);
        if (options.includeFeats) {
            Collections.addAll(names, noArgsSingleFeatures_Feats);
        }
        if (options.step == Step.ai || options.step == Step.ac) {
            Collections.addAll(names, argsSingleFeatures);
            if (options.includeFeats) {
                Collections.addAll(names, argsSingleFeatures_Feats);
            }
        }
        boolean forPredicateIdentification = options.step == Step.pi;
        // Insertion-ordered so the candidate order (which breaks score ties in iterate) is
        // reproducible and does not depend on Feature hash codes.
        Collection<Feature> candidates = new LinkedHashSet<Feature>();
        for (FeatureName name : names) {
            if (!list.contains(name.toString())) {
                candidates.add(featureGenerator.getFeature(name, forPredicateIdentification, options.POSPrefix, brownCluster));
            }
        }
        if (options.quadratic) {
            for (FeatureName first : names) {
                for (FeatureName second : names) {
                    if (first == second) {
                        continue;
                    }
                    try {
                        if (!list.contains(FeatureGenerator.getCanonicalName(first, second))) {
                            candidates.add(featureGenerator.getQFeature(first, second, forPredicateIdentification, options.POSPrefix, brownCluster));
                        }
                    } catch (IllegalArgumentException ignored) {
                        // Deliberately skipped. NOTE(review): presumably FeatureGenerator rejects
                        // pairs it cannot combine this way; confirm against FeatureGenerator.
                    }
                }
            }
        }
        return candidates;
    }

    /**
     * Reads {@code trainingCorpus}, optionally dropping sentences without predicates. The reader is
     * closed even if reading fails.
     */
    private static List<Sentence> readSentences(FeatureSelectionOptions featureSelectionOptions) {
        List<Sentence> sentences = new ArrayList<Sentence>();
        AllCoNLL09Reader reader = new AllCoNLL09Reader(featureSelectionOptions.trainingCorpus);
        try {
            for (Sentence sentence : reader) {
                if (!featureSelectionOptions.dropSentencesWithoutPredicates || sentence.getPredicates().size() > 0) {
                    sentences.add(sentence);
                }
            }
        } finally {
            reader.close();
        }
        return sentences;
    }

    /**
     * Distributes {@code list} round-robin over {@code partitions} lists. If
     * {@code randomizeInput} is set, {@code list} itself is shuffled in place first.
     *
     * @throws IllegalArgumentException if the configured number of partitions is not positive
     */
    private static List<List<Sentence>> partitionSentences(List<Sentence> list, FeatureSelectionOptions featureSelectionOptions) {
        int partitionCount = featureSelectionOptions.partitions;
        if (partitionCount <= 0) {
            throw new IllegalArgumentException("Number of partitions must be positive, got " + partitionCount);
        }
        List<List<Sentence>> partitions = new ArrayList<List<Sentence>>(partitionCount);
        for (int p = 0; p < partitionCount; p++) {
            partitions.add(new ArrayList<Sentence>());
        }
        if (featureSelectionOptions.randomizeInput) {
            Collections.shuffle(list);
        }
        int index = 0;
        for (Sentence sentence : list) {
            partitions.get(index % partitionCount).add(sentence);
            index++;
        }
        return partitions;
    }
}
