package com.clearnlp.component.srl;

import com.clearnlp.classification.feature.FtrToken;
import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.instance.StringInstance;
import com.clearnlp.classification.model.StringModel;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.train.StringTrainSpace;
import com.clearnlp.classification.vector.StringFeatureVector;
import com.clearnlp.component.AbstractStatisticalComponent;
import com.clearnlp.component.evaluation.SRLEval;
import com.clearnlp.component.label.IDEPLabel;
import com.clearnlp.component.state.SRLState;
import com.clearnlp.dependency.DEPArc;
import com.clearnlp.dependency.DEPLib;
import com.clearnlp.dependency.DEPNode;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.dependency.srl.SRLLib;
import com.clearnlp.propbank.PBLib;
import com.clearnlp.propbank.frameset.AbstractFrames;
import com.clearnlp.propbank.frameset.PBRoleset;
import com.clearnlp.propbank.frameset.PBType;
import com.clearnlp.util.UTCollection;
import com.clearnlp.util.map.Prob1DMap;
import com.clearnlp.util.pair.ObjectDoublePair;
import com.clearnlp.util.pair.StringIntPair;
import is2.data.PipeGen;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;

/* loaded from: input_file:com/clearnlp/component/srl/AbstractSRLabeler.class */
public abstract class AbstractSRLabeler extends AbstractStatisticalComponent<SRLState> {
    // Slots of the lexica array produced by getLexica() and consumed by initLexia().
    protected final int LEXICA_PATH_UP = 0;
    protected final int LEXICA_PATH_DOWN = 1;
    protected final int LEXICA_FRAMES = 2;

    // Path selectors understood by getPath() (number captured by JointFtrXml.P_PATH).
    protected final int PATH_ALL = 0;
    protected final int PATH_UP = 1;
    protected final int PATH_DOWN = 2;

    // Dependent scopes understood by getSubcat() (number captured by JointFtrXml.P_SUBCAT).
    protected final int SUBCAT_ALL = 0;
    protected final int SUBCAT_LEFT = 1;
    protected final int SUBCAT_RIGHT = 2;

    /** Label meaning "this candidate is not an argument of the current predicate". */
    protected final String LB_NO_ARG = IDEPLabel.LB_NO;

    // NOTE: the fields below intentionally have no initializers. The superclass constructors that
    // receive lexica or a model stream presumably call initLexia()/load() themselves (this class
    // never does); field initializers run after super() returns and would overwrite those values.

    /** Down-path counts collected in lexica mode (created only by the lexica constructor). */
    protected Prob1DMap m_down;
    /** Up-path counts collected in lexica mode (created only by the lexica constructor). */
    protected Prob1DMap m_up;
    /** Down paths (predicate to descendant) below which labelDown() keeps searching for arguments. */
    protected Set<String> s_down;
    /** Up paths kept with the lexica; saved and loaded alongside s_down. */
    protected Set<String> s_up;
    /** PropBank frames used to look up the roleset of each predicate; may be null. */
    protected AbstractFrames m_frames;

    /**
     * Creates a labeler with empty path counters, used to collect lexica.
     *
     * @param jointFtrXmlArr feature templates
     * @param abstractFrames PropBank frames used for roleset lookup
     */
    public AbstractSRLabeler(JointFtrXml[] jointFtrXmlArr, AbstractFrames abstractFrames) {
        super(jointFtrXmlArr);
        m_down = new Prob1DMap();
        m_up = new Prob1DMap();
        m_frames = abstractFrames;
    }

    /**
     * Creates a labeler for training.
     *
     * @param objArr lexica, presumably in the layout returned by {@link #getLexica()}
     */
    public AbstractSRLabeler(JointFtrXml[] jointFtrXmlArr, StringTrainSpace[] stringTrainSpaceArr, Object[] objArr) {
        super(jointFtrXmlArr, stringTrainSpaceArr, objArr);
    }

    /**
     * Creates a labeler for development evaluation, scored with an {@link SRLEval}.
     *
     * @param objArr lexica, presumably in the layout returned by {@link #getLexica()}
     */
    public AbstractSRLabeler(JointFtrXml[] jointFtrXmlArr, StringModel[] stringModelArr, Object[] objArr) {
        super(jointFtrXmlArr, stringModelArr, objArr, new SRLEval());
    }

    /** Creates a labeler from a serialized model stream. */
    public AbstractSRLabeler(ObjectInputStream objectInputStream) {
        super(objectInputStream);
    }

    /**
     * Creates a labeler for bootstrapping.
     *
     * @param objArr lexica, presumably in the layout returned by {@link #getLexica()}
     */
    public AbstractSRLabeler(JointFtrXml[] jointFtrXmlArr, StringTrainSpace[] stringTrainSpaceArr, StringModel[] stringModelArr, Object[] objArr) {
        super(jointFtrXmlArr, stringTrainSpaceArr, stringModelArr, objArr);
    }

    /**
     * Installs lexica laid out by the {@code LEXICA_*} slot indices.
     *
     * @param objArr {@code [LEXICA_PATH_UP]} up-path set, {@code [LEXICA_PATH_DOWN]} down-path set,
     *               {@code [LEXICA_FRAMES]} frames
     */
    @Override // com.clearnlp.component.AbstractStatisticalComponent
    @SuppressWarnings("unchecked") // the lexica array is untyped; the slots hold Set<String> by construction
    protected void initLexia(Object[] objArr) {
        s_down = (Set<String>) objArr[LEXICA_PATH_DOWN];
        s_up = (Set<String>) objArr[LEXICA_PATH_UP];
        m_frames = (AbstractFrames) objArr[LEXICA_FRAMES];
    }

    /** Returns a label to substitute for {@code label} at decode time, or null to keep it. */
    protected abstract String getHardLabel(SRLState state, String label);

    /** Returns the PropBank type used to look up the predicate's roleset, or null to skip the lookup. */
    protected abstract PBType getPBType(DEPNode pred);

    /** Hook invoked once after all predicates of a tree have been labeled. */
    protected abstract void postLabel(SRLState state);

    /**
     * Returns a node whose dependents should also be searched for arguments of {@code pred}
     * when the down path to {@code arg} is not in {@link #s_down}, or null for none.
     */
    protected abstract DEPNode getPossibleDescendent(DEPNode pred, DEPNode arg);

    /** Returns true if {@code prediction} should be penalized for argument candidate {@code arg}. */
    protected abstract boolean rerankFromArgument(StringPrediction prediction, DEPNode arg);

    /**
     * Reads the default model data followed by the lexica, then closes the stream.
     * Failures are reported and not propagated (best-effort, as before).
     */
    @Override // com.clearnlp.component.AbstractStatisticalComponent
    public void load(ObjectInputStream objectInputStream) {
        // try-with-resources closes the stream even when reading fails part-way.
        try (ObjectInputStream in = objectInputStream) {
            loadDefault(in);
            loadLexica(in);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes the default model data followed by the lexica, then closes the stream.
     * Failures are reported and not propagated (best-effort, as before).
     */
    @Override // com.clearnlp.component.AbstractStatisticalComponent
    public void save(ObjectOutputStream objectOutputStream) {
        // try-with-resources closes the stream even when writing fails part-way.
        try (ObjectOutputStream out = objectOutputStream) {
            saveDefault(out);
            saveLexica(out);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Reads the three lexica objects in the order {@link #saveLexica} writes them. */
    private void loadLexica(ObjectInputStream in) throws Exception {
        LOG.info("Loading lexica.\n");
        // saveLexica() writes down-set, up-set, frames; place each into the slot initLexia() reads it
        // from. (Packing them positionally swapped the down and up sets on every reload.)
        Object[] lexica = new Object[3];
        lexica[LEXICA_PATH_DOWN] = in.readObject();
        lexica[LEXICA_PATH_UP] = in.readObject();
        lexica[LEXICA_FRAMES] = in.readObject();
        initLexia(lexica);
    }

    /** Writes the lexica in the order down-set, up-set, frames. */
    private void saveLexica(ObjectOutputStream out) throws Exception {
        LOG.info("Saving lexica.\n");
        out.writeObject(s_down);
        out.writeObject(s_up);
        out.writeObject(m_frames);
    }

    /**
     * Returns the lexica laid out by the {@code LEXICA_*} slot indices. In lexica mode the path sets
     * are built from the collected counts using the cutoffs of the first feature template;
     * otherwise the currently installed sets are returned.
     */
    @Override // com.clearnlp.component.AbstractStatisticalComponent
    public Object[] getLexica() {
        Object[] lexica = new Object[3];
        lexica[LEXICA_PATH_DOWN] = isLexica() ? m_down.toSet(f_xmls[0].getPathDownCutoff()) : s_down;
        lexica[LEXICA_PATH_UP] = isLexica() ? m_up.toSet(f_xmls[0].getPathUpCutoff()) : s_up;
        lexica[LEXICA_FRAMES] = m_frames;
        return lexica;
    }

    /** Returns the collected down paths at the given cutoff (lexica-mode instances only). */
    public Set<String> getDownSet(int cutoff) {
        return m_down.toSet(cutoff);
    }

    /** Returns the collected up paths at the given cutoff (lexica-mode instances only). */
    public Set<String> getUpSet(int cutoff) {
        return m_up.toSet(cutoff);
    }

    @Override // com.clearnlp.component.AbstractStatisticalComponent
    public Set<String> getLabels() {
        return getDefaultLabels();
    }

    /** Labels (or collects lexica from) one tree; in develop mode also updates the evaluator. */
    @Override // com.clearnlp.component.AbstractComponent
    public void process(DEPTree dEPTree) {
        SRLState state = init(dEPTree);
        processAux(state);
        if (isDevelop()) {
            e_eval.countAccuracy(state.getTree(), state.getGoldLabels());
        }
    }

    /**
     * Builds the per-tree state. When decoding, the tree's semantic heads are initialised; otherwise
     * the existing semantic heads are kept as gold labels and cleared from the tree.
     */
    protected SRLState init(DEPTree dEPTree) {
        SRLState state = new SRLState(dEPTree);
        if (isDecode()) {
            dEPTree.initSHeads();
        } else {
            state.setGoldLabels(dEPTree.getSHeads());
            dEPTree.clearSHeads();
        }
        return state;
    }

    /** Dispatches to lexica collection or labeling depending on the component mode. */
    protected void processAux(SRLState state) {
        if (isLexica()) {
            addLexica(state);
        } else {
            label(state);
        }
    }

    /** Counts the down and up paths between every predicate and its gold arguments. */
    private void addLexica(SRLState state) {
        DEPNode pred;
        while ((pred = state.moveToNextPredicate()) != null) {
            for (DEPArc arc : pred.getGrandDependents()) {
                collectDown(pred, arc.getNode());
            }
            DEPNode head = pred.getHead();
            if (head != null) {
                collectUp(pred, head.getHead());
            }
        }
    }

    /**
     * Recursively visits {@code arg} and its descendants. For each one that is a gold argument of
     * {@code pred}, counts every down path from {@code pred} to the nodes between them.
     */
    private void collectDown(DEPNode pred, DEPNode arg) {
        if (arg.isArgumentOf(pred)) {
            for (String path : getDUPathList(pred, arg.getHead())) {
                m_down.add(path);
            }
        }
        for (DEPArc arc : arg.getDependents()) {
            collectDown(pred, arc.getNode());
        }
    }

    /**
     * Walks from {@code head} up to the root. Whenever an ancestor has a dependent that is a gold
     * argument of {@code pred}, counts the paths from that ancestor down to {@code pred}.
     */
    private void collectUp(DEPNode pred, DEPNode head) {
        if (head == null) {
            return;
        }
        for (DEPArc arc : head.getDependents()) {
            if (arc.getNode().isArgumentOf(pred)) {
                // NOTE(review): paths are counted once per matching dependent, not once per ancestor;
                // confirm this weighting is intended (it affects the up-path cutoff).
                for (String path : getDUPathList(head, pred)) {
                    m_up.add(path);
                }
            }
        }
        collectUp(pred, head.getHead());
    }

    /** Returns the dependency-label path from {@code bottom} up to (excluding) {@code top}, delimited by "|". */
    private String getDUPath(DEPNode top, DEPNode bottom) {
        return getPathAux(top, bottom, "d", "|", true);
    }

    /**
     * Returns one down path from {@code top} for every node on the chain from {@code bottom} up to
     * (excluding) {@code top}. Assumes {@code top} is an ancestor of {@code bottom}.
     */
    private List<String> getDUPathList(DEPNode top, DEPNode bottom) {
        List<String> paths = new ArrayList<>();
        for (DEPNode node = bottom; node != top; node = node.getHead()) {
            paths.add(getDUPath(top, node));
        }
        return paths;
    }

    /**
     * For every predicate: looks up its roleset, then labels candidates around each successive lowest
     * common ancestor. Calls {@link #postLabel} once all predicates are done.
     */
    private void label(SRLState state) {
        DEPNode pred;
        while ((pred = state.moveToNextPredicate()) != null) {
            setRoleset(pred, state);
            do {
                labelAux(state);
            } while (state.moveToNextLowestCommonAncestor());
        }
        postLabel(state);
    }

    /** Stores the predicate's roleset in the state when frames are available and a PropBank type applies. */
    private void setRoleset(DEPNode pred, SRLState state) {
        if (m_frames == null) {
            return;
        }
        PBType type = getPBType(pred);
        if (type != null) {
            state.setRoleset(m_frames.getRoleset(type, pred.lemma, pred.getFeat(DEPLib.FEAT_PB)));
        }
    }

    /** Labels the current lowest common ancestor (unless skipped), then its dependents. */
    private void labelAux(SRLState state) {
        DEPNode lca = state.getLowestCommonAncestor();
        if (!state.isSkip(lca)) {
            state.setArgument(lca);
            addArgument(getLabel(state), state);
        }
        labelDown(lca.getDependents(), state);
    }

    /**
     * Labels each non-skipped node in {@code arcs}. While the predicate itself is the lowest common
     * ancestor, descends further below nodes whose down path is a known lexica path, or below the node
     * suggested by {@link #getPossibleDescendent}.
     */
    private void labelDown(List<DEPArc> arcs, SRLState state) {
        DEPNode pred = state.getCurrentPredicate();
        for (DEPArc arc : arcs) {
            DEPNode node = arc.getNode();
            if (state.isSkip(node)) {
                continue;
            }
            state.setArgument(node);
            addArgument(getLabel(state), state);
            if (!state.isLowestCommonAncestor(pred)) {
                continue;
            }
            if (s_down.contains(getDUPath(pred, node))) {
                labelDown(node.getDependents(), state);
            } else {
                DEPNode descendent = getPossibleDescendent(pred, node);
                if (descendent != null) {
                    labelDown(descendent.getDependents(), state);
                }
            }
        }
    }

    /**
     * Produces the label for the current (predicate, argument) pair.
     * In train mode: the gold label (score 1.0), also added as a training instance.
     * In develop/decode mode: the model's reranked best prediction.
     * In bootstrap mode: the model's prediction, while the gold label is added as a training instance.
     * Any other mode returns null.
     */
    private StringPrediction getLabel(SRLState state) {
        StringFeatureVector vector = getFeatureVector(f_xmls[0], state);
        int direction = state.getDirection();
        StringPrediction prediction = null;
        if (isTrain()) {
            prediction = new StringPrediction(getGoldLabel(state), 1.0d);
            s_spaces[direction].addInstance(new StringInstance(prediction.label, vector));
        } else if (isDevelopOrDecode()) {
            prediction = getAutoLabel(direction, vector, state);
        } else if (isBootstrap()) {
            prediction = getAutoLabel(direction, vector, state);
            s_spaces[direction].addInstance(new StringInstance(getGoldLabel(state), vector));
        }
        return prediction;
    }

    /** Returns the gold label linking the current argument to the current predicate, or {@link #LB_NO_ARG}. */
    private String getGoldLabel(SRLState state) {
        for (StringIntPair gold : state.getGoldLabel()) {
            if (gold.i == state.getCurrPredicateID()) {
                return gold.s;
            }
        }
        return LB_NO_ARG;
    }

    /** Returns the best reranked prediction; at decode time a non-null hard label replaces an argument label. */
    private StringPrediction getAutoLabel(int direction, StringFeatureVector vector, SRLState state) {
        StringPrediction best = getBestPrediction(s_models[direction], vector, state);
        if (isDecode() && !best.label.equals(LB_NO_ARG)) {
            String hardLabel = getHardLabel(state, best.label);
            if (hardLabel != null) {
                best.label = hardLabel;
            }
        }
        return best;
    }

    /**
     * Marks the current argument as visited and, unless the label is {@link #LB_NO_ARG}, records it:
     * numbered arguments are registered, a core numbered argument replaces (and strips the semantic
     * head from) any previous holder of the same label, and a semantic head to the current predicate
     * is added with a function tag from the roleset when the label contains no '-'.
     */
    private void addArgument(StringPrediction prediction, SRLState state) {
        DEPNode arg = state.getCurrentArgument();
        state.addArgumentToSkipList();
        String label = prediction.label;
        if (label.equals(LB_NO_ARG)) {
            return;
        }
        if (PBLib.isNumberedArgument(label)) {
            state.addNumberedArgument(label);
            if (PBLib.isCoreNumberedArgument(label)) {
                ObjectDoublePair<DEPNode> previous = state.getCoreNumberedArgument(label);
                if (previous != null) {
                    ((DEPNode) previous.o).removeSHeadsByLabel(label);
                }
                state.putCoreNumberedArgument(label, new ObjectDoublePair<>(arg, prediction.score));
            }
        }
        String functionTag = "";
        PBRoleset roleset = state.getRoleset();
        if (!label.contains("-") && roleset != null) {
            functionTag = roleset.getFunctionTag(PBLib.getNumber(label));
        }
        arg.addSHead(state.getCurrentPredicate(), label, functionTag);
    }

    /**
     * Returns the value of a single-valued feature for the node selected by {@code ftrToken}, or null
     * when the node does not exist or the feature does not fire. Boolean features (P_BOOLEAN) fire with
     * the token's field as value: 0 node depends on the predicate, 1 predicate depends on node,
     * 2 predicate depends on the lowest common ancestor, 3 predicate is the lowest common ancestor,
     * 4 node is the lowest common ancestor.
     */
    @Override // com.clearnlp.component.AbstractStatisticalComponent
    public String getField(FtrToken ftrToken, SRLState sRLState) {
        DEPNode node = sRLState.getNode(ftrToken);
        if (node == null) {
            return null;
        }
        if (ftrToken.isField("f")) {
            return node.form;
        }
        if (ftrToken.isField(JointFtrXml.F_LEMMA)) {
            return node.lemma;
        }
        if (ftrToken.isField("p")) {
            return node.pos;
        }
        if (ftrToken.isField("d")) {
            return node.getLabel();
        }
        if (ftrToken.isField("n")) {
            return getDistance(node, sRLState);
        }
        Matcher argn = JointFtrXml.P_ARGN.matcher(ftrToken.field);
        if (argn.find()) {
            return sRLState.getNumberedArgument(Integer.parseInt(argn.group(1)));
        }
        Matcher path = JointFtrXml.P_PATH.matcher(ftrToken.field);
        if (path.find()) {
            return getPath(path.group(1), Integer.parseInt(path.group(2)), sRLState);
        }
        Matcher subcat = JointFtrXml.P_SUBCAT.matcher(ftrToken.field);
        if (subcat.find()) {
            return getSubcat(node, subcat.group(1), Integer.parseInt(subcat.group(2)));
        }
        Matcher feat = JointFtrXml.P_FEAT.matcher(ftrToken.field);
        if (feat.find()) {
            return node.getFeat(feat.group(1));
        }
        Matcher bool = JointFtrXml.P_BOOLEAN.matcher(ftrToken.field);
        if (!bool.find()) {
            return null;
        }
        DEPNode pred = sRLState.getCurrentPredicate();
        boolean fires;
        switch (Integer.parseInt(bool.group(1))) {
            case 0:
                fires = node.isDependentOf(pred);
                break;
            case 1:
                fires = pred.isDependentOf(node);
                break;
            case 2:
                fires = pred.isDependentOf(sRLState.getLowestCommonAncestor());
                break;
            case 3:
                fires = sRLState.isLowestCommonAncestor(pred);
                break;
            case 4:
                fires = sRLState.isLowestCommonAncestor(node);
                break;
            default:
                fires = false;
                break;
        }
        return fires ? ftrToken.field : null;
    }

    /** Returns the dependency-label set of the node's dependents or grand-dependents, or null. */
    @Override // com.clearnlp.component.AbstractStatisticalComponent
    public String[] getFields(FtrToken ftrToken, SRLState sRLState) {
        DEPNode node = sRLState.getNode(ftrToken);
        if (node == null) {
            return null;
        }
        if (ftrToken.isField(JointFtrXml.F_DEPREL_SET)) {
            return getDeprelSet(node.getDependents());
        }
        if (ftrToken.isField(JointFtrXml.F_GRAND_DEPREL_SET)) {
            return getDeprelSet(node.getGrandDependents());
        }
        return null;
    }

    /** Buckets the token distance to the current predicate: <=5 "0", <=10 "1", <=15 "2", else "3". */
    private String getDistance(DEPNode node, SRLState state) {
        int dist = Math.abs(state.getCurrPredicateID() - node.id);
        if (dist <= 5) {
            return "0";
        }
        if (dist <= 10) {
            return "1";
        }
        if (dist <= 15) {
            return "2";
        }
        // Literal instead of the decompiler-substituted is2.data.PipeGen._3 (unrelated library).
        return "3";
    }

    /**
     * Returns the path feature of type {@code field} ("p", "d" or "n") between the current predicate
     * and argument: {@link #PATH_UP} from the predicate up to the lowest common ancestor,
     * {@link #PATH_DOWN} from the lowest common ancestor down to the argument, otherwise the full path.
     */
    private String getPath(String field, int dir, SRLState state) {
        DEPNode pred = state.getCurrentPredicate();
        DEPNode arg = state.getCurrentArgument();
        DEPNode lca = state.getLowestCommonAncestor();
        if (dir == PATH_UP) {
            return (lca != pred) ? getPathAux(lca, pred, field, SRLLib.DELIM_PATH_UP, true) : null;
        }
        if (dir == PATH_DOWN) {
            return (lca != arg) ? getPathAux(lca, arg, field, "|", true) : null;
        }
        if (pred == lca) {
            return getPathAux(pred, arg, field, "|", true);
        }
        if (pred.isDescendentOf(arg)) {
            return getPathAux(arg, pred, field, SRLLib.DELIM_PATH_UP, true);
        }
        return getPathAux(lca, pred, field, SRLLib.DELIM_PATH_UP, true) + getPathAux(lca, arg, field, "|", false);
    }

    /**
     * Walks from {@code bottom} up to (excluding) {@code top}, stopping early at the root.
     * For "p" or "d", appends each node's POS tag or dependency label prefixed by {@code delim};
     * for "p" the top node's POS tag is also appended when {@code includeTop}. For "n", returns
     * {@code delim} followed by the number of steps. Returns null if nothing was collected.
     */
    private String getPathAux(DEPNode top, DEPNode bottom, String field, String delim, boolean includeTop) {
        StringBuilder sb = new StringBuilder();
        int steps = 0;
        DEPNode node = bottom;
        do {
            if (field.equals("p")) {
                sb.append(delim).append(node.pos);
            } else if (field.equals("d")) {
                sb.append(delim).append(node.getLabel());
            } else if (field.equals("n")) {
                steps++;
            }
            node = node.getHead();
        } while (node != top && node != null);

        if (field.equals("p")) {
            if (includeTop) {
                sb.append(delim).append(top.pos);
            }
        } else if (field.equals("n")) {
            sb.append(delim).append(steps);
        }
        return sb.length() == 0 ? null : sb.toString();
    }

    /**
     * Returns the "_"-joined POS tags ("p") or dependency labels ("d") of the node's dependents:
     * {@link #SUBCAT_LEFT} those before the node (left to right), {@link #SUBCAT_RIGHT} those after it
     * (right to left), otherwise all of them. Returns null when there are none.
     */
    private String getSubcat(DEPNode node, String field, int scope) {
        List<DEPArc> dependents = node.getDependents();
        StringBuilder sb = new StringBuilder();
        int size = dependents.size();
        if (scope == SUBCAT_LEFT) {
            for (int i = 0; i < size; i++) {
                DEPNode dep = dependents.get(i).getNode();
                if (dep.id > node.id) {
                    break;
                }
                getSubcatAux(sb, dep, field);
            }
        } else if (scope == SUBCAT_RIGHT) {
            for (int i = size - 1; i >= 0; i--) {
                DEPNode dep = dependents.get(i).getNode();
                if (dep.id < node.id) {
                    break;
                }
                getSubcatAux(sb, dep, field);
            }
        } else {
            for (int i = 0; i < size; i++) {
                getSubcatAux(sb, dependents.get(i).getNode(), field);
            }
        }
        // Every item is prefixed with "_"; drop the leading one.
        return sb.length() == 0 ? null : sb.substring("_".length());
    }

    /** Appends "_" followed by the node's POS tag ("p") or dependency label ("d"). */
    private void getSubcatAux(StringBuilder sb, DEPNode node, String field) {
        sb.append("_");
        if (field.equals("p")) {
            sb.append(node.pos);
        } else if (field.equals("d")) {
            sb.append(node.getLabel());
        }
    }

    /** Returns the first prediction after reranking (predictAll is assumed to return best-first). */
    private StringPrediction getBestPrediction(StringModel model, StringFeatureVector vector, SRLState state) {
        List<StringPrediction> predictions = model.predictAll(vector);
        rerankPredictions(predictions, state);
        return predictions.get(0);
    }

    /**
     * Sets the score of every invalid prediction (frame mismatch, redundant core argument, or rejected
     * by {@link #rerankFromArgument}) to -1 and, if any was changed, re-sorts the list with
     * {@code UTCollection.sortReverseOrder}.
     */
    protected void rerankPredictions(List<StringPrediction> list, SRLState sRLState) {
        DEPNode arg = sRLState.getCurrentArgument();
        boolean changed = false;
        for (StringPrediction prediction : list) {
            if (rerankFrameMismatch(prediction, sRLState)
                    || rerankRedundantNumberedArgument(prediction, sRLState)
                    || rerankFromArgument(prediction, arg)) {
                prediction.score = -1.0d;
                changed = true;
            }
        }
        if (changed) {
            UTCollection.sortReverseOrder(list);
        }
    }

    /** Returns true if a roleset is known and does not allow the predicted label. */
    protected boolean rerankFrameMismatch(StringPrediction stringPrediction, SRLState sRLState) {
        PBRoleset roleset = sRLState.getRoleset();
        return roleset != null && !roleset.isValidArgument(stringPrediction.label);
    }

    /** Returns true if the label is already held by a core argument scored at least as high. */
    protected boolean rerankRedundantNumberedArgument(StringPrediction stringPrediction, SRLState sRLState) {
        ObjectDoublePair<DEPNode> existing = sRLState.getCoreNumberedArgument(stringPrediction.label);
        return existing != null && existing.d >= stringPrediction.score;
    }
}
