VSZM

OCR Python Cached

Apr 7th, 2020
import random
import copy
import hashlib
import io
import itertools
import os
import pickle
import tempfile
import time
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Dict, List, Tuple

import cv2
import numpy as np
import pause
import pytesseract
import requests
from PIL import Image

import invoiceparser.tesseract_parser as tp
from business_objects import OCRMethod
from invoiceparser.image_util import PIL_img_to_bytes, scale_image_to_filesize
from logger_setup import getLogger
from business_objects import Fragment, InvoiceDocument, YDirection

from tempfile import TemporaryFile
from utility import RenamingUnpickler

logger = getLogger(__name__)


"""
====================================================================
Base OCR
====================================================================
"""

CACHE_FILE_NAME = os.path.join(os.path.dirname(__file__), 'ocr_output_cache.pickle')

def _init_cache() -> Dict[Tuple[OCRMethod, str], List[Fragment]]:
    global cache

    if not os.path.isfile(CACHE_FILE_NAME):
        return {}

    with io.open(CACHE_FILE_NAME, 'rb') as f:
        return RenamingUnpickler(f).load()


def _save_cache():
    global cache

    with io.open(CACHE_FILE_NAME, 'wb') as f:
        pickle.dump(cache, f)


cache = _init_cache()

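# The module-level cache above maps (OCRMethod, image_name + md5 hexdigest of the
# image bytes) to the list of Fragments previously recognized for that image;
# OCRBase.get_fragment_list below consults it and persists new entries to
# ocr_output_cache.pickle.
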
class OCRBase(ABC):

    def __init__(self):
        pass

    def get_fragment_list(self, image_name: str, image_bytes: bytes, use_cache: bool = False) -> List[Fragment]:
        global cache

        md5hash = hashlib.md5(image_bytes).hexdigest()
        key = (self._get_method(), image_name + str(md5hash))

        if use_cache and key in cache:
            logger.debug('Found the image in the cache by key |%s|!', key)
            return copy.deepcopy(cache[key])

        fragments = self._get_fragment_list(image_name, image_bytes)

        if use_cache and len(fragments) > 0:
            logger.debug('Adding %s to cache', '\n\t'.join(str(fragment) for fragment in fragments))
            cache[key] = fragments
            _save_cache()

        return fragments

    @abstractmethod
    def _get_fragment_list(self, image_name: str, image_bytes: bytes) -> List[Fragment]:
        pass

    @abstractmethod
    def _get_method(self) -> OCRMethod:
        pass

  93. """
  94. ====================================================================
  95. Azure OCR
  96. ====================================================================
  97. Pricing: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/
  98. API: https://westus.dev.cognitive.microsoft.com/docs/services/5adf991815e1060e6355ad44/operations/56f91f2e778daf14a499e1fc
  99. """
  100.  
  101.  
  102. class AzureOCR(OCRBase):
  103.  
  104.  
  105.     AZURE_PAUSE_TIME = timedelta(seconds=3)
  106.     AZURE_OCR_URL = "https://westeurope.api.cognitive.microsoft.com/vision/v2.0/ocr"
  107.     AZURE_SUBSCRIPTION_KEY = "44f8dd20ad21451ead1d09c87cf36e9d"
  108.     MAX_IMAGE_SIZE = 4 * 1024 * 1024
  109.     MAX_DIMENSION = 4000
  110.  
  111.  
  112.     def __init__(self):
  113.         self.next_azure_request_time = datetime.now()
  114.  
  115.  
    def _get_fragment_list(self, image_name: str, image_bytes: bytes) -> List[Fragment]:
        image_bytes = AzureOCR.__normalize_image(image_bytes)
        logger.debug('Image size after normalization: |%d| bytes', len(image_bytes))

        logger.debug('Pausing until |%s|', self.next_azure_request_time)
        pause.until(self.next_azure_request_time)  # Need to throttle down as per API contract

        headers = {'Ocp-Apim-Subscription-Key': AzureOCR.AZURE_SUBSCRIPTION_KEY,
                   'Content-Type': 'application/octet-stream'}
        params = {'language': 'unk', 'detectOrientation': 'true'}
        response = requests.post(AzureOCR.AZURE_OCR_URL, headers=headers, params=params, data=image_bytes)
        if response.status_code != 200:
            logger.warning('Bad response |%d|! Details:\n %s', response.status_code, response.json())
        response.raise_for_status()

        self.next_azure_request_time = datetime.now() + AzureOCR.AZURE_PAUSE_TIME

        return AzureOCR.__get_fragments_from_json(response.json())

    def _get_method(self) -> OCRMethod:
        return OCRMethod.AZURE

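    # Illustrative shape of one Azure OCR v2.0 "line" object that
    # __wordbox_to_fragment consumes below (inferred from the parsing code;
    # the concrete values are made up):
    #
    #   {
    #       "boundingBox": "28,36,240,42",          # "x,y,width,height" in pixels
    #       "words": [
    #           {"boundingBox": "28,36,90,42", "text": "Invoice"},
    #           {"boundingBox": "130,36,138,42", "text": "2020-04"}
    #       ]
    #   }
    #
    # The full response nests these under regions -> lines, which
    # __get_fragments_from_json flattens before mapping them to Fragments.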
    @staticmethod
    def __wordbox_to_fragment(wordbox) -> Fragment:
        bounding_box = wordbox['boundingBox'].split(',')
        x = int(bounding_box[0])
        y = int(bounding_box[1])
        width = int(bounding_box[2])
        height = int(bounding_box[3])
        text = ' '
        for word in [word['text'] for word in wordbox['words']]:
            # handling the case when a number's digits got separated
            if text[-1].isdigit() and word[0].isdigit():
                text = text + word
            else:
                text = text + ' ' + word
        text = text.strip()

        return Fragment(text, x, y, width, height)

    @staticmethod
    def __get_fragments_from_json(json_object) -> List[Fragment]:
        logger.debug(json_object)
        wordboxes = [region['lines'] for region in json_object['regions']]
        wordboxes = list(itertools.chain.from_iterable(wordboxes))
        return [AzureOCR.__wordbox_to_fragment(wordbox) for wordbox in wordboxes]

    @staticmethod
    def __normalize_image(image_bytes: bytes) -> bytes:
        img = Image.open(io.BytesIO(image_bytes))

        sizefactor = float(AzureOCR.MAX_IMAGE_SIZE) / len(image_bytes)

        if img.height > img.width:
            pixelfactor = AzureOCR.MAX_DIMENSION / img.height
        else:
            pixelfactor = AzureOCR.MAX_DIMENSION / img.width

        if pixelfactor < 1.0 or sizefactor < 1.0:
            logger.info('Image is too large! Size: |%s, %d bytes| Dimensions: (%d, %d)', img.size, len(image_bytes), img.width, img.height)

            factor = min(sizefactor, pixelfactor)
            factor = factor * 0.9
            # This is hacky, but the file size of the image PIL writes out is not deterministic,
            # so if we need to resize, let's be aggressive about it.

            logger.info('Reducing image to |%f| of the original size!', factor)
            img = img.resize((int(img.width * factor), int(img.height * factor)))
            logger.info('New size of the image is |%s|', img.size)

            return PIL_img_to_bytes(img)
        else:
            return image_bytes

"""
====================================================================
Tesseract OCR
====================================================================
"""

class TesseractOCR(OCRBase):

    def _get_fragment_list(self, image_name: str, image_bytes: bytes) -> List[Fragment]:
        # cv2.imdecode expects an imread flag; cv2.COLOR_BGR2GRAY is a color
        # conversion code, so decode as color here and convert to grayscale below.
        image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.threshold(gray, 0, 255,
                             cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]

        document = None
        with tempfile.TemporaryDirectory() as tmpdir:
            imgfname = os.path.join(tmpdir, "{}.png".format(os.getpid()))
            cv2.imwrite(imgfname, gray)
            hocr_str = pytesseract.pytesseract.run_and_get_output(imgfname, lang=None, config="hocr", extension='hocr')
            document = tp.HOCRDocument(hocr_str)

        page = document.pages[0]

        fragments = []
        for area in page.areas:
            for paragraph in area.paragraphs:
                for line in paragraph.lines:
                    for word in line.words:
                        fragments.append(
                                Fragment(
                                        word.ocr_text,
                                        word.coordinates[0],
                                        4000 - word.coordinates[1],
                                        word.coordinates[2] - word.coordinates[0],
                                        word.coordinates[3] - word.coordinates[1]
                                        )
                                )
        return fragments

    def _get_method(self) -> OCRMethod:
        return OCRMethod.TESSERACT

"""
====================================================================
OCR.SPACE OCR
====================================================================
"""

class OCRSpaceOCR(OCRBase):
    # wayasam@gmail.com (paid), wayasam@gmail.com (free), vszm5@hotmail.com, vszm@inf.elte.hu Get more keys if needed
    API_KEYS = ['PKMXB9465888A', 'f0a34a178b88957', '62f395656388957', '96b2d9106788957']
    MAX_SIZE_BYTES = 1024 * 1024 * 1

    def _get_fragment_list(self, image_name: str, image_bytes: bytes) -> List[Fragment]:
        image_bytes = scale_image_to_filesize(image_bytes, OCRSpaceOCR.MAX_SIZE_BYTES)

        with TemporaryFile(suffix='.' + image_name.split('.')[-1]) as fp:
            fp.write(image_bytes)
            fp.flush()

            for api_key in random.sample(OCRSpaceOCR.API_KEYS, len(OCRSpaceOCR.API_KEYS)):
                fp.seek(0)
                payload = {'isOverlayRequired': True,
                           'apikey': api_key,
                           'detectOrientation': True,
                           'language': 'hun'
                           }

                with requests.Session() as session:
                    response = session.post('https://api.ocr.space/parse/image',
                                            files={'filename': fp},
                                            data=payload,
                                            headers={'Connection': 'close'})

                json = response.json()

                if response.status_code == 200 and json["OCRExitCode"] < 3:
                    logger.info('Returning parsed OCR.space response: %s', json)
                    return self.__get_fragments_from_json(json)
                else:
                    logger.warning('Error occurred: |%s|', response.text)

                logger.debug('Failed to get correct response with key |%s|, Code: |%d|, Response: |%s|',
                             api_key, response.status_code, response.json())
                #logger.debug('Sleeping for 35 seconds for API throttling')
                #time.sleep(35)

        logger.error('Could not get response with any of the api keys. Image Name: |%s|, bytes length: |%d|',
                     image_name, len(image_bytes))
        return []

    def _get_method(self) -> OCRMethod:
        return OCRMethod.OCRSPACE

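    # Illustrative shape of one OCR.space "line" entry that __line_to_fragment
    # consumes below (inferred from the parsing code; the concrete values are
    # made up):
    #
    #   {
    #       "LineText": "Invoice 2020-04",
    #       "Words": [
    #           {"WordText": "Invoice", "Left": 28, "Top": 36, "Width": 90, "Height": 42},
    #           {"WordText": "2020-04", "Left": 130, "Top": 36, "Width": 76, "Height": 42}
    #       ]
    #   }
    #
    # The full response nests these under ParsedResults[0] -> TextOverlay -> Lines.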
    @staticmethod
    def __line_to_fragment(line) -> Fragment:
        x = int(line['Words'][0]['Left'])
        y = int(line['Words'][0]['Top'])
        width = int(line['Words'][-1]['Left'] + line['Words'][-1]['Width'] - x)
        height = int(line['Words'][-1]['Top'] + line['Words'][-1]['Height'] - y)
        text = line['LineText'].strip()

        return Fragment(text, x, y, width, height)

    @staticmethod
    def __get_fragments_from_json(json_object) -> List[Fragment]:
        try:
            lines = json_object['ParsedResults'][0]['TextOverlay']['Lines']
            return [OCRSpaceOCR.__line_to_fragment(line) for line in lines]
        except Exception:
            logger.error('Could not parse json! |%s|', json_object)
            return []

def get_ocr_engine(method: OCRMethod) -> OCRBase:
    if method == OCRMethod.OCRSPACE:
        return OCRSpaceOCR()
    elif method == OCRMethod.AZURE:
        return AzureOCR()
    else:
        return TesseractOCR()
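

if __name__ == '__main__':
    # Minimal usage sketch of the module's public surface: pick an engine via
    # get_ocr_engine and run it through the cached get_fragment_list entry point.
    # The path 'sample_invoice.png' is a hypothetical example file, not part of
    # the original code; any image readable by the chosen engine should work.
    engine = get_ocr_engine(OCRMethod.TESSERACT)

    with io.open('sample_invoice.png', 'rb') as image_file:
        image_bytes = image_file.read()

    # use_cache=True stores the result in ocr_output_cache.pickle, so a second
    # run on the same bytes is served from the cache instead of re-running OCR.
    fragments = engine.get_fragment_list('sample_invoice.png', image_bytes, use_cache=True)
    for fragment in fragments:
        logger.info('Recognized fragment: %s', fragment)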