카테고리 없음
CNN 알고리즘 - 전이학습
두설날
2024. 6. 25. 15:47

*이 글에는 작성자의 개인 의견이 포함되어 있으니, 읽기 전에 참고하시고 다른 블로그와 교차해서 읽는 것을 권장합니다.*
pizza-steak 데이터셋을 이용한 음식 이미지 분류
https://www.kaggle.com/datasets/kelixirr/pizza-steak-image-classification-dataset
Pizza Steak Image Classification Dataset
CNN Project Pizza Steak Dataset
www.kaggle.com
import os
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
# Pick the compute device: use the GPU when CUDA is available (e.g. after
# switching the Colab runtime type from CPU to GPU), otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Download and extract the Kaggle dataset.
#
# Security: Kaggle API credentials must never be hard-coded in a notebook or
# blog post. Anyone who can read the file can use them. If a key was ever
# published here, revoke it on kaggle.com (Account -> API). The credentials are
# taken from the environment when already set; otherwise they are prompted for,
# and the key is read without echo.
import getpass
import subprocess
import zipfile

if not os.environ.get('KAGGLE_USERNAME'):
    os.environ['KAGGLE_USERNAME'] = input('Kaggle username: ')
if not os.environ.get('KAGGLE_KEY'):
    os.environ['KAGGLE_KEY'] = getpass.getpass('Kaggle API key: ')

# Plain-Python equivalents of the `!kaggle` / `!unzip` notebook magics, so the
# code also runs as a regular script. The kaggle CLI writes the archive into
# the current working directory (/content on Colab). check=True stops the run
# if the download fails instead of continuing with a missing archive.
subprocess.run(
    ['kaggle', 'datasets', 'download', '-d',
     'kelixirr/pizza-steak-image-classification-dataset'],
    check=True,
)
# Extract next to the archive, which yields the ./pizza_steak/{train,test} tree
# used below.
with zipfile.ZipFile('pizza-steak-image-classification-dataset.zip') as archive:
    archive.extractall()

# Per-split preprocessing. Both splits are resized to 224x224 and converted by
# ToTensor to float tensors in [0, 1]. Only the training split is augmented.
# NOTE(review): the ResNet-50 weights used below were trained on
# ImageNet-normalized inputs, but no transforms.Normalize is applied here. If
# one is added, the raw-tensor imshow previews must un-normalize first.
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize((224, 224)),
        # Random affine augmentation: no rotation (degrees=0), shear up to
        # 10 degrees, and zoom between 0.8x and 1.2x.
        transforms.RandomAffine(degrees=0, shear=10, scale=(0.8, 1.2)),
        # Random left-right mirror (torchvision's default probability, 0.5).
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    test=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)
def target_transforms(target: int) -> torch.Tensor:
    """Wrap an ImageFolder class index as a 1-element float32 tensor.

    The model ends in ``Linear(128, 1)`` + ``Sigmoid`` and is trained with
    ``BCELoss``. Each label therefore has to be floating point with shape
    ``[1]``, so that a batch of labels matches the ``(batch, 1)`` prediction.

    Args:
        target: class index assigned by ImageFolder.

    Returns:
        ``tensor([float(target)])`` with dtype ``torch.float32``.
    """
    # torch.tensor with an explicit dtype replaces the legacy
    # torch.FloatTensor(...) constructor; the resulting value is the same.
    return torch.tensor([target], dtype=torch.float32)
# ImageFolder treats each sub-directory of a split as one class and applies the
# split's image transform plus the float label wrapper to every sample.
image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(
    'pizza_steak/train',
    data_transforms['train'],
    target_transform=target_transforms,
)
image_datasets['test'] = datasets.ImageFolder(
    'pizza_steak/test',
    data_transforms['test'],
    target_transform=target_transforms,
)

# Batches of 32. Only the training split is shuffled; test order stays fixed.
dataloaders = {
    split: DataLoader(image_datasets[split], batch_size=32, shuffle=(split == 'train'))
    for split in ('train', 'test')
}

# Number of images found in each split.
print(len(image_datasets['train']), len(image_datasets['test']))

# Preview the first shuffled training batch on a 4x8 grid (32 cells, one per
# image at batch_size=32). Each cell is titled with its float label.
imgs, labels = next(iter(dataloaders['train']))
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
for cell, image, target in zip(axes.flatten(), imgs, labels):
    # matplotlib expects channels last, so reorder (C, H, W) -> (H, W, C),
    # e.g. (3, 224, 224) -> (224, 224, 3).
    channels_last = image.permute(1, 2, 0)
    cell.imshow(channels_last)
    cell.set_title(target.item())
    cell.axis('off')

# Load ResNet-50 with ImageNet-pretrained weights (IMAGENET1K_V1) and move it
# to the selected device (the GPU when available).
model = models.resnet50(weights='IMAGENET1K_V1').to(device)
print(model)

# Transfer learning: freeze every pretrained parameter so that only the new
# classification head created below is trained.
for param in model.parameters():
    param.requires_grad = False

# Replace the ImageNet classification head with a small binary classifier.
# The input width is read from the existing head (2048 for ResNet-50) instead
# of being hard-coded, so the code still works if another ResNet variant is
# used. The right-hand side is evaluated before the assignment, so
# model.fc.in_features still refers to the original head. These layers are
# created after the freeze loop, so their parameters keep requires_grad=True.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # single probability in (0, 1), consumed by BCELoss
).to(device)
print(model)

# Optimize only the new head; the frozen backbone parameters are not passed in.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
# Binary cross-entropy on the sigmoid output. Created once, not per batch.
criterion = nn.BCELoss()
epochs = 10

for epoch in range(epochs):
    for phase in ['train', 'test']:
        # NOTE(review): model.train() also switches the frozen backbone's
        # BatchNorm layers to training mode, so their running statistics keep
        # updating even though requires_grad=False. If that is unwanted, call
        # model.eval() followed by model.fc.train() here instead.
        if phase == 'train':
            model.train()
        else:  # evaluation on the held-out 'test' split
            model.eval()

        # Per-sample totals, so the epoch averages are exact even when the last
        # batch is smaller than batch_size.
        sum_losses = 0.0
        num_correct = 0
        num_samples = 0

        # Record the autograd graph only while training. In the test phase
        # this saves memory and compute.
        with torch.set_grad_enabled(phase == 'train'):
            for x_batch, y_batch in dataloaders[phase]:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                y_pred = model(x_batch)
                loss = criterion(y_pred, y_batch)
                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                batch_size = y_batch.size(0)
                # .item() turns the values into plain Python numbers. Summing
                # the loss tensor itself would keep every batch's autograd
                # graph alive until the end of the epoch.
                sum_losses += loss.item() * batch_size
                # Threshold the probability at 0.5 to get a 0/1 prediction.
                y_bool = (y_pred >= 0.5).float()
                num_correct += (y_bool == y_batch).sum().item()
                num_samples += batch_size

        avg_loss = sum_losses / num_samples
        avg_acc = num_correct / num_samples * 100
        print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs} Loss: {avg_loss:.4f} Accuracy: {avg_acc:.2f}%')

from PIL import Image


def _load_rgb(path: str) -> Image.Image:
    """Open the image at *path* and return it converted to RGB.

    This mirrors torchvision's default ImageFolder loader, which also converts
    to RGB. Grayscale or RGBA files then still produce the 3-channel tensors
    the model expects. The context manager closes the file once the pixels
    have been read.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


# One sample image from each class of the test split.
img1 = _load_rgb('/content/pizza_steak/test/pizza/images.jpg')
img2 = _load_rgb('/content/pizza_steak/test/steak/images.jpg')

fig, axes = plt.subplots(1, 2, figsize=(12,6))
axes[0].imshow(img1)
axes[0].axis('off')
axes[1].imshow(img2)
axes[1].axis('off')
# show must be called; a bare `plt.show` reference does nothing.
plt.show()

# Preprocess the two sample images exactly like the test split (no augmentation).
img1_input = data_transforms['test'](img1)
img2_input = data_transforms['test'](img2)
print(img1_input.shape)
print(img2_input.shape)

# Stack into one batch: (2, C, H, W), i.e. (2, 3, 224, 224) for RGB inputs.
test_batch = torch.stack([img1_input, img2_input])
test_batch = test_batch.to(device)

# Inference: make sure dropout/BatchNorm are in eval mode and skip building an
# autograd graph.
model.eval()
with torch.no_grad():
    # Shape (2, 1): the sigmoid probability of class index 1 for each image.
    y_pred = model(test_batch)
y_pred

# Show each sample image with its predicted class split. The sigmoid output is
# read as P(steak), i.e. steak is assumed to be class index 1. ImageFolder
# sorts class folder names, giving pizza=0 and steak=1; verify with
# image_datasets['train'].classes.
fig, axes = plt.subplots(1, 2, figsize=(12,6))
for panel, picture, p_steak in zip(axes, (img1, img2), y_pred[:, 0]):
    panel.set_title(f'{(1-p_steak)*100:.2f}% pizza, {(p_steak)*100:.2f}% steak')
    panel.imshow(picture)
    panel.axis('off')
