<aside> 📢 사전 학습된 ResNet-18과 후크(Hook)를 활용해 실습을 진행한다.
</aside>
후크: 특정 이벤트가 발생했을 때 다른 코드를 실행하는 기술
Grad-CAM은 마지막 합성곱 계층의 순전파와 역전파를 활용하므로 해당 계층이 실행될 때 이벤트를 실행시켜 기울기 값을 받아온다.
#순전파와 역전파 후크 등록
import torch
class GradCAM:
#활성화 맵을 확인하려는 모델(model)과 마지막 합성곱 계층을 확인하기 위한 main과 sub를 받아옴
def __init__(self, model, main, sub):
self.model = model.eval()
#ResNet-18 기준: main = layer4, sub = conv2
self.register_hook(main, sub)
def register_hook(self, main, sub):
#named_children = 모듈의 이름과 모듈을 방향
for name, module in self.model.named_children():
#main 동일
if name == main:
for sub_name, sub_module in module[-1].named_children():
#sub 동일
if sub_name == sub:
#순방향 후크와 역방향 후크를 등록
sub_module.register_forward_hook(self.forward_hook)
sub_module.register_full_backward_hook(self.backward_hook)
#module: 입력된 모듈, input: 순방향 연산의 입력 데이터, output: 순방향 연산의 출력 데이터
def forward_hook(self, module, input, output):
#마지막 계층의 특징 맵을 알기 위해 순방향 연산의 출력값 저장
self.feature_map = output
#입력 모듈, 기울기 입력값, 기울기 출력값 제공
def backward_hook(self, module, grad_input, grad_output):
#기울기 출력값은 튜플로 감싸인 텐서를 갖고 있으므로 첫 번째 텐서만 반환
self.gradient = grad_output[0]
def __call__(self, x):
output = self.model(x)
#츨력값에 해당하는 클래스 색인 값을 추출한다.
#가장 높은 클래스로 할당된 색인을 추출한다.
index = output.argmax(axis=1)
#원-핫 인코딩을 적용
one_hot = torch.zeros_like(output)
#원-핫 인코딩이 적용된 배열에서 최댓값 색인 위치에만 1을 부여
for i in range(output.size(0)):
one_hot[i][index[i]] = 1
self.model.zero_grad()
#역전파 연산을 통해 후크를 호출
#역전파 메서드의 기울기 매개변수에 원-핫 배열 전달
#메모리 절약 및 이전 결과 재사용 가능
#retain_graph=True: 하나의 이미지에서 여러 클래스의 Grad-CAM을 보려고 하거나, 역전파 연산이 자주 일어나는 경우
#기울기 유지는 역전파 연산이 여러 번 일어날 때 발생하는 오류를 억제
output.backward(gradient=one_hot, retain_graph=True)
#self.gradient: [N, 512, 7, 7] 차원
#평균 값 계산 시 세 번째(dim = 2)와 네 번째(dim = 3)차원을 따라 평균을 계산
#keepdim=True: 평균 계산 시 차원 유지
#a_k: [N, 512, 1, 1] 차원
a_k = torch.mean(self.gradient, dim=(2, 3), keepdim=True)
#self.feature_map: fk(i,j) 저장되어 있음
#grad_cam: [N, 7, 7] 차원
#하나의 이미지만 사용해 연산하면 [7, 7]차원: 클래스 활성화 맵과 동일
grad_cam = torch.sum(a_k * self.feature_map, dim=1)
grad_cam = torch.relu(grad_cam)
return grad_cam
from PIL import Image
from torch.nn import functional as F
from torchvision import models
from torchvision import transforms
from matplotlib import pyplot as plt
transform = transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
]
)
files = [
"../datasets/images/airplane.jpg", "../datasets/images/bus.jpg",
"../datasets/images/dog.jpg", "../datasets/images/african_hunting_dog.jpg"
]
images, tensors = [], []
for file in files:
image = Image.open(file)
images.append(image)
tensors.append(transform(image))
tensors = torch.stack(tensors)
# GradCAM 모델을 초기화합니다. 기본값으로는 ResNet18을 사용하며, 주요 레이어와 서브 레이어를 지정합니다.
model = GradCAM(
model=models.resnet18(weights="ResNet18_Weights.IMAGENET1K_V1"),
main="layer4",
sub="conv2"
)
grad_cams = model(tensors)
for idx, image in enumerate(images):
grad_cam = F.interpolate(
input=grad_cams[idx].unsqueeze(0).unsqueeze(0),
size=(image.size[1], image.size[0]),
mode="bilinear",
).squeeze().detach().numpy()
plt.imshow(image)
plt.imshow(grad_cam, cmap="jet", alpha=0.5)
plt.axis("off")
plt.show()




대량의 이미지에 대한 클래스 활성화 맵을 확인할 수 있는 구조이다.
클래스 활성화 맵: 전역 평균 풀링 계층을 통과한 완전 연결 계층의 기울기를 사용
Grad-CAM: 마지막 합성곱 계층만 활용하므로 모든 합성곱 신경망에 적용할 수 있으며, 전역 평균 풀링 계층을 직접 사용하지 않아 공간 정보를 보존할 수 있음