Source code for interpretdl.interpreter.transition_attention_maps

import numpy as np
import re

from .abc_interpreter import Interpreter
from ..data_processor.readers import images_transform_pipeline, preprocess_save_path
from ..data_processor.visualizer import explanation_to_vis, show_vis_explanation, save_image


class TAMInterpreter(Interpreter):
    """
    TODO: Inherit from a subabstract interpreter.

    Transition Attention Maps Interpreter.

    This is a specific interpreter for Transformers models. TAMInterpreter assumes that the information flowing inside
    the Transformer model follows the Markov Chain. Within this supposition, TAMInterpreter considers the attention
    maps as transition matrices and computes the explanation by multiplying the initial state with the attention maps,
    with integrated gradients.

    More details regarding the Transition_Attention_Maps method can be found in the original paper:
    https://openreview.net/forum?id=TT-cf6QSDaQ.
    """

    def __init__(self, model: callable, device: str = 'gpu:0') -> None:
        """
        Args:
            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                etc.
        """
        Interpreter.__init__(self, model, device)
        # ``self.predict_fn`` is built lazily by ``_paddle_prepare`` on the first call to ``interpret``.
        self.paddle_prepared = False

    def interpret(self,
                  inputs: str or list(str) or np.ndarray,
                  start_layer: int = 4,
                  steps: int = 20,
                  label: int or None = None,
                  resize_to: int = 224,
                  crop_to: int or None = None,
                  visual: bool = True,
                  save_path: str or None = None):
        """
        Given ``inputs``, TAMInterpreter obtains all attention maps (the outputs of the layers whose name matches the
        regular expression ``^blocks.*.attn.attn_drop$``) and calculates their matrix multiplication. The
        ``start_layer`` controls the number of involved layers. The order of involving attention maps (from last layer
        to the first) is different from Rollout (from first to last). Then, an integrated gradients with ``steps`` is
        computed and multiplied to the attention result.

        Args:
            inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
                array of read images.
            start_layer (int, optional): Compute the state from the start layer. Must be smaller than the number of
                matched attention layers. Default: ``4``.
            steps (int, optional): Number of steps in the Riemann approximation of the integral. Must be at least
                ``1``. Default: ``20``.
            label (int or list or tuple or numpy.ndarray, optional): The target label(s) to analyze. The number of
                labels should be equal to the number of images. If None, the most likely label for each image will
                be used. Default: ``None``.
            resize_to (int, optional): Images will be rescaled with the shorter edge being ``resize_to``. Defaults to
                ``224``.
            crop_to (int, optional): After resize, images will be center cropped to a square image with the size
                ``crop_to``. If None, no crop will be performed. Defaults to ``None``.
            visual (bool, optional): Whether or not to visualize the processed image. Default: ``True``.
            save_path (str, optional): The filepath(s) to save the processed image(s). If None, the image will not be
                saved. Default: ``None``.

        Returns:
            [numpy.ndarray]: interpretations/heatmap for images, of shape ``(batch_size, grid, grid)`` where
            ``grid * grid`` equals the token count minus the first token (presumably the class token).

        Raises:
            ValueError: If ``steps`` is smaller than 1, or if the number of tokens after the first one is not a
                perfect square (so it cannot be laid out on a square grid).
        """
        if steps < 1:
            # With zero steps the integrated gradients below would divide by zero and yield NaNs.
            raise ValueError("steps should be a positive integer, got {}".format(steps))

        imgs, data = images_transform_pipeline(inputs, resize_to, crop_to)
        bsz = len(data)  # batch size
        save_path = preprocess_save_path(save_path, bsz)

        if not self.paddle_prepared:
            self._paddle_prepare()

        attns, _, preds = self.predict_fn(data)
        assert start_layer < len(attns), "start_layer should be in the range of [0, num_block-1]"

        if label is None:
            label = preds

        b, h, s, _ = attns[0].shape
        num_blocks = len(attns)

        # Initial state: head-averaged attention row of the first token in the last matched block.
        states = np.mean(attns[-1], axis=1)[:, 0, :].reshape((b, 1, s))
        # Propagate backwards, from the second-to-last block down to ``start_layer``, adding the previous state back
        # at each step (residual term).
        for i in range(start_layer, num_blocks - 1)[::-1]:
            attn = np.mean(attns[i], 1)
            states_ = states
            states = states @ attn
            states += states_

        # Integrated gradients of the target score w.r.t. the last attention map, along the path from an all-zero
        # input (alpha = 0) to the actual input (alpha = 1).
        total_gradients = np.zeros((b, h, s, s))
        for alpha in np.linspace(0, 1, steps):
            # forward propagation
            data_scaled = data * alpha
            _, gradients, _ = self.predict_fn(data_scaled, label=label)

            total_gradients += gradients

        W_state = np.mean((total_gradients / steps).clip(min=0), axis=1)[:, 0, :].reshape((b, 1, s))

        # Drop the first token and lay the remaining tokens out on a square grid. The side is derived from the
        # sequence length rather than hard-coded to 14, so token counts other than 1 + 14 * 14 are handled.
        num_patches = s - 1
        grid = int(round(np.sqrt(num_patches)))
        if grid * grid != num_patches:
            raise ValueError("cannot lay out {} tokens on a square grid".format(num_patches))
        tam_explanation = (states * W_state)[:, 0, 1:].reshape((-1, grid, grid))

        # visualization and save image.
        for i in range(bsz):
            vis_explanation = explanation_to_vis(imgs[i], tam_explanation[i], style='overlay_heatmap')
            if visual:
                show_vis_explanation(vis_explanation)
            if save_path[i] is not None:
                save_image(save_path[i], vis_explanation)

        return tam_explanation

    def _paddle_prepare(self, predict_fn=None):
        """Install ``self.predict_fn`` and mark the interpreter as prepared.

        Args:
            predict_fn (callable, optional): A custom prediction function with the same contract as the default one
                (``predict_fn(data, label=None) -> (attention_maps, gradients, labels)``). If None, a Paddle-based
                implementation is built from ``self.model``.
        """
        if predict_fn is None:
            import paddle
            paddle.set_device(self.device)
            # to get gradients, the ``train`` mode must be set.
            # we cannot set v.training = False for the same reason.
            self.model.train()

            for n, v in self.model.named_sublayers():
                if "batchnorm" in v.__class__.__name__.lower():
                    v._use_global_stats = True
                if "dropout" in v.__class__.__name__.lower():
                    v.p = 0

                # Report issues or pull requests if more layers need to be added.

            # Sublayers whose outputs are captured as attention maps; compiled once instead of per layer per call.
            attention_layer_regex = re.compile('^blocks.*.attn.attn_drop$')

            def predict_fn(data, label=None):
                """Run the model on ``data`` and backpropagate the softmax score of ``label``.

                Returns:
                    tuple: (list of captured attention-layer outputs as numpy arrays, gradient of the target score
                    w.r.t. the last captured output, the labels used). If ``label`` is None, the argmax predictions
                    are used and returned as a numpy array.
                """
                data = paddle.to_tensor(data)
                data.stop_gradient = False

                attns = []

                def hook(layer, layer_input, output):
                    attns.append(output)

                hooks = [
                    v.register_forward_post_hook(hook)
                    for n, v in self.model.named_sublayers()
                    if attention_layer_regex.match(n)
                ]
                try:
                    out = self.model(data)
                finally:
                    # Always detach the hooks, even if the forward pass raises, so they do not stay on the model.
                    for h in hooks:
                        h.remove()

                out = paddle.nn.functional.softmax(out, axis=1)
                preds = paddle.argmax(out, axis=1)
                if label is None:
                    label = preds.numpy()

                label_onehot = paddle.nn.functional.one_hot(paddle.to_tensor(label), num_classes=out.shape[1])
                target = paddle.sum(out * label_onehot, axis=1)
                target.backward()
                # NOTE(review): raises IndexError if no sublayer matched the attention regex.
                gradients = attns[-1].grad
                target.clear_gradient()
                if isinstance(gradients, paddle.Tensor):
                    gradients = gradients.numpy()

                a = [attn.numpy() for attn in attns]
                return a, gradients, label

        self.predict_fn = predict_fn
        self.paddle_prepared = True