package com.clearnlp.component.dep;

import com.clearnlp.classification.algorithm.old.AbstractAlgorithm;
import com.clearnlp.classification.feature.FtrToken;
import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.instance.StringInstance;
import com.clearnlp.classification.model.StringModel;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.train.StringTrainSpace;
import com.clearnlp.classification.vector.StringFeatureVector;
import com.clearnlp.component.AbstractStatisticalComponentSB;
import com.clearnlp.component.evaluation.DEPEval;
import com.clearnlp.component.label.IDEPLabel;
import com.clearnlp.component.state.DEPState;
import com.clearnlp.dependency.DEPHead;
import com.clearnlp.dependency.DEPLabel;
import com.clearnlp.dependency.DEPNode;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.util.UTCollection;
import com.clearnlp.util.pair.ObjectDoublePair;
import com.clearnlp.util.pair.StringIntPair;
import com.clearnlp.util.triple.ObjectsDoubleTriple;
import com.clearnlp.util.triple.Triple;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;

/* loaded from: input_file:com/clearnlp/component/dep/AbstractDEPParser.class */
/**
 * Base class for transition-based dependency parsers operating on a {@link DEPState}.
 *
 * <p>Parsing walks a lambda/beta pair over the tree. At each step a transition label is obtained
 * (gold label during training, model prediction otherwise) and applied via
 * {@link #parseAux(DEPLabel, DEPState)}. Low-margin predictions may spawn alternative states
 * ("branches") that are parsed afterwards. Tokens left without a head are attached in
 * {@link #processHeadless(DEPState)}.
 *
 * <p>Subclasses supply the language/strategy specific hooks: prediction reranking, pre/post
 * transition resets, post-processing and head filtering.
 */
public abstract class AbstractDEPParser extends AbstractStatisticalComponentSB<DEPState> implements IDEPLabel {
    /** Arc part of a label: {@code lambda.setHead(beta)} (see {@link #leftArc}). */
    private static final String ARC_LEFT = "L";
    /** Arc part of a label: {@code beta.setHead(lambda)} (see {@link #rightArc}). */
    private static final String ARC_RIGHT = "R";
    /** List part of a label: shift transition. */
    private static final String LIST_SHIFT = "S";
    /** List part of a label: reduce transition. */
    private static final String LIST_REDUCE = "R";
    /** Dependency label used when a headless token falls back to the root node. */
    private static final String ROOT_DEPREL = "root";
    /** Distance feature values above this are collapsed into this value. */
    private static final int MAX_DISTANCE = 6;

    public AbstractDEPParser(JointFtrXml[] jointFtrXmlArr, StringTrainSpace[] stringTrainSpaceArr, Object[] objArr, double d, int i) {
        super(jointFtrXmlArr, stringTrainSpaceArr, objArr, d, i);
    }

    public AbstractDEPParser(JointFtrXml[] jointFtrXmlArr, StringModel[] stringModelArr, Object[] objArr, double d, int i) {
        super(jointFtrXmlArr, stringModelArr, objArr, new DEPEval(), d, i);
    }

    public AbstractDEPParser(JointFtrXml[] jointFtrXmlArr, StringTrainSpace[] stringTrainSpaceArr, StringModel[] stringModelArr, Object[] objArr, double d, int i) {
        super(jointFtrXmlArr, stringTrainSpaceArr, stringModelArr, objArr, d, i);
    }

    public AbstractDEPParser(ObjectInputStream objectInputStream) {
        super(objectInputStream);
    }

    /** This parser uses no lexica, so there is nothing to initialize. */
    @Override
    protected void initLexia(Object[] lexica) {
    }

    /** Reorders/rescores the (already normalized) predictions for the current state. */
    protected abstract void rerankPredictions(List<StringPrediction> predictions, DEPState state);

    /**
     * Hook run before a transition is predicted.
     *
     * @return {@code true} if the hook already handled the current step, so no label is predicted
     */
    protected abstract boolean resetPre(DEPState state);

    /** Hook run after {@code label} has been applied to the (lambda, beta) pair. */
    protected abstract void resetPost(DEPNode lambda, DEPNode beta, DEPLabel label, DEPState state);

    /** Final adjustments to the parsed tree. */
    protected abstract void postProcess(DEPState state);

    /** @return {@code true} if {@code node} must not be used as a head for a headless token */
    protected abstract boolean isNotHead(DEPNode node);

    /**
     * Reads the model and default settings, closing the stream afterwards (also when reading
     * fails).
     */
    @Override
    public void load(ObjectInputStream objectInputStream) throws Exception {
        try (ObjectInputStream in = objectInputStream) {
            loadSB(in);
            loadDefault(in);
        }
    }

    /**
     * Writes the model and default settings, closing the stream afterwards (also when writing
     * fails). Failures are printed rather than rethrown because this override declares no
     * checked exceptions.
     */
    @Override
    public void save(ObjectOutputStream objectOutputStream) {
        try (ObjectOutputStream out = objectOutputStream) {
            saveSB(out);
            saveDefault(out);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** This parser uses no lexica. */
    @Override
    public Object[] getLexica() {
        return null;
    }

    /** @return the distinct dependency labels (deprel part only) known to all models */
    @Override
    public Set<String> getLabels() {
        Set<String> deprels = Sets.newHashSet();
        for (StringModel model : this.s_models) {
            for (String label : model.getLabels()) {
                deprels.add(new DEPLabel(label).deprel);
            }
        }
        return deprels;
    }

    /**
     * Parses {@code tree} and returns one parsed copy per branch, ordered as
     * {@code UTCollection.sortReverseOrder} orders the branches (presumably best score first).
     *
     * @param tree the tree to parse; its heads are overwritten for each branch
     * @param uniqueOnly if {@code true}, skip branches whose resulting head assignment was
     *     already returned
     * @return pairs of (cloned parsed tree, branch score)
     */
    public List<ObjectDoublePair<DEPTree>> getParsedTrees(DEPTree tree, boolean uniqueOnly) {
        DEPState state = init(tree);
        processAux(state);

        List<ObjectsDoubleTriple<List<StringInstance>, StringIntPair[]>> branches = state.getBranches();
        UTCollection.sortReverseOrder(branches);

        List<ObjectDoublePair<DEPTree>> parsedTrees = Lists.newArrayList();
        Set<String> seenHeads = Sets.newHashSet();

        for (ObjectsDoubleTriple<List<StringInstance>, StringIntPair[]> branch : branches) {
            tree.resetHeads(branch.o2);
            processHeadless(state);
            postProcess(state);

            // Set.add returns false for an already-seen head assignment.
            String headsKey = Arrays.toString(tree.getHeads());
            if (!uniqueOnly || seenHeads.add(headsKey)) {
                parsedTrees.add(new ObjectDoublePair<DEPTree>(tree.clone(), branch.d));
            }
        }

        return parsedTrees;
    }

    /**
     * Parses {@code tree}. In develop/decode mode headless tokens are attached and the tree is
     * post-processed. In develop mode accuracy is also counted against the gold heads.
     */
    @Override
    public void process(DEPTree tree) {
        DEPState state = init(tree);
        processAux(state);

        if (!isDevelopOrDecode()) {
            return;
        }

        processHeadless(state);
        postProcess(state);

        if (isDevelop()) {
            this.e_eval.countAccuracy(state.getTree(), state.getGoldLabels());
        }
    }

    /**
     * Creates the parsing state. Outside decode mode, the tree's current heads are stored as gold
     * labels and then cleared.
     */
    protected DEPState init(DEPTree tree) {
        DEPState state = new DEPState(tree);

        if (!isDecode()) {
            state.setGoldLabels(tree.getHeads());
            tree.clearHeads();
        }

        return state;
    }

    /**
     * Runs the parse. In train/bootstrap mode the collected instances are added to the first
     * training space. In decode mode the parse is repeated whenever {@code resetPOSTags()}
     * reports a change.
     */
    protected void processAux(DEPState state) {
        List<StringInstance> instances = parse(state);

        if (isTrainOrBootstrap()) {
            this.s_spaces[0].addInstances(instances);
        }

        if (isDecode() && state.resetPOSTags()) {
            state.reInit();
            processAux(state);
        }
    }

    /** Parses the main state and, if alternative states exist, the branches as well. */
    protected List<StringInstance> parse(DEPState state) {
        List<StringInstance> instances = parseOne(state);

        if (state.hasMoreState()) {
            instances.addAll(parseBranches(state));
        }

        return instances;
    }

    /**
     * Applies transitions until beta runs off the end, then trims alternative states to
     * {@code n_beams} and records the result as a branch.
     *
     * @return the training instances collected along the way (empty in develop/decode mode)
     */
    protected List<StringInstance> parseOne(DEPState state) {
        List<StringInstance> instances = Lists.newArrayList();

        while (state.isBetaValid()) {
            if (!state.isLambdaValid()) {
                state.shift();
            } else if (!resetPre(state)) {
                // Capture lambda/beta before parseAux moves them; resetPost needs the originals.
                DEPNode lambda = state.getLambda();
                DEPNode beta = state.getBeta();
                DEPLabel label = getLabel(instances, state);
                parseAux(label, state);
                resetPost(lambda, beta, label, state);
            }
        }

        state.trimStates(this.n_beams);
        state.addBranch(instances);
        return instances;
    }

    /**
     * Applies one transition {@code label} to the current (lambda, beta) pair. An arc is not
     * created when lambda is the root (left arc) or when it would introduce a cycle; in those
     * cases the state shifts or passes instead.
     */
    protected void parseAux(DEPLabel label, DEPState state) {
        DEPNode lambda = state.getLambda();
        DEPNode beta = state.getBeta();

        state.increaseTransitionCount();
        state.addScore(label.score);

        if (label.isArc(ARC_LEFT)) {
            if (lambda.id == 0) {
                // Node 0 is never given a head.
                state.shift();
            } else if (beta.isDescendentOf(lambda)) {
                // lambda <- beta would create a cycle.
                state.pass();
            } else {
                leftArc(lambda, beta, label.deprel);
                if (label.isList(LIST_REDUCE)) {
                    state.reduce();
                } else {
                    state.pass();
                }
            }
        } else if (label.isArc(ARC_RIGHT)) {
            if (lambda.isDescendentOf(beta)) {
                // lambda -> beta would create a cycle.
                state.pass();
            } else {
                rightArc(lambda, beta, label.deprel);
                if (label.isList(LIST_SHIFT)) {
                    state.shift();
                } else {
                    state.pass();
                }
            }
        } else if (label.isList(LIST_SHIFT)) {
            state.shift();
        } else if (label.isList(LIST_REDUCE) && lambda.hasHead()) {
            // Reduce only removes lambda once it has a head.
            state.reduce();
        } else {
            state.pass();
        }
    }

    /**
     * Obtains the next transition label.
     *
     * <ul>
     *   <li>Train: uses the gold label and records it as an instance.
     *   <li>Develop/decode: uses the model prediction.
     *   <li>Bootstrap: uses the model prediction but records the gold label as an instance.
     * </ul>
     *
     * @return the label, or {@code null} if none of the modes above is active
     */
    protected DEPLabel getLabel(List<StringInstance> instances, DEPState state) {
        StringFeatureVector vector = getFeatureVector(this.f_xmls[0], state);

        if (isTrain()) {
            DEPLabel gold = state.getGoldLabel();
            instances.add(new StringInstance(gold.toString(), vector));
            return gold;
        }

        if (isDevelopOrDecode()) {
            return getAutoLabel(vector, state);
        }

        if (isBootstrap()) {
            DEPLabel auto = getAutoLabel(vector, state);
            instances.add(new StringInstance(state.getGoldLabel().toString(), vector));
            return auto;
        }

        return null;
    }

    /**
     * Returns the top prediction. If the runner-up scores within {@code d_margin} of it, the
     * runner-up is registered as an alternative state. When the top label is {@code LB_NO} it is
     * also registered as a secondary head.
     */
    private DEPLabel getAutoLabel(StringFeatureVector vector, DEPState state) {
        List<StringPrediction> predictions = getPredictions(vector, state);

        StringPrediction top = predictions.get(0);
        DEPLabel first = new DEPLabel(top.label, top.score);

        // With a single candidate there is no runner-up to branch on.
        if (predictions.size() < 2) {
            return first;
        }

        StringPrediction runnerUp = predictions.get(1);
        DEPLabel second = new DEPLabel(runnerUp.label, runnerUp.score);

        if (first.score - second.score < this.d_margin) {
            if (first.isArc(IDEPLabel.LB_NO)) {
                state.add2ndHead(second);
            }
            state.addState(second);
        }

        return first;
    }

    /** Predicts all labels with the first model, normalizes the scores, then reranks them. */
    private List<StringPrediction> getPredictions(StringFeatureVector vector, DEPState state) {
        List<StringPrediction> predictions = this.s_models[0].predictAll(vector);
        AbstractAlgorithm.normalize(predictions);
        rerankPredictions(predictions, state);
        return predictions;
    }

    /** Makes {@code beta} the head of {@code lambda} with dependency label {@code deprel}. */
    public void leftArc(DEPNode lambda, DEPNode beta, String deprel) {
        lambda.setHead(beta, deprel);
    }

    /** Makes {@code lambda} the head of {@code beta} with dependency label {@code deprel}. */
    public void rightArc(DEPNode lambda, DEPNode beta, String deprel) {
        beta.setHead(lambda, deprel);
    }

    /**
     * Gives a head to every token (except node 0) that still lacks one. The first candidate from
     * {@code get2ndHeads} that is allowed by {@link #isNotHead} and would not create a cycle is
     * tried first. Otherwise both sides are scanned for the best-scoring attachment. If nothing
     * scores above -1 the token is attached to node 0 with label {@code "root"}.
     */
    protected void processHeadless(DEPState state) {
        Triple<DEPNode, String, Double> best = new Triple<>(null, null, Double.valueOf(-1.0d));
        DEPNode root = state.getNode(0);
        int size = state.getTreeSize();

        for (int i = 1; i < size; i++) {
            DEPNode node = state.getNode(i);
            if (node.hasHead()) {
                continue;
            }

            for (DEPHead candidate : state.get2ndHeads(node.id)) {
                DEPNode head = state.getNode(candidate.headId);
                if (!isNotHead(head) && !head.isDescendentOf(node)) {
                    node.setHead(head, candidate.deprel);
                    break;
                }
            }

            if (!node.hasHead()) {
                best.set(root, ROOT_DEPREL, Double.valueOf(-1.0d));
                processHeadlessAux(node, -1, best, state);
                processHeadlessAux(node, 1, best, state);
                node.setHead(best.o1, best.o2);
            }
        }
    }

    /**
     * Scans from {@code node} in direction {@code direction} (negative = left, positive = right).
     * For each non-descendant candidate it asks the model for an arc that would make the
     * candidate the head of {@code node}. {@code best} is updated whenever such an arc beats the
     * score it currently holds.
     *
     * @throws IllegalArgumentException if {@code direction} is 0 (the scan would never terminate)
     */
    protected void processHeadlessAux(DEPNode node, int direction, Triple<DEPNode, String, Double> best, DEPState state) {
        if (direction == 0) {
            throw new IllegalArgumentException("direction must be non-zero");
        }

        int size = state.getTreeSize();
        boolean leftward = direction < 0;

        // The headless node stays fixed on one side of the pair; candidates fill the other side.
        if (leftward) {
            state.setBeta(node.id);
        } else {
            state.setLambda(node.id);
        }

        for (int j = node.id + direction; 0 <= j && j < size; j += direction) {
            DEPNode candidate = state.getNode(j);
            if (candidate.isDescendentOf(node)) {
                continue;
            }

            if (leftward) {
                state.setLambda(j);
            } else {
                state.setBeta(j);
            }

            List<StringPrediction> predictions = getPredictions(getFeatureVector(this.f_xmls[0], state), state);
            for (StringPrediction prediction : predictions) {
                if (prediction.score <= best.o3.doubleValue()) {
                    break;
                }

                DEPLabel label = new DEPLabel(prediction.label);
                // Right arc: lambda (candidate, left) heads beta (node).
                // Left arc: beta (candidate, right) heads lambda (node).
                if ((leftward && label.isArc(ARC_RIGHT)) || (!leftward && label.isArc(ARC_LEFT))) {
                    best.set(candidate, label.deprel, Double.valueOf(prediction.score));
                    break;
                }
            }
        }
    }

    /**
     * Parses all alternative states. In develop/decode mode the best branch's heads are applied
     * to the state. Otherwise branches are scored against the gold labels first.
     *
     * @return the training instances of the best branch
     */
    public List<StringInstance> parseBranches(DEPState state) {
        branch(state);

        ObjectsDoubleTriple<List<StringInstance>, StringIntPair[]> bestBranch;
        if (isDevelopOrDecode()) {
            bestBranch = state.getBestBranch();
            state.resetHeads(bestBranch.o2);
        } else {
            state.setGoldScoresToBranches();
            bestBranch = state.getBestBranch();
        }

        return bestBranch.o1;
    }

    /** Disables further branching, then parses each remaining alternative state to completion. */
    private void branch(DEPState state) {
        state.disableBranching();

        for (DEPLabel next = state.setToNextState(); next != null; next = state.setToNextState()) {
            parseAux(next, state);
            parseOne(state);
        }
    }

    /**
     * Extracts a single feature value for {@code token}.
     *
     * <p>Declared {@code public} in this class although the decompiler noted the original was
     * {@code protected}; kept {@code public} for compatibility with existing callers.
     *
     * @return the feature value, or {@code null} if the token's node is absent, the field is
     *     unknown, or a boolean feature is off
     * @throws IllegalArgumentException for an unsupported boolean feature index
     */
    @Override
    public String getField(FtrToken token, DEPState state) {
        DEPNode node = state.getNode(token);
        if (node == null) {
            return null;
        }

        if (token.isField("f")) {
            return node.form;
        }
        if (token.isField("sf")) {
            return node.simplifiedForm;
        }
        if (token.isField(JointFtrXml.F_LEMMA)) {
            return node.lemma;
        }
        if (token.isField("p")) {
            return node.pos;
        }
        if (token.isField("d")) {
            return node.getLabel();
        }
        if (token.isField("n")) {
            return Integer.toString(Math.min(state.getDistance(), MAX_DISTANCE));
        }
        if (token.isField(JointFtrXml.F_LEFT_VALENCY)) {
            return state.getLeftValency(node.id);
        }
        if (token.isField(JointFtrXml.F_RIGHT_VALENCY)) {
            return state.getRightValency(node.id);
        }

        Matcher booleanMatcher = JointFtrXml.P_BOOLEAN.matcher(token.field);
        if (booleanMatcher.find()) {
            int flag = Integer.parseInt(booleanMatcher.group(1));
            boolean on;
            switch (flag) {
                case 0:
                    on = state.isLambdaFirst();
                    break;
                case 1:
                    on = state.isBetaLast();
                    break;
                case 2:
                    on = state.isLambdaBetaAdjacent();
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported feature: " + flag);
            }
            return on ? token.field : null;
        }

        Matcher featMatcher = JointFtrXml.P_FEAT.matcher(token.field);
        return featMatcher.find() ? node.getFeat(featMatcher.group(1)) : null;
    }

    /**
     * This parser defines no multi-valued features.
     *
     * <p>Declared {@code public} in this class although the decompiler noted the original was
     * {@code protected}; kept {@code public} for compatibility.
     */
    @Override
    public String[] getFields(FtrToken token, DEPState state) {
        return null;
    }
}
