Explainable AI: Interpreting BERT Model
Learn how explainable AI techniques like Integrated Gradients can be used to make a deep learning NLP model interpretable.
Motivation and Background
Why is it important to build interpretable AI models?
The future of AI is in enabling humans and machines to work together to solve complex problems. Organizations are attempting to improve process efficiency and transparency by combining AI/ML technology with human review.
In recent years, with the advancement of AI, AI-specific regulations have emerged, for example, Good Machine Learning Practices (GMLP) in pharma and Model Risk Management (MRM) in finance, along with broader regulations addressing data privacy such as the EU’s GDPR and California’s CCPA. Similarly, internal compliance teams may want to interpret a model’s behavior when validating decisions based on model predictions. For instance, underwriters want to learn why a specific loan application was flagged as suspicious by an ML model.
Overview
What is interpretability?
In the ML context, interpretability refers to backtracking which factors contributed to an ML model making a certain prediction. In general, simpler models are easier to interpret but often produce lower accuracy than complex models, such as deep learning and transformer-based models, which can capture non-linear relations in the data and typically achieve higher accuracy.
Loosely defined, there are two types of explanations:
- Global explanation: explains the model at an overall level to understand which features contributed most to the output. For example, in a finance setting where the use case is to build an ML model that identifies customers most likely to default, some of the most influential features for that decision are the customer’s credit score, total number of credit cards, revolving balance, etc.
- Local explanation: lets you zoom in on a particular data point and observe the behavior of the model in that neighborhood. For example, in sentiment classification of movie reviews, certain words in a review may have a higher impact on the outcome than others, e.g., “I have never watched something as bad.”
What is a transformer model?
A transformer model is a neural network that tracks relationships in sequential data, such as the words in a sentence, to learn context and meaning. Transformer models use an evolving set of mathematical techniques, called attention or self-attention, to detect subtle relationships between even distant data elements in a series. Refer to Google’s publication “Attention Is All You Need” for more information.
Integrated Gradients
Integrated Gradients (IG) is an explainable AI technique introduced in the paper Axiomatic Attribution for Deep Networks. The paper assigns an attribution value to each input feature, which tells how much that input contributed to the final prediction.
IG is a local method and a popular interpretability technique due to its broad applicability to any differentiable model (e.g., text, image, structured data), ease of implementation, computational efficiency relative to alternative approaches, and theoretical justification. Integrated gradients represent the integral of gradients with respect to the inputs along a straight-line path from a given baseline to the input. The integral can be approximated using a Riemann sum or the Gauss–Legendre quadrature rule. Formally, it can be described as follows:
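For a model F, an input x, and a baseline x', the attribution of the i-th feature as defined in the paper is:

\mathrm{IntegratedGrads}_i(x) := (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial F\left(x' + \alpha\,(x - x')\right)}{\partial x_i}\, d\alpha

where the integral runs over the straight-line path from the baseline x' to the input x.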
The cornerstones of this approach are two fundamental axioms, namely sensitivity and implementation invariance. More information can be found in the original paper.
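To make the Riemann-sum approximation concrete, here is a minimal, illustrative sketch of IG for a differentiable PyTorch model. The function name integrated_gradients, the forward_fn argument, and the steps default are assumptions for illustration, not part of any library, and the sketch assumes a single unbatched 1-D input.

import torch

def integrated_gradients(forward_fn, inputs, baseline, target_idx, steps=50):
    # Points along the straight-line path from the baseline to the input
    alphas = torch.linspace(0.0, 1.0, steps + 1).view(-1, *([1] * inputs.dim()))
    path = (baseline + alphas * (inputs - baseline)).detach().requires_grad_(True)

    # Gradient of the target output w.r.t. every interpolated point
    outputs = forward_fn(path)[:, target_idx].sum()
    grads = torch.autograd.grad(outputs, path)[0]

    # Trapezoidal (Riemann-style) approximation of the path integral
    avg_grads = ((grads[:-1] + grads[1:]) / 2.0).mean(dim=0)

    # Scale by the input-baseline difference to get per-feature attributions
    return (inputs - baseline) * avg_grads

With a zero baseline, attributions for class c would be obtained as integrated_gradients(model, x, torch.zeros_like(x), target_idx=c), assuming the model accepts a batch of such inputs.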
Use Case
Now let’s see in action how the Integrated Gradients method can be applied using the Captum package. We will be fine-tuning a question-answering BERT (Bidirectional Encoder Representations from Transformers) model on the SQuAD dataset using the transformers library from Hugging Face; review the accompanying notebook for a detailed walkthrough.
Steps
- Load the tokenizer and pre-trained BERT model, in this case bert-base-uncased; a minimal loading sketch follows.
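A minimal loading sketch (assuming the checkpoint has already been fine-tuned on SQuAD, as in the notebook; bert-base-uncased stands in for that checkpoint here):

from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
model.eval()
model.zero_grad()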
- Next, compute attributions with respect to the BertEmbeddings layer. To do so, define baselines/references and numericalize both the baselines and the inputs.
def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                    token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    # Embed both the actual inputs and the baseline (reference) inputs
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    return input_embeddings, ref_input_embeddings
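Once input_ids and ref_input_ids have been built (see the next steps), the helper returns the embedding tensors for both the input and the baseline; a sketch of its use, relying on the keyword defaults:

input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, ref_input_ids)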
- Now, let's define the question-answer pair as an input to our BERT model:

question = "What is important to us?"
text = "It is important to us to include, empower and support humans of all kinds."
- Generate the corresponding baselines/references for the question-answer pair; a sketch of this step is shown below.
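The sketch below follows the pattern of Captum's SQuAD tutorial; the exact variable names are illustrative:

import torch

# Special tokens: [PAD] fills the baseline, [CLS]/[SEP] keep the sequence structure
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

question_ids = tokenizer.encode(question, add_special_tokens=False)
text_ids = tokenizer.encode(text, add_special_tokens=False)

# Input: [CLS] question [SEP] text [SEP]
input_ids = torch.tensor([[cls_token_id] + question_ids + [sep_token_id]
                          + text_ids + [sep_token_id]])
# Baseline: same structure, with every non-special token replaced by [PAD]
ref_input_ids = torch.tensor([[cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id]
                              + [ref_token_id] * len(text_ids) + [sep_token_id]])

# Segment ids (0 = question, 1 = text), position ids, and attention mask
token_type_ids = torch.tensor([[0] * (len(question_ids) + 2) + [1] * (len(text_ids) + 1)])
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)
attention_mask = torch.ones_like(input_ids)

all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())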
- The next step is to make predictions. One option is to use LayerIntegratedGradients and compute the attributions with respect to BertEmbeddings. LayerIntegratedGradients represents the integral of gradients with respect to the layer inputs/outputs along the straight-line path from the layer activations at the given baseline to the layer activations at the input.
start_scores, end_scores = predict(input_ids,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   attention_mask=attention_mask)

print('Question: ', question)
print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1]))

lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)
- Output:
Question: What is important to us?
Predicted Answer: to include , em ##power and support humans of all kinds
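The attributions themselves come from lig.attribute. The sketch below assumes squad_pos_forward_func takes the inputs plus token_type_ids, position_ids, attention_mask, and the position index (0 for the start position), as in Captum's SQuAD tutorial:

# Attributions for the start position; collapse the embedding dimension so each
# token receives a single score, then normalize
attributions_start, delta_start = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
    return_convergence_delta=True)

attributions_start_sum = attributions_start.sum(dim=-1).squeeze(0)
attributions_start_sum = attributions_start_sum / torch.norm(attributions_start_sum)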
- Visualize the attributions for each word token in the input sequence using a helper function
# storing a couple of samples in an array for visualization purposes
start_position_vis = viz.VisualizationDataRecord(
    attributions_start_sum,
    torch.max(torch.softmax(start_scores[0], dim=0)),
    torch.argmax(start_scores),
    torch.argmax(start_scores),
    str(ground_truth_start_ind),
    attributions_start_sum.sum(),
    all_tokens,
    delta_start)

# end_position_vis is constructed analogously from end_scores, attributions_end_sum,
# ground_truth_end_ind, and delta_end

print('\033[1m', 'Visualizations For Start Position', '\033[0m')
viz.visualize_text([start_position_vis])

print('\033[1m', 'Visualizations For End Position', '\033[0m')
viz.visualize_text([end_position_vis])
From the results above, we can tell that for predicting the start position, our model focuses more on the question side, specifically on the tokens ‘what’ and ‘important’. It also has a slight focus on the token sequence ‘to us’ on the text side.
In contrast, for predicting the end position, our model focuses more on the text side and has relatively high attribution on the last end-position token ‘kinds’.
Conclusion
This blog describes how explainable AI techniques like Integrated Gradients can be used to make a deep learning NLP model interpretable by highlighting positive and negative word influences on the outcome of the model.