Open
Description
I tried to use the TokenClassificationExplainer for my fine-tuned BERT model. Its word_attributions property returns a dictionary keyed by the tokenized inputs.
When I processed the returned dict manually, a token was missing. It turned out that the token occurs more than once in the input, and since a dictionary cannot hold duplicate keys, only one entry for it made it into the returned value. For those who use this class, I recommend changing the return value so that all the tokenized inputs are preserved.
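To show the problem in isolation, here is a minimal sketch that does not use the library at all. The token and label lists are made up for illustration; only the keying-by-token pattern mirrors the property.

```python
# Hypothetical token sequence in which "the" occurs twice.
tokens = ["[CLS]", "the", "cat", "saw", "the", "dog", "[SEP]"]
labels = ["O", "O", "B-ANIMAL", "O", "O", "B-ANIMAL", "O"]

# Keying by token text: the second "the" overwrites the entry of the first.
by_token = {}
for index, token in enumerate(tokens):
    by_token[token] = {"index": index, "label": labels[index]}

print(len(tokens))               # 7 tokens went in
print(len(by_token))             # 6 entries came out
print(by_token["the"]["index"])  # 4, the later occurrence won
```

One entry is lost for every repeated token, and the surviving entry holds the values of the last occurrence.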
Here is the original implementation:
@property
def word_attributions(self) -> Dict:
    "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
    if self.attributions is not None:
        word_attr = dict()
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        labels = self.predicted_class_names
        for index, attr in self.attributions.items():
            try:
                predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]
            except KeyError:
                predicted_class = torch.argmax(self.pred_probs[index]).item()
            word_attr[tokens[index]] = {
                "label": predicted_class,
                "attribution_scores": attr.word_attributions,
            }
        return word_attr
    else:
        raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")
Below are my modifications to the word_attributions property.
@property
def word_attributions(self) -> List:
    "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
    if self.attributions is not None:
        word_attr = []
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        labels = self.predicted_class_names
        for index in self._selected_indexes:
            try:
                predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]
            except KeyError:
                predicted_class = torch.argmax(self.pred_probs[index]).item()
            word_attr.append({
                "index": index,
                "token": tokens[index],
                "label": predicted_class,
                "attribution_scores": self.attributions[index].word_attributions,
            })
        return word_attr
    else:
        raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")
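If you still want a per-token lookup on top of the list, a rough sketch of one way to build it without losing repeats is shown below. The helper name group_by_token is hypothetical and not part of the library, and the entries are made-up placeholders that only copy the keys used in the modified property (the attribution scores are left empty).

```python
from collections import defaultdict
from typing import Any, Dict, List


def group_by_token(word_attr: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group list-style entries by token text, keeping every occurrence."""
    grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for entry in word_attr:
        grouped[entry["token"]].append(entry)
    return dict(grouped)


# Made-up entries with the same keys as the modified property's output.
example = [
    {"index": 1, "token": "the", "label": "O", "attribution_scores": []},
    {"index": 2, "token": "cat", "label": "B-ANIMAL", "attribution_scores": []},
    {"index": 4, "token": "the", "label": "O", "attribution_scores": []},
]

grouped = group_by_token(example)
print([entry["index"] for entry in grouped["the"]])  # [1, 4]
```

Because each entry carries its own index, both occurrences of a repeated token stay distinguishable.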
Notes:
- I prefer iterating with index in self._selected_indexes, for consistency with other methods in the class.
- I have checked that the labels from self.predicted_class_names are consistent with the labels inferred in the try...except statement. I think it would be better to use the pre-inferred labels for consistency. However, please re-check and verify this when you try it.
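For the re-check suggested in the second note, a small comparison helper along these lines may be useful. This is only a sketch: find_label_mismatches is a hypothetical name, and the two lists are made-up stand-ins for the pre-computed labels and the labels inferred inside the loop, which you would need to collect from your own explainer run.

```python
from typing import List, Sequence, Tuple


def find_label_mismatches(
    precomputed: Sequence, inferred: Sequence
) -> List[Tuple[int, object, object]]:
    """Return (position, precomputed, inferred) for every disagreeing pair."""
    if len(precomputed) != len(inferred):
        raise ValueError("label sequences differ in length")
    return [
        (position, a, b)
        for position, (a, b) in enumerate(zip(precomputed, inferred))
        if a != b
    ]


# Made-up labels standing in for the two sources being compared.
precomputed_labels = ["O", "B-PER", "I-PER", "O"]
inferred_labels = ["O", "B-PER", "I-PER", "O"]
print(find_label_mismatches(precomputed_labels, inferred_labels))  # []
```

An empty result means the two sources agree at every position; any tuple in the result points at a position worth inspecting before switching to the pre-computed labels.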