<aside> 📢 사전 학습된 ResNet-18과 후크(Hook)를 활용해 실습을 진행한다.

</aside>

후크: 특정 이벤트가 발생했을 때 다른 코드를 실행하는 기술

Grad-CAM은 마지막 합성곱 계층의 순전파와 역전파를 활용하므로 해당 계층이 실행될 때 이벤트를 실행시켜 기울기 값을 받아온다.

#순전파와 역전파 후크 등록
import torch

class GradCAM:
		#활성화 맵을 확인하려는 모델(model)과 마지막 합성곱 계층을 확인하기 위한 main과 sub를 받아옴
    def __init__(self, model, main, sub):
        self.model = model.eval()
        #ResNet-18 기준: main = layer4, sub = conv2
        self.register_hook(main, sub)

    def register_hook(self, main, sub):
		    #named_children = 모듈의 이름과 모듈을 방향
        for name, module in self.model.named_children():
		        #main 동일
            if name == main:
                for sub_name, sub_module in module[-1].named_children():
                    #sub 동일
                    if sub_name == sub:
		                    #순방향 후크와 역방향 후크를 등록
                        sub_module.register_forward_hook(self.forward_hook)
                        sub_module.register_full_backward_hook(self.backward_hook)
		#module: 입력된 모듈, input: 순방향 연산의 입력 데이터, output: 순방향 연산의 출력 데이터
    def forward_hook(self, module, input, output):
		    #마지막 계층의 특징 맵을 알기 위해 순방향 연산의 출력값 저장
        self.feature_map = output  
		#입력 모듈, 기울기 입력값, 기울기 출력값 제공
    def backward_hook(self, module, grad_input, grad_output):
		    #기울기 출력값은 튜플로 감싸인 텐서를 갖고 있으므로 첫 번째 텐서만 반환
        self.gradient = grad_output[0]
        
    def __call__(self, x):
        output = self.model(x)
				#츨력값에 해당하는 클래스 색인 값을 추출한다.
				#가장 높은 클래스로 할당된 색인을 추출한다.
        index = output.argmax(axis=1)
        #원-핫 인코딩을 적용
        one_hot = torch.zeros_like(output)
        #원-핫 인코딩이 적용된 배열에서 최댓값 색인 위치에만 1을 부여                
        for i in range(output.size(0)):
            one_hot[i][index[i]] = 1
            
        self.model.zero_grad()
        #역전파 연산을 통해 후크를 호출
        #역전파 메서드의 기울기 매개변수에 원-핫 배열 전달
        #메모리 절약 및 이전 결과 재사용 가능
        #retain_graph=True: 하나의 이미지에서 여러 클래스의 Grad-CAM을 보려고 하거나, 역전파 연산이 자주 일어나는 경우
        #기울기 유지는 역전파 연산이 여러 번 일어날 때 발생하는 오류를 억제
        output.backward(gradient=one_hot, retain_graph=True)
				#self.gradient: [N, 512, 7, 7] 차원
				#평균 값 계산 시 세 번째(dim = 2)와 네 번째(dim = 3)차원을 따라 평균을 계산
				#keepdim=True: 평균 계산 시 차원 유지
				#a_k: [N, 512, 1, 1] 차원
        a_k = torch.mean(self.gradient, dim=(2, 3), keepdim=True)
        #self.feature_map: fk(i,j) 저장되어 있음
        #grad_cam: [N, 7, 7] 차원
        #하나의 이미지만 사용해 연산하면 [7, 7]차원: 클래스 활성화 맵과 동일
        grad_cam = torch.sum(a_k * self.feature_map, dim=1)
        grad_cam = torch.relu(grad_cam)
        return grad_cam
from PIL import Image
from torch.nn import functional as F
from torchvision import models
from torchvision import transforms
from matplotlib import pyplot as plt

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]
)

files = [
    "../datasets/images/airplane.jpg", "../datasets/images/bus.jpg",
    "../datasets/images/dog.jpg", "../datasets/images/african_hunting_dog.jpg"
]
images, tensors = [], []
for file in files:
    image = Image.open(file)
    images.append(image)
    tensors.append(transform(image))
tensors = torch.stack(tensors)
# GradCAM 모델을 초기화합니다. 기본값으로는 ResNet18을 사용하며, 주요 레이어와 서브 레이어를 지정합니다.
model = GradCAM(
    model=models.resnet18(weights="ResNet18_Weights.IMAGENET1K_V1"),
    main="layer4",
    sub="conv2"
)
grad_cams = model(tensors)

for idx, image in enumerate(images):
    grad_cam = F.interpolate(
        input=grad_cams[idx].unsqueeze(0).unsqueeze(0),
        size=(image.size[1], image.size[0]),
        mode="bilinear",
    ).squeeze().detach().numpy()

    plt.imshow(image)
    plt.imshow(grad_cam, cmap="jet", alpha=0.5)
    plt.axis("off")
    plt.show()

Untitled

Untitled

Untitled

Untitled

대량의 이미지에 대한 클래스 활성화 맵을 확인할 수 있는 구조이다.

클래스 활성화 맵: 전역 평균 풀링 계층을 통과한 완전 연결 계층의 기울기를 사용

Grad-CAM: 마지막 합성곱 계층만 활용하므로 모든 합성곱 신경망에 적용할 수 있으며, 전역 평균 풀링 계층을 직접 사용하지 않아 공간 정보를 보존할 수 있음