# Source code for datable_ai.ocr
import base64
from io import BytesIO
from langchain_core.messages import HumanMessage
from PIL import Image
from datable_ai.core.llm import LLM_TYPE, create_llm
# NOTE: Currently only LLM_TYPE.ANTHROPIC is supported.
class OCR:
    """
    Run an OCR-style prompt against an image using a multimodal LLM.

    The image is downscaled, re-encoded as JPEG, base64-encoded and sent to
    the model together with ``prompt_template`` in a single human message.

    Attributes:
        llm_type: The type of language model in use.
        prompt_template: The text prompt sent alongside the image.
        llm: The model object returned by ``create_llm(llm_type)``.
        max_size: Maximum allowed length of the base64-encoded image string.
        quality: Initial JPEG quality used when compressing an image.
    """

    # Bounding box passed to ``Image.thumbnail`` before encoding.
    _THUMBNAIL_SIZE = (1092, 1092)

    def __init__(self, llm_type: LLM_TYPE, prompt_template: str) -> None:
        """
        Initialize the OCR class.

        Args:
            llm_type (LLM_TYPE): The type of language model to use.
            prompt_template (str): The prompt template for the OCR task.
        """
        self.llm_type = llm_type
        self.prompt_template = prompt_template
        self.llm = create_llm(self.llm_type)
        # Compared against the length of the base64 string, not the raw bytes.
        self.max_size = 4 * 1024 * 1024
        # Starting JPEG quality; never modified after construction.
        self.quality = 85

    def invoke(self, image_path: str):
        """
        Invoke the OCR process on the given image.

        Args:
            image_path (str): The path to the image file.

        Returns:
            The result of ``self.llm.invoke`` for the image and prompt.

        Raises:
            ValueError: If the image cannot be compressed below ``max_size``.
        """
        # The file is only needed while compressing; close it before the
        # (potentially slow) model call.
        with open(image_path, "rb") as image_file:
            compressed_image = self._compress_image(image_file)

        messages = [
            HumanMessage(
                content=[
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{compressed_image}",  # noqa: E501
                        },
                    },
                    {"type": "text", "text": self.prompt_template},
                ]
            )
        ]
        return self.llm.invoke(messages)

    @staticmethod
    def _to_jpeg_mode(img):
        """
        Return an image whose mode can be written as JPEG.

        Images already in ``RGB`` or ``L`` mode are returned unchanged.
        Images carrying transparency are flattened onto a white background so
        that transparent regions do not turn into arbitrary colours; every
        other mode is converted to ``RGB``.

        Args:
            img: A PIL image.

        Returns:
            A PIL image in ``RGB`` or ``L`` mode.
        """
        if img.mode in ("RGB", "L"):
            return img
        if img.mode in ("RGBA", "LA") or "transparency" in img.info:
            rgba = img.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.getchannel("A"))
            return background
        return img.convert("RGB")

    def _compress_image(self, image_file):
        """
        Compress the image to reduce its size.

        Args:
            image_file: The image file object.

        Returns:
            The base64-encoded compressed image.

        Raises:
            ValueError: If the image cannot be compressed below ``max_size``.
        """
        with Image.open(image_file) as img:
            # Normalise the mode first: JPEG cannot store alpha/palette data,
            # and resizing in RGB avoids unsmoothed palette downscaling.
            prepared = self._to_jpeg_mode(img)
            prepared.thumbnail(self._THUMBNAIL_SIZE)
            output_buffer = BytesIO()
            prepared.save(
                output_buffer, format="JPEG", optimize=True, quality=self.quality
            )
        return self._get_base64_encoded_image(output_buffer.getvalue())

    def _get_base64_encoded_image(self, image_data):
        """
        Get the base64-encoded representation of the image data.

        If the encoded string is longer than ``max_size``, the image is
        re-encoded as JPEG at progressively lower quality (steps of 5,
        starting below ``self.quality``) until it fits.

        Args:
            image_data: The image data as bytes.

        Returns:
            The base64-encoded image string.

        Raises:
            ValueError: If the image compression ratio reaches its limit.
        """
        encoded_image = base64.b64encode(image_data).decode("utf-8")
        if len(encoded_image) <= self.max_size:
            return encoded_image

        # Work on a local copy so one oversized image does not permanently
        # lower the quality used for later calls on this instance.
        quality = self.quality
        # Decode once and re-encode the same pixels at each quality level,
        # rather than recompressing the previous (already degraded) output.
        with Image.open(BytesIO(image_data)) as img:
            source = self._to_jpeg_mode(img)
            while len(encoded_image) > self.max_size:
                quality -= 5
                if quality < 5:
                    raise ValueError("Image compression ratio has reached its limit")
                output_buffer = BytesIO()
                source.save(
                    output_buffer, format="JPEG", optimize=True, quality=quality
                )
                encoded_image = base64.b64encode(output_buffer.getvalue()).decode(
                    "utf-8"
                )
        return encoded_image