{ "cells": [
{ "cell_type": "code", "execution_count": 11, "id": "4be470de-020c-4e56-9b8d-1377e2b31e2c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# %pip install torch pandas torchvision scikit-learn tqdm kaggle torchmetrics -q" ] },
{ "cell_type": "code", "execution_count": 1, "id": "e9a6d902-9fe8-47cd-8991-d75f749b6148", "metadata": {}, "outputs": [],
"source": [
"# Upload kaggle.json to the working directory first.\n",
"!mkdir -p ~/.kaggle\n",
"# Move the token only while it is still in the working directory, so re-running this cell is a quiet no-op.\n",
"!test -f kaggle.json && mv kaggle.json ~/.kaggle/\n",
"# The Kaggle CLI expects the token to be readable by its owner only.\n",
"!chmod 600 ~/.kaggle/kaggle.json"
] },
{ "cell_type": "code", "execution_count": 2, "id": "a71ff6d8-54a2-483b-aae2-fdf947305d4a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset URL: https://www.kaggle.com/datasets/nirmalsankalana/sugarcane-leaf-disease-dataset\n", "License(s): CC0-1.0\n" ] } ],
"source": [
"# !apt update -qq\n",
"# !apt install -qq unzip\n",
"!kaggle datasets download nirmalsankalana/sugarcane-leaf-disease-dataset\n",
"# -o overwrites existing files without prompting, so a re-run does not stall on unzip's replace question.\n",
"!unzip -o -q sugarcane-leaf-disease-dataset.zip -d data"
] },
{ "cell_type": "code", "execution_count": 1, "id": "e9a11f4d-c12d-4e5f-aa6e-d37668f4e409", "metadata": {}, "outputs": [],
"source": [
"from dataset import get_mean_teacher_dataloaders\n",
"\n",
"# Build the four loaders from the folder extracted above.\n",
"# NOTE(review): 0.2 and 16 look like a split fraction and a batch size -- confirm against dataset.py.\n",
"train_loader, test_loader, unlabeled_loader, unlabeled_student_loader = get_mean_teacher_dataloaders('data', 0.2, 16)"
] },
{ "cell_type": "code", "execution_count": 2, "id": "1f2be036-3f51-4e18-8b72-e7a7b464f302", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"from coreplant import Classifier\n",
"import torch\n",
"\n",
"NUM_CLASSES = 5\n",
"# Use the GPU when one is available, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Student and teacher are built with identical constructor arguments.\n",
"# NOTE(review): 512 and 256 are passed straight to coreplant.Classifier; their meaning is not visible here -- confirm.\n",
"student = Classifier(512, 256, NUM_CLASSES).to(device)\n",
"teacher = Classifier(512, 256, NUM_CLASSES).to(device)\n",
"\n",
"# Synchronize initial weights: the teacher starts as an exact copy of the student.\n",
"# A module's state dict contains the parameters of every registered submodule,\n",
"# so this single call also copies the encoder; no separate encoder sync is needed.\n",
"teacher.load_state_dict(student.state_dict())"
] },
{ "cell_type": "code", "execution_count": 3, "id": "ef1b3cd2-7082-40cb-94ac-a018f39a8463", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1 - Train Loss: 0.9669 Acc: 0.5933\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.0751 Acc: 0.7129\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2 - Train Loss: 0.6921 Acc: 0.9534\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.0058 Acc: 0.6337\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3 - Train Loss: 0.5936 Acc: 0.9871\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 0.9520 Acc: 0.7327\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4 - Train Loss: 0.5735 Acc: 0.9921\n" ] }, { "name": "stderr",
"output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.1093 Acc: 0.7624\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5 - Train Loss: 0.5805 Acc: 0.9931\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2088 Acc: 0.6535\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 6 - Train Loss: 0.5997 Acc: 0.9906\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2791 Acc: 0.6832\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 7 - Train Loss: 0.6033 Acc: 0.9960\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2851 Acc: 0.7129\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 8 - Train Loss: 0.6383 Acc: 0.9955\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2957 Acc: 0.7525\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9 - Train Loss: 0.6469 Acc: 0.9965\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2491 Acc: 0.7327\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 - Train Loss: 0.6597 Acc: 0.9985\n" ] }, { "name": "stderr", "output_type": "stream", 
"text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2907 Acc: 0.7624\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 11 - Train Loss: 0.6739 Acc: 0.9970\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2727 Acc: 0.7921\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 12 - Train Loss: 0.6986 Acc: 0.9965\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2876 Acc: 0.8020\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 13 - Train Loss: 0.7012 Acc: 0.9985\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.3070 Acc: 0.6337\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 14 - Train Loss: 0.7089 Acc: 0.9970\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2998 Acc: 0.7228\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 15 - Train Loss: 0.7142 Acc: 0.9975\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.2990 Acc: 0.7624\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 16 - Train Loss: 0.7181 Acc: 0.9936\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": 
"stdout", "output_type": "stream", "text": [ "Validation Loss: 1.3175 Acc: 0.7525\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 17 - Train Loss: 0.7199 Acc: 0.9921\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.3741 Acc: 0.6733\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 18 - Train Loss: 0.7261 Acc: 0.9871\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4587 Acc: 0.7525\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 19 - Train Loss: 0.7460 Acc: 0.9856\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.3727 Acc: 0.6931\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 20 - Train Loss: 0.7452 Acc: 0.9851\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4956 Acc: 0.3960\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 21 - Train Loss: 0.7395 Acc: 0.9851\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4490 Acc: 0.5842\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 22 - Train Loss: 0.7491 Acc: 0.9916\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": 
"stream", "text": [ "Validation Loss: 1.4563 Acc: 0.5347\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 23 - Train Loss: 0.7618 Acc: 0.9921\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.3545 Acc: 0.7525\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 24 - Train Loss: 0.7594 Acc: 0.9896\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4820 Acc: 0.4554\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 25 - Train Loss: 0.7588 Acc: 0.9970\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.5111 Acc: 0.3861\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 26 - Train Loss: 0.7752 Acc: 0.9916\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4428 Acc: 0.5842\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 27 - Train Loss: 0.7753 Acc: 0.9985\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4036 Acc: 0.6931\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 28 - Train Loss: 0.7685 Acc: 0.9965\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ 
"Validation Loss: 1.4063 Acc: 0.6634\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 29 - Train Loss: 0.7910 Acc: 0.9985\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4430 Acc: 0.4851\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 30 - Train Loss: 0.8158 Acc: 0.9945\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4121 Acc: 0.6634\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 31 - Train Loss: 0.8174 Acc: 0.9975\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4690 Acc: 0.4158\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 32 - Train Loss: 0.8119 Acc: 0.9975\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4666 Acc: 0.4752\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 33 - Train Loss: 0.8281 Acc: 0.9985\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.4412 Acc: 0.6436\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 34 - Train Loss: 0.8392 Acc: 0.9965\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.5037 
Acc: 0.3069\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 35 - Train Loss: 0.8356 Acc: 0.9980\n" ] }, { "name": "stderr", "output_type": "stream", "text": [] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 1.5167 Acc: 0.3168\n", "Best Val Acc: 0.8020\n" ] } ],
"source": [
"from train import mean_teacher_train, validate, plot\n",
"\n",
"# Train with the models and loaders created in the cells above.\n",
"# The call returns the model that is saved below, plus a results object.\n",
"student_model, results = mean_teacher_train(\n",
"    student, teacher, train_loader, test_loader,\n",
"    unlabeled_loader, unlabeled_student_loader, NUM_CLASSES,\n",
")"
] },
{ "cell_type": "code", "execution_count": null, "id": "166635ba-574d-40a2-a451-b887d8f7b367", "metadata": {}, "outputs": [],
"source": [
"import os\n",
"\n",
"# torch.save does not create missing folders, so make sure models/ exists before writing.\n",
"os.makedirs('models', exist_ok=True)\n",
"# Only the weights (state dict) are stored; loading requires rebuilding the Classifier first.\n",
"torch.save(student_model.state_dict(), 'models/coreplant_nirmal.pth')"
] },
{ "cell_type": "code", "execution_count": 1, "id": "4170e973-bee1-4ad7-83c0-e318bffd2cf3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", " " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Validation Accuracy: 0.7030\n", "Validation Precision: 0.7306\n", "Validation Recall: 0.7133\n", "Validation F1 Score: 0.7164\n" ] }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
"source": [
"from train import validate\n",
"from dataset import get_mean_teacher_dataloaders\n",
"from coreplant import Classifier\n",
"import torch\n",
"\n",
"NUM_CLASSES = 5\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Rebuild the architecture with the same arguments used for training, then load the saved weights.\n",
"model = Classifier(512, 256, NUM_CLASSES).to(device)\n",
"# map_location remaps the stored tensors onto the current device, so a checkpoint\n",
"# written on a GPU machine still loads when only the CPU is available.\n",
"state_dict = torch.load('models/coreplant_nirmal.pth', map_location=device)\n",
"\n",
"model.load_state_dict(state_dict)\n",
"# Switch to evaluation mode before measuring metrics.\n",
"model.eval()\n",
"\n",
"# NOTE(review): the loaders are rebuilt here; if dataset.py splits randomly without a fixed seed,\n",
"# this test_loader may differ from the one used during training -- confirm.\n",
"train_loader, test_loader, unlabeled_loader, unlabeled_student_loader = get_mean_teacher_dataloaders('data', 0.2, 16)\n",
"validate(model, test_loader, NUM_CLASSES)"
] },
{ "cell_type": "code", "execution_count": null, "id": "ea16f1e8-4f49-4605-94f1-faba6ca79a5c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }