*Note before reading: this post contains the author's personal opinions, so cross-reading with other blogs is recommended.*
1. Pokemon Classification
- Train: https://www.kaggle.com/datasets/thedagger/pokemon-generation-one (Pokemon Generation One)
- Validation: https://www.kaggle.com/hlrhegemony/pokemon-image-dataset (Complete Pokemon Image Dataset: 2,500+ clean labeled images, all official art, Generations 1 through 8)
import os

# Kaggle API credentials (use your own; never publish a real key)
os.environ['KAGGLE_USERNAME'] = 'your_kaggle_username'
os.environ['KAGGLE_KEY'] = 'your_kaggle_api_key'
# Download and unzip both datasets
!kaggle datasets download -d thedagger/pokemon-generation-one
!unzip -q /content/pokemon-generation-one.zip
!kaggle datasets download -d hlrhegemony/pokemon-image-dataset
!unzip -q /content/pokemon-image-dataset.zip

# Rename the extracted folders to train/ and validation/
!mv dataset train
!rm -rf train/dataset  # drop the duplicate nested folder
!mv images validation
train_labels = os.listdir('train')
val_labels = os.listdir('validation')
print(val_labels)
print(len(val_labels))

import shutil

# Remove validation classes that do not exist in the training set
for val_label in val_labels:
    if val_label not in train_labels:
        shutil.rmtree(os.path.join('validation', val_label))

val_labels = os.listdir('validation')
len(val_labels)
# Which 2 classes account for the difference between train and validation?
for train_label in train_labels:
    if train_label not in val_labels:
        print(train_label)

# Create empty validation folders for the missing classes
for train_label in train_labels:
    if train_label not in val_labels:
        print(train_label)
        os.makedirs(os.path.join('validation', train_label), exist_ok=True)

val_labels = os.listdir('validation')
len(val_labels)
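As a quick sanity check (my addition, not in the original code), Python sets can confirm the two label sets now match; both differences should print empty:

# Both set differences should now be empty
train_set, val_set = set(train_labels), set(val_labels)
print(train_set - val_set)  # classes only in train
print(val_set - train_set)  # classes only in validation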
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
# Switch the Colab runtime to GPU first
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        # random shear and scale augmentation
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        # random horizontal flip
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]),
    'validation': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
}
# ImageFolder loads images from subfolders, one subfolder per class
image_datasets = {
    'train': datasets.ImageFolder('train', data_transforms['train']),
    'validation': datasets.ImageFolder('validation', data_transforms['validation'])
}

dataloaders = {
    'train': DataLoader(
        image_datasets['train'],
        batch_size=32,
        shuffle=True
    ),
    'validation': DataLoader(
        image_datasets['validation'],
        batch_size=32,
        shuffle=False
    )
}
print(len(image_datasets['train']), len(image_datasets['validation']))
imgs, labels = next(iter(dataloaders['train']))

fig, axes = plt.subplots(4, 8, figsize=(16, 8))
for ax, img, label in zip(axes.flatten(), imgs, labels):
    # permute(): reorder the tensor dimensions by index
    ax.imshow(img.permute(1, 2, 0))  # (3, 224, 224) -> (224, 224, 3)
    ax.set_title(label.item())
    ax.axis('off')

# Look up the class name for label index 101
image_datasets['train'].classes[101]
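The integer titles above are class indices that ImageFolder assigns in alphabetical folder order. A minimal sketch of mapping back and forth (the 'Pikachu' lookup assumes such a class folder exists, which it should for a Gen 1 dataset):

# classes is sorted alphabetically; class_to_idx maps name -> integer label
print(image_datasets['train'].classes[:5])
print(image_datasets['train'].class_to_idx['Pikachu'])  # assumes a Pikachu class folder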
2. EfficientNet
- A neural network developed by a Google research team that achieves strong performance on computer vision tasks such as image classification and object detection.
- Its defining idea is compound scaling: expanding the network's depth, width, and input resolution together to maximize both efficiency and accuracy, as sketched below.
- EfficientNet-B4 is a mid-sized model in the EfficientNet series.
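As a rough illustration of compound scaling (my own sketch, not from the original post): the EfficientNet paper (Tan & Le, 2019) scales all three axes with a single coefficient phi, using fixed constants alpha=1.2, beta=1.1, gamma=1.15 chosen so that alpha * beta^2 * gamma^2 is roughly 2.

# Compound scaling sketch: one coefficient phi scales all three axes together.
# Constants from the EfficientNet paper; alpha * beta**2 * gamma**2 ~= 2,
# so increasing phi by 1 roughly doubles the FLOPs.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15

def compound_scale(phi):
    depth_mult = ALPHA ** phi   # multiplier on the number of layers
    width_mult = BETA ** phi    # multiplier on the number of channels
    res_mult = GAMMA ** phi     # multiplier on the input resolution
    return depth_mult, width_mult, res_mult

print(compound_scale(4))  # larger phi -> larger model (the B-numbers roughly track phi)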
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url

# Workaround for a torchvision/torch version mismatch: drop the unsupported
# check_hash argument before downloading the pretrained weights
def get_state_dict(self, *args, **kwargs):
    kwargs.pop("check_hash")
    return load_state_dict_from_url(self.url, *args, **kwargs)

WeightsEnum.get_state_dict = get_state_dict
model = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1).to(device)
model
# Freeze the pretrained backbone
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head: 1792 input features -> 149 Pokemon classes
model.classifier = nn.Sequential(
    nn.Linear(1792, 512),
    nn.ReLU(),
    nn.Linear(512, 149)
).to(device)

print(model)
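As a sanity check (my addition, not in the original code) that the backbone is really frozen, count trainable versus total parameters; only the new classifier head should show up as trainable:

# Only the replaced classifier head should contribute trainable parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'trainable: {trainable_params:,} / total: {total_params:,}')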
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

epochs = 10
for epoch in range(epochs):
    for phase in ['train', 'validation']:
        if phase == 'train':
            model.train()
        else:
            model.eval()
        sum_losses = 0
        sum_accs = 0
        for x_batch, y_batch in dataloaders[phase]:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            # track gradients only during the training phase
            with torch.set_grad_enabled(phase == 'train'):
                y_pred = model(x_batch)
                loss = criterion(y_pred, y_batch)
                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            sum_losses = sum_losses + loss.item()  # .item() avoids holding the graph
            y_prob = nn.Softmax(dim=1)(y_pred)
            y_pred_index = torch.argmax(y_prob, dim=1)
            acc = (y_batch == y_pred_index).float().sum() / len(y_batch) * 100
            sum_accs = sum_accs + acc.item()
        avg_loss = sum_losses / len(dataloaders[phase])
        avg_acc = sum_accs / len(dataloaders[phase])
        print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs} Loss: {avg_loss:.4f} Accuracy: {avg_acc:.2f}%')
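One refinement the loop above does not include, sketched here as a suggestion rather than part of the original code: snapshot the weights whenever validation accuracy improves, so the saved model is not simply whatever the last epoch produced. The avg_acc value here is a stand-in for the one computed in the validation phase.

import copy

# Hypothetical pattern: keep a copy of the best-performing weights
best_acc = 0.0
best_weights = copy.deepcopy(model.state_dict())

avg_acc = 85.0  # stand-in; inside the loop this is the computed validation accuracy
if avg_acc > best_acc:
    best_acc = avg_acc
    best_weights = copy.deepcopy(model.state_dict())

# After training, restore the best snapshot before saving to disk
model.load_state_dict(best_weights)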
# Save the trained model weights
# (.pth is the conventional PyTorch extension; Keras uses .h5 for the same purpose)
torch.save(model.state_dict(), 'model.pth')
# Rebuild the same architecture, then load the saved weights into it
model = models.efficientnet_b4().to(device)
model.classifier = nn.Sequential(
    nn.Linear(1792, 512),
    nn.ReLU(),
    nn.Linear(512, 149)
).to(device)
print(model)

model.load_state_dict(torch.load('model.pth'))
model.eval()
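One caveat worth noting (my addition): if these weights are later reloaded on a machine without a GPU, pass map_location so the CUDA tensors are remapped to the available device, for example:

# Hypothetical reload on a CPU-only machine: remap CUDA tensors to the CPU
state_dict = torch.load('model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)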
from PIL import Image

# Load two sample validation images for a quick prediction check
img1 = Image.open('validation/Snorlax/4.jpg')
img2 = Image.open('validation/Diglett/0.jpg')
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(img1)
axes[0].axis('off')
axes[1].imshow(img2)
axes[1].axis('off')
plt.show()
# Apply the validation transforms, then stack the two tensors into one batch
img1_input = data_transforms['validation'](img1)
img2_input = data_transforms['validation'](img2)
print(img1_input.shape)
print(img2_input.shape)

test_batch = torch.stack([img1_input, img2_input])
test_batch = test_batch.to(device)
test_batch.shape
# Run inference with gradient tracking disabled
with torch.no_grad():
    y_pred = model(test_batch)
y_pred

y_prob = nn.Softmax(dim=1)(y_pred)
y_prob

# Top-3 class probabilities and indices for each image
probs, idx = torch.topk(y_prob, k=3)
print(probs)
print(idx)
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

axes[0].set_title('{:.2f}% {}, {:.2f}% {}, {:.2f}% {}'.format(
    probs[0, 0] * 100,
    image_datasets['validation'].classes[idx[0, 0]],
    probs[0, 1] * 100,
    image_datasets['validation'].classes[idx[0, 1]],
    probs[0, 2] * 100,
    image_datasets['validation'].classes[idx[0, 2]],
))
axes[0].imshow(img1)
axes[0].axis('off')

axes[1].set_title('{:.2f}% {}, {:.2f}% {}, {:.2f}% {}'.format(
    probs[1, 0] * 100,
    image_datasets['validation'].classes[idx[1, 0]],
    probs[1, 1] * 100,
    image_datasets['validation'].classes[idx[1, 1]],
    probs[1, 2] * 100,
    image_datasets['validation'].classes[idx[1, 2]],
))
axes[1].imshow(img2)
axes[1].axis('off')