package com.clearnlp.bin;

import com.clearnlp.classification.algorithm.AbstractAlgorithm;
import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.model.StringModelAD;
import com.clearnlp.collection.list.FloatArrayList;
import com.clearnlp.component.evaluation.AbstractEval;
import com.clearnlp.component.online.AbstractOnlineStatisticalComponent;
import com.clearnlp.component.online.OnlinePOSTagger;
import com.clearnlp.component.state.AbstractState;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.morphology.MPLib;
import com.clearnlp.nlp.NLPMode;
import com.clearnlp.reader.JointReader;
import com.clearnlp.util.UTArgs4j;
import com.clearnlp.util.UTFile;
import com.clearnlp.util.UTInput;
import com.clearnlp.util.UTOutput;
import com.clearnlp.util.UTXml;
import com.clearnlp.util.map.Prob1DMap;
import com.clearnlp.util.pair.ObjectDoublePair;
import com.google.common.collect.Sets;
import edu.stanford.nlp.pipeline.CleanXmlAnnotator;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.lucene.util.packed.PackedInts;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/* loaded from: input_file:com/clearnlp/bin/NLPDevelop.class */
public class NLPDevelop extends AbstractNLP implements NLPMode {
    /** Separator between feature template file names; the same character the {@code -f} usage text documents. */
    protected final String DELIM_FILENAME = ":";
    /** Not referenced in the visible part of this class; presumably an upper bound on trees handled at once — TODO confirm. */
    protected final int MAX_TREES = 5000;

    /** Path of the XML configuration file ({@code -c}); opened and parsed in the constructor. */
    @Option(name = "-c", usage = "configuration file (required)", required = true, metaVar = "<filename>")
    protected String s_configFile;

    /** Feature template file names joined by ':' ({@code -f}); split in the constructor. */
    @Option(name = "-f", usage = "feature template files delimited by ':' (required)", required = true, metaVar = "<filename>")
    protected String s_featureFiles;

    /** Training file or directory ({@code -i}); expanded to a file list in the constructor. */
    @Option(name = "-i", usage = "training file or directory containing training files (required)", required = true, metaVar = "<directory>")
    protected String s_trainPath;

    /** Mode name ({@code -z}); used as the tag name of the mode-specific configuration element. */
    @Option(name = "-z", usage = "mode (pos|dep|pred|role|srl)", required = true, metaVar = "<string>")
    protected String s_mode;

    /** Development file or directory ({@code -d}); expanded to a file list in the constructor. */
    @Option(name = "-d", usage = "development file or directory containing development files (required)", required = true, metaVar = "<directory>")
    protected String s_developPath;

    /** Model file ({@code -m}); written by training (type 1) and read by decoding (type 2). */
    @Option(name = "-m", usage = "model file (required)", required = true, metaVar = "<filename>")
    protected String s_modelFile;

    /**
     * Task selector ({@code -t}). The constructor handles 0 (develop), 1 (train) and 2 (decode);
     * the usage text also advertises 3, which the visible constructor does not handle.
     */
    @Option(name = "-t", usage = "type (required)", required = true, metaVar = "<0|1|2|3>")
    protected int i_type;

    /**
     * Parses the command line, resolves the training and development file lists, loads the XML
     * configuration, and runs the task selected by {@code -t}: {@code 0} develops, {@code 1}
     * trains and saves a model, {@code 2} decodes with a saved model. Any other value runs
     * nothing.
     * <p>
     * Exceptions raised while loading the configuration or running the task are reported with a
     * stack trace and not propagated.
     *
     * @param strArr raw command-line arguments, bound to the {@code @Option} fields
     */
    public NLPDevelop(String[] strArr) {
        UTArgs4j.initArgs(this, strArr);
        String[] featureFiles = this.s_featureFiles.split(this.DELIM_FILENAME);
        // CoreNLP's CleanXmlAnnotator.DEFAULT_XML_TAGS is just the string ".*"; spelling the
        // literal out keeps the same argument without leaning on an unrelated library constant.
        final String anyName = ".*";
        String[] trainFiles = UTFile.getSortedFileListBySize(this.s_trainPath, anyName, true);
        String[] developFiles = UTFile.getSortedFileListBySize(this.s_developPath, anyName, true);
        try {
            Element config;
            // Release the file handle once the document is built instead of leaking it.
            // Assumes getDocumentElement parses the whole stream before returning (usual DOM behaviour).
            try (FileInputStream configStream = new FileInputStream(this.s_configFile)) {
                config = UTXml.getDocumentElement(configStream);
            }
            JointFtrXml[] templates = getFeatureTemplates(featureFiles);
            switch (this.i_type) {
                case 0:
                    develop(templates, trainFiles, developFiles, config, this.s_mode, -1);
                    break;
                case 1:
                    train(templates, trainFiles, this.s_modelFile, config, this.s_mode);
                    break;
                case 2:
                    decode(developFiles, this.s_modelFile, config);
                    break;
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads a saved component from a gzip-compressed object stream and runs it over the given
     * files with mode flag {@code (byte) 2} under the progress label {@code "Decoding:"}.
     * <p>
     * The model stream is closed when processing finishes or fails; closing it again is harmless
     * should the decoder already have closed it.
     *
     * @param strArr  files to process
     * @param str     path of the gzip-compressed model file
     * @param element root configuration element; its {@code reader} child configures the reader
     * @throws Exception if the model cannot be read or processing fails
     */
    public void decode(String[] strArr, String str, Element element) throws Exception {
        JointReader reader = getJointReader(UTXml.getFirstElementByTagName(element, "reader"));
        // Separate resources so the file handle is released even if the gzip header is invalid.
        try (FileInputStream file = new FileInputStream(str);
             GZIPInputStream gzip = new GZIPInputStream(file);
             ObjectInputStream modelStream = new ObjectInputStream(new BufferedInputStream(gzip))) {
            process(strArr, reader, getDecoder(modelStream), "Decoding:", (byte) 2, -1);
        }
    }

    /**
     * Trains a component on the given files, optionally followed by bootstrap rounds, then trims
     * the models and saves the component to a gzip-compressed object stream.
     * <p>
     * The initial training pass always runs. Each bootstrap round (as many as
     * {@code getNumberOfBootstraps} reports for the mode element) regenerates instances with mode
     * flag {@code (byte) 3} and trains again.
     *
     * @param jointFtrXmlArr feature templates
     * @param strArr         training files
     * @param str            path of the model file to write
     * @param element        root configuration element
     * @param str2           mode name, used as the tag name of the mode-specific element
     * @throws Exception if instance generation, training, or writing the model fails
     */
    public void train(JointFtrXml[] jointFtrXmlArr, String[] strArr, String str, Element element, String str2) throws Exception {
        JointReader reader = getJointReader(UTXml.getFirstElementByTagName(element, "reader"));
        AbstractOnlineStatisticalComponent<? extends AbstractState> component = preBootstrap(jointFtrXmlArr, strArr, reader, element, str2, -1);
        Element modeConfig = UTXml.getFirstElementByTagName(element, str2);
        NodeList trainConfigs = modeConfig.getElementsByTagName("train");
        int bootstraps = getNumberOfBootstraps(modeConfig);

        train(component, trainConfigs, 0);
        for (int boot = 1; boot <= bootstraps; boot++) {
            this.LOG.info(String.format("===== Bootstrap: %d =====\n", Integer.valueOf(boot)));
            process(strArr, reader, component, "Generating instances:", (byte) 3, -1);
            train(component, trainConfigs, boot);
        }

        for (StringModelAD model : component.getModels()) {
            // 0f is the value of Lucene's PackedInts.COMPACT, an unrelated constant that the
            // original referenced only because it shares this value.
            model.trimFeatures(this.LOG, 0f);
        }

        // Close the stream chain so the gzip trailer is written and the handle released even
        // when save() fails part-way; a second close after save() is harmless.
        try (FileOutputStream file = new FileOutputStream(str);
             ObjectOutputStream modelStream = new ObjectOutputStream(new BufferedOutputStream(new GZIPOutputStream(file)))) {
            component.save(modelStream);
        }
    }

    /**
     * Runs the development loop: an initial round, then bootstrap rounds that continue for as
     * long as each round's score is strictly higher than the previous best.
     * <p>
     * The initial round passes the value of {@code getBootstrapScore} as its threshold; later
     * rounds pass {@code 0.0}. Before every later round, instances are regenerated from the
     * training files with mode flag {@code (byte) 3}.
     * <p>
     * NOTE(review): the initial round evaluates on {@code strArr} (the training files) while
     * later rounds evaluate on {@code strArr2} — confirm this asymmetry is intended.
     *
     * @param jointFtrXmlArr feature templates
     * @param strArr         training files
     * @param strArr2        development files
     * @param element        root configuration element
     * @param str            mode name, used as the tag name of the mode-specific element
     * @param i              index forwarded to lexicon collection and instance generation
     * @throws Exception if any processing step fails
     */
    public void develop(JointFtrXml[] jointFtrXmlArr, String[] strArr, String[] strArr2, Element element, String str, int i) throws Exception {
        JointReader reader = getJointReader(UTXml.getFirstElementByTagName(element, "reader"));
        AbstractOnlineStatisticalComponent<? extends AbstractState> component = preBootstrap(jointFtrXmlArr, strArr, reader, element, str, i);
        Element modeConfig = UTXml.getFirstElementByTagName(element, str);
        NodeList trainConfigs = modeConfig.getElementsByTagName("train");

        // Round 0: its score is not compared against anything.
        develop(strArr, reader, component, trainConfigs, getBootstrapScore(modeConfig), 0, (byte) 4);

        double bestScore = 0.0d;
        for (int boot = 1; ; boot++) {
            this.LOG.info(String.format("===== Bootstrap: %d =====\n", Integer.valueOf(boot)));
            process(strArr, reader, component, "Generating instances:", (byte) 3, i);

            double score = develop(strArr2, reader, component, trainConfigs, 0.0d, boot, (byte) 4);
            if (score <= bestScore) {
                return;
            }
            bestScore = score;
        }
    }

    /**
     * Builds the trainer used before any bootstrapping: optionally collects lexica with a
     * collector component, then generates training instances with a fresh trainer.
     *
     * @param jointFtrXmlArr feature templates
     * @param strArr         training files
     * @param jointReader    reader used for every pass over the files
     * @param element        root configuration element
     * @param str            mode name (not used by this method itself)
     * @param i              index forwarded to the collector and to both processing passes
     * @return the trainer after instance generation (mode flag {@code (byte) 1})
     * @throws Exception if either processing pass fails
     */
    private AbstractOnlineStatisticalComponent<? extends AbstractState> preBootstrap(JointFtrXml[] jointFtrXmlArr, String[] strArr, JointReader jointReader, Element element, String str, int i) throws Exception {
        AbstractOnlineStatisticalComponent<? extends AbstractState> lexiconCollector = getCollector(jointFtrXmlArr, strArr, jointReader, element, i);

        // Without a collector the trainer is created with null lexica.
        Object[] lexica;
        if (lexiconCollector == null) {
            lexica = null;
        } else {
            process(strArr, jointReader, lexiconCollector, "Collecting lexica:", (byte) 0, i);
            lexica = lexiconCollector.getLexica();
        }

        AbstractOnlineStatisticalComponent<? extends AbstractState> instanceTrainer = getTrainer(jointFtrXmlArr, lexica);
        process(strArr, jointReader, instanceTrainer, "Generating instances:", (byte) 1, i);
        return instanceTrainer;
    }

    /**
     * Runs a component over a set of files. <strong>The body of this method was not recovered by
     * the decompiler</strong>: as it stands, every call throws
     * {@link UnsupportedOperationException}, so every task in this class fails until the original
     * implementation is restored.
     * <p>
     * What the call sites in this class show: {@code r14} is a progress label (or {@code null}),
     * {@code r15} is a mode flag (0 with "Collecting lexica:", 1 and 3 with "Generating
     * instances:", 2 with "Decoding:", and the caller-supplied flag during development), and
     * {@code r16} is an index that callers pass as {@code -1} or forward unchanged. The returned
     * list is indexed per input file by {@code printOutput}.
     *
     * @param r11 files to process
     * @param r12 reader used to read the files
     * @param r13 component to run
     * @param r14 progress label, may be {@code null}
     * @param r15 mode flag
     * @param r16 index; presumably the file to skip, matching its use elsewhere — TODO confirm
     * @return nothing at present; the stub always throws
     * @throws java.lang.Exception declared by the original signature
     */
    /*  JADX ERROR: JadxRuntimeException in pass: RegionMakerVisitor
        jadx.core.utils.exceptions.JadxRuntimeException: Failed to find switch 'out' block (already processed)
        (decompiler stack trace omitted; the failure was in switch/loop region reconstruction)
        */
    /* JADX WARN: Failed to find 'out' block for switch in B:11:0x0057. Please report as an issue. */
    protected java.util.List<java.lang.String> process(java.lang.String[] r11, com.clearnlp.reader.JointReader r12, com.clearnlp.component.online.AbstractOnlineStatisticalComponent<? extends com.clearnlp.component.state.AbstractState> r13, java.lang.String r14, byte r15, int r16) throws java.lang.Exception {
        /*
            Method dump skipped, instructions count: 364
            To view this dump add '--comments-level debug' option
        */
        // Placeholder emitted by the decompiler in place of the real implementation.
        throw new UnsupportedOperationException("Method not decompiled: com.clearnlp.bin.NLPDevelop.process(java.lang.String[], com.clearnlp.reader.JointReader, com.clearnlp.component.online.AbstractOnlineStatisticalComponent, java.lang.String, byte, int):java.util.List");
    }

    /**
     * Builds and trains every model of the component, pairing the model at position {@code k}
     * with the {@code k}-th element of {@code nodeList} for its settings.
     *
     * @param abstractOnlineStatisticalComponent component whose models are trained
     * @param nodeList                           per-model training configuration elements
     * @param i                                  round number, passed to {@code getNumberOfIterations}
     */
    protected void train(AbstractOnlineStatisticalComponent<? extends AbstractState> abstractOnlineStatisticalComponent, NodeList nodeList, int i) {
        int position = 0;
        for (StringModelAD model : abstractOnlineStatisticalComponent.getModels()) {
            Element settings = (Element) nodeList.item(position);
            position++;

            model.build(getLabelCutoff(settings), getFeatureCutoff(settings), getRandomSeed(settings), true);
            model.printInfo(this.LOG);

            AbstractAlgorithm algorithm = getAlgorithm(settings);
            int iterations = getNumberOfIterations(settings, i);
            trainOnline(model, algorithm, iterations);
        }
    }

    /**
     * Applies the algorithm to the model a fixed number of times, logging a dot after each pass
     * and a newline at the end.
     *
     * @param stringModelAD     model to train
     * @param abstractAlgorithm training algorithm
     * @param i                 number of passes; zero or negative means none
     */
    private void trainOnline(StringModelAD stringModelAD, AbstractAlgorithm abstractAlgorithm, int i) {
        int pass = 1;
        while (pass <= i) {
            abstractAlgorithm.train(stringModelAD);
            this.LOG.info(".");
            pass++;
        }
        this.LOG.info("\n");
    }

    /**
     * Builds each model of the component and develops it online against the given files, using
     * the {@code k}-th element of {@code nodeList} for the {@code k}-th model.
     * <p>
     * Only the result for the last model is kept. When {@code b} equals 5 that result's outputs
     * are written out via {@code printOutput}.
     *
     * @param strArr                             files to evaluate on
     * @param jointReader                        reader used for evaluation
     * @param abstractOnlineStatisticalComponent component whose models are developed
     * @param nodeList                           per-model training configuration elements
     * @param d                                  score threshold forwarded to {@code developOnline}
     * @param i                                  round number, used as the output file suffix
     * @param b                                  mode flag forwarded to processing
     * @return the score obtained for the last model
     * @throws Exception if processing fails
     */
    protected double develop(String[] strArr, JointReader jointReader, AbstractOnlineStatisticalComponent<? extends AbstractState> abstractOnlineStatisticalComponent, NodeList nodeList, double d, int i, byte b) throws Exception {
        ObjectDoublePair<List<String>> lastResult = null;
        int position = 0;

        for (StringModelAD model : abstractOnlineStatisticalComponent.getModels()) {
            Element settings = (Element) nodeList.item(position);
            position++;

            model.build(getLabelCutoff(settings), getFeatureCutoff(settings), getRandomSeed(settings), true);
            model.printInfo(this.LOG);
            lastResult = developOnline(strArr, jointReader, abstractOnlineStatisticalComponent, model, getAlgorithm(settings), d, b);
        }

        if (b == 5) {
            List<String> outputs = (List<String>) lastResult.o;
            printOutput(strArr, outputs, i);
        }
        return lastResult.d;
    }

    /**
     * Trains the model one pass at a time, evaluating after each pass, and stops as soon as the
     * score fails to improve on the best so far. When {@code d} is positive, it also stops once
     * the score reaches {@code d}. The weights from the best pass are restored before returning.
     * <p>
     * If no pass beats the initial best of {@code 0.0} (the first evaluation scores zero or
     * less), there is no snapshot to restore. The model then keeps its current weights instead
     * of being handed {@code null}, and the returned pair holds a {@code null} output list with
     * score {@code 0.0}.
     *
     * @param strArr                             files to evaluate on
     * @param jointReader                        reader used for evaluation
     * @param abstractOnlineStatisticalComponent component that owns the model and the evaluator
     * @param stringModelAD                      model being trained
     * @param abstractAlgorithm                  training algorithm
     * @param d                                  target score; values {@code <= 0} disable the target
     * @param b                                  mode flag forwarded to processing
     * @return outputs and score of the best pass
     * @throws Exception if processing fails
     */
    protected ObjectDoublePair<List<String>> developOnline(String[] strArr, JointReader jointReader, AbstractOnlineStatisticalComponent<? extends AbstractState> abstractOnlineStatisticalComponent, StringModelAD stringModelAD, AbstractAlgorithm abstractAlgorithm, double d, byte b) throws Exception {
        final boolean hasTarget = d > 0.0d;
        List<String> bestOutput = null;
        FloatArrayList bestWeights = null;
        double bestScore = 0.0d;

        for (int pass = 1; ; pass++) {
            abstractAlgorithm.train(stringModelAD);
            List<String> output = process(strArr, jointReader, abstractOnlineStatisticalComponent, null, b, -1);

            AbstractEval eval = abstractOnlineStatisticalComponent.getEval();
            double score = eval.getAccuracies()[0];
            this.LOG.info(String.format("%2d: %s\n", Integer.valueOf(pass), eval.toString()));
            eval.clear();

            // No improvement: keep the previous best and stop.
            if (bestScore >= score) {
                break;
            }
            bestWeights = stringModelAD.cloneWeights();
            bestScore = score;
            bestOutput = output;

            if (hasTarget && d <= score) {
                break;
            }
        }

        // Restore only when a snapshot exists; passing null would discard the model's weights.
        if (bestWeights != null) {
            stringModelAD.setWeights(bestWeights);
        }
        return new ObjectDoublePair<>(bestOutput, bestScore);
    }

    /**
     * Trains the model with a single call to the algorithm, evaluates once, logs the evaluation,
     * and returns the outputs together with the first reported accuracy.
     *
     * @param strArr                             files to evaluate on
     * @param jointReader                        reader used for evaluation
     * @param abstractOnlineStatisticalComponent component that owns the model and the evaluator
     * @param stringModelAD                      model being trained
     * @param abstractAlgorithm                  training algorithm
     * @param b                                  mode flag forwarded to processing
     * @return outputs and score of the single evaluation
     * @throws Exception if processing fails
     */
    protected ObjectDoublePair<List<String>> developBatch(String[] strArr, JointReader jointReader, AbstractOnlineStatisticalComponent<? extends AbstractState> abstractOnlineStatisticalComponent, StringModelAD stringModelAD, AbstractAlgorithm abstractAlgorithm, byte b) throws Exception {
        abstractAlgorithm.train(stringModelAD);
        List<String> outputs = process(strArr, jointReader, abstractOnlineStatisticalComponent, null, b, -1);

        AbstractEval evaluator = abstractOnlineStatisticalComponent.getEval();
        double[] accuracies = evaluator.getAccuracies();
        double score = accuracies[0];

        String summary = String.format("%s\n", evaluator.toString());
        this.LOG.info(summary);

        // Reset the evaluator so the next evaluation starts clean.
        evaluator.clear();
        return new ObjectDoublePair<>(outputs, score);
    }

    /**
     * Writes one output file per input file: the {@code k}-th entry of {@code list} goes to a
     * file named after the {@code k}-th input file with {@code "." + i} appended.
     * <p>
     * Each stream is closed even if fetching or writing its entry fails.
     *
     * @param strArr input file names, used as the base of the output names
     * @param list   output text per input file, in the same order as {@code strArr}
     * @param i      suffix appended to every output file name
     */
    protected void printOutput(String[] strArr, List<String> list, int i) {
        for (int index = 0; index < strArr.length; index++) {
            try (PrintStream out = UTOutput.createPrintBufferedFileStream(strArr[index] + "." + i)) {
                out.print(list.get(index));
            }
        }
    }

    /**
     * Creates the component used for the lexicon-collection pass: a POS tagger seeded with the
     * set of simplified lowercase word forms that pass the document-frequency cutoff.
     * <p>
     * Both settings are read from the {@code pos} element of the configuration. A positive
     * maximum token count selects the variant that counts frequencies over token-count chunks;
     * otherwise frequencies are counted per file.
     *
     * @param jointFtrXmlArr feature templates
     * @param strArr         training files
     * @param jointReader    reader used to scan the files
     * @param element        root configuration element
     * @param i              index of a file to leave out of the scan
     * @return the collector component
     */
    protected AbstractOnlineStatisticalComponent<? extends AbstractState> getCollector(JointFtrXml[] jointFtrXmlArr, String[] strArr, JointReader jointReader, Element element, int i) {
        Element posConfig = UTXml.getFirstElementByTagName(element, "pos");
        int cutoff = getDocumentFrequencyCutoff(posConfig);
        int maxTokens = getDocumentMaxTokenCount(posConfig);

        Set<String> forms;
        if (maxTokens > 0) {
            forms = getLowerSimplifiedFormsByDocumentFrequencies(jointReader, strArr, i, cutoff, maxTokens);
        } else {
            forms = getLowerSimplifiedFormsByDocumentFrequencies(jointReader, strArr, i, cutoff);
        }
        return new OnlinePOSTagger(jointFtrXmlArr, forms);
    }

    /**
     * Creates the component used to generate training instances.
     *
     * @param jointFtrXmlArr feature templates
     * @param objArr         lexica from the collection pass, or {@code null} when none was run
     * @return a POS tagger built from the templates and lexica
     */
    protected AbstractOnlineStatisticalComponent<? extends AbstractState> getTrainer(JointFtrXml[] jointFtrXmlArr, Object[] objArr) {
        OnlinePOSTagger trainer = new OnlinePOSTagger(jointFtrXmlArr, objArr);
        return trainer;
    }

    /**
     * Creates the component used for decoding from a saved model.
     *
     * @param objectInputStream stream positioned at a saved component
     * @return a POS tagger constructed from the stream
     */
    protected AbstractOnlineStatisticalComponent<? extends AbstractState> getDecoder(ObjectInputStream objectInputStream) {
        OnlinePOSTagger decoder = new OnlinePOSTagger(objectInputStream);
        return decoder;
    }

    /**
     * Renders a tree using its POS string form.
     *
     * @param dEPTree tree to render
     * @return the result of {@code toStringPOS()} on the tree
     */
    protected String toString(DEPTree dEPTree) {
        String rendered = dEPTree.toStringPOS();
        return rendered;
    }

    /**
     * Collects simplified lowercase word forms by document frequency, treating every file as one
     * document: each distinct form found in a file is counted once for that file, and the forms
     * that pass the cutoff are returned.
     * <p>
     * The reader is closed after each file even when reading fails part-way.
     *
     * @param jointReader reader used to scan the files
     * @param strArr      files to scan
     * @param i           index of a file to skip; an index outside the array skips nothing
     * @param i2          cutoff handed to {@code Prob1DMap.toSet}
     * @return the forms selected by the cutoff
     */
    private Set<String> getLowerSimplifiedFormsByDocumentFrequencies(JointReader jointReader, String[] strArr, int i, int i2) {
        HashSet<String> formsInFile = Sets.newHashSet();
        Prob1DMap documentCounts = new Prob1DMap();
        this.LOG.info(String.format("Collecting simplified-forms: cutoff = %d\n", Integer.valueOf(i2)));

        for (int fileIndex = 0; fileIndex < strArr.length; fileIndex++) {
            if (fileIndex == i) {
                continue;
            }
            jointReader.open(UTInput.createBufferedFileReader(strArr[fileIndex]));
            try {
                formsInFile.clear();
                DEPTree tree;
                while ((tree = jointReader.next()) != null) {
                    // Node 0 is skipped, as in the original loop.
                    int size = tree.size();
                    for (int node = 1; node < size; node++) {
                        formsInFile.add(MPLib.getSimplifiedLowercaseWordForm(tree.get(node).form));
                    }
                }
                documentCounts.addAll(formsInFile);
            } finally {
                jointReader.close();
            }
            this.LOG.info(".");
        }

        this.LOG.info("\n");
        return documentCounts.toSet(i2);
    }

    /**
     * Collects simplified lowercase word forms by document frequency, where a "document" is a
     * run of consecutive trees whose accumulated {@code size()} reaches {@code i3}. Runs carry
     * across file boundaries, and a final partial run is counted as one more document.
     * <p>
     * The reader is closed after each file even when reading fails part-way.
     *
     * @param jointReader reader used to scan the files
     * @param strArr      files to scan
     * @param i           index of a file to skip; an index outside the array skips nothing
     * @param i2          cutoff handed to {@code Prob1DMap.toSet}
     * @param i3          accumulated tree size at which the current run is closed
     * @return the forms selected by the cutoff
     */
    private Set<String> getLowerSimplifiedFormsByDocumentFrequencies(JointReader jointReader, String[] strArr, int i, int i2, int i3) {
        int runSize = 0;
        HashSet<String> formsInRun = Sets.newHashSet();
        Prob1DMap documentCounts = new Prob1DMap();
        this.LOG.info(String.format("Collecting simplified-forms: cutoff = %d, max = %d\n", Integer.valueOf(i2), Integer.valueOf(i3)));

        for (int fileIndex = 0; fileIndex < strArr.length; fileIndex++) {
            if (fileIndex == i) {
                continue;
            }
            jointReader.open(UTInput.createBufferedFileReader(strArr[fileIndex]));
            try {
                DEPTree tree;
                while ((tree = jointReader.next()) != null) {
                    // Node 0 is skipped when collecting forms but size() is added in full.
                    int size = tree.size();
                    for (int node = 1; node < size; node++) {
                        formsInRun.add(MPLib.getSimplifiedLowercaseWordForm(tree.get(node).form));
                    }
                    runSize += size;
                    if (runSize >= i3) {
                        documentCounts.addAll(formsInRun);
                        this.LOG.info(".");
                        formsInRun.clear();
                        runSize = 0;
                    }
                }
            } finally {
                jointReader.close();
            }
        }

        this.LOG.info("\n");
        // Count whatever is left over from an unfinished run.
        if (!formsInRun.isEmpty()) {
            documentCounts.addAll(formsInRun);
        }
        return documentCounts.toSet(i2);
    }

    /**
     * Command-line entry point.
     *
     * @param strArr command-line arguments, passed straight to the constructor
     */
    public static void main(String[] strArr) {
        // The constructor performs the whole run; the instance itself is not kept.
        new NLPDevelop(strArr);
    }
}
