# Source code for interpretdl.interpreter.forgetting_events

import numpy as np
import os, sys
import pickle
import paddle

from .abc_interpreter import Interpreter


class ForgettingEventsInterpreter(Interpreter):
    """
    Forgetting Events Interpreter computes the frequency of forgetting events for each training sample by running a
    normal training process. A training sample undergoes a forgetting event if it is misclassified at step t+1 after
    having been correctly classified at step t. A training sample with more forgetting events is more likely to be
    mislabeled or hard to learn.

    More details regarding the Forgetting Events method can be found in the original paper:
    https://arxiv.org/abs/1812.05159.
    """

    def __init__(self, model: callable, device: str = 'gpu:0'):
        """
        Args:
            model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"`` etc.
        """
        Interpreter.__init__(self, model, device)

    @staticmethod
    def _prepare_batch(data_train):
        """Split a batch of ``(data index, input, label)`` samples into paddle input and label tensors."""
        if isinstance(data_train[0][1], np.ndarray):
            x_train = [t[1] for t in data_train]
        else:
            x_train = [t[1].numpy() for t in data_train]
        y_train = [t[2] for t in data_train]
        x_train = paddle.to_tensor(x_train)
        # Labels are shaped (batch, 1), as passed to softmax_with_cross_entropy below.
        y_train = paddle.to_tensor(np.array(y_train).reshape((-1, 1)))
        return x_train, y_train

    def interpret(self,
                  train_reader: callable,
                  optimizer: paddle.optimizer,
                  batch_size: int,
                  epochs: int,
                  find_noisy_labels=False,
                  save_path=None):
        """Run the training process and record the forgetting events statistics.

        Args:
            train_reader (callable): A training data generator. Each call ``train_reader()`` must yield batches, where a
                batch is a sequence of ``(data index, input, label)`` samples. Inputs may be numpy.ndarray or objects
                exposing ``.numpy()``.
            optimizer (paddle.optimizer): The paddle optimizer. It is stepped once per batch.
            batch_size (int): Kept for API compatibility and not used by this method. The batch size is whatever
                ``train_reader`` yields.
            epochs (int): The number of epochs to train the model.
            find_noisy_labels (bool, optional): If True, return the samples flagged by :py:func:`find_noisy_labels`
                instead of the forgetting statistics. Defaults to False.
            save_path (str, optional): Directory in which the pickled statistics are saved; it is created if needed.
                If None, ``'assets'`` is used. ``stats.pkl`` is always written. ``count_forgotten.pkl`` and
                ``forgotten.pkl`` are also written when ``find_noisy_labels`` is False. Default: None.

        Returns:
            tuple: ``(stats, result)``. ``stats`` is ``{data index: [list of per-presentation 0/1 accuracies,
            list of per-presentation predicted labels]}``.
            If ``find_noisy_labels`` is True, ``result`` is the list of data indices returned by
            :py:func:`find_noisy_labels`.
            Otherwise ``result`` is ``(count_forgotten, forgotten)`` as returned by
            :py:func:`compute_and_order_forgetting_stats`:
            - count_forgotten is {count of forgetting events: list of data indices with such count of forgetting
              events}.
            - forgotten is {data index: numpy.ndarray of wrong predictions that follow true predictions in the
              training process}.
        """
        stats = {}
        if save_path is None:
            save_path = 'assets'
        # exist_ok avoids the race between an existence check and directory creation.
        os.makedirs(save_path, exist_ok=True)

        paddle.set_device(self.device)
        for i in range(epochs):
            correct = 0
            total = 0
            for step_id, data_train in enumerate(train_reader()):
                x_train, y_train = self._prepare_batch(data_train)

                # NOTE(review): assumes logits are (batch, num_classes), since argmax is taken over axis 1. Confirm.
                logits = self.model(x_train)
                predicted = paddle.argmax(logits, axis=1).numpy()
                bsz = len(predicted)

                loss = paddle.nn.functional.softmax_with_cross_entropy(logits, y_train)
                avg_loss = paddle.mean(loss)

                labels = y_train.reshape((bsz, )).numpy()
                acc = (predicted == labels).astype(int)

                for sample, sample_acc, sample_pred in zip(data_train, acc, predicted):
                    # first list is acc, second list is predicted label
                    index_stats = stats.setdefault(sample[0], [[], []])
                    index_stats[0].append(sample_acc)
                    index_stats[1].append(sample_pred)

                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()

                correct += np.sum(acc)
                total += bsz
                sys.stdout.write('\r')
                sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d]\t\tLoss: %.4f Acc@1: %.3f%%' %
                                 (i + 1, epochs, step_id + 1, avg_loss.numpy().item(), 100. * correct / total))
                sys.stdout.flush()

        if stats:
            # Terminate the carriage-return progress line so later output starts on a fresh line.
            sys.stdout.write('\n')
            sys.stdout.flush()

        with open(os.path.join(save_path, "stats.pkl"), "wb") as f:
            pickle.dump(stats, f)

        if find_noisy_labels:
            noisy_samples = self.find_noisy_labels(stats)
            return stats, noisy_samples

        count_forgotten, forgotten = self.compute_and_order_forgetting_stats(stats, epochs, save_path)
        return stats, (count_forgotten, forgotten)

    def compute_and_order_forgetting_stats(self, stats, epochs, save_path=None):
        """Derive per-sample forgetting statistics from the raw training records.

        Args:
            stats (dict): ``{data index: [accuracies, predicted labels]}`` as built by :py:func:`interpret`. Each list
                holds one entry per presentation of the sample (0/1 correctness and the predicted label).
            epochs (int): Only the first ``epochs`` presentations of each sample are considered.
            save_path (str, optional): Directory in which ``count_forgotten.pkl`` and ``forgotten.pkl`` are written.
                If None, nothing is saved. Default: None.

        Returns:
            (dict, dict): ``(count_forgotten, forgotten)``.
            - ``count_forgotten`` maps a number of forgetting events to the list of data indices with that count.
              Samples never classified correctly are grouped under ``-1``.
            - ``forgotten`` maps a data index to a numpy.ndarray of the predicted labels at each forgetting event.
              For samples never classified correctly, it holds the predicted labels of every presentation.
        """
        unlearned_per_presentation = {}
        first_learned = {}
        forgotten = {}
        for example_id, example_stats in stats.items():
            # accuracies
            presentation_acc = np.array(example_stats[0][:epochs])
            # predicted labels
            presentation_predicted = np.array(example_stats[1][:epochs])

            # A transition of -1 means correct (1) at presentation j, then wrong (0) at presentation j + 1.
            transitions = presentation_acc[1:] - presentation_acc[:-1]
            forgetting_idx = np.where(transitions == -1)[0]
            if len(forgetting_idx) > 0:
                # 1-based presentation (epoch) numbers at which the sample was forgotten
                unlearned_per_presentation[example_id] = forgetting_idx + 2
                # predicted labels at those forgetting presentations
                forgotten[example_id] = presentation_predicted[forgetting_idx + 1]
            else:
                unlearned_per_presentation[example_id] = []
                forgotten[example_id] = np.array([])

            learned_idx = np.where(presentation_acc == 1)[0]
            if len(learned_idx) > 0:
                first_learned[example_id] = learned_idx[0]
            else:
                # Never learned: mark with NaN and keep every prediction for inspection.
                first_learned[example_id] = np.nan
                forgotten[example_id] = presentation_predicted

        count_forgotten = {}
        for example_id, forgotten_epochs in unlearned_per_presentation.items():
            count = -1 if np.isnan(first_learned[example_id]) else len(forgotten_epochs)
            count_forgotten.setdefault(count, []).append(example_id)

        if save_path is not None:
            with open(os.path.join(save_path, "count_forgotten.pkl"), "wb") as f:
                pickle.dump(count_forgotten, f)
            with open(os.path.join(save_path, "forgotten.pkl"), "wb") as f:
                pickle.dump(forgotten, f)

        return count_forgotten, forgotten

    def find_noisy_labels(self, stats, num_std=5):
        """Flag samples whose correct predictions occur unusually late in training.

        Each sample that was classified correctly at least once gets a score: the mean 0-based presentation index of
        its correct predictions. A sample is flagged when its score exceeds ``mean + num_std * std`` over all scores.

        Args:
            stats (dict): ``{data index: [accuracies, predicted labels]}`` as built by :py:func:`interpret`.
            num_std (float, optional): Number of standard deviations above the mean score that a sample must exceed
                to be flagged. Default: 5.

        Returns:
            list: Data indices of the flagged samples, sorted by decreasing score. The list is empty if no sample was
            ever classified correctly.
        """
        pairs = []
        for example_id, example_stats in stats.items():
            learned_idx = np.where(np.array(example_stats[0]) == 1)[0]
            if len(learned_idx) == 0:
                continue
            pairs.append([example_id, learned_idx.mean()])

        if not pairs:
            return []

        scores = [p[1] for p in pairs]
        thre = np.mean(scores) + num_std * np.std(scores)
        noisy_pairs = [p for p in pairs if p[1] > thre]
        noisy_pairs.sort(key=lambda x: x[1], reverse=True)
        return [p[0] for p in noisy_pairs]