# Copyright 2024, Battelle Energy Alliance, LLC ALL RIGHTS RESERVED
"""
Created in March 2022
@author: wangc, mandd
"""
import pandas as pd
from spacy.tokenizer import Tokenizer
from spacy.util import compile_infix_regex
from spacy.tokens import Span
# filter_spans is used to resolve the overlap issue in entities
# It gives primacy to longer spans (entities)
from spacy.util import filter_spans
import logging
import networkx as nx
import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)
###########################################################################
def displayNER(doc, includePunct=False):
"""
Generate data frame for visualization of spaCy doc with custom attributes.
Args:
doc: spacy.tokens.doc.Doc, the processed document using nlp pipelines
      includePunct: bool, True if punctuation is included
    Returns:
      df: pandas.DataFrame, data frame containing the attributes of tokens
"""
rows = []
for i, t in enumerate(doc):
if not t.is_punct or includePunct:
row = {'token': i,
'text': t.text, 'lemma': t.lemma_,
'pos': t.pos_, 'dep': t.dep_, 'ent_type': t.ent_type_,
'ent_iob_': t.ent_iob_}
if doc.has_extension('coref_chains'):
        if t.has_extension('coref_chains') and t._.coref_chains: # coreference attributes
row['coref_chains'] = t._.coref_chains.pretty_representation
else:
row['coref_chains'] = None
if t.has_extension('ref_n'): # referent attribute
row['ref_n'] = t._.ref_n
row['ref_t'] = t._.ref_t
if t.has_extension('ref_ent'): # ref_n/ref_t
row['ref_ent'] = t._.ref_ent
rows.append(row)
df = pd.DataFrame(rows).set_index('token')
df.index.name = None
return df
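# Example usage (a minimal sketch, not executed on import; assumes the spaCy model
# "en_core_web_sm" is installed and the sentence is purely illustrative):
#
#   import spacy
#   nlp = spacy.load("en_core_web_sm")
#   doc = nlp("The centrifugal pump is not experiencing enough flow.")
#   df = displayNER(doc)
#   print(df)  # one row per non-punctuation token with text, lemma, pos, dep, ent_type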
def resetPipeline(nlp, pipes):
"""
    Remove all custom pipes and add the specified pipes to the nlp pipeline
Args:
nlp: spacy.Language object, contains all components and data needed to process text
pipes: list, list of pipes that will be added to nlp pipeline
Returns:
nlp: spacy.Language object, contains updated components and data needed to process text
"""
customPipes = [pipe for (pipe, _) in nlp.pipeline
if pipe not in ['tagger', 'parser',
'tok2vec', 'attribute_ruler', 'lemmatizer', 'ner']]
for pipe in customPipes:
_ = nlp.remove_pipe(pipe)
# re-add specified pipes
for pipe in pipes:
if pipe in ['pysbdSentenceBoundaries']:
# nlp.add_pipe(pipe, before='parser')
nlp.add_pipe(pipe, first=True)
elif pipe not in ['tagger', 'parser', 'tok2vec', 'attribute_ruler', 'lemmatizer', 'ner']:
nlp.add_pipe(pipe)
logger.info(f"Model: {nlp.meta['name']}, Language: {nlp.meta['lang']}")
  logger.info('Available pipelines: ' + ', '.join([pipe for (pipe, _) in nlp.pipeline]))
return nlp
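# Example usage (a minimal sketch; assumes 'pysbdSentenceBoundaries' has been registered
# as a spaCy component factory elsewhere in this package):
#
#   import spacy
#   nlp = spacy.load("en_core_web_sm")
#   nlp = resetPipeline(nlp, ['pysbdSentenceBoundaries'])
#   # existing custom pipes are removed, then 'pysbdSentenceBoundaries' is added as the first pipe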
def printDepTree(doc, skipPunct=True):
"""
Utility function to pretty print the dependency tree.
Args:
doc: spacy.tokens.doc.Doc, the processed document using nlp pipelines
      skipPunct: bool, True to skip punctuation
Returns:
None
"""
def printRecursive(root, indent, skipPunct):
if not root.dep_ == 'punct' or not skipPunct:
print(" "*indent + f"{root} [{root.pos_}, {root.dep_}]")
for left in root.lefts:
printRecursive(left, indent=indent+4, skipPunct=skipPunct)
for right in root.rights:
printRecursive(right, indent=indent+4, skipPunct=skipPunct)
for sent in doc.sents: # iterate over all sentences in a doc
printRecursive(sent.root, indent=0, skipPunct=skipPunct)
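# Example usage (a minimal sketch; the sentence is illustrative):
#
#   import spacy
#   nlp = spacy.load("en_core_web_sm")
#   doc = nlp("The pump failed because of a worn bearing.")
#   printDepTree(doc)  # prints each sentence root first, then its children indented by depth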
def plotDAG(edges, colors='k'):
"""
Args:
edges: list of tuples, [(subj, conj), (..,..)] or [(subj, conj, {"color":"blue"}), (..,..)]
colors: str or list, list of colors
"""
g = nx.MultiDiGraph()
g.add_edges_from(edges)
nx.draw_networkx(g, edge_color=colors)
  ax = plt.gca()
  plt.axis("off")
  plt.show()
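# Example usage (a minimal sketch; node names are illustrative):
#
#   edges = [("pump", "failure"), ("failure", "shutdown")]
#   plotDAG(edges, colors='k')  # opens a matplotlib window with the directed graph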
def generatePattern(form, label, id, attr="LOWER"):
"""
Generate entity pattern
Args:
      form: str or list, a string (when attr="LOWER") or a list of lemmas (when attr="LEMMA") used to generate the pattern
label: str, the label name for the pattern
id: str, the id name for the pattern
attr: str, attribute used for the pattern, either "LOWER" or "LEMMA"
Returns:
pattern: dict, pattern will be used by entity matcher
"""
# if any of "!", "?", "+", "*" present in the provided string "form", we will treat it as determiter for the form
if attr.lower() == "lower":
ptn = []
attr = "LOWER"
for elem in form.lower().split():
if "-" in form:
subList = list(elem.split("-"))
# slow when using re.split
# subList = re.split("(-)", elem)
# for sub in subList:
# if sub != '-':
# ptn.append({attr:sub} if sub not in ["!", "?", "+", "*"] else {"POS":"DET", "OP":sub})
# else:
# ptn.append({"ORTH":"-"})
for i, sub in enumerate(subList):
ptn.append({attr:sub} if sub not in ["!", "?", "+", "*"] else {"POS":"DET", "OP":sub})
if i < len(subList)-1:
ptn.append({"ORTH":"-"})
else:
ptn.append({attr:elem} if elem not in ["!", "?", "+", "*"] else {"POS":"DET", "OP":elem})
elif attr.lower() == "lemma":
attr = "LEMMA"
ptn = [{attr:elem} if elem not in ["!", "?", "+", "*"] else {"POS":"DET", "OP":elem} for elem in form]
else:
raise IOError(f"Incorrect 'attr={attr}' is provided, valid value for 'attr' is either 'LOWER' or 'LEMMA'")
pattern = {"label":label, "pattern":ptn, "id": id}
return pattern
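# Example (a minimal sketch; the label and id values are illustrative):
#
#   pattern = generatePattern("cooling system", label="comp", id="cooling_system", attr="LOWER")
#   # pattern == {"label": "comp",
#   #             "pattern": [{"LOWER": "cooling"}, {"LOWER": "system"}],
#   #             "id": "cooling_system"}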
def generatePatternList(entList, label, id, nlp, attr="LOWER"):
"""
Generate a list of entity patterns
Args:
entList: list, list of entities
label: str, the label name for the pattern
      id: str, the id name for the pattern
      nlp: spacy.Language object, contains all components and data needed to process text
      attr: str, attribute used for the pattern, either "LOWER" or "LEMMA"
    Returns:
      ptnList: list, list of patterns that will be used by the entity matcher
"""
ptnList = []
for ent in entList:
if len(ent.strip()) == 0:
continue
    # If attr is LEMMA, we also add a LOWER pattern, since the lemma for "Cooling System"
    # is "cool System", which cannot match "cool system".
ptn = generatePattern(ent, label, id, attr="LOWER")
ptnList.append(ptn)
if attr.lower() == "lemma":
try:
entLemma = extractLemma(ent, nlp)
except ValueError:
        logger.warning(f"Cannot add '{ent}' to the pattern list")
continue
ptnLemma = generatePattern(entLemma, label, id, attr)
# print(ptnLemma)
ptnList.append(ptnLemma)
return ptnList
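# Example usage (a minimal sketch; entity names, label, and id are illustrative, and the
# generated patterns are consumed here by spaCy's built-in EntityRuler):
#
#   import spacy
#   nlp = spacy.load("en_core_web_sm")
#   ptnList = generatePatternList(["centrifugal pump", "cooling system"],
#                                 label="comp", id="ssc", nlp=nlp, attr="LOWER")
#   ruler = nlp.add_pipe("entity_ruler")
#   ruler.add_patterns(ptnList)
#   doc = nlp("The cooling system was isolated.")
#   print([(ent.text, ent.label_) for ent in doc.ents])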
###############
# methods can be used for callback in "add" method
###############
def extendEnt(matcher, doc, i, matches):
"""
Extend the doc's entity
Args:
matcher: spacy.Matcher, the spacy matcher instance
doc: spacy.tokens.doc.Doc, the processed document using nlp pipelines
i: int, index of the current match (matches[i])
matches: List[Tuple[int, int, int]], a list of (match_id, start, end) tuples, describing
the matches. A match tuple describes a span doc[start:end]
"""
  matchId, start, end = matches[i]
  ent = Span(doc, start, end, label=matchId)
  logger.debug(ent.label_ + ' ' + ent.text)
  doc.ents = filter_spans(list(doc.ents) + [ent])
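# Example usage (a minimal sketch showing extendEnt as the on_match callback of a
# spaCy Matcher; the token pattern and sentence are illustrative):
#
#   import spacy
#   from spacy.matcher import Matcher
#   nlp = spacy.load("en_core_web_sm")
#   matcher = Matcher(nlp.vocab)
#   matcher.add("comp", [[{"LOWER": "centrifugal"}, {"LOWER": "pump"}]], on_match=extendEnt)
#   doc = nlp("The centrifugal pump failed.")
#   matcher(doc)  # extendEnt appends each matched span to doc.ents, keeping the longest spans
#   print([(ent.text, ent.label_) for ent in doc.ents])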
def customTokenizer(nlp):
"""custom tokenizer to keep hyphens between letters and digits
When apply tokenizer, the words with hyphens will be splitted into multiple tokens,
this function can be used to avoid the split of the words when hyphens are present.
Args:
nlp (spacy nlp model): spacy nlp model
Returns:
nlp: nlp with custom tokenizer
"""
inf = list(nlp.Defaults.infixes) # Default infixes
inf.remove(r"(?<=[0-9])[+\-\*^](?=[0-9-])") # Remove the generic op between numbers or between a number and a -
inf = tuple(inf) # Convert inf to tuple
infixes = inf + tuple([r"(?<=[0-9])[+\*^](?=[0-9-])", r"(?<=[0-9])-(?=-)"]) # Add the removed rule after subtracting (?<=[0-9])-(?=[0-9]) pattern
infixes = [x for x in infixes if '-|–|—|--|---|——|~' not in x] # Remove - between letters rule
infix_re = compile_infix_regex(infixes)
nlp.tokenizer = Tokenizer(nlp.vocab, prefix_search=nlp.tokenizer.prefix_search,
suffix_search=nlp.tokenizer.suffix_search,
infix_finditer=infix_re.finditer,
token_match=nlp.tokenizer.token_match,
rules=nlp.Defaults.tokenizer_exceptions)
return nlp
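# Example usage (a minimal sketch; the sentence is illustrative):
#
#   import spacy
#   nlp = customTokenizer(spacy.load("en_core_web_sm"))
#   doc = nlp("The shut-down of pump 1-A was initiated.")
#   print([t.text for t in doc])  # "shut-down" and "1-A" remain single tokens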