VZ working for encoder-decoder
gsarti committed Feb 27, 2024
1 parent ea78f50 commit 6c4d1df
Showing 11 changed files with 151 additions and 77 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -6,6 +6,9 @@

- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173))
- Added `rollout` (`inseq.data.aggregation_functions.RolloutAggregationFunction`) aggregation function for `SequenceAttributionAggregator` class ([#173](https://github.com/inseq-team/inseq/pull/173)).
- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)).

## 🔧 Fixes & Refactoring

@@ -26,4 +29,5 @@

## 💥 Breaking Changes

*No changes*
- If `attention` is used as the attribution method in `model.attribute`, `step_scores` cannot be extracted at the same time, since the method no longer requires iterating over the full sequence ([#173](https://github.com/inseq-team/inseq/pull/173)). As an alternative, step scores can be extracted separately using the `dummy` attribution method (i.e. no attribution).
- BOS is always included in target-side attribution and generated sequences if present ([#173](https://github.com/inseq-team/inseq/pull/173)).
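A minimal usage sketch of the new `value_zeroing` method and of the `dummy` workaround mentioned in the breaking changes (the checkpoint name below is only an illustrative assumption, not part of this commit):

```python
import inseq

# Load a seq2seq model with the new value_zeroing attribution method
# (the checkpoint below is just an example encoder-decoder model).
model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "value_zeroing")

# Scores are computed from the last generation step only (is_final_step_method = True).
out = model.attribute("Hello everyone, welcome to the meeting.")
out.show()

# Since final-step methods no longer iterate over the full sequence, step scores
# such as token probabilities are extracted separately with the dummy method.
step_out = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "dummy").attribute(
    "Hello everyone, welcome to the meeting.", step_scores=["probability"]
)
```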
8 changes: 6 additions & 2 deletions inseq/attr/feat/attribution_utils.py
@@ -144,11 +144,15 @@ def extract_args(
def get_source_target_attributions(
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
is_encoder_decoder: bool,
has_sequence_scores: bool = False,
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
if isinstance(attr, tuple):
if is_encoder_decoder:
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
if has_sequence_scores:
return (attr[0], attr[1], attr[2])
else:
return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
else:
return (None, attr[0])
return (None, None, attr[0]) if has_sequence_scores else (None, attr[0])
else:
return (attr, None) if is_encoder_decoder else (None, attr)
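A small sketch of how the extended helper is expected to dispatch its inputs once `has_sequence_scores` is set (tensor shapes are illustrative placeholders):

```python
import torch

from inseq.attr.feat.attribution_utils import get_source_target_attributions

enc, cross, dec = torch.rand(1, 5, 5), torch.rand(1, 5, 7), torch.rand(1, 7, 7)

# Encoder-decoder with sequence scores: the (encoder, cross, decoder) triple passes through.
src, x, tgt = get_source_target_attributions((enc, cross, dec), is_encoder_decoder=True, has_sequence_scores=True)
assert src is enc and x is cross and tgt is dec

# Decoder-only with sequence scores: the two source-side slots are None.
src, x, tgt = get_source_target_attributions((dec,), is_encoder_decoder=False, has_sequence_scores=True)
assert src is None and x is None and tgt is dec
```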
74 changes: 55 additions & 19 deletions inseq/attr/feat/ops/value_zeroing.py
@@ -45,7 +45,6 @@ class ValueZeroingSimilarityMetric(Enum):
class ValueZeroingModule(Enum):
DECODER = "decoder"
ENCODER = "encoder"
CROSS = "cross"


class ValueZeroing(InseqAttribution):
@@ -155,20 +154,26 @@ def compute_modules_post_zeroing_similarity(
inputs: TensorOrTupleOfTensorsGeneric,
additional_forward_args: TensorOrTupleOfTensorsGeneric,
hidden_states: MultiLayerEmbeddingsTensor,
attention_module_name: str,
attributed_seq_len: Optional[int] = None,
similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
mode: str = ValueZeroingModule.DECODER.value,
zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
threshold: float = 1e-5,
min_score_threshold: float = 1e-5,
use_causal_mask: bool = False,
) -> MultiLayerScoreTensor:
"""Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block.
Args:
modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for.
hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean
counterparts when computing the similarity.
similarity_scores_shape (:obj:`torch.Size`): The shape of the similarity scores tensor to be returned.
attention_module_name (:obj:`str`): The name of the attention module to zero the values for.
attributed_seq_len (:obj:`int`): The length of the sequence to attribute. If not specified, it is assumed
to be the same as the length of the hidden states.
similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine".
mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder".
zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and
`Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads
that should be zeroed to compute corrupted states.
@@ -179,18 +184,25 @@
- If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
the corresponding layer. Any missing layer will not be zeroed.
Default: None.
min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the
similarity. Default: 1e-5.
use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to the zeroing scores. Default: False.
Returns:
:obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, attributed_seq_len, generated_seq_len, num_layers]``
containing distances (1 - similarity score) between original and corrupted states for each layer.
"""
if mode == ValueZeroingModule.DECODER.value:
modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder())
batch_size = hidden_states.size(0)
num_layers = len(modules)
sequence_length = hidden_states.size(2)
elif mode == ValueZeroingModule.ENCODER.value:
modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder())
else:
raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.")
if attributed_seq_len is None:
attributed_seq_len = hidden_states.size(2)
batch_size = hidden_states.size(0)
generated_seq_len = hidden_states.size(2)
num_layers = len(modules)

# Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the
# embedding layer, and we are only interested in the transformer blocks outputs.
@@ -199,7 +211,7 @@ def compute_modules_post_zeroing_similarity(
}
# Scores for every layer of the model
all_scores = torch.ones(
batch_size, num_layers, sequence_length, sequence_length, device=hidden_states.device
batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device
) * float("nan")

# Hooks:
@@ -218,11 +230,11 @@
modules[block_idx].register_forward_hook(states_extract_and_patch_hook)
)
# Zeroing is done for every token in the sequence separately (O(n) complexity)
for token_idx in range(sequence_length):
for token_idx in range(attributed_seq_len):
value_zeroing_hook_handles: list[RemovableHandle] = []
# Value zeroing hooks are registered for every token separately since they are token-dependent
for block_idx, block in enumerate(modules):
attention_module = block.get_submodule(self.forward_func.config.attention_module)
attention_module = block.get_submodule(attention_module_name)
if isinstance(zeroed_units_indices, dict):
if block_idx not in zeroed_units_indices:
continue
@@ -259,19 +271,22 @@
for block_idx in range(len(modules)):
similarity_scores = self.SIMILARITY_METRICS[similarity_metric](
self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx]
)[:, token_idx:]
all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores
)
if use_causal_mask:
all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:]
else:
all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores
self.corrupted_block_output_states = {}
for handle in states_extraction_hook_handles:
handle.remove()
self.clean_block_output_states = {}
all_scores = torch.where(all_scores < threshold, torch.zeros_like(all_scores), all_scores)
all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores)
# Normalize scores to sum to 1
per_token_sum_score = all_scores.sum(dim=-1, keepdim=True)
per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True)
per_token_sum_score[per_token_sum_score == 0] = 1
all_scores = all_scores / per_token_sum_score

# Final shape: [batch_size, seq_len, seq_len, num_layers]
# Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers]
return all_scores.permute(0, 3, 2, 1)

def attribute(
@@ -312,18 +327,39 @@ def attribute(
f"Similarity metric {similarity_metric} not available."
f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}"
)

decoder_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
hidden_states=decoder_hidden_states,
attention_module_name=self.forward_func.config.self_attention_module,
similarity_metric=similarity_metric,
mode=ValueZeroingModule.DECODER.value,
zeroed_units_indices=zeroed_units_indices,
use_causal_mask=True,
)
return decoder_scores
# Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
# Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
# if is_encoder_decoder:
# encoder_hidden_states = torch.stack(outputs.encoder_hidden_states)
# encoder = self.forward_func.get_encoder()
# encoder_stack = find_block_stack(encoder)
if self.forward_func.is_encoder_decoder:
# TODO: Enable different encoder/decoder/cross zeroing indices
encoder_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
hidden_states=encoder_hidden_states,
attention_module_name=self.forward_func.config.self_attention_module,
similarity_metric=similarity_metric,
mode=ValueZeroingModule.ENCODER.value,
zeroed_units_indices=zeroed_units_indices,
)
cross_scores = self.compute_modules_post_zeroing_similarity(
inputs=inputs,
additional_forward_args=additional_forward_args,
hidden_states=decoder_hidden_states,
attributed_seq_len=encoder_hidden_states.size(2),
attention_module_name=self.forward_func.config.cross_attention_module,
similarity_metric=similarity_metric,
mode=ValueZeroingModule.DECODER.value,
zeroed_units_indices=zeroed_units_indices,
)
return encoder_scores, cross_scores, decoder_scores
return (decoder_scores,)
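The encoder-decoder branch above reuses `compute_modules_post_zeroing_similarity` three times (encoder self-attention, cross-attention, decoder self-attention). The following minimal sketch reproduces only the scoring arithmetic applied after each zeroing pass, with random tensors standing in for the clean and value-zeroed block outputs gathered by the forward hooks; it is not the hook machinery itself.

```python
import torch
import torch.nn.functional as F

batch_size, num_layers, gen_len, attr_len = 1, 2, 4, 4
clean = torch.rand(num_layers, batch_size, gen_len, 8)  # cached clean block outputs
scores = torch.full((batch_size, num_layers, gen_len, attr_len), float("nan"))

for token_idx in range(attr_len):
    # In the real method, forward hooks zero the value vector of `token_idx` inside
    # every attention module and the model is re-run to obtain corrupted states.
    corrupted = clean + 0.1 * torch.rand_like(clean)
    for layer in range(num_layers):
        sim = F.cosine_similarity(clean[layer], corrupted[layer], dim=-1)  # [batch, gen_len]
        # use_causal_mask=True: a zeroed token can only affect itself and later positions.
        scores[:, layer, token_idx:, token_idx] = 1 - sim[:, token_idx:]

# Clip negligible scores, then renormalize each row over the attributed tokens
# (nansum ignores the positions masked out above).
scores = torch.where(scores < 1e-5, torch.zeros_like(scores), scores)
totals = scores.nansum(dim=-1, keepdim=True)
totals[totals == 0] = 1
scores = scores / totals
print(scores.permute(0, 3, 2, 1).shape)  # [batch, attributed_seq_len, generated_seq_len, num_layers]
```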
22 changes: 18 additions & 4 deletions inseq/attr/feat/perturbation_attribution.py
@@ -144,11 +144,25 @@ def attribute_step(
attribution_args: dict[str, Any] = {},
) -> MultiDimensionalFeatureAttributionStepOutput:
attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
source_attributions, target_attributions = get_source_target_attributions(
attr, self.attribution_model.is_encoder_decoder
encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions(
attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True
)
sequence_scores = {}
if self.attribution_model.is_encoder_decoder:
if len(attribute_fn_main_args["inputs"]) > 1:
target_attributions = decoder_self_scores.to("cpu")
else:
target_attributions = None
sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu")
sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu")
return MultiDimensionalFeatureAttributionStepOutput(
source_attributions=decoder_cross_scores.to("cpu"),
target_attributions=target_attributions,
sequence_scores=sequence_scores,
_num_dimensions=1, # num_layers
)
return MultiDimensionalFeatureAttributionStepOutput(
source_attributions=source_attributions,
target_attributions=target_attributions,
source_attributions=None,
target_attributions=decoder_self_scores,
_num_dimensions=1, # num_layers
)
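A hedged sketch of how the richer encoder-decoder output could be inspected downstream (attribute names follow inseq's standard output objects; the checkpoint is again only an example):

```python
import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "value_zeroing")
out = model.attribute("The cat is on the mat.")
seq = out.sequence_attributions[0]

# Cross-attention zeroing scores fill the usual source-attribution slot...
print(seq.source_attributions.shape)
# ...while encoder and decoder self-attention zeroing scores are kept as sequence scores.
print(list(seq.sequence_scores))  # expected keys: "decoder_self_scores", "encoder_self_scores"
```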
9 changes: 5 additions & 4 deletions inseq/data/attribution.py
@@ -209,8 +209,11 @@ def from_step_attributions(
curr_target = [a.target[seq_idx][0] for a in attributions]
targets.append(drop_padding(curr_target, pad_token))
if has_bos_token:
tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][1:]
tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding(
tokenized_target_sentences[seq_idx], pad_token
)
else:
tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
if attr_pos_end is None:
attr_pos_end = max(len(t) for t in tokenized_target_sentences)
for seq_idx in range(num_sequences):
@@ -238,8 +241,6 @@ def from_step_attributions(
[att.target_attributions for att in attributions], padding_dims=[1]
)
for seq_id in range(num_sequences):
if has_bos_token:
target_attributions[seq_id] = target_attributions[seq_id][1:, ...]
start_idx = max(pos_start) - pos_start[seq_id]
end_idx = start_idx + len(tokenized_target_sentences[seq_id])
target_attributions[seq_id] = target_attributions[seq_id][
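An illustrative sketch of the new BOS handling when trimming padding (`drop_padding` below is a simplified stand-in for the helper used in `inseq/data/attribution.py`, not its actual implementation):

```python
# Simplified stand-in: remove padding tokens from a tokenized sequence.
def drop_padding(tokens, pad_token):
    return [t for t in tokens if t != pad_token]

tokens = ["<pad>", "▁Le", "▁forze", "▁di", "▁pace", "</s>", "<pad>", "<pad>"]

# Before: the leading BOS/pad token was stripped together with the padding.
# Now: it is kept in front of the de-padded sequence, matching the updated test
# expectation where "<pad>" appears as the first aligned target token.
kept = tokens[:1] + drop_padding(tokens, "<pad>")
print(kept)  # ['<pad>', '▁Le', '▁forze', '▁di', '▁pace', '</s>']
```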
14 changes: 10 additions & 4 deletions inseq/models/model_config.py
@@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import yaml

@@ -12,18 +13,23 @@ class ModelConfig:
"""Configuration used by the methods for which the attribute ``use_model_config=True``.
Args:
attention_module (:obj:`str`):
The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in
transformers). Can be identified by looking at the name of the attribute instantiating the attention module
self_attention_module (:obj:`str`):
The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in
transformers). Can be identified by looking at the name of the self-attention module attribute
in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2).
cross_attention_module (:obj:`str`):
The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models
in transformers). Can be identified by looking at the name of the cross-attention module attribute
in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`).
value_vector (:obj:`str`):
The name of the variable in the forward pass of the attention module containing the value vector
(e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of
the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2).
"""

attention_module: str
self_attention_module: str
value_vector: str
cross_attention_module: Optional[str] = None


MODEL_CONFIGS = {
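A sketch of how the renamed and newly added configuration fields could be instantiated in code (the values mirror the YAML entries in this commit; keyword construction of the dataclass is an assumption):

```python
from inseq.models.model_config import ModelConfig

# Decoder-only model: only the self-attention module needs to be located.
gpt2_config = ModelConfig(self_attention_module="attn", value_vector="value")

# Encoder-decoder model: the cross-attention module is configured as well.
marian_config = ModelConfig(
    self_attention_module="self_attn",
    cross_attention_module="encoder_attn",
    value_vector="value_states",
)
```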
53 changes: 30 additions & 23 deletions inseq/models/model_config.yaml
@@ -1,46 +1,53 @@
# AutoModelForCausalLM
BloomForCausalLM:
self_attention_module: "self_attention"
value_vector: "value_layer"
GPT2LMHeadModel:
attention_module: "attn"
self_attention_module: "attn"
value_vector: "value"
OpenAIGPTLMHeadModel:
attention_module: "attn"
self_attention_module: "attn"
value_vector: "value"
GPTNeoXForCausalLM:
attention_module: "attention"
self_attention_module: "attention"
value_vector: "value"
BloomForCausalLM:
attention_module: "self_attention"
value_vector: "value_layer"
LlamaForCausalLM:
attention_module: "self_attn"
self_attention_module: "self_attn"
value_vector: "value_states"
GPTBigCodeForCausalLM:
attention_module: "attn"
self_attention_module: "attn"
value_vector: "value"
CodeGenForCausalLM:
attention_module: "attn"
self_attention_module: "attn"
value_vector: "value"

# TODO ForCausalLM
# TODO
# BioGptForCausalLM
# GemmaForCausalLM
# GPTNeoForCausalLM
# GPTJForCausalLM
# MistralForCausalLM
# MixtralForCausalLM
# MptForCausalLM
# OpenLlamaForCausalLM
# OPTForCausalLM
# PhiForCausalLM
# StableLmForCausalLM
# XGLMForCausalLM
# BioGptForCausalLM
# XLNetLMHeadModel

# AutoModelForSeq2SeqLM
MarianMTModel:
self_attention_module: "self_attn"
cross_attention_module: "encoder_attn"
value_vector: "value_states"

# TODO ForConditionalGeneration
# BartForConditionalGeneration
# BlenderbotForConditionalGeneration
# T5ForConditionalGeneration
# MarianMTModel
# LongT5ForConditionalGeneration
# FSMTForConditionalGeneration
# LongT5ForConditionalGeneration
# M2M100ForConditionalGeneration
# MBartForConditionalGeneration
# PegasusForConditionalGeneration
# ProphetNetForConditionalGeneration
# LEDForConditionalGeneration
# BigBirdPegasusForConditionalGeneration
# PLBartForConditionalGeneration
# SwitchTransformerForConditionalGeneration
# MT5ForConditionalGeneration
# NllbMoeForConditionalGeneration
# SeamlessM4TForTextToText
# SeamlessM4Tv2ForTextToText
# T5ForConditionalGeneration
2 changes: 1 addition & 1 deletion tests/attr/feat/test_feature_attribution.py
@@ -75,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu
"orig_tgt": "I soldati della pace ONU",
"contrast_tgt": "Le forze militari di pace delle Nazioni Unite",
"alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]],
"aligned_tgts": ["▁Le → ▁I", "▁forze → ▁soldati", "▁di → ▁della", "▁pace", "▁Nazioni → ▁ONU", "</s>"],
"aligned_tgts": ["<pad>", "▁Le → ▁I", "▁forze → ▁soldati", "▁di → ▁della", "▁pace", "▁Nazioni → ▁ONU", "</s>"],
}
out = saliency_mt_model_larger.attribute(
aligned["src"],