# Copyright 2024, Battelle Energy Alliance, LLC  ALL RIGHTS RESERVED
"""
Created on April, 2024
@author: wangc, mandd
"""
import logging
import pandas as pd
import re
from operator import itemgetter
from spacy.tokens import Token
from spacy.tokens import Span
from ..text_processing.Preprocessing import Preprocessing
from ..utils.utils import getOnlyWords, getShortAcronym
from ..config import nlpConfig
from .CausalBase import CausalBase
from ..pipelines.CustomPipelineComponents import mergeEntitiesWithSameID
logger = logging.getLogger(__name__)

# Custom spaCy attributes used by this module, with their default values.
# Each one is registered on both Span and Token. Registration is skipped when
# the attribute already exists, because spaCy raises on duplicate registration.
_EXTENSION_DEFAULTS = (
  ('conjecture', False),
  ('status', None),
  ('neg', None),
  ('neg_text', None),
  ('alias', None),
  ('action', None),
  ('edep', None),
  # written by CausalSimple.handleValidSent when a status is an (amod, status) pair
  ('status_amod', None),
)


def _registerExtensions():
  """
    Register every attribute of ``_EXTENSION_DEFAULTS`` on ``Span`` and ``Token``
    unless it is already registered.
    Args:
      None
    Returns:
      None
  """
  for cls in (Span, Token):
    for name, default in _EXTENSION_DEFAULTS:
      if not cls.has_extension(name):
        cls.set_extension(name, default=default)


_registerExtensions()
class CausalSimple(CausalBase):
  """
    Class to process OPG Operator Shift Logs dataset

    Rule-based extraction on top of ``CausalBase``. For each sentence and its
    entities it fills the custom spaCy attributes ``status`` (and ``status_amod``),
    ``neg``, ``neg_text``, ``conjecture``, ``action`` and ``edep`` from the
    dependency parse.
  """
  def __init__(self, nlp, entID='SSC', causalKeywordID='causal', *args, **kwargs):
    """
      Construct
      Args:
        nlp: spacy.Language object, contains all components and data needed to process text
        entID: str, optional, entity ID forwarded to ``CausalBase`` (default 'SSC')
        causalKeywordID: str, optional, causal keyword ID forwarded to ``CausalBase`` (default 'causal')
        args: list, positional arguments
        kwargs: dict, keyword arguments
      Returns:
        None
    """
    super().__init__(nlp, entID, causalKeywordID, *args, **kwargs)
    # merge entities that share the same ID; placed after the 'aliasResolver' component
    if not nlp.has_pipe('mergeEntitiesWithSameID'):
      self.nlp.add_pipe('mergeEntitiesWithSameID', after='aliasResolver')
    # dependency labels treated as the subject of a predicate
    self._subjList = ['nsubj', 'nsubjpass', 'nsubj:pass']
    # dependency labels treated as the object of a predicate
    self._objList = ['pobj', 'dobj', 'iobj', 'obj', 'obl', 'oprd']
    # names of the per-entity information fields
    self._entInfoNames = ['entity', 'label', 'status', 'amod', 'action', 'dep', 'alias', 'negation', 'conjecture', 'sentence']

  def reset(self):
    """
      Reset rule-based matcher
      Args:
        None
      Returns:
        None
    """
    super().reset()
    # NOTE(review): this leaves _entInfoNames as None and nothing in this class
    # restores it; confirm that no consumer reads it after reset().
    self._entInfoNames = None

  def textProcess(self):
    """
      Function to clean text
      Args:
        None
      Returns:
        procObj, DACKAR.Preprocessing object
    """
    procObj = super().textProcess()
    return procObj

  @staticmethod
  def _firstRightCcomp(token):
    """
      Return the first right child of ``token`` whose dependency is 'ccomp', or None
      Args:
        token: spacy.tokens.Token, token whose right children are searched
      Returns:
        spacy.tokens.Token or None
    """
    return next((child for child in token.rights if child.dep_ == 'ccomp'), None)

  @staticmethod
  def _textAfterEntity(sent, ent):
    """
      Return the part of ``sent`` that follows ``ent``
      Args:
        sent: spacy.tokens.Span, sentence containing ``ent``
        ent: spacy.tokens.Span, entity inside ``sent``
      Returns:
        spacy.tokens.Span, tokens from the end of ``ent`` to the end of ``sent``
    """
    # ent.end and sent.end are document indices; indexing ``sent`` directly
    # would interpret them relative to the sentence start.
    return sent.doc[ent.end:sent.end]

  @staticmethod
  def _substituteEntity(text, ent, replacement):
    """
      Replace the stand-alone occurrences of the entity text in ``text``
      Args:
        text: str, text to process
        ent: spacy.tokens.Span, entity whose text is searched for literally
        replacement: str, text inserted in place of each occurrence (used verbatim)
      Returns:
        str, the substituted text
    """
    # Escape the entity text so regex metacharacters are matched literally.
    # The lookarounds act like ``\b`` for word characters, and they also work
    # when the entity starts or ends with punctuation.
    pattern = r"(?<!\w)%s(?!\w)" % re.escape(str(ent.text))
    # a function replacement keeps backslashes in ``replacement`` literal
    return re.sub(pattern, lambda _match: replacement, text)

  def handleValidSent(self, sent, ents):
    """
      Handle a sentence that has a (subj, predicate, obj) structure
      Args:
        sent: spacy.tokens.Span, the sentence
        ents: list of spacy.tokens.Span, entities found in ``sent``
      Returns:
        None, results are stored in the custom attributes of ``sent`` and of each entity
    """
    root = sent.root
    neg, negText = self.isNegation(root)
    conjecture = self.isConjecture(root)
    sent._.set('neg', neg)
    sent._.set('neg_text', negText)
    sent._.set('conjecture', conjecture)
    action = root if root.pos_ in ['VERB', 'AUX'] else None
    sent._.set('action', action)
    for ent in ents:
      neg = None
      negText = None
      status = None        # store health status for identified entities
      entRoot = ent.root
      if entRoot.dep_ in self._subjList:
        status, neg, negText = self.getStatusForSubj(ent)
      elif entRoot.dep_ in self._objList:
        status, neg, negText = self.getStatusForObj(ent)
        head = entRoot.head
        if status is None and head.dep_ in ['xcomp', 'advcl', 'relcl']:
          status = self._firstRightCcomp(head)
      elif entRoot.dep_ in ['compound']:
        status = self.getAmod(ent, ent.start, ent.end, include=False)
        head = entRoot.head
        if status is None and head.dep_ not in ['compound']:
          status = head
      elif entRoot.dep_ in ['conj']:
        # TODO: recursive function to retrieve non-conj
        amod = self.getAmod(ent, ent.start, ent.end, include=False)
        head = entRoot.head
        headStatus = None
        if head.dep_ in ['conj']:
          head = head.head
        headEnt = head.doc[head.i:head.i+1]
        if head.dep_ in self._subjList:
          headStatus, neg, negText = self.getStatusForSubj(headEnt)
        elif head.dep_ in ['pobj', 'dobj']:
          headStatus, neg, negText = self.getStatusForObj(headEnt)
          # Like the object branch above, inspect the word governing the object
          # token. entRoot.head is itself a 'conj'/'pobj'/'dobj' noun here, so
          # it could never carry an xcomp/advcl/relcl dependency.
          governor = head.head
          if headStatus is None and governor.dep_ in ['xcomp', 'advcl', 'relcl']:
            headStatus = self._firstRightCcomp(governor)
        if headStatus is None:
          status = amod
        elif isinstance(headStatus, list):
          status = headStatus if amod is None else [amod, headStatus[-1]]
        else:
          status = headStatus if amod is None else [amod, headStatus]
      elif entRoot.dep_ in ['ROOT']:
        status = self.getAmod(ent, ent.start, ent.end, include=False)
        if status is None:
          status = next((tk for tk in entRoot.rights
                         if tk.pos_ in ['VERB', 'NOUN', 'ADJ', 'ADV'] and tk.i >= ent.end), None)
      else:
        status = self.getAmod(ent, ent.start, ent.end, include=False)
      # a list status is an [amod, status] pair
      if isinstance(status, list):
        ent._.set('status', status[1])
        ent._.set('status_amod', status[0])
      else:
        ent._.set('status', status)
      ent._.set('neg', neg)
      ent._.set('neg_text', negText)
      ent._.set('edep', entRoot.dep_)
      entHead = entRoot.head
      if entHead.pos_ in ['VERB', 'AUX']:
        ent._.set('action', entHead)
      elif entHead.dep_ in ['prep'] and entHead.head.pos_ in ['VERB', 'AUX']:
        ent._.set('action', entHead.head)

  def handleInvalidSent(self, sent, ents):
    """
      Handle a sentence that does not have a (subj, predicate, obj) structure
      The entity status is derived from the surrounding sentence text rather than from the parse.
      Args:
        sent: spacy.tokens.Span, the sentence
        ents: list of spacy.tokens.Span, entities found in ``sent``
      Returns:
        None, results are stored in the custom attributes of ``sent`` and of each entity
    """
    root = sent.root
    neg, negText = self.isNegation(root)
    conjecture = self.isConjecture(root)
    sent._.set('neg', neg)
    sent._.set('neg_text', negText)
    sent._.set('conjecture', conjecture)
    action = root if root.pos_ in ['VERB', 'AUX'] else None
    sent._.set('action', action)
    for ent in ents:
      # sentence-level negation/conjecture are propagated to every entity
      ent._.set('neg', neg)
      ent._.set('neg_text', negText)
      ent._.set('conjecture', conjecture)
      entRoot = ent.root
      ent._.set('edep', entRoot.dep_)
      if entRoot.head.pos_ in ['VERB', 'AUX']:
        ent._.set('action', entRoot.head)
      if ent._.alias is not None:
        # entity at the beginning of sentence: status is the rest of the sentence
        if ent.start == sent.start:
          text = self._textProcess(self._textAfterEntity(sent, ent).text)
          ent._.set('status', text)
        # entity at the end of sentence (optionally followed by punctuation)
        elif ent.end == sent.end or (ent.end == sent.end - 1 and sent[-1].is_punct):
          # substitute entity ID with its alias
          text = self._substituteEntity(sent.text, ent, ent._.alias)
          ent._.set('status', self._textProcess(text))
        # entity in the middle of sentence
        else:
          # Only include Pred and Obj info
          if entRoot.dep_ in self._subjList:
            text = self._textProcess(self._textAfterEntity(sent, ent).text)
            ent._.set('status', text)
          # Include the whole info with alias substitution
          elif entRoot.dep_ in self._objList:
            text = self._substituteEntity(sent.text, ent, ent._.alias)
            text = getOnlyWords(text)
            ent._.set('status', self._textProcess(text))
      # other type of entities
      else:
        if entRoot.dep_ in self._subjList:
          # depend on the application, can use self.getStatusForSubj to get the status
          text = self._textProcess(self._textAfterEntity(sent, ent).text)
          ent._.set('status', text)
        elif entRoot.dep_ in self._objList:
          # depend on the application, can use self.getStatusForObj to get the status
          text = getOnlyWords(sent.text)
          ent._.set('status', self._textProcess(text))
        elif len(ents) == 1:
          # a single entity that is neither subj nor obj: report the sentence without it
          text = self._substituteEntity(sent.text, ent, '')
          text = getOnlyWords(text)
          ent._.set('status', self._textProcess(text))
        # otherwise, with several entities, an entity that is neither subj nor obj is not reported

  def isSubElements(self, elem1, elemList):
    """
      True if elem1 is a subelement of at least one element of elemList
      Args:
        elem1: spacy.tokens.Token or spacy.tokens.Span, candidate subelement
        elemList: iterable of spacy.tokens.Token or spacy.tokens.Span, candidate containers
      Returns:
        bool, False when elemList is empty
    """
    return any(self.isSubElement(elem1, elem) for elem in elemList)

  @staticmethod
  def _elementBounds(elem):
    """
      Return the half-open document index range [start, end) covered by elem
      Args:
        elem: spacy.tokens.Token or spacy.tokens.Span
      Returns:
        tuple (start, end)
      Raises:
        IOError, if elem is neither a Token nor a Span
    """
    if isinstance(elem, Token):
      # a token occupies [i, i+1), consistent with Span's exclusive end
      return elem.i, elem.i + 1
    if isinstance(elem, Span):
      return elem.start, elem.end
    raise IOError("Wrong data type is provided!")

  def isSubElement(self, elem1, elem2):
    """
      True if elem1 is a subelement of elem2
      Args:
        elem1: spacy.tokens.Token or spacy.tokens.Span, candidate subelement
        elem2: spacy.tokens.Token or spacy.tokens.Span, candidate container
      Returns:
        bool, True when the index range of elem1 lies within that of elem2
      Raises:
        IOError, if either element is neither a Token nor a Span
    """
    s1, e1 = self._elementBounds(elem1)
    s2, e2 = self._elementBounds(elem2)
    return s1 >= s2 and e1 <= e2