<aside> 📢 MS COCO(Microsoft Common Objects in Context) 데이터세트를 활용해 Faster R-CNN 모델을 미세 조정해 이미지를 분류
</aside>
#이미지 정보 및 어노테이션 정보
{
"images": [
{
#라이선스 ID: 이미지가 어떠한 라이선스를 사용하는지 의미
"license": 1,
"file_name": "000000440755.jpg",
"coco_url": "<http://images.cocodataset.org/train2017/000000440755.jpg>",
"height": 480,
"width": 640,
"date_captured": "2013-11-21 00:59:51",
"flickr_url": "<http://farm9.staticflickr.com/8436/7904535454_19b70862de_z.jpg>",
"id": 440755
},
...
],
"annotations": [
{
#분할 마스크 좌표: x1,y1,...xn,yn 구조
"segmentation": [
[
203.08,
379.43,
...
279.31,
371,66
]
],
"area": 93351.66034999999,
#군집 객체 여부: 1할당은 픽셀 수준으로 분리가 어려운 경우
#0할당: 군집 객체 여부를 포함하지 않는다.
"iscrowd": 0,
#이미지 ID
"image_id": 440755,
#경계 상자 좌표: 왼쪽 상단 x, y 좌표와 너비, 높이로 구성
"bbox": [
200.16,
72.39,
...
439.84,
394.49
],
#카테고리: 어노테이션마다 다른 클래스라는 정보를 할당
"category_id": 2,
#어노테이션 ID: 해당 어노테이션의 고유 ID: 이미지 ID와 매핑해 사용
"id": 13555
},
...
]
}
#데이터세트 클래스 선언
import os
import torch
from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset
class COCODataset(Dataset):
#root: MS COCO 데이터세트의 경로
#train: 학습 데이터세트 불러오기 여부: 거짓으로 하면 검증용 데이터세트
def __init__(self, root, train, transform=None):
super().__init__()
directory = "train" if train else "val"
#annotations: annotations 디렉터리에 있는 어노테이션 JSON 파일 경로 설정
annotations = os.path.join(root, "annotations", f"{directory}_annotations.json")
#이미지와 어노테이션 정보를 불러오기 전에 학습에 사용되는 카테고리 정보를 불러옴
self.coco = COCO(annotations)
self.iamge_path = os.path.join(root, directory)
self.transform = transform
#카테고리 정보를 불러오기
self.categories = self._get_categories()
#이미지와 어노테이션 정보를 불러오기
self.data = self._load_data()
#self.coco 인스턴스의 cats 속성에서 카테고리 정보를 불러옴
#cats: 딕셔너리 구조: 상위 카테고리, 카테고리 ID, 카테고리 이름 포함
def _get_categories(self):
#categories: 모델 추론 시 카테고리 정보를 확인하기 위해 사용
#0: 배경을 의미
categories = {0: "background"}
for category in self.coco.cats.values():
categories[category["id"]] = category["name"]
return categories
#COCO 데이터세트 불러오기
def _load_data(self):
data = []
#imgs 속성: 어노테이션 JSON 파일의 이미지 정보(images)를 순차적으로 반환
#어노테이션 정보는 이미지 ID와 매핑될 수 있으므로 이미지 ID(_id)를 추출
for _id in self.coco.imgs:
#입력된 이미지 ID를 받아 어노테이션 정보를 반환
#한 번에 여러 ID를 입력받을 수 있어 리스트 형식으로 반환
#현재 하나의 ID만 전달하므로 첫 번째 어노테이션 정보를 가져와 파일 이름을 추출하고 이미지를 불러옴
file_name = self.coco.loadImgs(_id)[0]["file_name"]
image_path = os.path.join(self.iamge_path, file_name)
image = Image.open(image_path).convert("RGB")
boxes = []
labels = []
#self.coco.loadAnns: 어노테이션 정보를 불러온다. (어노테이션 ID를 가져온다.)
#self.coco.getAnnIds: 이미지 ID를 입력했을 때 어노테이션 ID를 반환한다.
anns = self.coco.loadAnns(self.coco.getAnnIds(_id))
#이미지 안에 여러 객체가 존재할 수 있으므로 다수의 어노테이션 정보가 포함될 수 있다.
#반복문을 활용해 카테고리 ID와 경계 상자 정보를 추출한다.
for ann in anns:
x, y, w, h = ann["bbox"]
#Faster R-CNN은 x(min),y(min),x(max),y(max)의 구조를 사용하므로 경계 상자 데이터 구조를 변경
boxes.append([x, y, x + w, y + h])
labels.append(ann["category_id"])
#target 딕셔너리: 이미지 ID, 경계 상자, 레이블 저장
#적합한 텐서 형식으로 변환
target = {
#이미지 ID: 모델 학습에는 사용되지 않지만, 모델 평가 과정에서 사용
"image_id": torch.LongTensor([_id]),
"boxes": torch.FloatTensor(boxes),
"labels": torch.LongTensor(labels)
}
data.append([image, target])
return data
#호출 및 길이 반환 메서드
def __getitem__(self, index):
image, target = self.data[index]
#호출 메서드는 이미지 변환이 적용될 수 있으므로 self.transform속성이 존재하면 변환을 적용
if self.transform:
image = self.transform(image)
return image, target
#저장한 데이터의 길이를 반환
def __len__(self):
return len(self.data)
#데이터로더
from torchvision import transforms
from torch.utils.data import DataLoader
def collator(batch):
return tuple(zip(*batch))
transform = transforms.Compose(
[
#PIL 이미지를 텐서로 변환
transforms.PILToTensor(),
#텐서 이미지를 다시 float 형식으로 변환
#Faster R-CNN이 float 형식의 [0.0, 1.0] 범위를 갖는 이미지 텐서를 사용하기 때문
transforms.ConvertImageDtype(dtype=torch.float)
]
)
train_dataset = COCODataset("../datasets/coco", train=True, transform=transform)
test_dataset = COCODataset("../datasets/coco", train=False, transform=transform)
#COCO 데이터세트는 이미지 내에 여러 객체 정보가 담길 수 있으므로 데이터의 길이가 다를 수 있다.
#집합 함수(collate_fn)를 적용해 데이터를 패딩한다.
#collator: 데이터로더에 데이터 패딩을 적용한다.
train_dataloader = DataLoader(
train_dataset, batch_size=4, shuffle=True, drop_last=True, collate_fn=collator
)
test_dataloader = DataLoader(
test_dataset, batch_size=1, shuffle=True, drop_last=True, collate_fn=collator
)
Faster R-CNN 모델의 백본으로 사용하려는 VGG-16 모델과 영역 제안 네트워크, 관심 영역 풀링을 적용한다.
#백본 및 모델 구조 정의
from torchvision import models
from torchvision import ops
from torchvision.models.detection import rpn
from torchvision.models.detection import FasterRCNN
#백본 모델은 VGG-16, 마지막 분류 계층을 제외해 특징 추출 모델로 사용
backbone = models.vgg16(weights="VGG16_Weights.IMAGENET1K_V1").features
#출력 채널 수를 지정: 512채널 반환
backbone.out_channels = 512
#영역 제안 네트워크와 관심 영역 풀링으로 구성된 2-stage 객체 탐지 모델
#영역 제안 네트워크: 입력 이미지에서 객체 위치 후보군을 생성
#앵커 생성기: 입력 이미지의 각 픽셀에 대해 앵커 박스를 생성: 서로 다른 크기와 종횡비
anchor_generator = rpn.AnchorGenerator(
#매개변수 형식: Tuple[Tuple[int]]
sizes=((32, 64, 128, 256, 512),), #크기
aspect_ratios=((0.5, 1.0, 2.0),) #종횡비
)
#관심 영역 풀링: 영역 제안 네트워크에서 생성한 객체 후보군을 입력으로 받아 후보군 내의 특징 맵 영역을 일정한 크기의 고정된 영역으로 샘플링
#다중 스케일 관심 영역 정렬: 관심 영역 정렬(ROI Align)기능이 포함됨, 다중 스케일 이미지에서 관심영역 풀링을 수행
#다양한 스케일의 특징 맵을 입력으로 받아, 각 관심 영역 후보군을 해당 스케일의 특징 맵에 맞게 샘플링해 고정된 크기의 관심 영역 특징 맵을 생성
#생성된 특징 맵은 분류 계층의 입력으로 사용됨
roi_pooler = ops.MultiScaleRoIAlign(
#특징 맵 이름: 관심 영역 풀링에 사용할 특징 맵의 이름
#VGG-16 모델의 특징 추출 계층은 "0"으로 정의돼 있다.
featmap_names=["0"],
#출력 크기: 관심 영역 풀링을 통해 추출된 특징 맵의 크기 (height, width)
output_size=(7, 7), #7X7 크기의 관심 영역 특징 맵 생성
#샘플링 비율: 관심 영역 특징 맵 사용 시 원본 특징 맵 영역을 샘플링하는 데 사용
sampling_ratio=2 #관심 영역을 샘플링하기 위해 2X2 크기의 그리드를 사용
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = FasterRCNN(
backbone=backbone,
#배경도 클래스에 포함되기에 클래스 개수는 3으로 적용
num_classes=3,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler
).to(device)
#최적화 함수 및 학습률 스케줄러
from torch import optim
#학습이 가능한 매개변수만 params 변수에 저장해 확률적 경사 하강법을 적용
params = [p for p in model.parameters() if p.requires_grad]
#학습률 = 0.001, 모멘텀 = 0.9, 가중치 감쇠 = 0.0005
optimizer = optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
#학습률 스케줄러: 지정된 주기마다 학습률을 감소시킴
#5 에폭마다 학습률이 0.1씩 줄어든다.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
#[?]스케줄러도 step 메서드로 학습률을 갱신 가능: 한 에폭이 완료된 후에 호출[?]
Faster R-CNN 모델을 미세 조정
#Faster R-CNN 모델 미세 조정
for epoch in range(5):
cost = 0.0
for idx, (images, targets) in enumerate(train_dataloader):
#배치 크기로 데이터가 묶여 있으므로 리스트 간소화를 통해 장치 설정
images = list(image.to(device) for image in images)
#targets 변수는 딕셔너리 이므로 딕셔너리 간소화를 통해 적용
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
#반환되는 손실값: 분류 손실, 박스 회귀 손실, 객체 유무 손실, 영역 제안 네트워크 손실
loss_dict = model(images, targets)
#학습 모드일 때 모든 손실값을 출력
#모델은 네 개의 손실이 모두 최소가 되는 방향으로 학습돼야 하므로 손실값을 모두 더해 역전파를 계산
losses = sum(loss for loss in loss_dict.values())
optimizer.zero_grad()
losses.backward()
optimizer.step()
cost += losses
lr_scheduler.step()
cost = cost / len(train_dataloader)
print(f"Epoch : {epoch+1:4d}, Cost : {cost:.3f}")
#모델 추론 및 시각화
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image
#Pillow 라이브러리로 사각형과 텍스트를 이미지 위에 그리는 함수
def draw_bbox(ax, box, text, color):
ax.add_patch(
plt.Rectangle(
xy=(box[0], box[1]),
width=box[2] - box[0],
height=box[3] - box[1],
fill=False,
edgecolor=color,
linewidth=2,
)
)
ax.annotate(
text=text,
xy=(box[0] - 5, box[1] - 5),
color=color,
weight="bold",
fontsize=13,
)
#임곗값을 0.5로 설정해 50% 이상의 객체만 표시한다.
threshold = 0.5
categories = test_dataset.categories
with torch.no_grad():
model.eval()
for images, targets in test_dataloader:
images = [image.to(device) for image in images]
outputs = model(images)
boxes = outputs[0]["boxes"].to("cpu").numpy()
labels = outputs[0]["labels"].to("cpu").numpy()
scores = outputs[0]["scores"].to("cpu").numpy()
boxes = boxes[scores >= threshold].astype(np.int32)
labels = labels[scores >= threshold]
scores = scores[scores >= threshold]
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(1, 1, 1)
plt.imshow(to_pil_image(images[0]))
for box, label, score in zip(boxes, labels, scores):
draw_bbox(ax, box, f"{categories[label]} - {score:.4f}", "red")
tboxes = targets[0]["boxes"].numpy()
tlabels = targets[0]["labels"].numpy()
for box, label in zip(tboxes, tlabels):
draw_bbox(ax, box, f"{categories[label]}", "blue")
plt.show()