Moving around the testing notebooks. Autocropping is about done, with the exception of any new versions or converting the code to C. Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
386 lines
13 KiB
Plaintext
386 lines
13 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 137,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Checkpoints are read from <savepath>/v<version>/; datasets from <cachepath>datasets/.\n",
|
|
"version = 2.0\n",
|
|
"savepath = \"./savespot/\"\n",
|
|
"cachepath = \"../.cache/\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 138,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as fn\n",
|
|
"import torch.optim as optim\n",
|
|
"import torchvision.transforms.functional as tvf\n",
|
|
"import torchvision.transforms.v2 as v2\n",
|
|
"import torchvision.models as models\n",
|
|
"import torchvision.transforms as t\n",
|
|
"\n",
|
|
"\n",
|
|
"from PIL import Image\n",
|
|
"\n",
|
|
"import datasets as ds\n",
|
|
"from tqdm.autonotebook import tqdm\n",
|
|
"\n",
|
|
"import random\n",
|
|
"\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"\n",
|
|
"torch.cuda.empty_cache()\n",
|
|
"\n",
|
|
"\n",
|
|
"import os\n",
|
|
"import cv2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 139,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# array = np.load(\"./testing_space/outputarray.npy\")\n",
|
|
"# counter = np.load(\"./testing_space/counter.npy\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 140,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# print(array)\n",
|
|
"# print(counter)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 141,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class RotationDeterminer(nn.Module):\n",
|
|
"    \"\"\"Predict a single rotation value per image with a ResNet-18 backbone.\n",
|
|
"\n",
|
|
"    Images are square-padded, resized/center-cropped to 512x512 and converted\n",
|
|
"    to a 3-channel grayscale float tensor (``self.imageprep``), then passed\n",
|
|
"    through ResNet-18 (1000 outputs) and an MLP head (1000 -> 4096 -> 1).\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        new: if True, initialise the backbone with torchvision's ImageNet\n",
|
|
"            weights (``ResNet18_Weights.IMAGENET1K_V1``, i.e. what the deprecated\n",
|
|
"            ``pretrained=True`` loaded); otherwise start from random weights.\n",
|
|
"    \"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, new=False):\n",
|
|
"        super().__init__()\n",
|
|
"\n",
|
|
"        torch.cuda.empty_cache()\n",
|
|
"\n",
|
|
"        # is_available must be *called*: the bare function object is always\n",
|
|
"        # truthy, so the old check picked cuda:0 even on CPU-only machines.\n",
|
|
"        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"\n",
|
|
"        # Random augmentations; forward() in this class does not use them.\n",
|
|
"        self.appliers = [v2.RandomApply(transforms=[v2.RandomPosterize(bits=1)], p=0.25),\n",
|
|
"                         v2.RandomApply(transforms=[v2.ElasticTransform(alpha=25.0)], p=0.25),  # maybe add fill=appliedFill\n",
|
|
"                         v2.RandomApply(transforms=[v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.))], p=0.25),\n",
|
|
"                         v2.RandomApply(transforms=[v2.RandomEqualize()], p=0.25)]\n",
|
|
"\n",
|
|
"        # `pretrained=` is deprecated since torchvision 0.13; pass the weights enum instead.\n",
|
|
"        self.conv = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if new else None)\n",
|
|
"\n",
|
|
"        self.classifier = nn.Sequential(nn.Linear(1000, 4096),\n",
|
|
"                                        nn.ReLU(inplace=True),\n",
|
|
"                                        nn.Linear(4096, 1))\n",
|
|
"\n",
|
|
"        self.lossfunc = nn.MSELoss()\n",
|
|
"\n",
|
|
"        self.imageprep = v2.Compose([self.SquarePad(),\n",
|
|
"                                     v2.Resize(512),\n",
|
|
"                                     v2.Grayscale(num_output_channels=3),\n",
|
|
"                                     v2.CenterCrop(512),\n",
|
|
"                                     v2.ToImageTensor(),\n",
|
|
"                                     v2.ConvertImageDtype()])\n",
|
|
"\n",
|
|
"    class SquarePad:\n",
|
|
"        \"\"\"Edge-pad an image so that its height and width become equal.\n",
|
|
"\n",
|
|
"        NOTE(review): earlier versions swapped height and width, padding the\n",
|
|
"        wrong axis so that the later CenterCrop cut content instead. Checkpoints\n",
|
|
"        trained with that preprocessing may need re-validation.\n",
|
|
"        \"\"\"\n",
|
|
"\n",
|
|
"        def __call__(self, image):\n",
|
|
"            # [channels, height, width]; works for tensors and PIL images.\n",
|
|
"            _, height, width = tvf.get_dimensions(image)\n",
|
|
"            side = max(height, width)\n",
|
|
"            left = (side - width) // 2\n",
|
|
"            top = (side - height) // 2\n",
|
|
"            # Any odd leftover pixel goes right/bottom so the result is exactly square.\n",
|
|
"            padding = [left, top, side - width - left, side - height - top]\n",
|
|
"            return tvf.pad(image, padding, 0, 'edge')\n",
|
|
"\n",
|
|
"    def forward(self, image):\n",
|
|
"        \"\"\"Return one predicted rotation value per image.\n",
|
|
"\n",
|
|
"        Args:\n",
|
|
"            image: one image or a batch of images accepted by ``self.imageprep``\n",
|
|
"                (e.g. a CxHxW or NxCxHxW tensor).\n",
|
|
"\n",
|
|
"        Returns:\n",
|
|
"            1-D tensor of length N (N = 1 for a single image).\n",
|
|
"\n",
|
|
"        Raises:\n",
|
|
"            ValueError: if the prepared tensor is neither 3D nor 4D.\n",
|
|
"        \"\"\"\n",
|
|
"        transformedimage = self.imageprep(image).to(self.device)\n",
|
|
"\n",
|
|
"        if transformedimage.ndim not in (3, 4):\n",
|
|
"            raise ValueError(f\"Sorry, Dimension of image is incorrect ({transformedimage.ndim}). \"\n",
|
|
"                             \"Expected a 3D (single image) or 4D (batch of images) tensor\")\n",
|
|
"\n",
|
|
"        # Promote a single image to a batch of one.\n",
|
|
"        x = transformedimage.unsqueeze(0) if transformedimage.ndim == 3 else transformedimage\n",
|
|
"\n",
|
|
"        x = self.conv(x)        # (N, 1000) backbone outputs\n",
|
|
"        x = self.classifier(x)  # (N, 1)\n",
|
|
"        return x.flatten()\n",
|
|
"\n",
|
|
"    def loss(self, guess, trueAnswer):\n",
|
|
"        \"\"\"Mean-squared error between predicted and target rotations.\"\"\"\n",
|
|
"        return self.lossfunc(guess, trueAnswer)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 142,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
|
" warnings.warn(\n",
|
|
"/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
|
|
" warnings.warn(msg)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = RotationDeterminer(new=True)\n",
|
|
"# is_available must be *called*: the bare function object is always truthy,\n",
|
|
"# so the old check picked cuda:0 even on CPU-only machines.\n",
|
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
|
"model = model.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 143,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# def ResizeWithAspectRatio(image, width=None, height=None, inter=cv2.INTER_AREA):\n",
|
|
"# dim = None\n",
|
|
"# (h, w) = image.shape[:2]\n",
|
|
"\n",
|
|
"# if width is None and height is None:\n",
|
|
"# return image\n",
|
|
"# if width is None:\n",
|
|
"# r = height / float(h)\n",
|
|
"# dim = (int(w * r), height)\n",
|
|
"# else:\n",
|
|
"# r = width / float(w)\n",
|
|
"# dim = (width, int(h * r))\n",
|
|
"\n",
|
|
"# return cv2.resize(image, dim, interpolation=inter)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 163,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([1, 4032, 3024])\n",
|
|
"torch.Size([3, 4032, 3024])\n",
|
|
"0.7532281875610352\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"working_dataset = ds.load_from_disk(cachepath + \"datasets/customrotation/\")\n",
|
|
"prepimage = v2.Compose([v2.Grayscale(num_output_channels=3), v2.Resize(512), v2.CenterCrop(512), v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
|
"tensorize = v2.Compose([v2.ToImageTensor(), v2.ConvertImageDtype()])\n",
|
|
"grayscaler = v2.Grayscale(num_output_channels=3)\n",
|
|
"working_dataset.set_transform(prepimage)\n",
|
|
"\n",
|
|
"# Checkpoint file: <savepath>/v<version>/modelsave<counter>epochs, counter read from counter.npy.\n",
|
|
"checkpoint_dir = os.path.join(savepath, \"v\" + str(version))\n",
|
|
"counter = np.load(os.path.join(checkpoint_dir, \"counter.npy\"))\n",
|
|
"checkpoint_path = os.path.join(checkpoint_dir, \"modelsave\" + str(counter) + \"epochs\")\n",
|
|
"# map_location lets a checkpoint saved from the GPU load on a CPU-only machine.\n",
|
|
"model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
|
|
"# The loaded weights are used for inference: keep BatchNorm statistics frozen.\n",
|
|
"model.eval();"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 165,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([1, 800, 723])\n",
|
|
"torch.Size([3, 800, 723])\n",
|
|
"-1.3860492706298828\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"image_path = \"./testing_space/cropped.jpg\"\n",
|
|
"filereadimage = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)\n",
|
|
"if filereadimage is None:\n",
|
|
"    # cv2.imread returns None (instead of raising) for missing/unreadable files.\n",
|
|
"    raise FileNotFoundError(f\"Could not read image: {image_path}\")\n",
|
|
"tensorizedimage = torch.unsqueeze(torch.from_numpy(filereadimage), 0)\n",
|
|
"print(tensorizedimage.shape)\n",
|
|
"adjustedtensorizedimage = tensorize(grayscaler(t.ToPILImage()(tensorizedimage)))\n",
|
|
"print(adjustedtensorizedimage.shape)\n",
|
|
"\n",
|
|
"# eval(): BatchNorm uses its stored statistics instead of this single image's\n",
|
|
"# (and does not overwrite them); no_grad(): no autograd graph for inference.\n",
|
|
"model.eval()\n",
|
|
"with torch.no_grad():\n",
|
|
"    rotation = model(adjustedtensorizedimage).item()\n",
|
|
"print(rotation)\n",
|
|
"\n",
|
|
"rotatedimage = t.Resize(size=1000)(tvf.rotate(adjustedtensorizedimage, rotation))\n",
|
|
"open_cv_image = np.array(t.ToPILImage()(rotatedimage))\n",
|
|
"cv2.imshow('image', open_cv_image)\n",
|
|
"cv2.waitKey(0)\n",
|
|
"cv2.destroyAllWindows()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Work with the test split of the dataset loaded above.\n",
|
|
"active_dataset = working_dataset['test']\n",
|
|
"index = 0  # start from the first sample"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# plt.imshow(t.ToPILImage()(working_dataset['test'][3]['image']), cmap='gray', vmin=0, vmax=255)\n",
|
|
"# plt.show()\n",
|
|
"# rotationapplier = model(working_dataset['test'][3]['image']).item()\n",
|
|
"# print(rotationapplier)\n",
|
|
"# img = tvf.rotate(working_dataset['test'][3]['image'], rotationapplier)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# plt.imshow(t.ToPILImage()(img), cmap='gray', vmin=0, vmax=255)\n",
|
|
"# plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# # To call the model on a bunch of the images and rotate them back\n",
|
|
"\n",
|
|
"# while(True):\n",
|
|
"# activeimage = active_dataset[index]['image']\n",
|
|
"# # img = cv2.imread(active_dataset[index]['image'], 0)\n",
|
|
"# activeimage = tvf.rotate(activeimage, model(activeimage).item())\n",
|
|
"# open_cv_image = np.array(t.ToPILImage()(activeimage))\n",
|
|
"# print(index)\n",
|
|
"# cv2.imshow(f'current image', open_cv_image)\n",
|
|
"# key = cv2.waitKey(0)\n",
|
|
"\n",
|
|
"# if key == ord('c'):\n",
|
|
"# print(\"\\tCopying this one\")\n",
|
|
"# elif key == ord('x'):\n",
|
|
"# index -= 1\n",
|
|
"# elif key == ord('v'):\n",
|
|
"# index +=1\n",
|
|
"# elif key == ord('q'):\n",
|
|
"# break\n",
|
|
"\n",
|
|
"# cv2.destroyAllWindows()\n",
|
|
"# cv2.destroyAllWindows()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# # for trying to call the model on the picture repeatedly to see if it will just get more and more straight if it's called multiple times\n",
|
|
"\n",
|
|
"# currentimage = working_dataset['test'][3]['image']\n",
|
|
"# while(True):\n",
|
|
"# rotationapplier = model(currentimage).item()\n",
|
|
"# print(rotationapplier)\n",
|
|
"# img = tvf.rotate(currentimage, rotationapplier)\n",
|
|
"# open_cv_image = np.array(t.ToPILImage()(img))\n",
|
|
"# cv2.imshow(f'current image', open_cv_image)\n",
|
|
"# key = cv2.waitKey(0)\n",
|
|
" \n",
|
|
"# if key == ord('q'):\n",
|
|
"# break\n",
|
|
"# elif key == ord('v'):\n",
|
|
"# currentimage = img\n",
|
|
"# # cv2.destroyAllWindows()\n",
|
|
"# cv2.destroyAllWindows()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|