#include "pialign/model-hier.h"

using namespace pialign;
using namespace std;

// Add the derivation rooted at `node` to the model and return its log probability.
//
// e, f      : the two sides of the sentence pair; node->span indexes into them.
// node      : derivation root; nothing happens (returns 0) if it is null or
//             node->add is false.
// ePhrases, fPhrases, pairs : id tables; missing phrases/pairs get new ids.
// base      : base measure, updated for nodes that are neither trees nor TYPE_GEN.
//
// Side effects: sets node->phraseid on every visited node, updates phrases_
// (addExisting for TYPE_GEN, addNew otherwise), the type counts via addType,
// and the average-derivation statistics.
//
// Ordering: children are added right subtree, then left subtree, then this node.
// removePhrasePair undoes this in exactly the reverse order (node, left, right),
// so the two must be kept in sync.
Prob HierModel::addSentence(const WordString & e, const WordString & f, SpanNode* node, StringWordSet & ePhrases, StringWordSet & fPhrases, PairWordSet & pairs, BaseMeasure * base) {
    if(!node || !node->add) return 0;
    // get the phrase IDs
    const Span & mySpan = node->span;
    const int s = mySpan.es, t = mySpan.ee, u = mySpan.fs, v = mySpan.fe;
    const WordId eId = ePhrases.getId(e.substr(s,t-s),true);
    const WordId fId = fPhrases.getId(f.substr(u,v-u),true);
    node->phraseid = pairs.getId(WordPairHash(eId, fId, GlobalVars::maxPhrase),true);
    int toAdd = node->type;
    const bool isTree = (toAdd == TYPE_REG || toAdd == TYPE_INV);
    // handle either non-terminals (recurse) or terminals (update the base measure)
    Prob rightProb = 0, leftProb = 0;
    if(isTree) {
        rightProb = addSentence(e,f,node->right,ePhrases,fPhrases,pairs,base);
        leftProb = addSentence(e,f,node->left,ePhrases,fPhrases,pairs,base);
    } else if(toAdd != TYPE_GEN) {
        base->add(node->span,node->phraseid,node->baseProb,node->baseElems);
        toAdd = TYPE_TERM;
    }
    // Record child ids only when the children were actually added above.
    // removePhrasePair recurses whenever the stored right id is >= 0, so storing
    // ids of un-added children for a terminal node would make removal diverge
    // from addition.
    WordId lId = -1, rId = -1;
    if(isTree) {
        if(node->left) lId = node->left->phraseid;
        if(node->right) rId = node->right->phraseid;
    }
    // add the appropriate values for the derivation
    Prob totProb = 0;
    if(node->type == TYPE_GEN) {
        // reuse of an existing table
        totProb = log(phrases_.getProb(node->phraseid,0));
        PRINT_DEBUG(" ModelHier::=genProb: "<<totProb<<" @ "<<mySpan<<endl, 2);
        phrases_.addExisting(node->phraseid);
    } else {
        if(node->type == TYPE_BASE) {
            totProb = log(phrases_.getFallbackProb())+addType(toAdd)+node->baseProb;
            PRINT_DEBUG(" ModelHier::=baseProb: "<<totProb<<" @ "<<mySpan<<endl, 2);
        }
        else {
            const Prob typeProb = addType(toAdd);
            // computed after addType, matching the original call order
            const Prob fallbackProb = log(phrases_.getFallbackProb());
            totProb = fallbackProb+typeProb+leftProb+rightProb;
            PRINT_DEBUG(" ModelHier::=treeProb: "<<fallbackProb<<"+"<<typeProb<<"+"<<leftProb<<"+"<<rightProb<<" == "<<totProb<<" @ "<<mySpan<<endl, 2);
        }
        phrases_.addNew(node->phraseid,lId,rId,toAdd);
    }
    addAverageDerivation(node->phraseid,phrases_.getTotal(node->phraseid),node->prob);
    return totProb;
}

// Remove one occurrence of phrase pair `jId` from the model and rebuild the
// derivation node that produced it.
//
// jId  : phrase-pair id; a negative id means "no node" and yields 0.
// base : base measure, updated when the removed table came from the base.
//
// Returns a newly allocated SpanNode that the caller owns (nullptr for jId < 0).
// ret->prob accumulates the values returned by phrases_.remove and
// removeType, plus the children's prob or the base probability.
//
// Removal order (this node, then left, then right) is the exact reverse of
// the order used by addSentence (right, left, then the node); keep them in sync.
SpanNode* HierModel::removePhrasePair(WordId jId, BaseMeasure * base) {
    if(jId < 0) return 0;
    SpanNode* ret = new SpanNode(Span(0,0,0,0));
    ret->phraseid = jId;
    ret->prob = phrases_.remove(jId);
    PRINT_DEBUG("ret->prob("<<jId<<") = "<<ret->prob<<endl, 2);
    // no table was removed: this was generated from the cache
    if(!phrases_.isRemovedTable()) {
        ret->type = TYPE_GEN;
        return ret;
    }
    // A table was removed: this was generated from the fallback.
    // Take a copy on purpose — the recursive removals below call
    // phrases_.remove again, which changes what getLastTable() refers to.
    const PyTable<WordId> table = phrases_.getLastTable();
    ret->prob += removeType(table.type);
    ret->type = table.type;
    if(table.right >= 0) {
        // generated by breaking down into two children
        ret->left = removePhrasePair(table.left,base);
        // left may be absent (negative id) if the tree was added without a left child
        if(ret->left) ret->prob += ret->left->prob;
        ret->right = removePhrasePair(table.right,base);
        // table.right >= 0, so the recursive call always returns a node
        ret->prob += ret->right->prob;
    } else {
        // generated directly from the base measure
        ret->baseProb = base->getBase(jId);
        ret->baseElems = base->getElems(jId);
        ret->prob += ret->baseProb;
        PRINT_DEBUG("ret->baseProb == base["<<jId<<"] == "<<ret->baseProb<<endl, 2);
        ret->type = TYPE_BASE;
        base->remove(ret->span,ret->phraseid,ret->baseProb,ret->baseElems);
    }
    return ret;
}
