# Copyright 2024, Battelle Energy Alliance, LLC ALL RIGHTS RESERVED
"""
Created on October, 2022
@author: mandd, wangc
"""
import re
import itertools
import numpy as np
import spacy
from spacy.vocab import Vocab
import logging
logger = logging.getLogger(__name__)

# The spell-checking back ends are optional dependencies. A missing package is only
# logged here; the corresponding name stays undefined, so selecting that back end
# later fails with a NameError when the name is first used.
try:
    from contextualSpellCheck.contextualSpellCheck import ContextualSpellCheck
except ModuleNotFoundError as error:
    logger.error(f"Unable to import contextualSpellCheck: {error}")
    logger.info("Please try to install it via: 'pip install contextualSpellCheck'")
try:
    import autocorrect
except ModuleNotFoundError as error:
    logger.error(f"Unable to import autocorrect: {error}")
    logger.info("Please try to install it via: 'pip install autocorrect'")
try:
    from spellchecker import SpellChecker as PySpellChecker
except ModuleNotFoundError as error:
    logger.error(f"Unable to import spellchecker: {error}")
    # The 'spellchecker' module is distributed on PyPI under the name 'pyspellchecker'
    logger.info("Please try to install it via: 'pip install pyspellchecker'")
from ..similarity.simUtils import wordsSimilarity
from ..config import nlpConfig
class SpellChecker(object):
    """
    Object to find misspelled words and automatically correct spelling

    Three back ends are supported and selected through the ``checker`` constructor argument:
    'autocorrect', 'pyspellchecker' and, for any other value, ContextualSpellCheck.

    Note: when using autocorrect, one needs to conduct a spell test to identify the threshold
    (the word frequencies)
    """

    # Frequency assigned to user supplied words when they are added to the autocorrect dictionary
    _ADDED_WORD_FREQUENCY = 1000000
    # Tokenizer used to split a text into candidate words (drops whitespace and basic punctuation)
    _TOKEN_PATTERN = re.compile(r'[^\s!,.?":;-]+')
    # Splits a text around ASCII words that are followed by whitespace or a period
    _WORD_SPLIT_PATTERN = re.compile(r"([A-Za-z]+(?=\s|\.))")
    # Matches a piece of text made only of ASCII letters
    _ASCII_WORD_PATTERN = re.compile(r"[A-Za-z]+")
    # Accepted values for the ``type`` argument of the abbreviation handling methods
    _ABBREVIATION_MODES = ('spellcheck', 'hard', 'mixed')

    def __init__(self, checker='autocorrect'):
        """
        SpellChecker object constructor

        Args:
            checker: str, optional, spelling corrector to use ('autocorrect', 'pyspellchecker' or
                'ContextualSpellCheck'); the comparison is case-insensitive and any value other
                than the first two selects ContextualSpellCheck
        Returns:
            None
        """
        self.checker = checker.lower()
        self.includedWords = []
        # Always defined, so that every back end works when no 'extra_vocab' file is configured
        self.addedWords = []
        if 'extra_vocab' in nlpConfig['files']:
            file2open = nlpConfig['files']['extra_vocab']
            with open(file2open, 'r') as file:
                vocabulary = {line.replace('\n', '') for line in file}
            # Blank lines in the vocabulary file are not words
            vocabulary.discard('')
            self.addedWords = list(vocabulary)
        # get included and additional dictionary words and update speller dictionary
        if self.checker == 'autocorrect':
            self.speller = autocorrect.Speller()
            self.speller.nlp_data.update(
                {word: self._ADDED_WORD_FREQUENCY for word in self.addedWords}
            )
        elif self.checker == 'pyspellchecker':
            self.speller = PySpellChecker()
            self.speller.word_frequency.load_words(self.addedWords)
        else:
            name = 'contextual spellcheck'
            languageModel = nlpConfig['params']['spacy_language_pipeline']
            self.nlp = spacy.load(languageModel)
            self.speller = ContextualSpellCheck(self.nlp, name)
            self.includedWords = list(self.speller.BertTokenizer.get_vocab().keys())
            self.speller.vocab = Vocab(strings=self.includedWords + self.addedWords)

    def addWordsToDictionary(self, words):
        """
        Adds a list of words to the spell check dictionary

        Args:
            words: list, list of words to add to the dictionary
        Returns:
            None
        """
        words = list(words)
        if self.checker == 'autocorrect':
            self.speller.nlp_data.update({word: self._ADDED_WORD_FREQUENCY for word in words})
        elif self.checker == 'pyspellchecker':
            self.speller.word_frequency.load_words(self.addedWords + words)
        else:
            # The vocabulary is rebuilt from scratch on every call: remember the new words,
            # otherwise a later call would silently drop the ones added here
            self.addedWords = self.addedWords + words
            self.speller.vocab = Vocab(strings=self.includedWords + self.addedWords)

    def getMisspelledWords(self, text):
        """
        Returns the words of the provided text that are misspelled according to the dictionary used

        Args:
            text: str, text to be analyzed
        Returns:
            misspelled: set, set of misspelled words
        """
        if self.checker == 'autocorrect':
            tokens = self._TOKEN_PATTERN.findall(text)
            # The lookup is case-sensitive: a token is reported unless it is a key of the dictionary
            misspelled = {word for word in tokens if word not in self.speller.nlp_data}
        elif self.checker == 'pyspellchecker':
            tokens = self._TOKEN_PATTERN.findall(text)
            misspelled = self.speller.unknown(tokens)
        else:
            doc = self.speller(self.nlp(text))
            misspelled = {str(token) for token in doc._.suggestions_spellCheck.keys()}
        return misspelled

    def correct(self, text):
        """
        Performs automatic spelling correction and returns corrected text

        Args:
            text: str, text to be corrected
        Returns:
            corrected: str, spelling corrected text
        """
        if self.checker == 'autocorrect':
            return self.speller(text)
        if self.checker == 'pyspellchecker':
            # Only ASCII words followed by whitespace or a period are isolated by the split,
            # every other piece of text is copied to the output unchanged
            corrected = []
            for piece in self._WORD_SPLIT_PATTERN.split(text):
                if self._ASCII_WORD_PATTERN.fullmatch(piece) and piece not in self.speller:
                    suggestion = self.speller.correction(piece)
                    # correction() may return None when no candidate exists: keep the original word
                    corrected.append(piece if suggestion is None else suggestion)
                else:
                    corrected.append(piece)
            return "".join(corrected)
        doc = self.speller(self.nlp(text))
        return doc._.outcome_spellCheck

    def handleAbbreviations(self, abbrDatabase, text, type):
        """
        Performs automatic correction of abbreviations and returns corrected text
        This method relies on a database of abbreviations located at:
        `src/nlp/data/abbreviations.xlsx`
        This database contains the most common abbreviations collected from literature and
        it provides for each abbreviation its corresponding full word(s); an abbreviation might
        have multiple words associated. In such case the full word that makes more sense given the
        context is chosen (see findOptimalOption method)

        Args:
            abbrDatabase: pandas dataframe, dataframe containing library of abbreviations
                ('Abbreviation' column) and their corresponding full expression ('Full' column)
            text: str, string of text that will be analyzed
            type: string, type of abbreviation method ('spellcheck','hard','mixed') that are employed
                to determine which words are abbreviations that need to be expanded
                * spellcheck: in this case spellchecker is used to identify words that
                  are not recognized
                * hard: here we directly search for the abbreviations in the provided
                  sentence
                * mixed: here we perform first a "hard" search followed by a "spellcheck"
                  search
        Returns:
            bestOpt: str, the corrected text (the original text if nothing can be expanded)
        Raises:
            ValueError: if type is not one of 'spellcheck', 'hard' or 'mixed'
        """
        from difflib import SequenceMatcher
        # Convert the columns only once instead of once per inspected word
        abbreviations = abbrDatabase['Abbreviation'].tolist()
        fullForms = abbrDatabase['Full'].tolist()
        knownAbbreviations = set(abbreviations)
        unknowns = self._collectUnknownWords(knownAbbreviations, text, type)
        corrections = {}
        for word in unknowns:
            lowered = word.lower()
            if lowered in knownAbbreviations:
                candidates = [full for abbr, full in zip(abbreviations, fullForms) if abbr == lowered]
            else:
                # Here we are addressing the fact that the abbreviation database will never be complete
                # Given an abbreviation that is not part of the abbreviation database, we are looking
                # for the subset of abbreviations in the database that are close enough and consider
                # all of them as possible candidates
                candidates = [
                    full
                    for abbr, full in zip(abbreviations, fullForms)
                    if SequenceMatcher(None, word, abbr).ratio() > 0.8
                ]
            if candidates:
                corrections[word] = candidates
        return self._applyCorrections(corrections, text)

    def generateAbbrDict(self, abbrDatabase):
        """
        Generates an AbbrDict that can be used by handleAbbreviationsDict

        Args:
            abbrDatabase: pandas dataframe, dataframe containing library of abbreviations
                ('Abbreviation' column) and their corresponding full expression ('Full' column)
        Returns:
            abbrDict: dictionary, maps each abbreviation to the list of its full expressions
        """
        abbrDict = {}
        for row in abbrDatabase.itertuples():
            abbrDict.setdefault(row.Abbreviation, []).append(row.Full)
        return abbrDict

    def handleAbbreviationsDict(self, abbrDict, text, type):
        """
        Performs automatic correction of abbreviations and returns corrected text
        This method relies on a database of abbreviations located at:
        src/nlp/data/abbreviations.xlsx
        This database contains the most common abbreviations collected from literature and
        it provides for each abbreviation its corresponding full word(s); an abbreviation might
        have multiple words associated. In such case the full word that makes more sense given the
        context is chosen (see findOptimalOption method)

        Args:
            abbrDict: dictionary, dictionary containing library of abbreviations
                and their corresponding full expression (see generateAbbrDict)
            text: str, string of text that will be analyzed
            type: string, type of abbreviation method ('spellcheck','hard','mixed') that are employed
                to determine which words are abbreviations that need to be expanded
                * spellcheck: in this case spellchecker is used to identify words that
                  are not recognized
                * hard: here we directly search for the abbreviations in the provided
                  sentence
                * mixed: here we perform first a "hard" search followed by a "spellcheck"
                  search
        Returns:
            bestOpt: str, the corrected text (the original text if nothing can be expanded)
        Raises:
            ValueError: if type is not one of 'spellcheck', 'hard' or 'mixed'
        """
        from difflib import SequenceMatcher
        unknowns = self._collectUnknownWords(abbrDict, text, type)
        corrections = {}
        for word in unknowns:
            lowered = word.lower()
            if lowered in abbrDict:
                candidates = abbrDict[lowered]
            else:
                # Here we are addressing the fact that the abbreviation database will never be complete
                # Given an abbreviation that is not part of the abbreviation database, we are looking
                # for the closest abbreviation of the database, provided that it is close enough
                candidates = []
                # The best score is tracked across the whole loop so that the closest match wins
                bestRatio = 0
                for abbr, fullForms in abbrDict.items():
                    ratio = SequenceMatcher(None, word, abbr).ratio()
                    if ratio >= 0.75 and ratio > bestRatio:
                        candidates = fullForms
                        bestRatio = ratio
            if candidates:
                corrections[word] = candidates
        return self._applyCorrections(corrections, text)

    def findOptimalOption(self, options):
        """
        Method to handle abbreviation with multiple meanings
        Each option is scored with the sum of the similarities between all pairs of its words
        and the option with the highest score is selected

        Args:
            options: list, list of sentence options
        Returns:
            optimalOpt: string, option from the provided options list with the highest score
        Raises:
            ValueError: if options is empty
        """
        if not options:
            raise ValueError("findOptimalOption requires at least one option")
        if len(options) == 1:
            # Nothing to compare: skip the quadratic number of similarity evaluations
            return options[0]
        combScore = np.zeros(len(options))
        for index, opt in enumerate(options):
            for first, second in itertools.combinations(opt.split(), 2):
                combScore[index] += wordsSimilarity(first, second)
        return options[int(np.argmax(combScore))]

    def _collectUnknownWords(self, knownAbbreviations, text, mode):
        """
        Identifies the words of the text that are candidates for abbreviation expansion

        Args:
            knownAbbreviations: container, supports the `in` operator on lower case abbreviations
            text: str, string of text that will be analyzed
            mode: string, one of 'spellcheck', 'hard' or 'mixed'
        Returns:
            unknowns: list, candidate words without duplicates, in a reproducible order
        Raises:
            ValueError: if mode is not one of 'spellcheck', 'hard' or 'mixed'
        """
        if mode not in self._ABBREVIATION_MODES:
            raise ValueError(
                f"Unknown abbreviation method '{mode}', expected one of {self._ABBREVIATION_MODES}"
            )
        if mode == 'spellcheck':
            # The spell checker returns a set: sort it to get a reproducible order
            return sorted(self.getMisspelledWords(text))
        # dict keys keep the order of first appearance while dropping duplicates
        unknowns = dict.fromkeys(
            word for word in text.split() if word.lower() in knownAbbreviations
        )
        if mode == 'mixed':
            unknowns.update(dict.fromkeys(sorted(self.getMisspelledWords(text))))
        return list(unknowns)

    def _applyCorrections(self, corrections, text):
        """
        Builds every corrected version of the text and returns the best one

        Args:
            corrections: dict, maps each word to replace to the list of its possible replacements
            text: str, string of text to correct
        Returns:
            bestOpt: str, the corrected text (the original text if there is nothing to correct)
        """
        if not corrections:
            return text
        # The words are escaped since they may contain regular expression metacharacters.
        # The lookarounds behave like \b for words that start and end with a word character
        # and, unlike \b, also delimit words that start or end with punctuation (e.g. 'w/')
        patterns = [
            re.compile(r"(?<!\w)%s(?!\w)" % re.escape(str(word))) for word in corrections
        ]
        options = []
        for combination in itertools.product(*corrections.values()):
            corrected = text
            for pattern, replacement in zip(patterns, combination):
                # A callable replacement is inserted literally (no backslash escape processing)
                corrected = pattern.sub(lambda _match, value=replacement: value, corrected)
            options.append(corrected)
        return self.findOptimalOption(options)