import numpy as np
import paddle
from .lime import LIMECVInterpreter
from ._lime_base import compute_segments
from ._global_prior_base import get_cluster_label, cluster_global_weights_to_local_prior
from ..data_processor.readers import preprocess_image, read_image, restore_image
from ..data_processor.readers import load_npy_dict_file
from ..data_processor.visualizer import sp_weights_to_image_explanation, overlay_threshold, save_image, show_vis_explanation
class GLIMECVInterpreter(LIMECVInterpreter):
    """
    G-LIME CV Interpreter. This method integrates global information, aggregated from many LIME explanations by
    NormLIME or Average, into the local LIME explanation as a prior.

    More details can be found in this
    [pdf link](https://github.com/PaddlePaddle/InterpretDL/files/10110787/glime-aij-paper.pdf).
    """

    def __init__(self, model: callable, device: str = 'gpu:0') -> None:
        """
        Args:
            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                etc.
        """
        super().__init__(model, device)
        # Populated by set_global_weights() or compute_global_weights(); required unless prior_method == "none".
        self.global_weights = None

    def set_global_weights(self, global_weights_info: "str | dict"):
        """Set the global weights directly, without any pre-computation.

        Args:
            global_weights_info (str or dict): Either a path to a file readable by ``load_npy_dict_file`` (e.g.
                one written by ``compute_global_weights(save_path=...)``), or the global-weights dict itself.
                Any other type only prints a warning and leaves the current global weights unchanged.
        """
        if isinstance(global_weights_info, str):
            self.global_weights = load_npy_dict_file(global_weights_info)
            print(f"Set Global Weights from {global_weights_info}")
        elif isinstance(global_weights_info, dict):
            self.global_weights = global_weights_info
            # Do not print the dict itself: it may hold weights for many labels and clusters.
            print("Set Global Weights from the given dict.")
        else:
            print("Warning: Not set global weights. Unknown type.")

    def compute_global_weights(self,
                               g_name: str = 'normlime',
                               list_of_lime_explanations: list = None,
                               list_file_paths: list = None,
                               save_path: str = None):
        """Compute the global weights from ``list_of_lime_explanations``, using NormLIME or Average Global
        Explanations, introduced in https://arxiv.org/abs/1909.04200 and https://arxiv.org/abs/1907.03039
        respectively.

        Every superpixel of every explanation is mapped to a global cluster label by ``get_cluster_label``.
        For each predicted label ``y`` and superpixel weight ``w``, a contribution is computed as:

        - ``'normlime'``: ``w ** 2 / sum(|w_i|)``, normalized by the L1 norm of that explanation's weights for
          ``y``. When that norm is 0, all weights are 0 and the contribution is ``0.0``.
        - ``'avg'``: ``|w|``.
        - any other value: ``w ** 2`` (un-normalized).

        The global weight of ``(y, cluster)`` is the arithmetic mean of all its contributions.

        Args:
            g_name (str, optional): The method to aggregate local explanations. Defaults to ``'normlime'``.
            list_of_lime_explanations (list): The LIME results. Each item is a dict with the keys ``'input'``,
                ``'segmentation'`` and ``'lime_weights'``, where ``'lime_weights'`` maps each label to an
                iterable of ``(superpixel_id, weight)`` pairs.
            list_file_paths (list, optional): This is not implemented currently. Defaults to None.
            save_path (str, optional): A path to save the global weights, which can be reused next time via
                ``set_global_weights()``. Defaults to None.

        Raises:
            NotImplementedError: If ``list_file_paths`` is given.
            ValueError: If ``list_of_lime_explanations`` is missing or empty, or an item lacks a required key.

        Returns:
            dict: Global weights, ``{label: {cluster_label: averaged_weight}}``.
        """
        if list_file_paths is not None:
            raise NotImplementedError("Use scripts/benchmark.py to compute LIME explanations.")
        if not list_of_lime_explanations:
            raise ValueError("`list_of_lime_explanations` must be a non-empty list of LIME results.")

        # Validate every explanation up front (explicit errors, not asserts that vanish under `python -O`).
        required_keys = ('input', 'lime_weights', 'segmentation')
        for idx, lime_explanation in enumerate(list_of_lime_explanations):
            missing = [key for key in required_keys if key not in lime_explanation]
            if missing:
                raise ValueError(f"LIME explanation #{idx} is missing required key(s): {missing}.")

        # label -> {cluster_label -> list of contributions}; averaged in place below.
        global_weights_all_labels = {}
        for lime_explanation in list_of_lime_explanations:
            cluster_labels = get_cluster_label(lime_explanation['input'][np.newaxis, ...],
                                               lime_explanation['segmentation'])
            for y, sp_weights in lime_explanation['lime_weights'].items():
                global_weights_y = global_weights_all_labels.setdefault(y, {})
                w_f_y_l1norm = sum(abs(w[1]) for w in sp_weights)
                for w in sp_weights:
                    seg_label, sp_weight = w[0], w[1]
                    if g_name == 'normlime':
                        # Guard against 0/0 when every weight of this explanation is zero.
                        weight = sp_weight * sp_weight / w_f_y_l1norm if w_f_y_l1norm != 0 else 0.0
                    elif g_name == 'avg':
                        weight = abs(sp_weight)
                    else:
                        weight = sp_weight * sp_weight
                    global_weights_y.setdefault(cluster_labels[seg_label], []).append(weight)

        # Replace each list of contributions by its mean (only values are reassigned, so iterating is safe).
        for global_weights_y in global_weights_all_labels.values():
            for cluster_label, contributions in global_weights_y.items():
                global_weights_y[cluster_label] = sum(contributions) / len(contributions)

        if save_path is not None:
            print(f"Saving Global Weights to {save_path}")
            np.save(save_path, global_weights_all_labels)

        self.global_weights = global_weights_all_labels
        return self.global_weights

    def interpret(self,
                  data: "str | np.ndarray",
                  interpret_class: "int | list | None" = None,
                  top_k: int = 1,
                  prior_method: str = 'none',
                  prior_reg_force: float = 1.0,
                  num_samples: int = 1000,
                  batch_size: int = 50,
                  resize_to: int = 224,
                  crop_to: "int | None" = None,
                  visual: bool = True,
                  save_path: "str | None" = None):
        """
        Note that for the GLIME interpreter, :py:func:`set_global_weights()` (or
        :py:func:`compute_global_weights()`) needs to be called before :py:func:`interpret()` unless
        ``prior_method`` is ``"none"``. Basically, the technical process of GLIME is similar to LIME. See the
        `tutorial
        <https://github.com/PaddlePaddle/InterpretDL/blob/master/tutorials/LIME_Variants_part2.ipynb>`_ for more
        details.

        Args:
            data (str or np.ndarray): The input file path, or an image array (3-D arrays get a batch axis added).
                Integer arrays are used as the image directly. Float arrays are passed through ``restore_image``
                for visualization. ``resize_to`` and ``crop_to`` only apply to file paths.
            interpret_class (int, list or None, optional): The index (or indices) of classes to interpret. If
                None, the ``top_k`` most likely labels are used, most likely first. Default: ``None``.
            top_k (int, optional): Number of top classes to interpret; must be >= 1. Not used if
                ``interpret_class`` is given. Default: ``1``.
            prior_method (str, optional): Prior method, e.g. ``"none"`` or ``"ridge"``. ``"none"`` (the default)
                uses a zero prior, which is equivalent to LIME. Any other value builds the prior from the global
                weights for the first class in ``interpret_class``.
            prior_reg_force (float, optional): The regularization force to apply. Default: ``1.0``.
            num_samples (int, optional): LIME sampling numbers. A larger number of samples usually gives a more
                accurate interpretation. Default: ``1000``.
            batch_size (int, optional): Number of samples to forward each time. Default: ``50``.
            resize_to (int, optional): Images will be rescaled with the shorter edge being ``resize_to``.
                Defaults to ``224``.
            crop_to (int, optional): After resizing, images will be center cropped to a square image with the
                size ``crop_to``. If None, no crop will be performed. Defaults to ``None``.
            visual (bool, optional): Whether or not to visualize the processed image. Default: ``True``.
            save_path (str, optional): The path to save the processed image. If None, the image will not be
                saved. Default: ``None``.

        Raises:
            ValueError: If a prior is requested before global weights are set, or if ``top_k < 1`` while
                ``interpret_class`` is None.

        Returns:
            [dict]: LIME results: {interpret_label_i: weights on features}
        """
        if self.global_weights is None and prior_method != "none":
            raise ValueError(
                "The interpreter is not prepared. Call compute_global_weights() or set_global_weights()"
                " before interpretation."
            )
        if interpret_class is None and top_k < 1:
            # np.argsort(...)[-0:] would silently select *every* class, so reject it explicitly.
            raise ValueError(f"`top_k` must be a positive integer, got {top_k}.")

        # Preprocess the input; `img` is kept (batched) for sampling and visualization.
        if isinstance(data, str):
            img = read_image(data, resize_to, crop_to)
        else:
            if len(data.shape) == 3:
                data = np.expand_dims(data, axis=0)
            if np.issubdtype(data.dtype, np.integer):
                img = data
            else:
                img = restore_image(data.copy())  # for later visualization
        data = preprocess_image(img)

        self._build_predict_fn(output='probability')  # creates self.predict_fn.
        probability, _, _ = self.predict_fn(data, None)
        probability = probability[0]  # only one example here

        if interpret_class is None:
            # Most likely class first, so interpret_class[0] (prior + visualization) is the top prediction.
            interpret_class = np.argsort(probability)[::-1][:top_k]
        else:
            # Accept an int, a list/tuple, or an array; always produce a flat 1-D index array.
            interpret_class = np.asarray(interpret_class).ravel()

        def predict_fn_for_lime(_imgs):
            _data = preprocess_image(_imgs)  # transpose to [N, 3, H, W], scaled to [0.0, 1.0]
            output, _, _ = self.predict_fn(_data, None)
            return output

        self.predict_fn_for_lime = predict_fn_for_lime

        # Segments are computed only if lime_base holds none; an existing segmentation is reused as-is.
        # NOTE(review): confirm a segmentation cached from a previous image is reset elsewhere (e.g. by the base
        # class) before interpreting a different image.
        if self.lime_base.segments is None:
            self.lime_base.segments = compute_segments(img[0])
        segments = self.lime_base.segments

        if prior_method == "none":
            prior = np.zeros(len(np.unique(segments)))  # one zero per superpixel: plain LIME
        else:
            # The prior is derived for the first requested class only and shared by all classes.
            prior = cluster_global_weights_to_local_prior(img, segments, interpret_class[0], self.global_weights)

        lime_weights, r2_scores = self.lime_base.interpret_instance(img[0],
                                                                    self.predict_fn_for_lime,
                                                                    interpret_class,
                                                                    num_samples=num_samples,
                                                                    batch_size=batch_size,
                                                                    prior=prior,
                                                                    reg_force=prior_reg_force)

        # Visualize (and optionally save) the explanation of the first class.
        explanation_mask = sp_weights_to_image_explanation(img[0], lime_weights, interpret_class[0], segments)
        explanation_vis = overlay_threshold(img[0], explanation_mask)
        if visual:
            show_vis_explanation(explanation_vis)
        if save_path is not None:
            save_image(save_path, explanation_vis)

        self.lime_results['probability'] = {c: probability[c] for c in interpret_class}
        self.lime_results['input'] = img[0]
        self.lime_results['segmentation'] = segments
        self.lime_results['r2_scores'] = r2_scores
        self.lime_results['lime_weights'] = lime_weights

        return lime_weights