<aside> 📢 사전 학습된 ResNet-18의 특징 맵을 활용해 클래스 활성화 맵을 구현
</aside>
평균값 풀링을 적용하기 전의 특징 맵을 활용하기 위해 마지막 계층을 제외한 모든 계층을 추출한다.
from torch import nn
from torchvision import models
model = models.resnet18(weights="ResNet18_Weights.IMAGENET1K_V1").eval()
features = nn.Sequential(*list(model.children())[:-2])
레즈넷 모델을 평가 모드로 변경하여 마지막 계층을 제외한 모든 계층을 추출한다.
children 메서드: 모듈 클래스에 포함된 하위 모듈을 반환: model 변수에서 사용된 모든 계층을 순차적으로 반환
레즈넷 모델 구조
평균값 풀링과 완전 연결 계층: 분류기 : [:-2]를 통해 특징만 연산하는 계층을 features 변수로 저장
#특징 맵과 가중치 추출
from PIL import Image
from torchvision import transforms
# 이미지를 전처리하기 위한 변환 과정을 정의합니다.
transform = transforms.Compose(
[
# 이미지 크기를 (224, 224)로 조정합니다.
transforms.Resize((224, 224)),
# 이미지를 PyTorch Tensor로 변환합니다.
transforms.ToTensor(),
# 이미지를 정규화합니다. 평균(mean)과 표준편차(std)를 사용하여 RGB 채널별로 정규화합니다.
transforms.Normalize(
mean=[0.485, 0.456, 0.406], # RGB 각 채널의 평균값
std=[0.229, 0.224, 0.225] # RGB 각 채널의 표준편차
),
]
)
image = Image.open("C:\\\\Users\\\\yuhyu\\\\Desktop\\\\CODE\\\\pytorch\\\\datasets\\\\images/airplane.jpg")
target = transform(image).unsqueeze(0)
output = model(target)
#모델이 판단한 클래스 색인 값(ID) 추출
class_idx = int(output.argmax())
#완전 연결 계층 클래스 색인에 해당하는 가중치를 추출
#완전 연결 계층: 입력 차원은 512, 출력 차원은 1,000
#model.fc.weight[class_idx]: 512 차원
#차원 변경(reshape) 메서드로 확장
weights = model.fc.weight[class_idx].reshape(-1, 1, 1)
#features에 전처리된 이미지를 전달해 순방향 연산으로 특징 맵 추출
#[1, 512, 7, 7]의 차원 형태로 반환: 차원을 감소시키기 전 크기
#배치 크기를 의미하는 첫 차원은 불필요: squeeze 메서드로 제거
features_output = features(target).squeeze()
print(weights.shape)
print(features_output.shape)

#클래스 활성화 맵 생성
import torch
from torch.nn import functional as F
#class_idx에 해당하는 채널별 이미지 영역의 주요 특징이 계산
cam = features_output * weights
#[512, 7, 7]구조를 합 연산을 통해 [7, 7]크기로 변경
cam = torch.sum(cam, dim=0)
#cam: 이미지를 7등분했을 때 어떤 영역에서 가장 많은 영향을 미쳤는지 알려줌
#보간(interpolate)함수를 통해 입력 이미지 크기와 동일한 크기로 변경
cam = F.interpolate(
#4차원 배열을 unsqueeze 메서드를 통해 차원을 확장
input=cam.unsqueeze(0).unsqueeze(0),
#cam을 이미지 크기와 동일한 크기로 변경
size=(image.size[1], image.size[0]),
#보간 방법(mode)은 이중 선형(bilinear) 보간을 통해 부드럽게 확장
mode="bilinear",
#현재 cam 변수의 차원은 [1, 1, 7, 7]이므로 다시 스퀴즈 메스드를 통해 차원을 [이미지 너비, 이미지 높이]로 변경
#넘파이 배열로 변경
).squeeze().detach().numpy()
#클래스 활성화 맵 시각화
import matplotlib.pyplot as plt
plt.imshow(image)
plt.imshow(cam, cmap="jet", alpha=0.5)
plt.axis("off")
plt.show()

모델의 분류 과정에 가장 많은 영향을 준 영역은 붉은색 계열로 표시
파란색 계열로 표시된 영역은 분류 과정에 큰 영향을 주지 않았음
보간 함수의 보간 방법을 이웃 보간(nearest)로 설정한다면 7X7 블록 형태로 영향도 확인 가능
