This code aims to build on our last lesson.
Previously, we obtained our data from the DuckDuckGo API and built our model around that.
This time, we will retrieve data from a Kaggle dataset, enhance our model, improve our understanding of different available pre-trained vision architectures in PyTorch using the timm
library, and implement another model according to our requirements.
Level 1 : Repeat last lesson
1.1 : Download dataset from Kaggle
from google.colab import files
# Upload the Kaggle API key JSON file
uploaded = files.upload()
!pip install kaggle
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
Saving kaggle.json to kaggle.json
Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.5.16)
Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)
Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.2.2)
Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.31.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.2)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)
Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)
Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)
Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.6)
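Outside Colab, the same key setup can be done from Python. This is a minimal sketch, assuming the key file is named kaggle.json as above; the helper name install_kaggle_key and its target_dir parameter are hypothetical and not part of the kaggle package.

```python
import os
import shutil
import stat
from pathlib import Path

def install_kaggle_key(key_file, target_dir=None):
    """Copy a Kaggle API key into target_dir and make it readable by the owner only."""
    # Default to ~/.kaggle, the location used by the shell commands above
    target_dir = Path(target_dir) if target_dir else Path.home() / ".kaggle"
    target_dir.mkdir(parents=True, exist_ok=True)
    dest = target_dir / "kaggle.json"
    shutil.copy(key_file, dest)
    # Equivalent of `chmod 600`: read/write for the owner, nothing for anyone else
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)
    return dest
```

Calling install_kaggle_key("kaggle.json") would then mirror the mkdir, mv and chmod steps, except that it copies the file instead of moving it.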
Download dataset :
Open the dataset in Kaggle, click on the three vertical dots (ellipsis), then click on Copy API command, and we are good to go.
!kaggle datasets download -d gpiosenka/cats-in-the-wild-image-classification
# Create a folder called "Big_Cat"
!mkdir -p Big_Cat
# Unzip the dataset into the "Big_Cat" folder
!unzip cats-in-the-wild-image-classification.zip -d Big_Cat
# Remove the zip file
!rm cats-in-the-wild-image-classification.zip
Downloading cats-in-the-wild-image-classification.zip to /content
96% 118M/123M [00:01<00:00, 57.4MB/s]
100% 123M/123M [00:01<00:00, 65.0MB/s]
Archive: cats-in-the-wild-image-classification.zip
inflating: Big_Cat/EfficientNetB0-10-(224 X 224)-100.00.h5
inflating: Big_Cat/MobileNetV3 small-10-(224 X 224)-95.96.h5
inflating: Big_Cat/WILDCATS.CSV
inflating: Big_Cat/test/AFRICAN LEOPARD/1.jpg
inflating: Big_Cat/test/AFRICAN LEOPARD/5.jpg
inflating: Big_Cat/test/CARACAL/1.jpg
inflating: Big_Cat/test/CARACAL/5.jpg
inflating: Big_Cat/test/CHEETAH/1.jpg
inflating: Big_Cat/test/CHEETAH/5.jpg
inflating: Big_Cat/test/CLOUDED LEOPARD/1.jpg
inflating: Big_Cat/test/CLOUDED LEOPARD/5.jpg
inflating: Big_Cat/test/JAGUAR/1.jpg
inflating: Big_Cat/test/JAGUAR/5.jpg
inflating: Big_Cat/test/LIONS/1.jpg
inflating: Big_Cat/test/LIONS/5.jpg
inflating: Big_Cat/test/OCELOT/1.jpg
inflating: Big_Cat/test/OCELOT/5.jpg
inflating: Big_Cat/test/PUMA/1.jpg
inflating: Big_Cat/test/PUMA/5.jpg
inflating: Big_Cat/test/SNOW LEOPARD/1.jpg
inflating: Big_Cat/test/SNOW LEOPARD/5.jpg
inflating: Big_Cat/test/TIGER/1.jpg
inflating: Big_Cat/test/TIGER/5.jpg
inflating: Big_Cat/train/AFRICAN LEOPARD/001.jpg
inflating: Big_Cat/train/AFRICAN LEOPARD/236.jpg
inflating: Big_Cat/train/CARACAL/001.jpg
inflating: Big_Cat/train/CARACAL/236.jpg
inflating: Big_Cat/train/CHEETAH/001.jpg
inflating: Big_Cat/train/CHEETAH/235.jpg
inflating: Big_Cat/train/CLOUDED LEOPARD/001.jpg
inflating: Big_Cat/train/CLOUDED LEOPARD/229.jpg
inflating: Big_Cat/train/JAGUAR/001.jpg
inflating: Big_Cat/train/JAGUAR/238.jpg
inflating: Big_Cat/train/LIONS/001.jpg
inflating: Big_Cat/train/LIONS/228.jpg
inflating: Big_Cat/train/OCELOT/001.jpg
inflating: Big_Cat/train/OCELOT/233.jpg
inflating: Big_Cat/train/PUMA/001.jpg
inflating: Big_Cat/train/PUMA/236.jpg
inflating: Big_Cat/train/SNOW LEOPARD/001.jpg
inflating: Big_Cat/train/SNOW LEOPARD/231.jpg
inflating: Big_Cat/train/TIGER/001.jpg
inflating: Big_Cat/train/TIGER/237.jpg
inflating: Big_Cat/valid/AFRICAN LEOPARD/1.jpg
inflating: Big_Cat/valid/AFRICAN LEOPARD/5.jpg
inflating: Big_Cat/valid/CARACAL/1.jpg
inflating: Big_Cat/valid/CARACAL/5.jpg
inflating: Big_Cat/valid/CHEETAH/1.jpg
inflating: Big_Cat/valid/CHEETAH/5.jpg
inflating: Big_Cat/valid/CLOUDED LEOPARD/1.jpg
inflating: Big_Cat/valid/CLOUDED LEOPARD/5.jpg
inflating: Big_Cat/valid/JAGUAR/1.jpg
inflating: Big_Cat/valid/JAGUAR/5.jpg
inflating: Big_Cat/valid/LIONS/1.jpg
inflating: Big_Cat/valid/LIONS/5.jpg
inflating: Big_Cat/valid/OCELOT/1.jpg
inflating: Big_Cat/valid/OCELOT/5.jpg
inflating: Big_Cat/valid/PUMA/1.jpg
inflating: Big_Cat/valid/PUMA/5.jpg
inflating: Big_Cat/valid/SNOW LEOPARD/1.jpg
inflating: Big_Cat/valid/SNOW LEOPARD/5.jpg
inflating: Big_Cat/valid/TIGER/1.jpg
inflating: Big_Cat/valid/TIGER/5.jpg
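Before reorganising the folders, it can help to confirm what was actually unzipped. The sketch below counts image files per split and class, assuming the root/split/class/file layout shown in the listing above; count_images is a hypothetical helper name.

```python
from collections import Counter
from pathlib import Path

def count_images(root, exts=(".jpg", ".jpeg", ".png")):
    """Return a Counter mapping 'split/class' to the number of image files."""
    root = Path(root)
    counts = Counter()
    for p in root.rglob("*"):
        rel = p.relative_to(root)
        # Only count files shaped like <split>/<class>/<file>
        if p.is_file() and p.suffix.lower() in exts and len(rel.parts) == 3:
            counts[f"{rel.parts[0]}/{rel.parts[1]}"] += 1
    return counts
```

For example, count_images("Big_Cat") would give one entry per class for each of the train, valid and test folders.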
Remove all files with the “.h5” extension and move images from the “valid” folder to their corresponding subfolders in the “train” folder
import os
import shutil
# Define the absolute paths of main, train and valid folders
big_cat_folder = '/content/Big_Cat/'
train_folder = '/content/Big_Cat/train'
valid_folder = '/content/Big_Cat/valid'
# Remove all files with ".h5" extension
!find {big_cat_folder} -type f -name '*.h5' -delete
# Create a dictionary to track image counts
image_counts = {}
# Get a list of subfolders in the train folder
train_subfolders = [f.path for f in os.scandir(train_folder) if f.is_dir()]
# Get a list of subfolders in the valid folder
valid_subfolders = [f.path for f in os.scandir(valid_folder) if f.is_dir()]
# Move images from valid to their corresponding subfolders in train
for subfolder in valid_subfolders:
    class_name = os.path.basename(subfolder)
    train_subfolder = os.path.join(train_folder, class_name)
    # Create the train subfolder if it doesn't exist
    if not os.path.exists(train_subfolder):
        os.makedirs(train_subfolder)
    # Initialize counts in the dictionary
    image_counts[f"{class_name}_VALID"] = len(os.listdir(subfolder))
    # Move images from valid to train subfolder and update counts
    for file in os.listdir(subfolder):
        file_path = os.path.join(subfolder, file)
        dest_path = os.path.join(train_subfolder, file)
        shutil.move(file_path, dest_path)
# Remove the empty valid subfolders
for subfolder in valid_subfolders:
    os.rmdir(subfolder)
# Get count of images in train folder so that we can understand on how much we are training
for subfolder in train_subfolders:
    class_name = os.path.basename(subfolder)
    image_counts[f"{class_name}_TRAIN"] = len(os.listdir(subfolder))
sorted_image_counts = dict(sorted(image_counts.items()))
# Print the sorted image counts dictionary
print("Sorted Image Counts:")
for key, value in sorted_image_counts.items():
    print(f"{key}: {value}")
Sorted Image Counts:
AFRICAN LEOPARD_TRAIN: 241
AFRICAN LEOPARD_VALID: 5
CARACAL_TRAIN: 241
CARACAL_VALID: 5
CHEETAH_TRAIN: 240
CHEETAH_VALID: 5
CLOUDED LEOPARD_TRAIN: 234
CLOUDED LEOPARD_VALID: 5
JAGUAR_TRAIN: 243
JAGUAR_VALID: 5
LIONS_TRAIN: 233
LIONS_VALID: 5
OCELOT_TRAIN: 238
OCELOT_VALID: 5
PUMA_TRAIN: 241
PUMA_VALID: 5
SNOW LEOPARD_TRAIN: 236
SNOW LEOPARD_VALID: 5
TIGER_TRAIN: 242
TIGER_VALID: 5
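The merge works cleanly here because the valid images are named like 1.jpg while the train images are named like 001.jpg, so no two files share a name. With other datasets, a file moved onto an existing name may silently replace it, so a quick check before moving is cheap insurance. A minimal sketch, with find_name_collisions as a hypothetical helper:

```python
from pathlib import Path

def find_name_collisions(src_dir, dst_dir):
    """List file names that exist in both folders and would collide on a move."""
    src_names = {p.name for p in Path(src_dir).iterdir() if p.is_file()}
    dst_names = {p.name for p in Path(dst_dir).iterdir() if p.is_file()}
    return sorted(src_names & dst_names)
```

An empty list for every class means the valid images can be moved into train without losing anything.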
Install latest FastAI Version
#hide
! [ -e /content ] && pip install -Uqq fastbook
! pip install timm
import fastbook
fastbook.setup_book()
import timm
#hide
from fastbook import *
from fastai.vision.widgets import *
from fastai.vision.all import *
Collecting timm
Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from timm) (2.2.1+cu121)
Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from timm) (0.17.1+cu121)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from timm) (6.0.1)
Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (from timm) (0.20.3)
Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from timm) (0.4.2)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm) (3.13.3)
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm) (2023.6.0)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm) (2.31.0)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm) (4.66.2)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm) (4.10.0)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub->timm) (24.0)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->timm) (1.12)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->timm) (3.2.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (3.1.3)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (12.1.105)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (12.1.105)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (12.1.105)
Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (8.9.2.26)
Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (12.1.3.1)
Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (11.0.2.54)
Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (10.3.2.106)
Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (11.4.5.107)
Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (12.1.0.106)
Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (2.19.3)
Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (12.1.105)
Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch->timm) (2.2.0)
Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->timm) (12.4.127)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision->timm) (1.25.2)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision->timm) (9.4.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->timm) (2.1.5)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub->timm) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub->timm) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub->timm) (2.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub->timm) (2024.2.2)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->timm) (1.3.0)
Installing collected packages: timm
Successfully installed timm-0.9.16
Mounted at /content/gdrive
verify_images() returns the paths of images that are corrupt; using unlink, we can remove these files.
path = Path('Big_Cat')
fns = get_image_files(path)
total_imagelength = len(fns)
failed = verify_images(fns)
failed_imagelength = len(failed)
failed.map(Path.unlink)
Image_Count_Dict = {"Total_Image_Count": total_imagelength, "Failed_Image_Count": failed_imagelength}
Image_Count_Dict
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
{'Total_Image_Count': 2439, 'Failed_Image_Count': 0}
We have a good number of images to train on.
1.2 : Prepare data for model training (Data Loaders, Data Augmentation, etc.).
Data Loaders
big_cat = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=Resize(128))
dls = big_cat.dataloaders(path)
dls.valid.show_batch(max_n=8, nrows=2)
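To get a feel for what a seeded 80/20 split does, here is a standard-library sketch. It is not fastai's implementation (so the selected indices will differ from RandomSplitter's), and random_split is a hypothetical name; it only illustrates that a fixed seed makes the split reproducible and that 20% of 2439 images leaves 487 for validation.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed and set aside the first valid_pct for validation."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]  # (train indices, validation indices)

train_idx, valid_idx = random_split(2439)
print(len(train_idx), len(valid_idx))  # 1952 487
```

Because the seed is fixed, running the split twice gives the same validation set, which keeps error rates comparable between runs.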
Data Augmentation
big_cat = big_cat.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())
big_cat_dls = big_cat.dataloaders(path)
big_cat_dls.train.show_batch(max_n=8, nrows=2)
1.3 : Train Model
learn = vision_learner(big_cat_dls, resnet34, metrics=error_rate)
learn.fine_tune(5)
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 164MB/s]
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.837715 | 0.237122 | 0.088296 | 00:13 |
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.565063 | 0.141556 | 0.041068 | 00:18 |
1 | 0.431256 | 0.147855 | 0.051335 | 00:20 |
2 | 0.337753 | 0.119496 | 0.036961 | 00:17 |
3 | 0.268090 | 0.115800 | 0.032854 | 00:14 |
Understand the structure of the model
learn.summary()
Sequential (Input shape: 64 x 3 x 224 x 224)
============================================================================
Layer (type) Output Shape Param # Trainable
============================================================================
64 x 64 x 112 x 112
Conv2d 9408 True
BatchNorm2d 128 True
ReLU
____________________________________________________________________________
64 x 64 x 56 x 56
MaxPool2d
Conv2d 36864 True
BatchNorm2d 128 True
ReLU
Conv2d 36864 True
BatchNorm2d 128 True
Conv2d 36864 True
BatchNorm2d 128 True
ReLU
Conv2d 36864 True
BatchNorm2d 128 True
Conv2d 36864 True
BatchNorm2d 128 True
ReLU
Conv2d 36864 True
BatchNorm2d 128 True
____________________________________________________________________________
64 x 128 x 28 x 28
Conv2d 73728 True
BatchNorm2d 256 True
ReLU
Conv2d 147456 True
BatchNorm2d 256 True
Conv2d 8192 True
BatchNorm2d 256 True
Conv2d 147456 True
BatchNorm2d 256 True
ReLU
Conv2d 147456 True
BatchNorm2d 256 True
Conv2d 147456 True
BatchNorm2d 256 True
ReLU
Conv2d 147456 True
BatchNorm2d 256 True
Conv2d 147456 True
BatchNorm2d 256 True
ReLU
Conv2d 147456 True
BatchNorm2d 256 True
____________________________________________________________________________
64 x 256 x 14 x 14
Conv2d 294912 True
BatchNorm2d 512 True
ReLU
Conv2d 589824 True
BatchNorm2d 512 True
Conv2d 32768 True
BatchNorm2d 512 True
Conv2d 589824 True
BatchNorm2d 512 True
ReLU
Conv2d 589824 True
BatchNorm2d 512 True
Conv2d 589824 True
BatchNorm2d 512 True
ReLU
Conv2d 589824 True
BatchNorm2d 512 True
Conv2d 589824 True
BatchNorm2d 512 True
ReLU
Conv2d 589824 True
BatchNorm2d 512 True
Conv2d 589824 True
BatchNorm2d 512 True
ReLU
Conv2d 589824 True
BatchNorm2d 512 True
Conv2d 589824 True
BatchNorm2d 512 True
ReLU
Conv2d 589824 True
BatchNorm2d 512 True
____________________________________________________________________________
64 x 512 x 7 x 7
Conv2d 1179648 True
BatchNorm2d 1024 True
ReLU
Conv2d 2359296 True
BatchNorm2d 1024 True
Conv2d 131072 True
BatchNorm2d 1024 True
Conv2d 2359296 True
BatchNorm2d 1024 True
ReLU
Conv2d 2359296 True
BatchNorm2d 1024 True
Conv2d 2359296 True
BatchNorm2d 1024 True
ReLU
Conv2d 2359296 True
BatchNorm2d 1024 True
____________________________________________________________________________
64 x 512 x 1 x 1
AdaptiveAvgPool2d
AdaptiveMaxPool2d
____________________________________________________________________________
64 x 1024
Flatten
BatchNorm1d 2048 True
Dropout
____________________________________________________________________________
64 x 512
Linear 524288 True
ReLU
BatchNorm1d 1024 True
Dropout
____________________________________________________________________________
64 x 10
Linear 5120 True
____________________________________________________________________________
Total params: 21,817,152
Total trainable params: 21,817,152
Total non-trainable params: 0
Optimizer used: <function Adam at 0x7b5a37dbbeb0>
Loss function: FlattenedLoss of CrossEntropyLoss()
Model unfrozen
Callbacks:
- TrainEvalCallback
- CastToTensor
- Recorder
- ProgressCallback
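The Param # column in the summary can be reproduced by hand, which is a nice way to check our understanding of the layers. The sketch below assumes the usual ResNet-34 layout (a 7x7 first convolution, 3x3 convolutions in the blocks, 1x1 shortcut convolutions, no bias terms in the convolutions or in the head's linear layers); the summary itself does not print these details, so treat them as assumptions.

```python
def conv2d_params(in_ch, out_ch, kernel, bias=False):
    """Weights of a square-kernel 2D convolution, plus an optional bias per output channel."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)

def batchnorm_params(channels):
    """One scale and one shift parameter per channel."""
    return 2 * channels

def linear_params(in_features, out_features, bias=False):
    """Weights of a fully connected layer, plus an optional bias per output."""
    return in_features * out_features + (out_features if bias else 0)

print(conv2d_params(3, 64, 7))     # 9408, the first Conv2d
print(conv2d_params(64, 64, 3))    # 36864
print(conv2d_params(64, 128, 1))   # 8192, a 1x1 shortcut convolution
print(batchnorm_params(64))        # 128
print(linear_params(1024, 512))    # 524288
print(linear_params(512, 10))      # 5120, one output per cat class
```

The last line also shows where the 10 in the final output shape comes from: one output per class in the dataset.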
Confusion Matrix
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
Display Images with highest loss, to Get the picture 😊
interp.plot_top_losses(6, nrows=2, figsize=(18,4))
We can observe from both the confusion matrix and visual representation that the model is having difficulty differentiating between the Jaguar and the African Leopard. Even I find it challenging to distinguish between the two. 😵 So, we can let it be.
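Reading the biggest off-diagonal cells from a confusion matrix by eye gets harder as the number of classes grows. The idea can be sketched with the standard library as below; the labels are hypothetical and only illustrate the counting, they are not the model's real predictions. fastai's interpretation object also offers a most_confused method that should serve the same purpose on the real data.

```python
from collections import Counter

def most_confused(actual, predicted, min_count=1):
    """Return (actual, predicted, count) for misclassified pairs, most frequent first."""
    pairs = Counter((a, p) for a, p in zip(actual, predicted) if a != p)
    return [(a, p, n) for (a, p), n in pairs.most_common() if n >= min_count]

# Hypothetical labels, for illustration only
actual    = ["JAGUAR", "JAGUAR", "AFRICAN LEOPARD", "TIGER", "JAGUAR"]
predicted = ["AFRICAN LEOPARD", "JAGUAR", "JAGUAR", "TIGER", "AFRICAN LEOPARD"]
print(most_confused(actual, predicted))
```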
1.4 : Clean the data
#hide_output
cleaner = ImageClassifierCleaner(learn)
cleaner
VBox(children=(Dropdown(options=('AFRICAN LEOPARD', 'CARACAL', 'CHEETAH', 'CLOUDED LEOPARD', 'JAGUAR', 'LIONS'…
Apply those changes
for idx in cleaner.delete(): cleaner.fns[idx].unlink()
for idx,cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/cat)
Level 2 : Understand Computer Vision Architectures
timm is a wonderful library by Ross Wightman which provides state-of-the-art pre-trained computer vision models. It’s like Huggingface Transformers, but for computer vision instead of NLP.
2.1 : Download Data
Let’s download Ross’s GitHub repository, which is regularly updated with benchmark data for computer vision architectures. These benchmarks are computed on ImageNet.
! git clone --depth 1 https://github.com/rwightman/pytorch-image-models.git
%cd pytorch-image-models/results
Cloning into 'pytorch-image-models'...
remote: Enumerating objects: 572, done.
remote: Counting objects: 100% (572/572), done.
remote: Compressing objects: 100% (403/403), done.
remote: Total 572 (delta 222), reused 341 (delta 163), pack-reused 0
Receiving objects: 100% (572/572), 2.59 MiB | 4.87 MiB/s, done.
Resolving deltas: 100% (222/222), done.
/content/pytorch-image-models/results/pytorch-image-models/results
import pandas as pd
Benchmark_Result = pd.read_csv('results-imagenet.csv')
Benchmark_Result['model_org'] = Benchmark_Result['model']
Benchmark_Result['model'] = Benchmark_Result['model'].str.split('.').str[0]
Benchmark_Result.head(5)
(index) | model | top1 | top1_err | top5 | top5_err | param_count | img_size | crop_pct | interpolation | model_org |
---|---|---|---|---|---|---|---|---|---|---|
0 | eva02_large_patch14_448 | 90.052 | 9.948 | 99.048 | 0.952 | 305.08 | 448 | 1.0 | bicubic | eva02_large_patch14_448.mim_m38m_ft_in22k_in1k |
1 | eva02_large_patch14_448 | 89.970 | 10.030 | 99.012 | 0.988 | 305.08 | 448 | 1.0 | bicubic | eva02_large_patch14_448.mim_in22k_ft_in22k_in1k |
2 | eva_giant_patch14_560 | 89.786 | 10.214 | 98.992 | 1.008 | 1,014.45 | 560 | 1.0 | bicubic | eva_giant_patch14_560.m30m_ft_in22k_in1k |
3 | eva02_large_patch14_448 | 89.622 | 10.378 | 98.950 | 1.050 | 305.08 | 448 | 1.0 | bicubic | eva02_large_patch14_448.mim_in22k_ft_in1k |
4 | eva02_large_patch14_448 | 89.574 | 10.426 | 98.924 | 1.076 | 305.08 | 448 | 1.0 | bicubic | eva02_large_patch14_448.mim_m38m_ft_in1k |
Let’s add a “family” column that will allow us to group architectures into categories with similar characteristics:
def get_data(part, col):
    df = pd.read_csv(f'benchmark-{part}-amp-nhwc-pt111-cu113-rtx3090.csv').merge(Benchmark_Result, on='model')
    df['secs'] = 1. / df[col]
    df['family'] = df.model.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')
    df = df[~df.model.str.endswith('gn')]
    df.loc[df.model.str.contains('in22'),'family'] = df.loc[df.model.str.contains('in22'),'family'] + '_in22'
    df.loc[df.model.str.contains('resnet.*d'),'family'] = df.loc[df.model.str.contains('resnet.*d'),'family'] + 'd'
    return df[df.family.str.contains('^re[sg]netd?|beit|convnext|levit|efficient|vit|vgg|swin')]
Inference_Data = get_data('infer', 'infer_samples_per_sec')
Inference_Data.head(5)
(index) | model | infer_samples_per_sec | infer_step_time | infer_batch_size | infer_img_size | param_count_x | top1 | top1_err | top5 | top5_err | param_count_y | img_size | crop_pct | interpolation | model_org | secs | family |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
12 | levit_128s | 21485.80 | 47.648 | 1024 | 224 | 7.78 | 76.526 | 23.474 | 92.872 | 7.128 | 7.78 | 224 | 0.900 | bicubic | levit_128s.fb_dist_in1k | 0.000047 | levit |
13 | regnetx_002 | 17821.98 | 57.446 | 1024 | 224 | 2.68 | 68.752 | 31.248 | 88.542 | 11.458 | 2.68 | 224 | 0.875 | bicubic | regnetx_002.pycls_in1k | 0.000056 | regnetx |
15 | regnety_002 | 16673.08 | 61.405 | 1024 | 224 | 3.16 | 70.280 | 29.720 | 89.530 | 10.470 | 3.16 | 224 | 0.875 | bicubic | regnety_002.pycls_in1k | 0.000060 | regnety |
17 | levit_128 | 14657.83 | 69.849 | 1024 | 224 | 9.21 | 78.490 | 21.510 | 94.012 | 5.988 | 9.21 | 224 | 0.900 | bicubic | levit_128.fb_dist_in1k | 0.000068 | levit |
18 | regnetx_004 | 14440.03 | 70.903 | 1024 | 224 | 5.16 | 72.402 | 27.598 | 90.826 | 9.174 | 5.16 | 224 | 0.875 | bicubic | regnetx_004.pycls_in1k | 0.000069 | regnetx |
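The regular expression that builds the family column is easier to follow when tried on a few names in isolation. The sketch below mirrors the name-splitting, the family extraction and the resnet "d" suffix from get_data using only the re module; it leaves out the in22 suffix and the gn filter, and family_of is a hypothetical helper name.

```python
import re

FAMILY_RE = re.compile(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')

def family_of(model_org):
    """Derive the family label for one model name, mirroring the steps in get_data."""
    model = model_org.split('.')[0]        # drop the pretrained-tag suffix
    match = FAMILY_RE.match(model)
    family = match.group(1) if match else None
    if family and re.search('resnet.*d', model):
        family += 'd'                      # e.g. resnet50d joins the "resnetd" family
    return family

for name in ['levit_128s.fb_dist_in1k', 'regnetx_002.pycls_in1k', 'resnet50d', 'efficientnetv2_rw_s']:
    print(name, '->', family_of(name))
```

The lazy `[a-z]+?` stops at the first digit or underscore, while the optional `(?:v2)?` keeps second-generation families such as efficientnetv2 separate from their predecessors.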
2.2 : Plot of all the architectures.
Here are the results for inference performance (see the last section for training performance). In this chart:
the x axis shows how many seconds it takes to process one image (note: it’s a log scale)
the y axis is the accuracy on Imagenet
the size of each bubble is proportional to the size of images used in testing
the color shows what “family” the architecture is from.
Hover your mouse over a marker to see details about the model. Double-click in the legend to display just one family. Single-click in the legend to show or hide a family.
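The x axis value is simply the inverse of the benchmark throughput, and these values span several orders of magnitude across all models, which is why a log scale is used. A small sketch using the throughput numbers from the first rows of the table above:

```python
import math

# Throughput values copied from the first rows of the benchmark table above
samples_per_sec = {"levit_128s": 21485.80, "regnetx_002": 17821.98, "levit_128": 14657.83}

for model, rate in samples_per_sec.items():
    secs = 1.0 / rate  # seconds per image, as computed in get_data
    print(f"{model}: {secs:.6f} s/image, log10 = {math.log10(secs):.2f}")
```

The printed seconds should match the secs column of the table once rounded to six decimals.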
import plotly.express as px
w,h = 1000,800
def show_all(Inference_Data, title, size):
    return px.scatter(Inference_Data, width=w, height=h, size=Inference_Data[size]**2, title=title,
        x='secs', y='top1', log_x=True, color='family', hover_name='model_org', hover_data=[size])
show_all(Inference_Data, 'Inference', 'infer_img_size')
2.3 : Specific Architectures Plot
Let’s create a plot for the selected architectures that we are most likely to use in practice.
# Filter data only for convnext, resnet, levit and beit
keywords = ['convnext', 'resnet','levit','beit']
# Filter rows based on the exact keywords
Best_Model_Df = Inference_Data[Inference_Data['family'].isin(keywords)]
show_all(Best_Model_Df, 'Inference', 'infer_img_size')
2.4 : Family Connection Plot
Let’s add lines through the points of each family, to help see how they compare – but note that we can see that a linear fit isn’t actually ideal here! It’s just there to help visually see the groups.
subs = 'levit|resnetd?|regnetx|vgg|convnext.*|efficientnetv2|beit|swin'
def show_subs(Inference_Data, title, size):
    df_subs = Inference_Data[Inference_Data.family.str.fullmatch(subs)]
    return px.scatter(df_subs, width=w, height=h, size=df_subs[size]**2, title=title,
        trendline="ols", trendline_options={'log_x':True},
        x='secs', y='top1', log_x=True, color='family', hover_name='model_org', hover_data=[size])
show_subs(Inference_Data, 'Inference', 'infer_img_size')
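As far as I understand it, the trendline drawn per family is an ordinary least-squares line of accuracy against the logarithm of the time per image. The sketch below shows that kind of fit with the standard library, assuming a base-10 logarithm; the data points are hypothetical and fit_log_x is an illustrative name, not part of plotly.

```python
import math

def fit_log_x(xs, ys):
    """Least-squares fit of y = slope * log10(x) + intercept."""
    lx = [math.log10(x) for x in xs]
    n = len(lx)
    mean_x, mean_y = sum(lx) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in lx)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(lx, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x

# Hypothetical (secs, top1) points, for illustration only
secs = [1e-5, 1e-4, 1e-3]
top1 = [70.0, 75.0, 80.0]
slope, intercept = fit_log_x(secs, top1)
print(slope, intercept)
```

A slope of 5 here would mean that every tenfold increase in time per image buys about 5 points of accuracy within that family, which is the kind of trade-off the lines help to compare.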
We can conclude that Convnext can be our go-to model when we have a decent GPU at our disposal, because it is more accurate than resnet and takes less time than beit.
Level 3 : Build a model using Convnext (base or tiny)
List all the base and tiny models in the Convnext family and choose the best.
[model for model in timm.list_models('convnext*') if 'base' in model or 'tiny' in model]
['convnext_base',
'convnext_tiny',
'convnext_tiny_hnf',
'convnextv2_base',
'convnextv2_tiny']
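The listing combines two filters: a wildcard pattern on the model name and a substring check for the size. The same logic can be tried without timm using fnmatch from the standard library; the names below are a hypothetical subset chosen for illustration.

```python
from fnmatch import fnmatch

# Hypothetical subset of model names, for illustration only
names = ['convnext_base', 'convnext_large', 'convnext_tiny', 'convnextv2_tiny', 'resnet34']

selected = [n for n in names if fnmatch(n, 'convnext*') and ('base' in n or 'tiny' in n)]
print(selected)  # ['convnext_base', 'convnext_tiny', 'convnextv2_tiny']
```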
learn_conv = vision_learner(dls, convnext_base, metrics=error_rate).to_fp16()
learn_conv.fine_tune(5)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ConvNeXt_Base_Weights.IMAGENET1K_V1`. You can also use `weights=ConvNeXt_Base_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.130588 | 0.183650 | 0.045175 | 00:17 |
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.306945 | 0.164550 | 0.047228 | 00:20 |
1 | 0.253462 | 0.118687 | 0.026694 | 00:11 |
2 | 0.207020 | 0.103058 | 0.024641 | 00:11 |
3 | 0.175618 | 0.094374 | 0.020534 | 00:11 |
4 | 0.151438 | 0.092010 | 0.020534 | 00:13 |
Compared to the resnet34 model, which had an error rate of about 3.3%, the convnext_base model shows a clear improvement, with an error rate of about 2.1%.
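The error_rate column is a fraction, not a percentage, so it helps to convert it into something more tangible. The sketch below turns the final error rates reported in the two training tables into percentages and approximate numbers of misclassified validation images, assuming the validation set holds 20% of the 2439 verified images (as set by valid_pct=0.2).

```python
n_valid = int(0.2 * 2439)  # assumed validation size: 20% of the 2439 verified images

# Final error_rate values from the last rows of the two training tables
final_error = {"resnet34": 0.032854, "convnext_base": 0.020534}

for name, err in final_error.items():
    wrong = round(err * n_valid)
    print(f"{name}: {err:.1%} error, about {wrong} of {n_valid} validation images wrong")
```

Seen this way, the gap between the two models is a handful of images, which is worth keeping in mind when comparing architectures on a validation set of this size.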
Structure of the architecture
learn_conv.summary()
Sequential (Input shape: 64 x 3 x 128 x 128)
============================================================================
Layer (type) Output Shape Param # Trainable
============================================================================
64 x 128 x 32 x 32
Conv2d 6272 True
LayerNorm2d 256 True
Conv2d 6400 True
____________________________________________________________________________
64 x 32 x 32 x 128
Permute
LayerNorm 256 True
____________________________________________________________________________
64 x 32 x 32 x 512
Linear 66048 True
GELU
____________________________________________________________________________
64 x 32 x 32 x 128
Linear 65664 True
____________________________________________________________________________
64 x 128 x 32 x 32
Permute
StochasticDepth
Conv2d 6400 True
____________________________________________________________________________
64 x 32 x 32 x 128
Permute
LayerNorm 256 True
____________________________________________________________________________
64 x 32 x 32 x 512
Linear 66048 True
GELU
____________________________________________________________________________
64 x 32 x 32 x 128
Linear 65664 True
____________________________________________________________________________
64 x 128 x 32 x 32
Permute
StochasticDepth
Conv2d 6400 True
____________________________________________________________________________
64 x 32 x 32 x 128
Permute
LayerNorm 256 True
____________________________________________________________________________
64 x 32 x 32 x 512
Linear 66048 True
GELU
____________________________________________________________________________
64 x 32 x 32 x 128
Linear 65664 True
____________________________________________________________________________
64 x 128 x 32 x 32
Permute
StochasticDepth
LayerNorm2d 256 True
____________________________________________________________________________
64 x 256 x 16 x 16
Conv2d 131328 True
Conv2d 12800 True
____________________________________________________________________________
64 x 16 x 16 x 256
Permute
LayerNorm 512 True
____________________________________________________________________________
64 x 16 x 16 x 1024
Linear 263168 True
GELU
____________________________________________________________________________
64 x 16 x 16 x 256
Linear 262400 True
____________________________________________________________________________
64 x 256 x 16 x 16
Permute
StochasticDepth
Conv2d 12800 True
____________________________________________________________________________
64 x 16 x 16 x 256
Permute
LayerNorm 512 True
____________________________________________________________________________
64 x 16 x 16 x 1024
Linear 263168 True
GELU
____________________________________________________________________________
64 x 16 x 16 x 256
Linear 262400 True
____________________________________________________________________________
64 x 256 x 16 x 16
Permute
StochasticDepth
Conv2d 12800 True
____________________________________________________________________________
64 x 16 x 16 x 256
Permute
LayerNorm 512 True
____________________________________________________________________________
64 x 16 x 16 x 1024
Linear 263168 True
GELU
____________________________________________________________________________
64 x 16 x 16 x 256
Linear 262400 True
____________________________________________________________________________
64 x 256 x 16 x 16
Permute
StochasticDepth
LayerNorm2d 512 True
____________________________________________________________________________
64 x 512 x 8 x 8
Conv2d 524800 True
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
Conv2d 25600 True
____________________________________________________________________________
64 x 8 x 8 x 512
Permute
LayerNorm 1024 True
____________________________________________________________________________
64 x 8 x 8 x 2048
Linear 1050624 True
GELU
____________________________________________________________________________
64 x 8 x 8 x 512
Linear 1049088 True
____________________________________________________________________________
64 x 512 x 8 x 8
Permute
StochasticDepth
LayerNorm2d 1024 True
____________________________________________________________________________
64 x 1024 x 4 x 4
Conv2d 2098176 True
Conv2d 51200 True
____________________________________________________________________________
64 x 4 x 4 x 1024
Permute
LayerNorm 2048 True
____________________________________________________________________________
64 x 4 x 4 x 4096
Linear 4198400 True
GELU
____________________________________________________________________________
64 x 4 x 4 x 1024
Linear 4195328 True
____________________________________________________________________________
64 x 1024 x 4 x 4
Permute
StochasticDepth
Conv2d 51200 True
____________________________________________________________________________
64 x 4 x 4 x 1024
Permute
LayerNorm 2048 True
____________________________________________________________________________
64 x 4 x 4 x 4096
Linear 4198400 True
GELU
____________________________________________________________________________
64 x 4 x 4 x 1024
Linear 4195328 True
____________________________________________________________________________
64 x 1024 x 4 x 4
Permute
StochasticDepth
Conv2d 51200 True
____________________________________________________________________________
64 x 4 x 4 x 1024
Permute
LayerNorm 2048 True
____________________________________________________________________________
64 x 4 x 4 x 4096
Linear 4198400 True
GELU
____________________________________________________________________________
64 x 4 x 4 x 1024
Linear 4195328 True
____________________________________________________________________________
64 x 1024 x 4 x 4
Permute
StochasticDepth
____________________________________________________________________________
64 x 1024 x 1 x 1
AdaptiveAvgPool2d
AdaptiveMaxPool2d
____________________________________________________________________________
64 x 2048
Flatten
BatchNorm1d 4096 True
Dropout
____________________________________________________________________________
64 x 512
Linear 1048576 True
ReLU
BatchNorm1d 1024 True
Dropout
____________________________________________________________________________
64 x 10
Linear 5120 True
____________________________________________________________________________
Total params: 88,605,184
Total trainable params: 88,605,184
Total non-trainable params: 0
Optimizer used: <function Adam at 0x7b5a37dbbeb0>
Loss function: FlattenedLoss of CrossEntropyLoss()
Model unfrozen
Callbacks:
- TrainEvalCallback
- CastToTensor
- MixedPrecision
- Recorder
- ProgressCallback
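Let's sanity-check a few of the parameter counts in the summary above. The sketch below is plain Python arithmetic, not part of fastai; the helper names are made up for illustration. It assumes the usual ConvNeXt layout (7x7 depthwise convolutions inside each block and 2x2 stride-2 convolutions between stages), which the summary does not state explicitly but which the numbers agree with.

```python
def linear_params(n_in, n_out, bias=True):
    # Weight matrix plus optional bias vector.
    return n_in * n_out + (n_out if bias else 0)


def conv2d_params(c_in, c_out, kernel, groups=1, bias=True):
    # Each output channel sees c_in // groups input channels.
    return (c_in // groups) * c_out * kernel * kernel + (c_out if bias else 0)


def norm_params(channels):
    # One scale and one shift per channel (LayerNorm, BatchNorm1d).
    return 2 * channels


# Stage with 512 channels at 8 x 8 resolution.
depthwise = conv2d_params(512, 512, kernel=7, groups=512)  # assumed 7x7 depthwise
expand = linear_params(512, 2048)                          # 512 -> 2048
project = linear_params(2048, 512)                         # 2048 -> 512
layer_norm = norm_params(512)

# Downsampling from 256 to 512 channels (assumed 2x2 kernel, stride 2).
downsample = conv2d_params(256, 512, kernel=2)

# Custom head: the Linear layers carry no bias in the summary.
head_hidden = linear_params(2048, 512, bias=False)
head_out = linear_params(512, 10, bias=False)

print(depthwise, layer_norm, expand, project)
print(downsample, head_hidden, head_out)
```

The printed values match the Conv2d 25600, LayerNorm 1024, Linear 1050624, Linear 1049088, Conv2d 524800, Linear 1048576 and Linear 5120 rows of the summary.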
Let’s export the model for future reference
learn_conv.export('Big_Cat_Convnext_Model.pkl')
#learn_conv.export('/content/drive/MyDrive/Colab Notebooks/FastAI Course/Big_Cat_Convnext_Model.pkl')
Level 4. Test the Model
Let’s test the model with an image
from fastai.vision.all import *
import gradio as gr
im = PILImage.create('/content/drive/MyDrive/Colab Notebooks/FastAI Course/SnowLeopard.jpg')
im.thumbnail((224,224))
im
learn_conv = load_learner('/content/drive/MyDrive/Colab Notebooks/FastAI Course/Big_Cat_Convnext_Model.pkl')
learn_conv.predict(im)
('CHEETAH',
tensor(2),
tensor([1.0988e-04, 8.7617e-05, 9.2564e-01, 2.9294e-06, 1.2592e-06, 1.6162e-06,
2.6748e-02, 6.5111e-07, 4.7398e-02, 7.3348e-06]))
learn_conv.dls.vocab
['AFRICAN LEOPARD', 'CARACAL', 'CHEETAH', 'CLOUDED LEOPARD', 'JAGUAR', 'LIONS', 'OCELOT', 'PUMA', 'SNOW LEOPARD', 'TIGER']
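The three values returned by predict fit together as follows: the tensor of probabilities has one entry per class, the index is the position of the largest entry, and the label is the vocab entry at that index. A minimal sketch in plain Python, with the probabilities and vocab copied from the outputs above (no fastai needed):

```python
vocab = ['AFRICAN LEOPARD', 'CARACAL', 'CHEETAH', 'CLOUDED LEOPARD', 'JAGUAR',
         'LIONS', 'OCELOT', 'PUMA', 'SNOW LEOPARD', 'TIGER']

# Probabilities as printed by learn_conv.predict(im) above.
probs = [1.0988e-04, 8.7617e-05, 9.2564e-01, 2.9294e-06, 1.2592e-06,
         1.6162e-06, 2.6748e-02, 6.5111e-07, 4.7398e-02, 7.3348e-06]

# Position of the largest probability (argmax).
idx = max(range(len(probs)), key=probs.__getitem__)

# The predicted label is the vocab entry at that position.
label = vocab[idx]
print(idx, label)
```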
categories = learn_conv.dls.vocab
pred, idx, probs = learn_conv.predict(im)
result = dict(zip(categories, map(float,probs)))
result
{'AFRICAN LEOPARD': 0.00010988322173943743,
'CARACAL': 8.761714707361534e-05,
'CHEETAH': 0.9256432056427002,
'CLOUDED LEOPARD': 2.929433776444057e-06,
'JAGUAR': 1.2592141729328432e-06,
'LIONS': 1.6162448446266353e-06,
'OCELOT': 0.026747871190309525,
'PUMA': 6.511066317216319e-07,
'SNOW LEOPARD': 0.04739758372306824,
'TIGER': 7.334848760365276e-06}
Top 3 Predicted Cat Names with Highest Probability
sorted_result = dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
top_classes = list(sorted_result.keys())[:3]
top_classes
['CHEETAH', 'SNOW LEOPARD', 'OCELOT']
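So our model predicted CHEETAH with a probability of about 93%.
Note that the test file is named SnowLeopard.jpg, while SNOW LEOPARD only ranks second at about 5%, so if the image really shows a snow leopard, this is a misclassification.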
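If we also want the probabilities next to the top class names, one option is to pick the k largest entries directly instead of sorting the whole dictionary. The sketch below uses heapq.nlargest from the standard library; the helper name top_k is made up for illustration, and the result dictionary is copied (rounded) from the output above.

```python
import heapq

# Class probabilities, rounded from the result dictionary shown above.
result = {
    'AFRICAN LEOPARD': 0.000109883,
    'CARACAL': 8.7617e-05,
    'CHEETAH': 0.925643206,
    'CLOUDED LEOPARD': 2.9294e-06,
    'JAGUAR': 1.2592e-06,
    'LIONS': 1.6162e-06,
    'OCELOT': 0.026747871,
    'PUMA': 6.5111e-07,
    'SNOW LEOPARD': 0.047397584,
    'TIGER': 7.3348e-06,
}


def top_k(probabilities, k=3):
    # Return the k (name, probability) pairs with the highest probability,
    # ordered from most to least likely.
    return heapq.nlargest(k, probabilities.items(), key=lambda item: item[1])


top3 = top_k(result)
for name, p in top3:
    print(f"{name}: {p:.1%}")
```

For three classes out of ten the difference to a full sort is negligible; the main gain here is that names and probabilities stay together.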