Source code for interpretdl.interpreter._normlime_base

import numpy as np
import os, sys
from tqdm import tqdm
import paddle
from paddle.vision.transforms import functional as F

from ..common.paddle_utils import FeatureExtractor, extract_superpixel_features, get_pre_models
from ..data_processor.readers import load_pickle_file
from .lime import LIMECVInterpreter, LIMENLPInterpreter


[docs]class NormLIMECVInterpreter(LIMECVInterpreter): """ NormLIME Interpreter for CV tasks. (TODO) Some technical details will be complete soon. More details regarding the NormLIME method can be found in the original paper: https://arxiv.org/abs/1909.04200. """ def __init__(self, model: callable, device: str = 'gpu:0'): """ Args: model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions. device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"`` etc. """ LIMECVInterpreter.__init__(self, model, device) self.lime_interpret = super().interpret def _get_lime_weights(self, data, num_samples, batch_size, save=False): if data in self.all_lime_weights: return self.lime_base.segments = None # set to None to force recompute the segmentation. lime_weights = self.lime_interpret(data, num_samples=num_samples, batch_size=batch_size, visual=False) sp_seg = self.lime_results['segmentation'] data_instance = self.lime_results['input'] self.all_lime_weights[data] = {'lime_weights': lime_weights, 'segmentation': sp_seg, 'input': data_instance} if save and self.filepath_to_save is not None: np.savez(self.filepath_to_save, **self.all_lime_weights) # load: dict(np.load(filepath_to_load, allow_pickle=true)) return
[docs] def interpret(self, image_paths, num_samples=1000, batch_size=50, save_path='normlime_weights.npy', temp_data_file='all_lime_weights.npz'): """ Main function of the interpreter. (TODO) Some technical details will be complete soon. Args: image_paths (list of strs): A list of image filepaths. num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more accurate interpretation. Default: 1000 batch_size (int, optional): Number of samples to forward each time. Default: 50 save_path (str, optional): The .npy path to save the normlime weights. It is a dictionary where the key is label and value is segmentation ids with their importance. Default: 'normlime_weights.npy' temp_data_file (str, optional): The path to save the intermediate lime weights to avoid repeating computations. Default: 'all_lime_weights.npz'. Set to None will not save the intermediate lime weights. Returns: [dict] NormLIME weights: Global feature importance as a dict {label_i: weights on features} """ # Check `save_path`. Saving NormLIME results is necessary. if os.path.exists(save_path): print(f'{save_path} exists.') n = 0 tmp = save_path.split('.npy')[0] while os.path.exists(f'{tmp}-{n}.npy'): n += 1 save_path = f'{tmp}-{n}.npy' print(f'NormLIME results will be saved to {save_path}.') # Check `temp_data_file` and load computed results. self.all_lime_weights = {} if temp_data_file is None: self.filepath_to_save = None print("Intermediate LIME results will not be saved.") else: self.filepath_to_save = temp_data_file if temp_data_file.endswith('.npz') else temp_data_file + '.npz' if os.path.exists(self.filepath_to_save): self.all_lime_weights = dict(np.load(self.filepath_to_save, allow_pickle=True)) # compute lime weights and put in self.all_lime_weights for i in tqdm(range(len(image_paths)), leave=True, position=0): image_path = image_paths[i] self._get_lime_weights(image_path, num_samples, batch_size, save=(i % 10 == 0)) if self.filepath_to_save is not None: np.savez(self.filepath_to_save, **self.all_lime_weights) # convert superpixel indexes to cluster indexes. _, h_pre_models_kmeans = get_pre_models() kmeans_model = load_pickle_file(h_pre_models_kmeans) normlime_weights_all_labels = {} fextractor = FeatureExtractor() for i, image_path in enumerate(image_paths): lime_explanation_i = self.all_lime_weights[image_path] if isinstance(lime_explanation_i, np.ndarray): lime_explanation_i = lime_explanation_i.item() # lime_explanation_i is a dict of { # 'input': , # 'segmentation':, # 'lime_weights': # } img_to_show = lime_explanation_i['input'][np.newaxis, ...] # static model paddle.enable_static() f = fextractor.forward(img_to_show).transpose((1, 2, 0)) paddle.disable_static() img_size = (img_to_show.shape[1], img_to_show.shape[2]) f = F.resize(f, img_size) # compute clusters according to the sp features. X = extract_superpixel_features(f, lime_explanation_i['segmentation']) try: cluster_labels = kmeans_model.predict(X) # a list. len = number of sp. except AttributeError: from sklearn.metrics import pairwise_distances_argmin_min cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_) lime_weights = lime_explanation_i['lime_weights'] pred_labels = lime_weights.keys() for y in pred_labels: normlime_weights_label_y = normlime_weights_all_labels.get(y, {}) w_f_y = [abs(w[1]) for w in lime_weights[y]] w_f_y_l1norm = sum(w_f_y) for w in lime_weights[y]: seg_label = w[0] weight = w[1] * w[1] / w_f_y_l1norm tmp = normlime_weights_label_y.get(cluster_labels[seg_label], []) tmp.append(weight) normlime_weights_label_y[cluster_labels[seg_label]] = tmp normlime_weights_all_labels[y] = normlime_weights_label_y # compute normlime weights. for y in normlime_weights_all_labels: normlime_weights = normlime_weights_all_labels.get(y, {}) for k in normlime_weights: normlime_weights[k] = sum(normlime_weights[k]) / len(normlime_weights[k]) # check normlime if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1: print( "\n" + \ "Warning: !!! \n" + \ "There are at least {} classes, ".format(max(normlime_weights_all_labels.keys()) + 1) + \ "but the NormLIME has results of only {} classes. \n".format(len(normlime_weights_all_labels.keys())) + \ "It may have cause unstable results in the later computation" + \ " but can be improved by computing more test samples." + \ "\n" ) # Saving NormLIME results is necessary. np.save(save_path, normlime_weights_all_labels) return normlime_weights_all_labels
[docs]class NormLIMENLPInterpreter(LIMENLPInterpreter): """ NormLIME Interpreter for NLP tasks. More details regarding the NormLIME method can be found in the original paper: https://arxiv.org/abs/1909.04200. """ def __init__(self, model: callable, device: str = 'gpu:0'): """ Args: model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions. device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"`` etc. """ LIMENLPInterpreter.__init__(self, model, device) self.lime_interpret = super().interpret def _get_lime_weights(self, data, preprocess_fn, num_samples, batch_size, unk_id, pad_id, lod_levels, save=False): if isinstance(data, str): dict_key = data elif isinstance(data, dict): dict_key = '-'.join([str(v) for v in list(data.values())]) else: # TODO: open an issue if there are more cases that are not covered. dict_key = str(data) # dict_key = '_'.join(str(i) for i in data) # dict_key = data if dict_key in self.all_lime_weights: print(f"'{dict_key}' has been computed before. Check it if this is NOT expected.") return lime_weights = self.lime_interpret(data, text_to_input_fn=preprocess_fn, unk_id=unk_id, pad_id=pad_id, num_samples=num_samples, batch_size=batch_size) self.all_lime_weights[dict_key] = { 'lime_weights': lime_weights, } if save and self.filepath_to_save is not None: np.savez(self.filepath_to_save, **self.all_lime_weights) # load: dict(np.load(filepath_to_load, allow_pickle=True)) return
[docs] def interpret(self, list_of_raw_text, preprocess_fn: callable, num_samples: int, batch_size: int, unk_id: int = 0, pad_id: int = 0, lod_levels: int or None = None, save_path: str = 'normlime_weights.npy', temp_data_file: str = 'all_lime_weights.npz'): """ NormLIMENLPInterpreter computes the LIME results of each sample of ``data``, normalizes and averages the LIME results. ``preprocess_fn`` is used for coping with texts, see the tutorials for an example. ``num_samples`` and ``batch_size`` are LIME arguments, for the generated samples and the batch size of each pass. Args: list_of_raw_text (str): The raw string for analysis. preprocess_fn (Callable): A user-defined function that input raw string and outputs the a tuple of inputs to feed into the NLP model. num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more accurate interpretation. Default: 1000 batch_size (int, optional): Number of samples to forward each time. Default: 50 unk_id (int): The word id to replace occluded words. Typical choices include "", <unk>, and <pad>. pad_id (int or None): The word id used to pad the sequences. If None, it means there is no padding. Default: None. lod_levels (list or tuple or numpy.ndarray or None, optional): The lod levels for model inputs. It should have the length equal to number of outputs given by preprocess_fn. If None, lod levels are all zeros. Default: None. save_path (str, optional): The .npy path to save the normlime weights. It is a dictionary where the key is label and value is segmentation ids with their importance. Default: 'normlime_weights.npy' temp_data_file (str, optional): The .npz path to save the temporal LIME results, to avoid repeating the computations. Default: 'all_lime_weights.npz' Returns: [dict] NormLIME weights: {label_i: weights on features} """ # Check `save_path`. Saving NormLIME results is necessary. if os.path.exists(save_path): print(f'{save_path} exists.') n = 0 tmp = save_path.split('.npy')[0] while os.path.exists(f'{tmp}-{n}.npy'): n += 1 save_path = f'{tmp}-{n}.npy' print(f'NormLIME results will be saved to {save_path}.') # Check `temp_data_file` and load computed results. self.all_lime_weights = {} if temp_data_file is None: self.filepath_to_save = None print("Intermediate LIME results will not be saved.") else: self.filepath_to_save = temp_data_file if temp_data_file.endswith('.npz') else temp_data_file + '.npz' if os.path.exists(self.filepath_to_save): self.all_lime_weights = dict(np.load(self.filepath_to_save, allow_pickle=True)) # compute lime weights and put in self.all_lime_weights for i in tqdm(range(len(list_of_raw_text)), leave=True, position=0): self._get_lime_weights(list_of_raw_text[i], preprocess_fn=preprocess_fn, unk_id=unk_id, pad_id=pad_id, num_samples=num_samples, batch_size=batch_size, lod_levels=lod_levels, save=(i % 10) == 0) if self.filepath_to_save is not None: np.savez(self.filepath_to_save, **self.all_lime_weights) # Gather LIME weights and Normalize. normlime_weights_all_labels = {} for dict_key, lime_explanation_i in self.all_lime_weights.items(): if isinstance(lime_explanation_i, np.ndarray): lime_explanation_i = lime_explanation_i.item() # lime_explanation_i is a dict lime_weights = lime_explanation_i['lime_weights'] pred_labels = lime_weights.keys() for y in pred_labels: normlime_weights_label_y = normlime_weights_all_labels.get(y, {}) w_f_y = [abs(w[1]) for w in lime_weights[y]] w_f_y_l1norm = sum(w_f_y) for w in lime_weights[y]: word_id = w[0] if w[1] > 0: weight = w[1] * w[1] / w_f_y_l1norm tmp = normlime_weights_label_y.get(word_id, []) tmp.append(weight) normlime_weights_label_y[word_id] = tmp normlime_weights_all_labels[y] = normlime_weights_label_y # compute normlime weights. for y in normlime_weights_all_labels: normlime_weights = normlime_weights_all_labels.get(y, {}) for k in normlime_weights: normlime_weights[k] = (sum(normlime_weights[k]) / len(normlime_weights[k]), len(normlime_weights[k])) # check normlime if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1: print( "\n" + \ "Warning: !!! \n" + \ "There are at least {} classes, ".format(max(normlime_weights_all_labels.keys()) + 1) + \ "but the NormLIME has results of only {} classes. \n".format(len(normlime_weights_all_labels.keys())) + \ "It may have cause unstable results in the later computation" + \ " but can be improved by computing more test samples." + \ "\n" ) # Saving NormLIME results is necessary. np.save(save_path, normlime_weights_all_labels) return normlime_weights_all_labels