카테고리 없음

CNN 알고리즘 - 전이학습

두설날 2024. 6. 25. 15:47

*이 글에는 작성자의 개인 의견이 포함되어 있으니, 다른 블로그와 교차하여 읽는 것을 권장합니다.*

pizza-steak 데이터셋 음식 이미지 분류

https://www.kaggle.com/datasets/kelixirr/pizza-steak-image-classification-dataset

 

Pizza Steak Image Classification Dataset

CNN Project Pizza Steak Dataset

www.kaggle.com

 

import os
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
# Reminder: switch the notebook runtime type from CPU to GPU; when CUDA is
# not available everything below falls back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

os.environ['KAGGLE_USERNAME'] = 'himdo123'
os.environ['KAGGLE_KEY'] = '8f359ef1099d4ff2a07d5fa2cece36a2'
!kaggle datasets download -d kelixirr/pizza-steak-image-classification-dataset

!unzip /content/pizza-steak-image-classification-dataset.zip

# Per-split preprocessing pipelines, keyed by split name.
# Both splits resize to 224x224 and convert to a tensor; only the training
# split adds random augmentation in between.
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize(size=(224, 224)),
        # No rotation (degrees=0); random shear and random rescaling only.
        transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
        # Random horizontal flip.
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    test=transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]),
)
def target_transforms(target):
    """Wrap a single label in a one-element float32 tensor.

    Passed as ``target_transform`` to the datasets so that each label comes
    out as a float tensor of shape ``(1,)``.
    """
    return torch.tensor([target], dtype=torch.float32)
# Folder-backed datasets for each split; every split uses its own image
# pipeline and the shared label transform.
image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(
    root='pizza_steak/train',
    transform=data_transforms['train'],
    target_transform=target_transforms,
)
image_datasets['test'] = datasets.ImageFolder(
    root='pizza_steak/test',
    transform=data_transforms['test'],
    target_transform=target_transforms,
)
# One loader per split with 32 samples per batch; only the training split
# is shuffled.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=32,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'test')
}
# Report the number of samples in the train and test datasets.
print(*(len(image_datasets[split]) for split in ('train', 'test')))

# Take the first batch from the training loader for a visual sanity check.
imgs, labels = next(iter(dataloaders['train']))

# 4 x 8 grid of panels, one per image in the batch.
fig, axes = plt.subplots(nrows=4, ncols=8, figsize=(16, 8))

for panel, picture, target in zip(axes.flatten(), imgs, labels):
    # permute() reorders dimensions by index so the channel axis comes last
    # for imshow: (3, 224, 224) -> (224, 224, 3).
    channels_last = picture.permute(1, 2, 0)
    panel.imshow(channels_last)
    panel.set_title(target.item())
    panel.axis('off')

# Load ResNet-50 with the IMAGENET1K_V1 weights and send it to the device.
model = models.resnet50(weights='IMAGENET1K_V1')
model = model.to(device)
print(model)

# Turn off gradients for every parameter loaded above.
for frozen_param in model.parameters():
    frozen_param.requires_grad = False

# Swap in a new head: 2048 -> 128 -> 1 with a sigmoid on the output.
# Its freshly created parameters are the only ones that still require gradients.
model.fc = nn.Sequential(
    nn.Linear(in_features=2048, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
)
model.fc.to(device)

print(model)

# Only the parameters of model.fc are handed to the optimizer.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# Binary cross-entropy on probabilities; built once instead of once per batch.
criterion = nn.BCELoss()

epochs = 10

for epoch in range(epochs):
    # Each epoch runs a training pass followed by an evaluation pass.
    for phase in ['train', 'test']:
        is_train = phase == 'train'
        if is_train:
            model.train()
        else:  # evaluation on the 'test' split
            model.eval()

        # Running totals are plain Python numbers weighted by batch size, so
        # a smaller final batch does not skew the epoch averages and no
        # autograd history is kept alive between iterations.
        sum_losses = 0.0
        sum_correct = 0
        num_samples = 0

        for x_batch, y_batch in dataloaders[phase]:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            batch_size = y_batch.size(0)

            # Track gradients only while training; evaluation needs no graph.
            with torch.set_grad_enabled(is_train):
                y_pred = model(x_batch)
                loss = criterion(y_pred, y_batch)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # loss is the batch mean, so scale it back to a per-sample sum.
            sum_losses += loss.item() * batch_size

            # Threshold the predicted probability at 0.5 to get a hard label.
            y_bool = (y_pred >= 0.5).float()
            sum_correct += (y_batch == y_bool).sum().item()
            num_samples += batch_size

        # max(..., 1) keeps an empty loader from dividing by zero.
        avg_loss = sum_losses / max(num_samples, 1)
        avg_acc = sum_correct / max(num_samples, 1) * 100
        print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs} Loss: {avg_loss:.4f} Accuracy: {avg_acc:.2f}%')

from PIL import Image

# Load one sample image from each test class folder. convert('RGB') forces
# three channels so that grayscale, palette or RGBA files still have the
# channel count the model expects after ToTensor.
img1 = Image.open('/content/pizza_steak/test/pizza/images.jpg').convert('RGB')
img2 = Image.open('/content/pizza_steak/test/steak/images.jpg').convert('RGB')

# Show the two images side by side without axes.
fig, axes = plt.subplots(1, 2, figsize=(12,6))
axes[0].imshow(img1)
axes[0].axis('off')
axes[1].imshow(img2)
axes[1].axis('off')
# plt.show must be called; the bare attribute reference does nothing.
plt.show()

# Apply the same preprocessing as the test split to both sample images.
img1_input = data_transforms['test'](img1)
img2_input = data_transforms['test'](img2)
print(img1_input.shape)
print(img2_input.shape)

# Stack the two image tensors into one batch and move it to the device.
test_batch = torch.stack([img1_input, img2_input])
test_batch = test_batch.to(device)

# Put the model in eval mode explicitly instead of relying on earlier code,
# and skip autograd bookkeeping because no backward pass follows.
model.eval()
with torch.no_grad():
    y_pred = model(test_batch)
y_pred

# Show both sample images again, each titled with its predicted split.
# Column 0 of y_pred is read as the steak probability; its complement is
# shown as the pizza probability.
fig, axes = plt.subplots(1, 2, figsize=(12,6))
for idx, picture in enumerate((img1, img2)):
    steak_prob = y_pred[idx, 0]
    panel = axes[idx]
    panel.set_title(f'{(1-steak_prob)*100:.2f}% pizza, {(steak_prob)*100:.2f}% steak')
    panel.imshow(picture)
    panel.axis('off')