

Using NLP (BERT) to improve OCR accuracy


By Ravi Ilango


Optical Character Recognition (OCR) is a popular technique used to extract data from scanned documents. As you would expect, the accuracy of an OCR solution is contingent on the quality of images being used as input. One challenge facing practical applications of OCR solutions is the significant drop in word-level accuracy as a function of character-level accuracy. An OCR solution that achieves 98% character-level accuracy will find itself incorrectly extracting words 10–20% of the time, as depicted in the chart below.

Source: A Fast Alignment Scheme for Automatic OCR Evaluation of Books by Yalniz, Ismet and Manmatha, R.
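A rough way to see why word accuracy falls so much faster than character accuracy: if character errors are assumed to be independent (a simplification made here for illustration, not a claim from the cited paper), a word of n characters is read correctly with probability p^n. The sketch below evaluates this for a few word lengths; the function name is hypothetical.

```python
def word_accuracy(char_accuracy: float, word_length: int) -> float:
    """Probability that every character of a word is read correctly,
    assuming character errors are independent (a simplifying assumption)."""
    return char_accuracy ** word_length


# Chance of at least one wrong character per word at 98% character accuracy
for n in (5, 8, 10):
    error_rate = 1 - word_accuracy(0.98, n)
    print(f"{n}-character word: {error_rate:.1%} chance of an error")
```

Under this assumption, words of 5 to 10 characters come out wrong roughly 10 to 18% of the time, which is in line with the 10 to 20% range quoted above.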

One way to improve word-level accuracy is to use NLP (Natural Language Processing) techniques to replace incorrect words with correct ones. In this post, we will use a spell checker and BERT¹ (a pre-trained NLP model) to improve OCR accuracy.

OCR-BERT Pipeline

BERT (Bidirectional Encoder Representations from Transformers) is a Natural Language Processing technique developed by Google. The BERT model was trained on Wikipedia (2.5B words) and BookCorpus (800M words). BERT models can be used for a variety of NLP tasks, including next-sentence prediction, sentence classification, and missing-word prediction. In this post, we will use a PyTorch pre-trained BERT model³ to correct words incorrectly read by OCR.

Google BERT currently supports over 90 languages

Using BERT to increase accuracy of OCR processing

Let’s walk through an example with code. I’ll be using Python to process a scanned image and create a text document using OCR and BERT.

A. Process scanned image using OCR

Input scanned image

from PIL import Image
from pytesseract import image_to_string
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
import re
import nltk
from enchant.checker import SpellChecker
from difflib import SequenceMatcher
# run Tesseract OCR on the scanned image
filename = './sample6.png'
text = image_to_string(Image.open(filename))
# keep a copy of the raw OCR output; the final corrections are written into it
text_original = str(text)
print(text_original)

Output of OCR with incorrectly parsed words

B. Process document and identify unreadable words

Incorrect words are identified with enchant’s SpellChecker. One thing to be mindful of when using SpellChecker is that it flags uncommon names as misspelled words. I’ll work around this problem by using nltk’s part-of-speech tagging and named-entity chunking to exclude person names. To obtain a prediction from BERT, each incorrect word needs to be replaced with a [MASK] token. Finally, we will store the replacement suggestions from SpellChecker in our suggestedwords list.

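# cleanup text: pad punctuation with spaces and expand common contractions;
# straight and curly double quotes are normalized to a spaced straight quote
rep = { '\n': ' ', '\\': ' ', '-': ' ', '"': ' " ',
        '\u201c': ' " ', '\u201d': ' " ', ',':' , ', '.':' . ', '!':' ! ',
        '?':' ? ', "n't": " not" , "'ll": " will", '*':' * ',
        '(': ' ( ', ')': ' ) ', "s'": "s '"}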
rep = dict((re.escape(k), v) for k, v in rep.items()) 
pattern = re.compile("|".join(rep.keys()))
text = pattern.sub(lambda m: rep[re.escape(m.group(0))], text)
# collect person names so they are not flagged as misspellings
# (requires the nltk tokenizer, tagger, and named-entity chunker data to be downloaded)
def get_personslist(text):
    personslist=[]
    for sent in nltk.sent_tokenize(text):
        for chunk in nltk.ne_chunk(nltk.pos_tag(nltk.word_tokenize(sent))):
            if isinstance(chunk, nltk.tree.Tree) and chunk.label() == 'PERSON':
                personslist.insert(0, (chunk.leaves()[0][0]))
    return list(set(personslist))
personslist = get_personslist(text)
ignorewords = personslist + ["!", ",", ".", "\"", "?", '(', ')', '*', "'"]
# using enchant.checker.SpellChecker, identify incorrect words
d = SpellChecker("en_US")
words = text.split()
incorrectwords = [w for w in words if not d.check(w) and w not in ignorewords]
# using enchant.checker.SpellChecker, get suggested replacements
suggestedwords = [d.suggest(w) for w in incorrectwords]
# replace incorrect words with [MASK]
for w in incorrectwords:
    text = text.replace(w, '[MASK]')
    text_original = text_original.replace(w, '[MASK]')
    
print(text)

Document with incorrect words replaced with [MASK]
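One caveat with str.replace is that it matches substrings, so a short misread token can also be masked inside a longer, correctly read word. A minimal sketch of token-by-token masking follows. The function name and the known_words set are hypothetical; the set stands in for the SpellChecker dictionary so the example runs without enchant.

```python
def mask_unknown_words(text, known_words, ignore=()):
    """Replace each whitespace-separated token that is neither in known_words
    nor in ignore with [MASK]. Returns the masked text and the masked tokens
    in order of appearance."""
    masked_tokens = []
    output = []
    for token in text.split():
        if token.lower() in known_words or token in ignore:
            output.append(token)
        else:
            masked_tokens.append(token)
            output.append('[MASK]')
    return ' '.join(output), masked_tokens


# Hypothetical mini dictionary standing in for enchant's en_US word list.
known_words = {'both', 'prices', 'of', 'the', 'consumer', 'index', 'rose'}
sentence = 'both prices ot the conmer index rose .'

masked, flagged = mask_unknown_words(sentence, known_words, ignore=('.',))
print(masked)
print(flagged)

# For comparison: substring replacement also damages the correct word "both".
print(sentence.replace('ot', '[MASK]'))
```

Working on whole tokens also keeps the number of [MASK] tokens equal to the number of flagged words, which matters later when each mask is paired with its list of suggestions.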

C. Load BERT model and predict replacement words

The BERT model looks for the [MASK] tokens and then attempts to predict the original value of the masked words, based on the context provided by the other, non-masked words in the sequence. BERT also accepts segment embeddings, a vector used to distinguish multiple sentences and assist with word prediction. For example, the segment vector for “Tom went to store. He bought two gallons of milk.” would be [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1].

# Load the pre-trained tokenizer and model, then predict the masked tokens
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
MASKIDS = [i for i, e in enumerate(tokenized_text) if e == '[MASK]']
# Create the segments tensors: a new segment starts after each "."
segs = [i for i, e in enumerate(tokenized_text) if e == "."]
segments_ids=[]
prev=-1
for k, s in enumerate(segs):
    segments_ids = segments_ids + [k] * (s-prev)
    prev=s
segments_ids = segments_ids + [len(segs)] * (len(tokenized_text) - len(segments_ids))
# the pre-trained checkpoint only defines segment types 0 and 1, so cap the ids at 1
segments_ids = [min(s, 1) for s in segments_ids]
segments_tensors = torch.tensor([segments_ids])
# prepare Torch inputs 
tokens_tensor = torch.tensor([indexed_tokens])
# Load pre-trained model and switch to evaluation mode (disables dropout)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)
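To make the segment vector concrete, here is a small sketch that reproduces the example vector from the “Tom went to store” sentence. It uses a plain whitespace split on already-separated text in place of the BERT tokenizer, an assumption made to keep the example free of dependencies. The function name and the max_segment parameter are hypothetical; the cap reflects that the pre-trained BERT checkpoints define only two segment types, 0 and 1.

```python
def segment_ids(tokens, max_segment=1):
    """Assign a segment id to each token, moving to the next segment after
    every '.' token and capping the id at max_segment."""
    ids = []
    current = 0
    for token in tokens:
        ids.append(min(current, max_segment))
        if token == '.':
            current += 1
    return ids


tokens = 'tom went to store . he bought two gallons of milk .'.split()
print(segment_ids(tokens))
```

With a third sentence appended, its tokens would also receive id 1 rather than 2, since ids above 1 fall outside the segment embedding table of the pre-trained model.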

D. Refine BERT predictions by using suggestions from SpellChecker

The pre-trained BERT language model is useful for predicting multiple viable replacements for the masked words. That said, the model has no knowledge of the characters that OCR did recognize. We can compensate for this with our suggested word list from SpellChecker, which incorporates characters from the garbled OCR output. Combining BERT’s context-based suggestions with SpellChecker’s word-based suggestions yielded better predictions than relying solely on BERT.

# Predict words for each mask using BERT;
# refine the prediction by matching with proposals from SpellChecker
def predict_word(text_original, predictions, maskids):
    for i in range(len(maskids)):
        # top 50 BERT candidates for this masked position
        preds = torch.topk(predictions[0, maskids[i]], k=50) 
        indices = preds.indices.tolist()
        list1 = tokenizer.convert_ids_to_tokens(indices)
        list2 = suggestedwords[i]
        simmax=0
        # fall back to BERT's top candidate if no suggestion is similar
        predicted_token=list1[0]
        for word1 in list1:
            for word2 in list2:
                # character-level similarity between a BERT candidate and a suggestion
                s = SequenceMatcher(None, word1, word2).ratio()
                if s > simmax:
                    simmax = s
                    predicted_token = word1
        # fill in the masks one at a time, in document order
        text_original = text_original.replace('[MASK]', predicted_token, 1)
    return text_original
text_original = predict_word(text_original, predictions, MASKIDS)
print(text_original)

Final output with corrected words from BERT and SpellChecker

The output looks a lot better now! The incorrect words have been accurately replaced thanks to the pre-trained BERT model.

Wrapping up

The BERT language model does a good job of predicting viable replacements for the masked word(s). In the above example, when asked to predict the masked value for “conmer”, the model suggested “tax”, “government”, “business”, and “consumer” among its choices. While all of these suggestions make sense in context, “consumer” did not have the highest output probability. Without SpellChecker to augment BERT’s results, we would have picked the wrong replacement for the masked word. The approach could still run into problems when OCR misidentifies the majority of a word’s characters, because SpellChecker’s suggestions then have little to go on.
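The “conmer” case can be reproduced in miniature. In the sketch below, the BERT candidates are the four words quoted above, while the SpellChecker suggestions are a hypothetical list made up for illustration (the real output of enchant may differ). The function name is also hypothetical. It picks the BERT candidate that is most similar to any spelling suggestion, using difflib's SequenceMatcher.

```python
from difflib import SequenceMatcher


def pick_replacement(bert_candidates, spell_suggestions):
    """Return the BERT candidate with the highest character-level similarity
    to any spelling suggestion; default to the first candidate if nothing
    matches at all."""
    best, best_score = bert_candidates[0], 0.0
    for candidate in bert_candidates:
        for suggestion in spell_suggestions:
            score = SequenceMatcher(None, candidate, suggestion).ratio()
            if score > best_score:
                best, best_score = candidate, score
    return best


bert_candidates = ['tax', 'government', 'business', 'consumer']  # quoted in the text
spell_suggestions = ['conger', 'comer', 'consumer']  # hypothetical suggestions for "conmer"

print(pick_replacement(bert_candidates, spell_suggestions))
```

Because “consumer” appears in both lists, it scores a perfect match and wins even though it is not the first BERT candidate.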

Please note that the methods used in this post apply to words, not numbers. A different approach (such as a checksum) is needed to verify numbers read by OCR. I also recommend application-specific error and suggestion identification rather than relying on SpellChecker alone.
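As one illustration of the checksum idea, the sketch below applies the Luhn algorithm, which is used for payment card numbers among others. This is an example chosen here, not a method from the walkthrough above, and it only helps for numbers that carry a check digit. The function name is hypothetical.

```python
def luhn_valid(number: str) -> bool:
    """Return True if the digits in `number` pass the Luhn checksum."""
    digits = [int(ch) for ch in number if ch.isdigit()]
    if not digits:
        return False
    total = 0
    # Walk from the rightmost digit, doubling every second digit.
    for position, digit in enumerate(reversed(digits)):
        if position % 2 == 1:
            digit *= 2
            if digit > 9:
                digit -= 9
        total += digit
    return total % 10 == 0


print(luhn_valid('79927398713'))  # a commonly used valid test number
print(luhn_valid('79927398710'))  # same number with the last digit misread
```

A failed check only signals that at least one digit was misread; it does not say which one, so the field would still need to be re-read or reviewed.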
