Skip to content

codehooni/OpenSourceAI_Final

Folders and files

NameName
Last commit message
Last commit date

Latest commit

9a85491 · Dec 22, 2023

History

4 Commits
 
 
 
 
 
 

Repository files navigation

오픈소스AI 응용 - 2023년 2학기

Title: 반려동물 피부질환 데이터를 활용한 피부병 진단 모델

Dataset

AI-hub '반려동물 피부질환 데이터'

https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&dataSetSn=561

Tool

  • Google Colab

요구사항


  1. AI-hub의 데이터 다운로드

    a) Dataset Link 클릭 후 회원가입 진행

    b) 라벨링 데이터 다운로드


  2. Google Colab에 .ipynb 파일 업로드

실행 방법

  1. Import Libraries
 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import vgg16
from tqdm import tqdm  # Import tqdm for progress bar

  2. Data Preprocessing
 

# Data generation: one-off preprocessing that resizes every image under
# `input_folder` and mirrors the folder layout under `output_folder`.
# The script is intentionally kept disabled inside a string literal; paste it
# into its own cell to run it.

'''
from PIL import Image
import os

# Set the input and output folders
input_folder = "/content/gdrive/MyDrive/openAI/data/유증상_검증/유증상_라벨"
output_folder = "유증상_검증"

# Desired resolution
new_resolution = (480, 360)

# Iterate through subfolders and resize images
for root, dirs, files in os.walk(input_folder):
    for file in files:
        # Check if the file is an image (you can add more image extensions if needed)
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            # Create input and output paths
            input_path = os.path.join(root, file)
            output_path = os.path.join(output_folder, os.path.relpath(input_path, input_folder))

            # Create output folder if it doesn't exist
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Open image, resize, and save to the output folder.
            # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
            # same filter under its current name.
            with Image.open(input_path) as img:
                img = img.resize(new_resolution, Image.LANCZOS)
                img.save(output_path)

print("Resizing complete.")

'''

  3. Data Load
 

# Locations of the image data on the mounted Google Drive.
# NOTE(review): the disabled resize script in this README uses the mount point
# "/content/gdrive" while these paths use "/content/drive" - confirm which one
# the notebook actually mounts.
input_folder = "/content/drive/MyDrive/openAI/data/유증상_학습"  # training images
validation_folder = "/content/drive/MyDrive/openAI/data/유증상_검증"  # validation images
  4. Define Model
 
class CustomVGG16(nn.Module):
    """VGG16-based classifier with a newly created fully connected head.

    The convolutional feature extractor is taken from torchvision's
    ``vgg16(pretrained=True)``; the three-layer head is created here and ends
    in ``num_classes`` logits.

    Args:
        num_classes: Size of the final linear layer, i.e. the number of
            classes to predict.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE: recent torchvision releases deprecate `pretrained=True` in
        # favour of the `weights=` argument; it is kept here so the class also
        # works on older torchvision versions.
        vgg_model = vgg16(pretrained=True)
        self.features = vgg_model.features
        # Pool the feature map to a fixed 7x7 grid so the head's 512 * 7 * 7
        # input size holds for any input resolution, not only 224x224.
        # For a 7x7 feature map this pooling is an identity, and it has no
        # parameters, so existing state dicts still load.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        return self.classifier(x)

 
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate the model before moving it: the original line read `model`
# before it was ever assigned, which raises NameError.
# The class count is taken from the sub-folders of the training directory,
# the same way ImageFolder derives its labels.
num_classes = len(ImageFolder(input_folder).classes)
model = CustomVGG16(num_classes).to(device)
 

6. Model Train

# Training: run `num_epochs` passes over `train_loader`, reporting running
# loss/accuracy in a progress bar, then measure accuracy on
# `validation_loader` after each epoch.
# NOTE(review): `criterion`, `optimizer`, `train_loader` and
# `validation_loader` are not defined in this excerpt; they must be created
# before this loop runs.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch") as pbar:
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the running value is a
            # per-sample average even when the last batch is smaller.
            train_loss += loss.item() * inputs.size(0)
            # argmax replaces the legacy `torch.max(outputs.data, 1)` idiom.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({'loss': train_loss / total, 'accuracy': correct / total})
            pbar.update()

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in validation_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # An empty validation loader would otherwise divide by zero.
        print(f"Epoch {epoch + 1}/{num_epochs}, Validation skipped: no validation samples")
    else:
        accuracy = correct / total
        print(f"Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {accuracy * 100:.2f}%")

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published