receipt_indexer/code/textdataretriever/textextractor/modelimp.ipynb
Ethan Wellenreiter ae2e6366e1 Updating line isolator to work better with TrOCR model
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
2023-11-14 13:30:54 -05:00

275 lines
6.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [],
"source": [
"# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Inference_with_TrOCR_%2B_Gradio_demo.ipynb\n",
"# https://github.com/NielsRogge/Transformers-Tutorials/tree/master/TrOCR\n",
"# https://huggingface.co/docs/transformers/model_doc/trocr"
]
},
{
"cell_type": "code",
"execution_count": 191,
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrOCRProcessor\n",
"from transformers import VisionEncoderDecoderModel\n",
"\n",
"from PIL import Image\n",
"import torch\n",
"\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '../../autocropper')\n",
"import myfunctions as mf\n",
"\n",
"import extractorfunctions as ef\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n",
"Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-stage1 and are newly initialized: ['encoder.pooler.dense.weight', 'encoder.pooler.dense.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"# Load the TrOCR processor and the vision encoder-decoder model from the same checkpoint name.\n",
"# NOTE(review): the load log recorded for this cell reports newly initialized encoder pooler weights\n",
"# and recommends training on a downstream task; confirm that stage1 is the intended checkpoint.\n",
"modelname = \"microsoft/trocr-base-stage1\"\n",
"processor = TrOCRProcessor.from_pretrained(modelname)\n",
"model = VisionEncoderDecoderModel.from_pretrained(modelname)"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [],
"source": [
"# Run on the first GPU when CUDA is actually usable, otherwise stay on the CPU.\n",
"# is_available must be *called*: the bare function object is always truthy, so the\n",
"# previous check selected \"cuda:0\" even on machines without CUDA.\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Move the model's weights to the chosen device; inputs must be moved to the same device.\n",
"model = model.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [],
"source": [
"# Input image: file name and the (relative) directory prefix it is read from.\n",
"# The two are joined by plain string concatenation, so pathname must end with a slash.\n",
"filename = \"IMG_7605.jpg\"\n",
"pathname = \"../test_images/\""
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [],
"source": [
"# Read the receipt photo from disk.\n",
"# cv2.imread does not raise for a missing or unreadable file; it returns None, which would only\n",
"# surface later as a confusing error inside the preprocessing helpers. Fail here instead.\n",
"imagepath = pathname + filename\n",
"img = cv2.imread(imagepath)\n",
"if img is None:\n",
"    raise FileNotFoundError(f\"Could not read image: {imagepath}\")"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [],
"source": [
"# Preprocess the loaded image with mf.houghlineprocessing, then split the result into\n",
"# per-line images with ef.lineisolator.\n",
"# NOTE(review): both helpers come from project modules not shown here; lineimages is presumably\n",
"# a sequence of single-channel arrays (it is later iterated and converted with COLOR_GRAY2RGB) — TODO confirm.\n",
"clarified = mf.houghlineprocessing(img)\n",
"lineimages = ef.lineisolator(clarified)"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [],
"source": [
"# print(len(lineimages))"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [],
"source": [
"# Expand every isolated grayscale line to three RGB channels and wrap it as a PIL image,\n",
"# preserving the order of lineimages.\n",
"PILversions = [\n",
"    Image.fromarray(cv2.cvtColor(lineimg, cv2.COLOR_GRAY2RGB))\n",
"    for lineimg in lineimages\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [],
"source": [
"# PILversions[12]"
]
},
{
"cell_type": "code",
"execution_count": 201,
"metadata": {},
"outputs": [],
"source": [
"# image = Image.open(\"../result_images/6.jpg\").convert(\"RGB\")\n",
"# image"
]
},
{
"cell_type": "code",
"execution_count": 202,
"metadata": {},
"outputs": [],
"source": [
"# pixel_values = processor(image, return_tensors=\"pt\").pixel_values\n",
"# # print(pixel_values.shape)\n",
"# # print(image)\n",
"# # print(pixel_values)"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [],
"source": [
"# pixel_values = processor(image, return_tensors=\"pt\").pixel_values\n",
"# # print(pixel_values.shape)\n",
"# generated_ids = model.generate(pixel_values)\n",
"# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
"# print(generated_text)"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {},
"outputs": [],
"source": [
"# Recognized text for the whole receipt; starts empty before the OCR cell below runs.\n",
"finalstring = \"\""
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [],
"source": [
"# Run TrOCR on every line image, in order, and collect the decoded text.\n",
"# The text is gathered in a fresh list and assigned to finalstring (rather than appended to it),\n",
"# so re-running this cell no longer duplicates the receipt text.\n",
"recognized_lines = []\n",
"for image in PILversions:\n",
"    # The processor turns the PIL image into a tensor; it must live on the same device as the model.\n",
"    pixel_values = processor(image, return_tensors=\"pt\").pixel_values.to(device)\n",
"    generated_ids = model.generate(pixel_values)\n",
"    # A single image is passed in, so the first decoded entry is the text for this line.\n",
"    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
"    recognized_lines.append(generated_text)\n",
"\n",
"# One recognized line per row, each terminated by a newline (same layout as before).\n",
"finalstring = \"\".join(f\"{text}\\n\" for text in recognized_lines)"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Canada Computers\n",
"920 Upper Wentworth Street, Unit 1\n",
"Hamilton, ON LSA 5C5\n",
"905-388-5900\n",
"HST# : 882966712RT0081\n",
"==============\n",
"Invoice No : ARHM00352964\n",
"09/09/2023 1:43:12 PM\n",
"Cashier : Francesco\n",
"1 NTTP100137 24.99\n",
"TP- LINK ( Archer T3U PLUS ) ACI\n",
"300 High Gain Wireless\n",
"PART# : Archer T3U PLUS\n",
"[ Warranty : Defective Exchange\n",
":30 Days ; Manufacturing Warrant\n",
"y : 2 Years : ]\n",
"2233038004384\n",
"Subtotal : $24.99\n",
"HST : $3.25\n",
"Total : $28.24\n",
"TRANSACTION REGORD\n",
"TYPE : PURCHASE\n",
"ACCT : VISA $ 28.24\n",
"CARD NUMBER : *********** #845\n",
"DATE/TIME : 23/09/09 13:43:11\n",
"REFERENCE #: [6656723010010012980\n",
"AUTHORIZATION #: 093481\n",
"VISA CREDIT\n",
"A0000000031010\n",
"8080008000 6800\n",
"01/027 Approved - Thank You\n",
"IMPORTANT\n",
"Retain this copy for your records\n",
"*** CARDHOLDER COPY ***\n",
"Returns and exchanges : Please visit\n",
"http://www. canadacomputers. com/re\n",
"turns-exchanges for full return and\n",
"exchange details.\n",
"0' 0\n",
"-\n",
"\n"
]
}
],
"source": [
"print(finalstring)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}