Source code for interpretdl.interpreter.lime

import warnings
import numpy as np
from collections.abc import Iterable

from ..data_processor.readers import preprocess_image, read_image, restore_image
from ..data_processor.visualizer import sp_weights_to_image_explanation, overlay_threshold, save_image, show_vis_explanation

from ._lime_base import LimeBase
from .abc_interpreter import Interpreter, InputOutputInterpreter


class LIMECVInterpreter(InputOutputInterpreter):
    """
    LIME Interpreter for CV tasks.

    LIME presents a local explanation by fitting a set of perturbed samples near the target sample using an
    interpretable model, specifically a linear model.

    The implementation is based on https://github.com/marcotcr/lime.

    More details regarding the LIME method can be found in the original paper:
    https://arxiv.org/abs/1602.04938.
    """

    def __init__(self, model: callable, device: str = 'gpu:0', random_seed: int or None = None):
        """

        Args:
            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                etc.
            random_seed (int or None): Forwarded to ``LimeBase`` as ``random_state``. Defaults to ``None``.
        """
        InputOutputInterpreter.__init__(self, model, device)

        # use the default LIME setting
        self.lime_base = LimeBase(random_state=random_seed)
        # Filled by ``interpret``; keeps the intermediate results of the latest call.
        self.lime_results = {}

    def interpret(self,
                  data: str or np.ndarray,
                  interpret_class: int = None,
                  top_k: int = 1,
                  num_samples: int = 1000,
                  batch_size: int = 50,
                  resize_to: int = 224,
                  crop_to: int = None,
                  visual: bool = True,
                  save_path: str = None):
        """
        Main function of the interpreter.

        The implementation is based on https://github.com/marcotcr/lime.

        Args:
            data (str or numpy.ndarray): The input file path, or an already loaded image array. A 3-D array
                gets a leading batch axis. An integer-typed array is used as the image as is; any other
                dtype is passed through ``restore_image`` first to obtain the image used for visualization.
            interpret_class (int, list, tuple or numpy.ndarray, optional): The index (or indices) of the
                class(es) to interpret. If None, the ``top_k`` most likely labels will be used.
                Default: ``None``.
            top_k (int, optional): Number of top classes to interpret. Will not be used if
                ``interpret_class`` is given. Default: ``1``.
            num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more
                accurate interpretation. Default: ``1000``.
            batch_size (int, optional): Number of samples to forward each time. Default: ``50``.
            resize_to (int, optional): Images will be rescaled with the shorter edge being ``resize_to``.
                Only used when ``data`` is a file path. Defaults to ``224``.
            crop_to (int, optional): After resize, images will be center cropped to a square image with the
                size ``crop_to``. If None, no crop will be performed. Only used when ``data`` is a file
                path. Defaults to ``None``.
            visual (bool, optional): Whether or not to visualize the processed image. Default: ``True``.
            save_path (str, optional): The filepath(s) to save the processed image(s). If None, the image
                will not be saved. Default: ``None``.

        Returns:
            [dict]: LIME results: {interpret_label_i: weights on features}
        """
        # preprocess_inputs
        if isinstance(data, str):
            img = read_image(data, resize_to, crop_to)
        else:
            if len(data.shape) == 3:
                data = np.expand_dims(data, axis=0)
            if np.issubdtype(data.dtype, np.integer):
                img = data
            else:
                # for later visualization
                img = restore_image(data.copy())
        data = preprocess_image(img)
        self.data_type = np.array(data).dtype

        self._build_predict_fn(output='probability')
        probability, _, _ = self.predict_fn(data, None)
        # only one example here
        probability = probability[0]

        if interpret_class is None:
            # Not provided: fall back to the ``top_k`` most likely classes. ``argsort`` is ascending, so
            # the entries are ordered from the least to the most likely of those ``top_k`` classes.
            interpret_class = np.argsort(probability)[-top_k:]
        # Always work with a flat 1-D array of class indices. Wrapping every non-list value in a list
        # (as done previously) turned tuples / arrays into 2-D arrays, which broke ``interpret_class[0]``.
        interpret_class = np.atleast_1d(np.asarray(interpret_class)).ravel()

        def predict_fn_for_lime(_imgs):
            _data = preprocess_image(_imgs)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]

            output, _, _ = self.predict_fn(_data, None)
            return output

        self.predict_fn_for_lime = predict_fn_for_lime
        lime_weights, r2_scores = self.lime_base.interpret_instance(img[0],
                                                                    self.predict_fn_for_lime,
                                                                    interpret_class,
                                                                    num_samples=num_samples,
                                                                    batch_size=batch_size)

        # visualization and save image; skipped entirely when neither is requested.
        if visual or save_path is not None:
            # Only the first class in ``interpret_class`` is rendered.
            explanation_mask = sp_weights_to_image_explanation(img[0], lime_weights, interpret_class[0],
                                                               self.lime_base.segments)
            explanation_vis = overlay_threshold(img[0], explanation_mask)
            if visual:
                show_vis_explanation(explanation_vis)
            if save_path is not None:
                save_image(save_path, explanation_vis)

        # intermediate results, for possible further usages.
        self.lime_results['probability'] = {c: probability[c] for c in interpret_class}
        self.lime_results['input'] = img[0]
        self.lime_results['segmentation'] = self.lime_base.segments
        self.lime_results['r2_scores'] = r2_scores
        self.lime_results['lime_weights'] = lime_weights

        return lime_weights
class LIMENLPInterpreter(InputOutputInterpreter):
    """
    LIME Interpreter for NLP tasks.

    LIME presents a local explanation by fitting a set of perturbed samples near the target sample using an
    interpretable model, specifically a linear model.

    The implementation is based on https://github.com/marcotcr/lime.

    More details regarding the LIME method can be found in the original paper:
    https://arxiv.org/abs/1602.04938.
    """

    def __init__(self, model: callable, device: str = 'gpu:0', random_seed: int or None = None) -> None:
        """

        Args:
            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                etc.
            random_seed (int): random seed. Forwarded to ``LimeBase`` as ``random_state``. Defaults to None.
        """
        InputOutputInterpreter.__init__(self, model, device)

        # use the default LIME setting
        self.lime_base = LimeBase(random_state=random_seed)
        # Filled by ``interpret``; keeps the intermediate results of the latest call.
        self.lime_results = {}

    def interpret(self,
                  raw_text: str,
                  tokenizer: callable = None,
                  text_to_input_fn: callable = None,
                  preprocess_fn: callable = None,
                  unk_id: int = 0,
                  pad_id: int = 0,
                  classes_to_interpret: list or np.ndarray = None,
                  num_samples: int = 1000,
                  batch_size: int = 50,
                  max_seq_len: int = 128,
                  visual: bool = False):
        """
        Main function of the interpreter.

        The implementation is based on https://github.com/marcotcr/lime.

        Exactly one of ``tokenizer`` and ``text_to_input_fn`` (or its deprecated alias ``preprocess_fn``)
        must be given.

        Args:
            raw_text (str): The raw string for analysis.
            tokenizer (callable): Called as ``tokenizer(text=raw_text, max_seq_len=max_seq_len)``; every
                value of the returned mapping is wrapped into a batch of one and the values, in order, form
                the model inputs. If it has ``pad_token_id`` / ``unk_token_id`` attributes, they override
                ``pad_id`` / ``unk_id``.
            text_to_input_fn (callable): A user-defined function that convert raw text string to a tuple of
                inputs that can be fed into the NLP model.
            preprocess_fn (callable): Deprecated alias of ``text_to_input_fn``; a warning is issued when it
                is used.
            unk_id (int): The word id to replace occluded words. Typical choices include "", <unk>, and
                <pad>. Default: ``0``.
            pad_id (int or None): The word id used to pad the sequences. If None, it means there is no
                padding. Default: ``0``.
            classes_to_interpret (list or numpy.ndarray, optional): The index of class to interpret. If
                None, the most likely label will be used. Default: ``None``.
            num_samples (int, optional): LIME sampling numbers. Larger number of samples usually gives more
                accurate interpretation. Default: ``1000``.
            batch_size (int, optional): Number of samples to forward each time. Default: ``50``.
            max_seq_len (int, optional): Passed to ``tokenizer``; not used when ``text_to_input_fn`` is
                given. Default: ``128``.
            visual (bool, optional): Whether or not to visualize. Visualization is not implemented yet; a
                pointer to the tutorial is printed instead. Default: ``False``.

        Returns:
            [dict]: LIME results: {interpret_label_i: weights on features}
        """
        if preprocess_fn is not None:
            text_to_input_fn = preprocess_fn
            warnings.warn('``preprocess_fn`` would be deprecated soon. Use ``text_to_input`` directly.', stacklevel=2)

        assert (tokenizer is None) + (text_to_input_fn is None) == 1, "only one of them should be given."

        # tokenizer to text_to_input.
        if tokenizer is not None:
            if hasattr(tokenizer, 'pad_token_id'):
                pad_id = tokenizer.pad_token_id
                print("According to the tokenizer, pad_token_id is set to", pad_id)
            if hasattr(tokenizer, 'unk_token_id'):
                unk_id = tokenizer.unk_token_id
                print("According to the tokenizer, unk_token_id is set to", unk_id)

            def text_to_input_fn(raw_text):
                encoded_inputs = tokenizer(text=raw_text, max_seq_len=max_seq_len)
                # order is important. *_batched_and_to_tuple will be the input for the model.
                _batched_and_to_tuple = tuple([np.array([v]) for v in encoded_inputs.values()])
                return _batched_and_to_tuple
        else:
            print("Warning: Visualization can not be supported if tokenizer is not given.")

        # from raw text string to token ids (and other terms that the user-defined function outputs).
        model_input = text_to_input_fn(raw_text)
        if isinstance(model_input, Iterable) and not hasattr(model_input, 'shape'):
            # A plain sequence of inputs (list, tuple, ...): one entry per model input.
            self.model_inputs = tuple(model_input)
        else:
            # A single array-like (or non-iterable) input: wrap it as the only model input.
            # ``tuple(model_input)`` would instead split an array along its first axis.
            self.model_inputs = (model_input, )

        self._build_predict_fn(output='probability')

        def predict_fn_for_lime(*inputs):
            probability, _, _ = self.predict_fn(inputs, None)
            return probability

        probability, _, _ = self.predict_fn(self.model_inputs, classes_to_interpret)
        # only one example here
        probability = probability[0]

        # only interpret top 1
        if classes_to_interpret is None:
            classes_to_interpret = np.argsort(probability)[-1:]
        # Accept lists / scalars as well as arrays: everything below works on a flat 1-D index array.
        classes_to_interpret = np.atleast_1d(np.asarray(classes_to_interpret)).ravel()

        # this api is from LIME official repo: https://github.com/marcotcr/lime.
        lime_weights, r2_scores = self.lime_base.interpret_instance_text(self.model_inputs,
                                                                         classifier_fn=predict_fn_for_lime,
                                                                         interpret_labels=classes_to_interpret,
                                                                         unk_id=unk_id,
                                                                         pad_id=pad_id,
                                                                         num_samples=num_samples,
                                                                         batch_size=batch_size)

        # intermediate results, for possible further usages.
        self.predicted_proba = probability
        self.lime_results['probability'] = {c: probability[c] for c in classes_to_interpret}
        self.lime_results['r2_scores'] = r2_scores
        self.lime_results['lime_weights'] = lime_weights

        if visual:
            # TODO: visualize if tokenizer is given.
            print("Visualization is not supported yet.")
            print("Currently please see the tutorial for the visualization:")
            print("https://github.com/PaddlePaddle/InterpretDL/blob/master/tutorials/ernie-2.0-en-sst-2.ipynb")

        return lime_weights