diff --git a/CHANGELOG.md b/CHANGELOG.md
index faeca60e..3fade279 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,8 @@
- Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238))
- Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237))
+- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173))
+- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)).
## 🔧 Fixes & Refactoring
@@ -26,4 +28,5 @@
## 💥 Breaking Changes
-*No changes*
+- If `attention` is used as the attribution method in `model.attribute`, `step_scores` can no longer be extracted in the same call, since the method no longer iterates over the full sequence ([#173](https://github.com/inseq-team/inseq/pull/173)). As an alternative, step scores can be extracted separately using the `dummy` attribution method (i.e. no attribution).
+- BOS is always included in target-side attribution and generated sequences if present. ([#173](https://github.com/inseq-team/inseq/pull/173))
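
The breaking change above implies a two-call workflow when both attention scores and step scores are needed. A minimal sketch, assuming a standard `inseq.load_model` call with GPT-2 as a stand-in model and the built-in `probability` step function:

```python
import inseq

# Attention now attributes from the final generation step only, so step scores
# can no longer be requested in the same `model.attribute` call.
attn_model = inseq.load_model("gpt2", "attention")
attn_out = attn_model.attribute("The capital of France is")

# Step scores are instead extracted in a separate pass with the "dummy" method
# (no attribution is performed, only the requested scores are computed).
dummy_model = inseq.load_model("gpt2", "dummy")
score_out = dummy_model.attribute("The capital of France is", step_scores=["probability"])
```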
diff --git a/README.md b/README.md
index 79aa0563..5e09d734 100644
--- a/README.md
+++ b/README.md
@@ -147,6 +147,8 @@ Use the `inseq.list_feature_attribution_methods` function to list all available
- `lime`: ["Why Should I Trust You?": Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016)
+- `value_zeroing`: [Quantifying Context Mixing in Transformers](https://aclanthology.org/2023.eacl-main.245/) (Mohebbi et al., 2023)
+
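
A brief usage sketch for the newly listed method (the translation model name is illustrative; any architecture with a matching entry in `model_config.yaml` should work):

```python
import inseq

# "value_zeroing" now appears among the available methods.
print(inseq.list_feature_attribution_methods())

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "value_zeroing")
out = model.attribute("Hello world, here is the Inseq library!")
out.show()
```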
#### Step functions
Step functions are used to extract custom scores from the model at each step of the attribution process with the `step_scores` argument in `model.attribute`. They can also be used as targets for attribution methods relying on model outputs (e.g. gradient-based methods) by passing them as the `attributed_fn` argument. The following step functions are currently supported:
diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst
index 174b405c..1f282626 100644
--- a/docs/source/main_classes/feature_attribution.rst
+++ b/docs/source/main_classes/feature_attribution.rst
@@ -17,7 +17,7 @@ Attribution Methods
.. autoclass:: inseq.attr.FeatureAttribution
:members:
-Gradient Attribution Methods
+Gradient-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------
.. autoclass:: inseq.attr.feat.GradientAttributionRegistry
@@ -67,7 +67,7 @@ Layer Attribution Methods
:members:
-Attention Attribution Methods
+Internals-based Attribution Methods
-----------------------------------------------------------------------------------------------------------------------
.. autoclass:: inseq.attr.feat.InternalsAttributionRegistry
@@ -76,3 +76,18 @@ Attention Attribution Methods
.. autoclass:: inseq.attr.feat.AttentionWeightsAttribution
:members:
+
+Perturbation-based Attribution Methods
+-----------------------------------------------------------------------------------------------------------------------
+
+.. autoclass:: inseq.attr.feat.PerturbationAttributionRegistry
+ :members:
+
+.. autoclass:: inseq.attr.feat.OcclusionAttribution
+ :members:
+
+.. autoclass:: inseq.attr.feat.LimeAttribution
+ :members:
+
+.. autoclass:: inseq.attr.feat.ValueZeroingAttribution
+ :members:
\ No newline at end of file
diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py
index cc07f530..2b25778a 100644
--- a/inseq/attr/feat/__init__.py
+++ b/inseq/attr/feat/__init__.py
@@ -17,6 +17,8 @@
from .perturbation_attribution import (
LimeAttribution,
OcclusionAttribution,
+ PerturbationAttributionRegistry,
+ ValueZeroingAttribution,
)
__all__ = [
@@ -39,4 +41,6 @@
"OcclusionAttribution",
"LimeAttribution",
"SequentialIntegratedGradientsAttribution",
+ "ValueZeroingAttribution",
+ "PerturbationAttributionRegistry",
]
diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py
index 8da4f899..a9679845 100644
--- a/inseq/attr/feat/attribution_utils.py
+++ b/inseq/attr/feat/attribution_utils.py
@@ -144,11 +144,15 @@ def extract_args(
def get_source_target_attributions(
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
is_encoder_decoder: bool,
+ has_sequence_scores: bool = False,
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
if isinstance(attr, tuple):
if is_encoder_decoder:
- return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
+ if has_sequence_scores:
+ return (attr[0], attr[1], attr[2])
+ else:
+ return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None)
else:
- return (None, attr[0])
+ return (None, None, attr[0]) if has_sequence_scores else (None, attr[0])
else:
return (attr, None) if is_encoder_decoder else (None, attr)
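
A small, self-contained illustration of the updated unpacking above (tensors are random placeholders; the import path follows the file being modified):

```python
import torch
from inseq.attr.feat.attribution_utils import get_source_target_attributions

src, tgt, extra = torch.rand(1, 5), torch.rand(1, 5), torch.rand(1, 5)

# Encoder-decoder with sequence scores: the three-element tuple is passed through.
enc_dec = get_source_target_attributions((src, tgt, extra), True, has_sequence_scores=True)
assert enc_dec[2] is extra  # sequence scores kept as the third element

# Decoder-only with sequence scores: source slots are filled with None.
dec_only = get_source_target_attributions((extra,), False, has_sequence_scores=True)
assert dec_only[0] is None and dec_only[1] is None
```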
diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py
index 250cd700..ce0e9300 100644
--- a/inseq/attr/feat/feature_attribution.py
+++ b/inseq/attr/feat/feature_attribution.py
@@ -114,6 +114,7 @@ def __init__(self, attribution_model: "AttributionModel", hook_to_model: bool =
self.use_hidden_states: bool = False
self.use_predicted_target: bool = True
self.use_model_config: bool = False
+ self.is_final_step_method: bool = False
if hook_to_model:
self.hook(**kwargs)
@@ -272,6 +273,35 @@ def _run_compatibility_checks(self, attributed_fn) -> None:
" method."
)
+ @staticmethod
+ def _build_multistep_output_from_single_step(
+ single_step_output: FeatureAttributionStepOutput,
+ attr_pos_start: int,
+ attr_pos_end: int,
+ ) -> list[FeatureAttributionStepOutput]:
+ if single_step_output.step_scores:
+ raise ValueError("step_scores are not supported for final step attribution methods.")
+ num_seq = len(single_step_output.prefix)
+ steps = []
+ for pos_idx in range(attr_pos_start, attr_pos_end):
+ step_output = single_step_output.clone_empty()
+ step_output.source = single_step_output.source
+ step_output.prefix = [single_step_output.prefix[seq_idx][:pos_idx] for seq_idx in range(num_seq)]
+ step_output.target = (
+ single_step_output.target
+ if pos_idx == attr_pos_end - 1
+ else [[single_step_output.prefix[seq_idx][pos_idx]] for seq_idx in range(num_seq)]
+ )
+ if single_step_output.source_attributions is not None:
+ step_output.source_attributions = single_step_output.source_attributions[:, :, pos_idx - 1]
+ if single_step_output.target_attributions is not None:
+ step_output.target_attributions = single_step_output.target_attributions[:, :pos_idx, pos_idx - 1]
+            step_output.step_scores = {}
+ if single_step_output.sequence_scores is not None:
+ step_output.sequence_scores = single_step_output.sequence_scores
+ steps.append(step_output)
+ return steps
+
def format_contrastive_targets(
self,
target_sequences: TextSequences,
@@ -416,9 +446,9 @@ def attribute(
target_lengths=targets_lengths,
method_name=self.method_name,
show=show_progress,
- pretty=pretty_progress,
+ pretty=False if self.is_final_step_method else pretty_progress,
attr_pos_start=attr_pos_start,
- attr_pos_end=attr_pos_end,
+ attr_pos_end=1 if self.is_final_step_method else attr_pos_end,
)
whitespace_indexes = find_char_indexes(sequences.targets, " ")
attribution_outputs = []
@@ -427,6 +457,8 @@ def attribute(
# Attribution loop for generation
for step in range(attr_pos_start, iter_pos_end):
+ if self.is_final_step_method and step != iter_pos_end - 1:
+ continue
tgt_ids, tgt_mask = batch.get_step_target(step, with_attention=True)
step_output = self.filtered_attribute_step(
batch[:step],
@@ -450,7 +482,7 @@ def attribute(
contrast_targets_alignments=contrast_targets_alignments,
)
attribution_outputs.append(step_output)
- if pretty_progress:
+ if pretty_progress and not self.is_final_step_method:
tgt_tokens = batch.target_tokens
skipped_prefixes = tok2string(self.attribution_model, tgt_tokens, end=attr_pos_start)
attributed_sentences = tok2string(self.attribution_model, tgt_tokens, attr_pos_start, step + 1)
@@ -471,12 +503,17 @@ def attribute(
end = datetime.now()
close_progress_bar(pbar, show=show_progress, pretty=pretty_progress)
batch.detach().to("cpu")
+ if self.is_final_step_method:
+ attribution_outputs = self._build_multistep_output_from_single_step(
+ attribution_outputs[0],
+ attr_pos_start=attr_pos_start,
+ attr_pos_end=iter_pos_end,
+ )
out = FeatureAttributionOutput(
sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions(
attributions=attribution_outputs,
tokenized_target_sentences=target_tokens_with_ids,
- pad_id=self.attribution_model.pad_token,
- has_bos_token=self.attribution_model.is_encoder_decoder,
+ pad_token=self.attribution_model.pad_token,
attr_pos_end=attr_pos_end,
),
step_attributions=attribution_outputs if output_step_attributions else None,
@@ -593,7 +630,7 @@ def filtered_attribute_step(
step_output.step_scores[score] = get_step_scores(score, step_fn_args, step_fn_extra_args).to("cpu")
# Reinsert finished sentences
if target_attention_mask is not None and is_filtered:
- step_output.remap_from_filtered(target_attention_mask, orig_batch)
+ step_output.remap_from_filtered(target_attention_mask, orig_batch, self.is_final_step_method)
step_output = step_output.detach().to("cpu")
return step_output
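
A toy sketch of the slicing performed by `_build_multistep_output_from_single_step` above, using a bare tensor in place of the step output object: the single final-step attribution matrix is split into one column per generated position.

```python
import torch

batch_size, src_len, tgt_len = 1, 4, 6
attr_pos_start, attr_pos_end = 1, tgt_len
# Final-step source attributions with shape (batch, src_len, generated_len)
final_step_source_attr = torch.rand(batch_size, src_len, tgt_len)

per_step = []
for pos_idx in range(attr_pos_start, attr_pos_end):
    # Mirrors step_output.source_attributions = ...[:, :, pos_idx - 1]
    per_step.append(final_step_source_attr[:, :, pos_idx - 1])

assert len(per_step) == attr_pos_end - attr_pos_start
assert per_step[0].shape == (batch_size, src_len)
```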
diff --git a/inseq/attr/feat/internals_attribution.py b/inseq/attr/feat/internals_attribution.py
index 9c6e8923..003c1869 100644
--- a/inseq/attr/feat/internals_attribution.py
+++ b/inseq/attr/feat/internals_attribution.py
@@ -16,12 +16,12 @@
import logging
from typing import Any, Optional
+import torch
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
-from captum.attr._utils.attribution import Attribution
from ...data import MultiDimensionalFeatureAttributionStepOutput
from ...utils import Registry
-from ...utils.typing import MultiLayerMultiUnitScoreTensor
+from ...utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor
from .feature_attribution import FeatureAttribution
logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ class AttentionWeightsAttribution(InternalsAttributionRegistry):
method_name = "attention"
- class AttentionWeights(Attribution):
+ class AttentionWeights(InseqAttribution):
@staticmethod
def has_convergence_delta() -> bool:
return False
@@ -74,9 +74,14 @@ def attribute(
:class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: A step output containing attention
weights for each layer and head, with shape :obj:`(batch_size, seq_len, n_layers, n_heads)`.
"""
- # We adopt the format [batch_size, sequence_length, num_layers, num_heads]
+ # We adopt the format [batch_size, sequence_length, sequence_length, num_layers, num_heads]
# for consistency with other multi-unit methods (e.g. gradient attribution)
- decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2)
+ decoder_self_attentions = decoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
+ decoder_self_attentions = torch.where(
+ decoder_self_attentions == 0,
+ (torch.ones_like(decoder_self_attentions) * float("nan")),
+ decoder_self_attentions,
+ )
if self.forward_func.is_encoder_decoder:
sequence_scores = {}
if len(inputs) > 1:
@@ -85,10 +90,11 @@ def attribute(
target_attributions = None
sequence_scores["decoder_self_attentions"] = decoder_self_attentions
sequence_scores["encoder_self_attentions"] = (
- encoder_self_attentions.to("cpu").clone().permute(0, 3, 4, 1, 2)
+ encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
)
+ cross_attentions = cross_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
return MultiDimensionalFeatureAttributionStepOutput(
- source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2),
+ source_attributions=cross_attentions,
target_attributions=target_attributions,
sequence_scores=sequence_scores,
_num_dimensions=2, # num_layers, num_heads
@@ -106,6 +112,8 @@ def __init__(self, attribution_model, **kwargs):
self.use_attention_weights = True
# Does not rely on predicted output (i.e. decoding strategy agnostic)
self.use_predicted_target = False
+ # Needs only the final generation step to extract scores
+ self.is_final_step_method = True
self.method = self.AttentionWeights(attribution_model)
def attribute_step(
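
A self-contained sketch of the masking step introduced above: zeroed (padded) attention entries are replaced with NaN so that NaN-aware aggregations skip them instead of dragging averages toward zero.

```python
import torch

# Toy attention weights: (batch, target_len, source_len, n_layers, n_heads)
attn = torch.rand(1, 3, 4, 2, 2)
attn[:, :, -1] = 0.0  # pretend the last source position is padding

attn = torch.where(attn == 0, torch.full_like(attn, float("nan")), attn)
# Aggregating over layers and heads keeps padded cells as NaN (missing)
# rather than diluting the scores of real tokens.
print(attn.nanmean(dim=(-1, -2)))
```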
diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py
index 388ab042..7d86167a 100644
--- a/inseq/attr/feat/ops/__init__.py
+++ b/inseq/attr/feat/ops/__init__.py
@@ -2,10 +2,12 @@
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .sequential_integrated_gradients import SequentialIntegratedGradients
+from .value_zeroing import ValueZeroing
__all__ = [
"DiscretetizedIntegratedGradients",
"MonotonicPathBuilder",
+ "ValueZeroing",
"Lime",
"SequentialIntegratedGradients",
]
diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py
new file mode 100644
index 00000000..c2afdbc7
--- /dev/null
+++ b/inseq/attr/feat/ops/value_zeroing.py
@@ -0,0 +1,394 @@
+# Copyright 2023 The Inseq Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from enum import Enum
+from typing import TYPE_CHECKING, Callable, Optional
+
+import torch
+from captum._utils.typing import TensorOrTupleOfTensorsGeneric
+from torch import nn
+from torch.utils.hooks import RemovableHandle
+
+from ....utils import (
+ StackFrame,
+ find_block_stack,
+ get_post_variable_assignment_hook,
+ recursive_get_submodule,
+ validate_indices,
+)
+from ....utils.typing import (
+ EmbeddingsTensor,
+ InseqAttribution,
+ MultiLayerEmbeddingsTensor,
+ MultiLayerScoreTensor,
+ OneOrMoreIndices,
+ OneOrMoreIndicesDict,
+)
+
+if TYPE_CHECKING:
+ from ....models import HuggingfaceModel
+
+logger = logging.getLogger(__name__)
+
+
+class ValueZeroingSimilarityMetric(Enum):
+ COSINE = "cosine"
+ EUCLIDEAN = "euclidean"
+
+
+class ValueZeroingModule(Enum):
+ DECODER = "decoder"
+ ENCODER = "encoder"
+
+
+class ValueZeroing(InseqAttribution):
+ """Value Zeroing method for feature attribution.
+
+    Introduced by `Mohebbi et al. (2023) <https://aclanthology.org/2023.eacl-main.245/>`__ to quantify context mixing inside
+ Transformer models. The method is based on the observation that context mixing is regulated by the value vectors
+ of the attention mechanism. The method consists of two steps:
+
+ 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model.
+ 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it
+ as a measure of context mixing for the given token at the given layer.
+
+ The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at
+ specific layers, or by aggregating them across layers.
+
+ Attributes:
+ SIMILARITY_METRICS (:obj:`Dict[str, Callable]`):
+            Dictionary of available similarity metrics to be used for computing the distance between hidden states
+ produced with and without the zeroing operation. Converted to distances as 1 - produced values.
+ forward_func (:obj:`AttributionModel`):
+ The attribution model to be used for value zeroing.
+ clean_block_output_states (:obj:`Dict[int, torch.Tensor]`):
+ Dictionary to store the hidden states produced by the model without the zeroing operation.
+ corrupted_block_output_states (:obj:`Dict[int, torch.Tensor]`):
+ Dictionary to store the hidden states produced by the model with the zeroing operation.
+ """
+
+ SIMILARITY_METRICS = {
+ "cosine": nn.CosineSimilarity(dim=-1),
+ "euclidean": lambda x, y: torch.cdist(x, y, p=2),
+ }
+
+ def __init__(self, forward_func: "HuggingfaceModel") -> None:
+ super().__init__(forward_func)
+ self.clean_block_output_states: dict[int, EmbeddingsTensor] = {}
+ self.corrupted_block_output_states: dict[int, EmbeddingsTensor] = {}
+
+ @staticmethod
+ def get_value_zeroing_hook(varname: str = "value") -> Callable[..., None]:
+ """Returns a hook to zero the value vectors of the attention mechanism.
+
+ Args:
+ varname (:obj:`str`, optional): The name of the variable containing the value vectors. The variable
+                is expected to be a tensor of shape (batch_size, num_heads, seq_len, hidden_size) or
+                (batch_size * num_heads, seq_len, hidden_size), and is retrieved from the
+ local variables of the execution frame during the forward pass.
+ """
+
+ def value_zeroing_forward_mid_hook(
+ frame: StackFrame,
+ zeroed_token_index: Optional[int] = None,
+ zeroed_units_indices: Optional[OneOrMoreIndices] = None,
+ batch_size: int = 1,
+ ) -> None:
+ if varname not in frame.f_locals:
+ raise ValueError(
+ f"Variable {varname} not found in the local frame."
+ f"Other variable names: {', '.join(frame.f_locals.keys())}"
+ )
+ # Zeroing value vectors corresponding to the given token index
+ if zeroed_token_index is not None:
+ values_size = frame.f_locals[varname].size()
+ if len(values_size) == 3: # Assume merged shape (bsz * num_heads, seq_len, hidden_size) e.g. Whisper
+ values = frame.f_locals[varname].view(batch_size, -1, *values_size[1:])
+ elif len(values_size) == 4: # Assume per-head shape (bsz, num_heads, seq_len, hidden_size) e.g. GPT-2
+ values = frame.f_locals[varname].clone()
+ else:
+ raise ValueError(
+ f"Value vector shape {frame.f_locals[varname].size()} not supported. "
+ "Supported shapes: (batch_size, num_heads, seq_len, hidden_size) or "
+ "(batch_size * num_heads, seq_len, hidden_size)"
+ )
+ zeroed_units_indices = validate_indices(values, 1, zeroed_units_indices).to(values.device)
+ zeroed_token_index = torch.tensor(zeroed_token_index, device=values.device)
+ # Mask heads corresponding to zeroed units and tokens corresponding to zeroed tokens
+ values[:, zeroed_units_indices, zeroed_token_index] = 0
+ if len(values_size) == 3:
+ frame.f_locals[varname] = values.view(-1, *values_size[1:])
+ elif len(values_size) == 4:
+ frame.f_locals[varname] = values
+
+ return value_zeroing_forward_mid_hook
+
+ def get_states_extract_and_patch_hook(self, block_idx: int, hidden_state_idx: int = 0) -> Callable[..., None]:
+ """Returns a hook to extract the produced hidden states (corrupted by value zeroing)
+ and patch them with pre-computed clean states that will be passed onwards in the model forward.
+
+ Args:
+ block_idx (:obj:`int`): The idx of the block at which the hook is applied, used to store extracted states.
+ hidden_state_idx (:obj:`int`, optional): The index of the hidden state in the model output tuple.
+ """
+
+ def states_extract_and_patch_forward_hook(module, args, output) -> None:
+ self.corrupted_block_output_states[block_idx] = output[hidden_state_idx].clone().float().detach().cpu()
+
+ # Rebuild the output tuple patching the clean states at the place of the corrupted ones
+ output = (
+ output[:hidden_state_idx]
+ + (self.clean_block_output_states[block_idx].to(output[hidden_state_idx].device),)
+ + output[hidden_state_idx + 1 :]
+ )
+ return output
+
+ return states_extract_and_patch_forward_hook
+
+ @staticmethod
+ def has_convergence_delta() -> bool:
+ return False
+
+ def compute_modules_post_zeroing_similarity(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ additional_forward_args: TensorOrTupleOfTensorsGeneric,
+ hidden_states: MultiLayerEmbeddingsTensor,
+ attention_module_name: str,
+ attributed_seq_len: Optional[int] = None,
+ similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
+ mode: str = ValueZeroingModule.DECODER.value,
+ zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ min_score_threshold: float = 1e-5,
+ use_causal_mask: bool = False,
+ ) -> MultiLayerScoreTensor:
+ """Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block.
+
+ Args:
+ modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for.
+ hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean
+ counterparts when computing the similarity.
+ attention_module_name (:obj:`str`): The name of the attention module to zero the values for.
+ attributed_seq_len (:obj:`int`): The length of the sequence to attribute. If not specified, it is assumed
+ to be the same as the length of the hidden states.
+ similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine".
+ mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder".
+ zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and
+ `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads
+ that should be zeroed to compute corrupted states.
+ - If None, all attention heads across all layers are zeroed.
+ - If an integer, the same attention head is zeroed across all layers.
+ - If a tuple of two integers, the attention heads in the range are zeroed across all layers.
+ - If a list of integers, the attention heads in the list are zeroed across all layers.
+ - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
+ the corresponding layer. Any missing layer will not be zeroed.
+ Default: None.
+ min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the
+ similarity. Default: 1e-5.
+            use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to zeroing scores. Default: False.
+
+ Returns:
+ :obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, seq_len, num_layer]`` containing distances
+ (1 - similarity score) between original and corrupted states for each layer.
+ """
+ if mode == ValueZeroingModule.DECODER.value:
+ modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder())
+ elif mode == ValueZeroingModule.ENCODER.value:
+ modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder())
+ else:
+ raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.")
+ if attributed_seq_len is None:
+ attributed_seq_len = hidden_states.size(2)
+ batch_size = hidden_states.size(0)
+ generated_seq_len = hidden_states.size(2)
+ num_layers = len(modules)
+
+ # Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the
+ # embedding layer, and we are only interested in the transformer blocks outputs.
+ self.clean_block_output_states = {
+ block_idx: hidden_states[:, block_idx + 1, ...].clone().detach().cpu() for block_idx in range(len(modules))
+ }
+ # Scores for every layer of the model
+ all_scores = torch.ones(
+ batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device
+ ) * float("nan")
+
+ # Hooks:
+ # 1. states_extract_and_patch_hook on the transformer block stores corrupted states and force clean states
+ # as the output of the block forward pass, i.e. the zeroing is done independently across layers.
+ # 2. value_zeroing_hook on the attention module performs the value zeroing by replacing the "value" tensor
+ # during the forward (name is config-dependent) with a zeroed version for the specified token index.
+ #
+ # State extraction hooks can be registered only once since they are token-independent
+ # Skip last block since its states are not used raw, but may have further transformations applied to them
+ # (e.g. LayerNorm, Dropout). These are extracted separately from the model outputs.
+ states_extraction_hook_handles: list[RemovableHandle] = []
+ for block_idx in range(len(modules) - 1):
+ states_extract_and_patch_hook = self.get_states_extract_and_patch_hook(block_idx, hidden_state_idx=0)
+ states_extraction_hook_handles.append(
+ modules[block_idx].register_forward_hook(states_extract_and_patch_hook)
+ )
+ # Zeroing is done for every token in the sequence separately (O(n) complexity)
+ for token_idx in range(attributed_seq_len):
+ value_zeroing_hook_handles: list[RemovableHandle] = []
+ # Value zeroing hooks are registered for every token separately since they are token-dependent
+ for block_idx, block in enumerate(modules):
+ attention_module = recursive_get_submodule(block, attention_module_name)
+ if attention_module is None:
+ raise ValueError(f"Attention module {attention_module_name} not found in block {block_idx}.")
+ if isinstance(zeroed_units_indices, dict):
+ if block_idx not in zeroed_units_indices:
+ continue
+ zeroed_units_indices_block = zeroed_units_indices[block_idx]
+ else:
+ zeroed_units_indices_block = zeroed_units_indices
+ value_zeroing_hook = get_post_variable_assignment_hook(
+ module=attention_module,
+ varname=self.forward_func.config.value_vector,
+ hook_fn=self.get_value_zeroing_hook(self.forward_func.config.value_vector),
+ zeroed_token_index=token_idx,
+ zeroed_units_indices=zeroed_units_indices_block,
+ batch_size=batch_size,
+ )
+ value_zeroing_hook_handle = attention_module.register_forward_pre_hook(value_zeroing_hook)
+ value_zeroing_hook_handles.append(value_zeroing_hook_handle)
+
+ # Run forward pass with hooks. Fills self.corrupted_hidden_states with corrupted states across layers
+ # when zeroing the specified token index.
+ with torch.no_grad():
+ output = self.forward_func.forward_with_output(
+ *inputs, *additional_forward_args, output_hidden_states=True
+ )
+ # Extract last layer states directly from the model outputs
+ # This allows us to handle the presence of additional transformations (e.g. LayerNorm, Dropout)
+ # in the last layer automatically.
+ corrupted_states_dict = self.forward_func.get_hidden_states_dict(output)
+ corrupted_decoder_last_hidden_state = (
+ corrupted_states_dict[f"{mode}_hidden_states"][:, -1, ...].clone().detach().cpu()
+ )
+ self.corrupted_block_output_states[len(modules) - 1] = corrupted_decoder_last_hidden_state
+ for handle in value_zeroing_hook_handles:
+ handle.remove()
+ for block_idx in range(len(modules)):
+ similarity_scores = self.SIMILARITY_METRICS[similarity_metric](
+ self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx]
+ )
+ if use_causal_mask:
+ all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:]
+ else:
+ all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores
+ self.corrupted_block_output_states = {}
+ for handle in states_extraction_hook_handles:
+ handle.remove()
+ self.clean_block_output_states = {}
+ all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores)
+ # Normalize scores to sum to 1
+ per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True)
+ per_token_sum_score[per_token_sum_score == 0] = 1
+ all_scores = all_scores / per_token_sum_score
+
+ # Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers]
+ return all_scores.permute(0, 3, 2, 1)
+
+ def attribute(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ additional_forward_args: TensorOrTupleOfTensorsGeneric,
+ similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value,
+ encoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ decoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None,
+ encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
+ decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None,
+ output_decoder_self_scores: bool = True,
+ output_encoder_self_scores: bool = True,
+ ) -> TensorOrTupleOfTensorsGeneric:
+ """Perform attribution using the Value Zeroing method.
+
+ Args:
+ similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between
+ hidden states produced with and without the zeroing operation. Default: cosine similarity.
+ zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and
+ `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads
+ that should be zeroed to compute corrupted states.
+ - If None, all attention heads across all layers are zeroed.
+ - If an integer, the same attention head is zeroed across all layers.
+ - If a tuple of two integers, the attention heads in the range are zeroed across all layers.
+ - If a list of integers, the attention heads in the list are zeroed across all layers.
+ - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for
+ the corresponding layer.
+
+ Default: None (all heads are zeroed for every layer).
+ encoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
+ source_seq_len, hidden_size]`` containing hidden states of the encoder. Available only for
+ encoder-decoders models. Default: None.
+ decoder_hidden_states (:obj:`torch.Tensor`, optional): A tensor of shape ``[batch_size, num_layers + 1,
+ target_seq_len, hidden_size]`` containing hidden states of the decoder.
+ output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or
+ if target-side attribution is requested using `attribute_target=True`. Default: True.
+ output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ encoder self-attention value vectors in encoder-decoder models. Default: True.
+
+ Returns:
+ `TensorOrTupleOfTensorsGeneric`: Attribution outputs for source-only or source + target feature attribution
+ """
+ if similarity_metric not in self.SIMILARITY_METRICS:
+ raise ValueError(
+ f"Similarity metric {similarity_metric} not available."
+ f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}"
+ )
+ decoder_scores = None
+ if not self.forward_func.is_encoder_decoder or output_decoder_self_scores or len(inputs) > 1:
+ decoder_scores = self.compute_modules_post_zeroing_similarity(
+ inputs=inputs,
+ additional_forward_args=additional_forward_args,
+ hidden_states=decoder_hidden_states,
+ attention_module_name=self.forward_func.config.self_attention_module,
+ similarity_metric=similarity_metric,
+ mode=ValueZeroingModule.DECODER.value,
+ zeroed_units_indices=decoder_zeroed_units_indices,
+ use_causal_mask=True,
+ )
+ # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values
+ # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py
+ if self.forward_func.is_encoder_decoder:
+ encoder_scores = None
+ if output_encoder_self_scores:
+ encoder_scores = self.compute_modules_post_zeroing_similarity(
+ inputs=inputs,
+ additional_forward_args=additional_forward_args,
+ hidden_states=encoder_hidden_states,
+ attention_module_name=self.forward_func.config.self_attention_module,
+ similarity_metric=similarity_metric,
+ mode=ValueZeroingModule.ENCODER.value,
+ zeroed_units_indices=encoder_zeroed_units_indices,
+ )
+ cross_scores = self.compute_modules_post_zeroing_similarity(
+ inputs=inputs,
+ additional_forward_args=additional_forward_args,
+ hidden_states=decoder_hidden_states,
+ attributed_seq_len=encoder_hidden_states.size(2),
+ attention_module_name=self.forward_func.config.cross_attention_module,
+ similarity_metric=similarity_metric,
+ mode=ValueZeroingModule.DECODER.value,
+ zeroed_units_indices=cross_zeroed_units_indices,
+ )
+ return encoder_scores, cross_scores, decoder_scores
+ elif encoder_zeroed_units_indices is not None or cross_zeroed_units_indices is not None:
+ logger.warning(
+ "Zeroing indices for encoder and cross-attentions were specified, but the model is not an "
+ "encoder-decoder. Use `decoder_zeroed_units_indices` to parametrize zeroing for the decoder module."
+ )
+ return (decoder_scores,)
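
For intuition, a compact sketch of the idea this operator implements, on a toy `nn.MultiheadAttention` layer rather than a hooked HuggingFace model (for simplicity the value *input* is zeroed instead of the post-projection value states the real hook patches): zero one token's value vector, then measure how much every position's output changes.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, d_model = 5, 8
x = torch.rand(1, seq_len, d_model)
attn = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)

with torch.no_grad():
    clean, _ = attn(x, x, x)
    scores = torch.zeros(seq_len, seq_len)
    for zeroed_token in range(seq_len):
        v = x.clone()
        v[:, zeroed_token] = 0.0  # zero the value vector of a single token
        corrupted, _ = attn(x, x, v)
        # 1 - cosine similarity = how much each position relied on the zeroed token
        scores[:, zeroed_token] = 1 - nn.functional.cosine_similarity(clean, corrupted, dim=-1)[0]

# Row-normalize so each position's mixing scores sum to 1, as in the operator above.
scores = scores / scores.sum(dim=-1, keepdim=True).clamp(min=1e-5)
print(scores)
```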
diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py
index fbebb780..c3eb0211 100644
--- a/inseq/attr/feat/perturbation_attribution.py
+++ b/inseq/attr/feat/perturbation_attribution.py
@@ -6,11 +6,12 @@
from ...data import (
CoarseFeatureAttributionStepOutput,
GranularFeatureAttributionStepOutput,
+ MultiDimensionalFeatureAttributionStepOutput,
)
from ...utils import Registry
from .attribution_utils import get_source_target_attributions
from .gradient_attribution import FeatureAttribution
-from .ops import Lime
+from .ops import Lime, ValueZeroing
logger = logging.getLogger(__name__)
@@ -117,3 +118,101 @@ def attribute_step(
target_attributions=out.target_attributions,
sequence_scores=out.sequence_scores,
)
+
+
+class ValueZeroingAttribution(PerturbationAttributionRegistry):
+ """Value Zeroing method for feature attribution.
+
+    Introduced by `Mohebbi et al. (2023) <https://aclanthology.org/2023.eacl-main.245/>`__ to quantify context mixing
+ in Transformer models. The method is based on the observation that context mixing is regulated by the value vectors
+ of the attention mechanism. The method consists of two steps:
+
+ 1. Zeroing the value vectors of the attention mechanism for a given token index at a given layer of the model.
+ 2. Computing the similarity between hidden states produced with and without the zeroing operation, and using it
+ as a measure of context mixing for the given token at the given layer.
+
+ The method is converted into a feature attribution method by allowing for extraction of value zeroing scores at
+ specific layers, or by aggregating them across layers.
+
+ Reference implementations:
+        - Original implementation: `hmohebbi/ValueZeroing <https://github.com/hmohebbi/ValueZeroing>`__
+        - Encoder-decoder implementation: `hmohebbi/ContextMixingASR <https://github.com/hmohebbi/ContextMixingASR>`__
+
+ Args:
+ similarity_metric (:obj:`str`, optional): The similarity metric to use for computing the distance between
+ hidden states produced with and without the zeroing operation. Options: cosine, euclidean. Default: cosine.
+ encoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): The indices of
+ the attention heads that should be zeroed to compute corrupted states in the encoder self-attention module.
+            Not used for decoder-only models, or if ``output_encoder_self_scores`` is False. Format:
+
+ - None: all attention heads across all layers are zeroed.
+ - int: the same attention head is zeroed across all layers.
+ - tuple of two integers: the attention heads in the range are zeroed across all layers.
+ - list of integers: the attention heads in the list are zeroed across all layers.
+ - dictionary: the keys are the layer indices and the values are the zeroed attention heads for the corresponding layer.
+
+ Default: None (all heads are zeroed for every encoder layer).
+ decoder_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as
+            ``encoder_zeroed_units_indices`` but for the decoder self-attention module. In encoder-decoder models,
+            ignored if ``output_decoder_self_scores`` is False. Default: None (all heads are zeroed for every decoder layer).
+ cross_zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int], dict]`, optional): Same as
+ ``encoder_zeroed_units_indices`` but for the cross-attention module in encoder-decoder models. Not used
+ if the model is decoder-only. Default: None (all heads are zeroed for every layer).
+ output_decoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ decoder self-attention value vectors in encoder-decoder models. Cannot be false for decoder-only, or
+ if target-side attribution is requested using `attribute_target=True`. Default: True.
+ output_encoder_self_scores (:obj:`bool`, optional): Whether to produce scores derived from zeroing the
+ encoder self-attention value vectors in encoder-decoder models. Default: True.
+
+ Returns:
+ :class:`~inseq.data.MultiDimensionalFeatureAttributionStepOutput`: The final dimension returned by the method
+ is ``[attributed_seq_len, generated_seq_len, num_layers]``. If ``output_decoder_self_scores`` and
+ ``output_encoder_self_scores`` are True, the respective scores are returned in the ``sequence_scores``
+ output dictionary.
+ """
+
+ method_name = "value_zeroing"
+
+ def __init__(self, attribution_model, **kwargs):
+ super().__init__(attribution_model, hook_to_model=False)
+ # Hidden states will be passed to the attribute_step method
+ self.use_hidden_states = True
+ # Does not rely on predicted output (i.e. decoding strategy agnostic)
+ self.use_predicted_target = False
+ # Uses model configuration to access attention module and value vector variable
+ self.use_model_config = True
+ # Needs only the final generation step to extract scores
+ self.is_final_step_method = True
+ self.method = ValueZeroing(attribution_model)
+ self.hook(**kwargs)
+
+ def attribute_step(
+ self,
+ attribute_fn_main_args: dict[str, Any],
+ attribution_args: dict[str, Any] = {},
+ ) -> MultiDimensionalFeatureAttributionStepOutput:
+ attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
+ encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions(
+ attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True
+ )
+ sequence_scores = {}
+ if self.attribution_model.is_encoder_decoder:
+ if len(attribute_fn_main_args["inputs"]) > 1:
+ target_attributions = decoder_self_scores.to("cpu")
+ else:
+ target_attributions = None
+ if decoder_self_scores is not None:
+ sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu")
+ if encoder_self_scores is not None:
+ sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu")
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=decoder_cross_scores.to("cpu"),
+ target_attributions=target_attributions,
+ sequence_scores=sequence_scores,
+ _num_dimensions=1, # num_layers
+ )
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=None,
+ target_attributions=decoder_self_scores,
+ _num_dimensions=1, # num_layers
+ )
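
A hedged end-to-end sketch of how the class above is meant to be used (the MarianMT model name is illustrative):

```python
import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "value_zeroing")
out = model.attribute(
    "The developer argued with the designer because she did not like the design.",
    attribute_target=True,
)
# The trailing dimension of the returned scores indexes layers; the default
# visualization aggregates it, and individual layers can be selected through
# the aggregation utilities (see the `select_idx` handling in the aggregator).
out.show()
```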
diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py
index d9cd6092..83aa8d6e 100644
--- a/inseq/attr/step_functions.py
+++ b/inseq/attr/step_functions.py
@@ -462,8 +462,7 @@ def register_step_function(
attribution targets by gradient-based feature attribution methods.
Args:
- fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture
- unused ones when defining your function):
+ fn (:obj:`callable`): The function to be used to compute step scores. Default parameters (use kwargs to capture unused ones when defining your function):
- :obj:`attribution_model`: an :class:`~inseq.models.AttributionModel` instance, corresponding to the model
used for computing the score.
diff --git a/inseq/commands/commands_utils.py b/inseq/commands/commands_utils.py
index 7701409a..dbfb8ac4 100644
--- a/inseq/commands/commands_utils.py
+++ b/inseq/commands/commands_utils.py
@@ -18,5 +18,4 @@ def command_args_docstring(cls):
field_help = field.metadata.get("help", "")
docstring += textwrap.dedent(f"\n**{field.name}** (``{field_type}``): {field_help}\n")
cls.__doc__ = docstring
- print(docstring)
return cls
diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py
index bb475707..dbae5352 100644
--- a/inseq/data/aggregator.py
+++ b/inseq/data/aggregator.py
@@ -12,9 +12,10 @@
aggregate_token_sequence,
available_classes,
extract_signature_args,
+ validate_indices,
)
from ..utils import normalize as normalize_fn
-from ..utils.typing import IndexSpan, TokenWithId
+from ..utils.typing import IndexSpan, OneOrMoreIndices, TokenWithId
from .aggregation_functions import AggregationFunction
from .data_utils import TensorWrapper
@@ -305,7 +306,7 @@ def _process_attribution_scores(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
**kwargs,
):
@@ -366,7 +367,7 @@ def aggregate_source_attributions(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
**kwargs,
):
@@ -380,7 +381,7 @@ def aggregate_target_attributions(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
**kwargs,
):
@@ -398,7 +399,7 @@ def aggregate_sequence_scores(
cls,
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
- select_idx: Union[int, tuple[int, int], list[int], None] = None,
+ select_idx: Optional[OneOrMoreIndices] = None,
**kwargs,
):
if aggregate_fn.takes_sequence_scores:
@@ -439,46 +440,12 @@ def is_compatible(attr: "FeatureAttributionSequenceOutput"):
def _filter_scores(
scores: torch.Tensor,
dim: int = -1,
- indices: Union[int, tuple[int, int], list[int], None] = None,
+ indices: Optional[OneOrMoreIndices] = None,
) -> torch.Tensor:
- n_units = scores.shape[dim]
-
- if hasattr(indices, "__iter__"):
- if len(indices) == 0:
- raise RuntimeError("At least two indices must be specified for aggregation.")
- if len(indices) == 1:
- indices = indices[0]
-
+ indexed = scores.index_select(dim, validate_indices(scores, dim, indices).to(scores.device))
if isinstance(indices, int):
- if indices not in range(-n_units, n_units):
- raise IndexError(f"Index out of range. Scores only have {n_units} units.")
- indices = indices if indices >= 0 else n_units + indices
- return scores.select(dim, torch.tensor(indices, device=scores.device))
- else:
- if indices is None:
- indices = (0, n_units)
- logger.info("No indices specified for extraction. Using all units by default.")
-
- # Convert negative indices to positive indices
- if hasattr(indices, "__iter__"):
- indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices])
- if not hasattr(indices, "__iter__") or (
- len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1]
- ):
- raise RuntimeError(
- "A (start, end) tuple of indices representing a span, a list of individual indices"
- " or a single index must be specified for select_idx."
- )
- max_idx_val = n_units if isinstance(indices, list) else n_units + 1
- if not all(h in range(-n_units, max_idx_val) for h in indices):
- raise IndexError("One or more index out of range. Scores only have {n_units} units.")
- if len(set(indices)) != len(indices):
- raise IndexError("Duplicate indices are not allowed.")
- if isinstance(indices, tuple):
- scores = scores.index_select(dim, torch.arange(indices[0], indices[1], device=scores.device))
- else:
- scores = scores.index_select(dim, torch.tensor(indices, device=scores.device))
- return scores
+ return indexed.squeeze(dim)
+ return indexed
@staticmethod
def _aggregate_scores(
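
The refactored `_filter_scores` delegates validation to `validate_indices` and then applies a plain `index_select`. A rough equivalent of the accepted index formats, shown with bare `torch` calls on a placeholder tensor:

```python
import torch

scores = torch.rand(2, 6, 7, 4)  # e.g. (batch, target_len, source_len, n_layers)
dim = -1

# int -> a single unit, with the selected dimension squeezed away
single = scores.index_select(dim, torch.tensor([1])).squeeze(dim)
# (start, end) tuple -> a contiguous span of units
span = scores.index_select(dim, torch.arange(1, 3))
# list of ints -> an explicit subset of units
subset = scores.index_select(dim, torch.tensor([0, 2, 3]))
```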
diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py
index f3671244..7841cf7c 100644
--- a/inseq/data/attribution.py
+++ b/inseq/data/attribution.py
@@ -12,6 +12,7 @@
get_sequences_from_batched_steps,
json_advanced_dump,
json_advanced_load,
+ pad_with_nan,
pretty_dict,
remap_from_filtered,
)
@@ -178,9 +179,8 @@ def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callab
def from_step_attributions(
cls,
attributions: list["FeatureAttributionStepOutput"],
- tokenized_target_sentences: Optional[list[list[TokenWithId]]] = None,
- pad_id: Optional[Any] = None,
- has_bos_token: bool = True,
+ tokenized_target_sentences: list[list[TokenWithId]],
+ pad_token: Optional[Any] = None,
attr_pos_end: Optional[int] = None,
) -> list["FeatureAttributionSequenceOutput"]:
"""Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple
@@ -198,36 +198,35 @@ def from_step_attributions(
num_sequences = len(attr.prefix)
if not all(len(attr.prefix) == num_sequences for attr in attributions):
raise ValueError("All the attributions must include the same number of sequences.")
- seq_attributions = []
- sources = None
- if attr.source_attributions is not None:
- sources = [drop_padding(attr.source[seq_id], pad_id) for seq_id in range(num_sequences)]
- targets = [
- drop_padding([a.target[seq_id][0] for a in attributions], pad_id) for seq_id in range(num_sequences)
- ]
- if tokenized_target_sentences is None:
- tokenized_target_sentences = targets
- if has_bos_token:
- tokenized_target_sentences = [tok_seq[1:] for tok_seq in tokenized_target_sentences]
- tokenized_target_sentences = [
- drop_padding(tokenized_target_sentences[seq_id], pad_id) for seq_id in range(num_sequences)
- ]
+ seq_attributions: list[FeatureAttributionSequenceOutput] = []
+ sources = []
+ targets = []
+ pos_start = []
+ for seq_idx in range(num_sequences):
+ if attr.source_attributions is not None:
+ sources.append(drop_padding(attr.source[seq_idx], pad_token))
+ curr_target = [a.target[seq_idx][0] for a in attributions]
+ targets.append(drop_padding(curr_target, pad_token))
+ if all(attr.prefix[seq_idx][0] == pad_token for seq_idx in range(num_sequences)):
+ tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding(
+ tokenized_target_sentences[seq_idx][1:], pad_token
+ )
+ else:
+ tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token)
if attr_pos_end is None:
attr_pos_end = max(len(t) for t in tokenized_target_sentences)
- pos_start = [
- min(len(tokenized_target_sentences[seq_id]), attr_pos_end) - len(targets[seq_id])
- for seq_id in range(num_sequences)
- ]
- for seq_id in range(num_sequences):
- source = tokenized_target_sentences[seq_id][: pos_start[seq_id]] if sources is None else sources[seq_id]
- seq_attributions.append(
- attr.get_sequence_cls(
- source=source,
- target=tokenized_target_sentences[seq_id],
- attr_pos_start=pos_start[seq_id],
- attr_pos_end=attr_pos_end,
- )
+ for seq_idx in range(num_sequences):
+ # If the model is decoder-only, the source is the input prefix
+ curr_pos_start = min(len(tokenized_target_sentences[seq_idx]), attr_pos_end) - len(targets[seq_idx])
+ pos_start.append(curr_pos_start)
+ source = tokenized_target_sentences[seq_idx][:curr_pos_start] if not sources else sources[seq_idx]
+ curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls(
+ source=source,
+ target=tokenized_target_sentences[seq_idx],
+ attr_pos_start=pos_start[seq_idx],
+ attr_pos_end=attr_pos_end,
)
+ seq_attributions.append(curr_seq_attribution)
if attr.source_attributions is not None:
source_attributions = get_sequences_from_batched_steps([att.source_attributions for att in attributions])
for seq_id in range(num_sequences):
@@ -241,18 +240,13 @@ def from_step_attributions(
[att.target_attributions for att in attributions], padding_dims=[1]
)
for seq_id in range(num_sequences):
- if has_bos_token:
- target_attributions[seq_id] = target_attributions[seq_id][1:, ...]
start_idx = max(pos_start) - pos_start[seq_id]
end_idx = start_idx + len(tokenized_target_sentences[seq_id])
target_attributions[seq_id] = target_attributions[seq_id][
start_idx:end_idx, : len(targets[seq_id]), ... # noqa: E203
]
if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]):
- empty_final_row = torch.ones(
- 1, *target_attributions[seq_id].shape[1:], device=target_attributions[seq_id].device
- ) * float("nan")
- target_attributions[seq_id] = torch.cat([target_attributions[seq_id], empty_final_row], dim=0)
+ target_attributions[seq_id] = pad_with_nan(target_attributions[seq_id], dim=0, pad_size=1)
seq_attributions[seq_id].target_attributions = target_attributions[seq_id]
if attr.step_scores is not None:
step_scores = [{} for _ in range(num_sequences)]
@@ -427,47 +421,51 @@ def remap_from_filtered(
self,
target_attention_mask: TargetIdsTensor,
batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
+ is_final_step_method: bool = False,
) -> None:
"""Remaps the attributions to the original shape of the input sequence."""
+ batch_size = (
+ len(batch.sources.input_tokens) if self.source_attributions is not None else len(batch.target_tokens)
+ )
+ source_len = len(batch.sources.input_tokens[0])
+ target_len = len(batch.target_tokens[0])
+ # Normal per-step attribution outputs have shape (batch_size, seq_len, ...)
+ other_dims_start_idx = 2
+ # Final step attribution outputs have shape (batch_size, seq_len, seq_len, ...)
+ if is_final_step_method:
+ other_dims_start_idx += 1
+ other_dims = (
+ self.source_attributions.shape[other_dims_start_idx:]
+ if self.source_attributions is not None
+ else self.target_attributions.shape[other_dims_start_idx:]
+ )
if self.source_attributions is not None:
self.source_attributions = remap_from_filtered(
- original_shape=(len(batch.sources.input_tokens), *self.source_attributions.shape[1:]),
+ original_shape=(batch_size, *self.source_attributions.shape[1:]),
mask=target_attention_mask,
filtered=self.source_attributions,
)
if self.target_attributions is not None:
self.target_attributions = remap_from_filtered(
- original_shape=(len(batch.target_tokens), *self.target_attributions.shape[1:]),
+ original_shape=(batch_size, *self.target_attributions.shape[1:]),
mask=target_attention_mask,
filtered=self.target_attributions,
)
if self.step_scores is not None:
for score_name, score_tensor in self.step_scores.items():
self.step_scores[score_name] = remap_from_filtered(
- original_shape=(len(batch.target_tokens), 1),
+ original_shape=(batch_size, 1),
mask=target_attention_mask,
filtered=score_tensor.unsqueeze(-1),
).squeeze(-1)
if self.sequence_scores is not None:
for score_name, score_tensor in self.sequence_scores.items():
if score_name.startswith("decoder"):
- original_shape = (
- len(batch.target_tokens),
- self.target_attributions.shape[1],
- *self.target_attributions.shape[1:],
- )
+ original_shape = (batch_size, target_len, target_len, *other_dims)
elif score_name.startswith("encoder"):
- original_shape = (
- len(batch.sources.input_tokens),
- self.source_attributions.shape[1],
- *self.source_attributions.shape[1:],
- )
+ original_shape = (batch_size, source_len, source_len, *other_dims)
else: # default case: cross-attention
- original_shape = (
- len(batch.sources.input_tokens),
- self.target_attributions.shape[1],
- *self.source_attributions.shape[1:],
- )
+ original_shape = (batch_size, source_len, target_len, *other_dims)
self.sequence_scores[score_name] = remap_from_filtered(
original_shape=original_shape,
mask=target_attention_mask,
diff --git a/inseq/data/data_utils.py b/inseq/data/data_utils.py
index b907d627..d0f90203 100644
--- a/inseq/data/data_utils.py
+++ b/inseq/data/data_utils.py
@@ -112,7 +112,7 @@ def _torch(attr):
def _eq(self_attr: TensorClass, other_attr: TensorClass) -> bool:
try:
if isinstance(self_attr, torch.Tensor):
- return torch.allclose(self_attr, other_attr, equal_nan=True)
+ return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5)
elif isinstance(self_attr, dict):
return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys())
else:
@@ -175,6 +175,10 @@ def clone(self: TensorClass) -> TensorClass:
out_params[field.name] = None
return self.__class__(**out_params)
+ def clone_empty(self: TensorClass) -> TensorClass:
+ out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None}
+ return self.__class__(**out_params)
+
def to_dict(self: TensorClass) -> dict[str, Any]:
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py
index c74e82c8..b96db5c2 100644
--- a/inseq/models/attribution_model.py
+++ b/inseq/models/attribution_model.py
@@ -386,6 +386,24 @@ def attribute(
original_device = self.device
if device is not None:
self.device = device
+ attribution_method = self.get_attribution_method(method, override_default_attribution)
+ attributed_fn = self.get_attributed_fn(attributed_fn)
+ attribution_args, attributed_fn_args, step_scores_args = extract_args(
+ attribution_method,
+ attributed_fn,
+ step_scores,
+ default_args=self.formatter.get_step_function_reserved_args(),
+ **kwargs,
+ )
+ if isnotebook():
+ logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.")
+ pretty_progress = False
+ if attribution_method.is_final_step_method:
+ if step_scores:
+ raise ValueError(
+ "Step scores are not supported for final step methods since they do not iterate over the full"
+ " sequence. Please remove the step scores and compute them separatly passing method='dummy'."
+ )
input_texts, generated_texts = format_input_texts(input_texts, generated_texts)
has_generated_texts = generated_texts is not None
if not self.is_encoder_decoder:
@@ -411,36 +429,30 @@ def attribute(
f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
)
logger.debug(f"reference_texts={generated_texts}")
- attribution_method = self.get_attribution_method(method, override_default_attribution)
- attributed_fn = self.get_attributed_fn(attributed_fn)
- attribution_args, attributed_fn_args, step_scores_args = extract_args(
- attribution_method,
- attributed_fn,
- step_scores,
- default_args=self.formatter.get_step_function_reserved_args(),
- **kwargs,
- )
- if isnotebook():
- logger.debug("Pretty progress currently not supported in notebooks, falling back to tqdm.")
- pretty_progress = False
if not self.is_encoder_decoder:
assert all(
generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts))
), "Forced generations with decoder-only models must start with the input texts."
if has_generated_texts and len(input_texts) > 1:
- logger.info(
+ logger.warning(
"Batched constrained decoding is currently not supported for decoder-only models."
" Using batch size of 1."
)
batch_size = 1
if len(input_texts) > 1 and (attr_pos_start is not None or attr_pos_end is not None):
- logger.info(
+ logger.warning(
"Custom attribution positions are currently not supported when batching generations for"
" decoder-only models. Using batch size of 1."
)
batch_size = 1
+ elif attribution_method.is_final_step_method and len(input_texts) > 1:
+ logger.warning(
+ "Batched attribution with encoder-decoder models currently not supported for final-step methods."
+ " Using batch size of 1."
+ )
+ batch_size = 1
if attribution_method.method_name == "lime":
- logger.info("Batched attribution currently not supported for LIME. Using batch size of 1.")
+ logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.")
batch_size = 1
attribution_outputs = attribution_method.prepare_and_attribute(
input_texts,
diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py
index 05b8a468..52d9d47b 100644
--- a/inseq/models/model_config.py
+++ b/inseq/models/model_config.py
@@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path
+from typing import Optional
import yaml
@@ -10,14 +11,25 @@
@dataclass
class ModelConfig:
"""Configuration used by the methods for which the attribute ``use_model_config=True``.
+
Args:
- attention_module (:obj:`str`):
- The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in
- transformers). Can be identified by looking at the name of the attribute instantiating the attention module
+ self_attention_module (:obj:`str`):
+ The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in
+ transformers). Can be identified by looking at the name of the self-attention module attribute
in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2).
+ cross_attention_module (:obj:`str`):
+ The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models
+ in transformers). Can be identified by looking at the name of the cross-attention module attribute
+ in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`).
+ value_vector (:obj:`str`):
+ The name of the variable in the forward pass of the attention module containing the value vector
+ (e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of
+ the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2).
"""
- attention_module: str
+ self_attention_module: str
+ value_vector: str
+ cross_attention_module: Optional[str] = None
MODEL_CONFIGS = {
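
Note: to make the new schema concrete, a hedged sketch of a decoder-only entry built directly as a dataclass. The values mirror the GPT-2 entry in `model_config.yaml` below; constructing the config by hand is for illustration only.

```python
from inseq.models.model_config import ModelConfig

# cross_attention_module defaults to None, which is correct for decoder-only models.
gpt2_like_config = ModelConfig(
    self_attention_module="attn",  # attribute name of the self-attention module in the transformer block
    value_vector="value",          # local variable holding the value vector in the attention forward pass
)
```
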
diff --git a/inseq/models/model_config.yaml b/inseq/models/model_config.yaml
index b48ed209..bcc32e41 100644
--- a/inseq/models/model_config.yaml
+++ b/inseq/models/model_config.yaml
@@ -1,2 +1,111 @@
+# Decoder-only models
+BioGptForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+BloomForCausalLM:
+ self_attention_module: "self_attention"
+ value_vector: "value_layer"
+CodeGenForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
+FalconForCausalLM:
+ self_attention_module: "self_attention"
+ value_vector: "value_layer"
+GemmaForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+GPTBigCodeForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
+GPTJForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
GPT2LMHeadModel:
- attention_module: "attn"
\ No newline at end of file
+ self_attention_module: "attn"
+ value_vector: "value"
+GPTNeoForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value"
+GPTNeoXForCausalLM:
+ self_attention_module: "attention"
+ value_vector: "value"
+LlamaForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+MistralForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+MixtralForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+MptForCausalLM:
+ self_attention_module: "attn"
+ value_vector: "value_states"
+OpenAIGPTLMHeadModel:
+ self_attention_module: "attn"
+ value_vector: "value"
+OPTForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+PhiForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+Qwen2ForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+StableLmForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+XGLMForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+
+# Encoder-decoder models
+BartForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+MarianMTModel:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+FSMTForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "v"
+M2M100ForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+MBartForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+MT5ForConditionalGeneration:
+ self_attention_module: "SelfAttention"
+ cross_attention_module: "EncDecAttention"
+ value_vector: "value_states"
+NllbMoeForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "cross_attention"
+ value_vector: "value_states"
+PegasusForConditionalGeneration:
+ self_attention_module: "self_attn"
+ cross_attention_module: "encoder_attn"
+ value_vector: "value_states"
+SeamlessM4TForTextToText:
+ self_attention_module: "self_attn"
+ cross_attention_module: "cross_attention"
+ value_vector: "value"
+SeamlessM4Tv2ForTextToText:
+ self_attention_module: "self_attn"
+ cross_attention_module: "cross_attention"
+ value_vector: "value"
+T5ForConditionalGeneration:
+ self_attention_module: "SelfAttention"
+ cross_attention_module: "EncDecAttention"
+ value_vector: "value_states"
+UMT5ForConditionalGeneration:
+ self_attention_module: "SelfAttention"
+ cross_attention_module: "EncDecAttention"
+ value_vector: "value_states"
\ No newline at end of file
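
Note: each YAML entry maps one-to-one onto the `ModelConfig` fields; a small sanity-check sketch (the inline YAML string is copied from the `MarianMTModel` entry above, the loading code is illustrative).

```python
import yaml

from inseq.models.model_config import ModelConfig

entry = yaml.safe_load(
    """
MarianMTModel:
  self_attention_module: "self_attn"
  cross_attention_module: "encoder_attn"
  value_vector: "value_states"
"""
)
config = ModelConfig(**entry["MarianMTModel"])
assert config.cross_attention_module == "encoder_attn"
```
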
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index 29f81615..69d9d1ad 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -8,6 +8,7 @@
MissingAttributionMethodError,
UnknownAttributionMethodError,
)
+from .hooks import StackFrame, get_post_variable_assignment_hook
from .import_utils import (
is_captum_available,
is_datasets_available,
@@ -49,12 +50,16 @@
check_device,
euclidean_distance,
filter_logits,
+ find_block_stack,
get_default_device,
get_front_padding,
get_sequences_from_batched_steps,
normalize,
+ pad_with_nan,
+ recursive_get_submodule,
remap_from_filtered,
top_p_logits_mask,
+ validate_indices,
)
__all__ = [
@@ -118,4 +123,9 @@
"top_p_logits_mask",
"filter_logits",
"cli_arg",
+ "get_post_variable_assignment_hook",
+ "StackFrame",
+ "validate_indices",
+ "pad_with_nan",
+ "recursive_get_submodule",
]
diff --git a/inseq/utils/hooks.py b/inseq/utils/hooks.py
new file mode 100644
index 00000000..02472f4e
--- /dev/null
+++ b/inseq/utils/hooks.py
@@ -0,0 +1,110 @@
+import re
+from inspect import getsourcelines
+from sys import gettrace, settrace
+from typing import Callable, Optional, TypeVar
+
+from torch import nn
+
+from .misc import get_left_padding
+
+StackFrame = TypeVar("StackFrame")
+
+
+def get_last_variable_assignment_position(
+ module: nn.Module,
+ varname: str,
+ fname: str = "forward",
+) -> Optional[int]:
+ """Extract the code line number of the last variable assignment for a variable of interest in the specified method
+ of a `nn.Module` object.
+
+ Args:
+ module (`nn.Module`):
+ A PyTorch module containing a method with a variable assignment after which the hook should be executed.
+ varname (`str`):
+ The name of the variable to use as anchor for the hook.
+ fname (`str`, *optional*, defaults to "forward"):
+ The name of the method in which the variable assignment should be searched.
+
+ Returns:
+ `Optional[int]`: Returns the line number in the file (not relative to the method) of the last variable
+ assignment. Returns None if no assignment to the variable was found.
+ """
+ # Matches any assignment of variable varname
+ pattern = rf"^\s*(?:\w+\s*,\s*)*\b{varname}\b\s*(?:,.+\s*)*=\s*[^\W=]+.*$"
+ code, startline = getsourcelines(getattr(module, fname))
+ line_numbers = []
+ i = 0
+ while i < len(code):
+ line = code[i]
+ # Handles multi-line assignments
+ if re.match(pattern, line):
+ parentheses_count = line.count("(") - line.count(")")
+ ends_with_newline = lambda l: l.strip().endswith("\\")
+ follow_indent = lambda l, i: len(code) > i + 1 and get_left_padding(code[i + 1]) > get_left_padding(l)
+ while (ends_with_newline(line) or follow_indent(line, i) or parentheses_count > 0) and len(code) > i + 1:
+ i += 1
+ line = code[i]
+ parentheses_count += line.count("(") - line.count(")")
+ line_numbers.append(i)
+ i += 1
+ if len(line_numbers) == 0:
+ return None
+ return line_numbers[-1] + startline + 1
+
+
+def get_post_variable_assignment_hook(
+ module: nn.Module,
+ varname: str,
+ fname: str = "forward",
+ hook_fn: Callable[[StackFrame], None] = lambda **kwargs: None,
+ **kwargs,
+) -> Callable[[], None]:
+ """Creates a hook that is called after the last variable assignment in the specified method of a `nn.Module`.
+
+ This is a hacky method that uses ``sys.settrace()`` to work around the limited set of hook points offered by
+ Pytorch and set a custom hook point dynamically. This approach ensures broader compatibility with Hugging Face
+ transformers models, which currently do not expose dedicated hook points in their architectures.
+
+ Args:
+ module (`nn.Module`):
+ A PyTorch module containing a method with a variable assignment after which the hook should be executed.
+ varname (`str`):
+ The name of the variable to use as anchor for the hook.
+ fname (`str`, *optional*, defaults to "forward"):
+ The name of the method in which the variable assignment should be searched.
+ hook_fn (`Callable[[StackFrame], None]`, *optional*, defaults to a no-op):
+ A custom hook function that is called after the last variable assignment in the specified method. The first
+ parameter is the current frame in the execution at the hook point, and any additional arguments can be
+ passed when creating the hook. ``frame.f_locals`` is a dictionary containing all local variables.
+
+ Returns:
+ The hook function that can be registered with the module. If hooking the module's ``forward()`` method, the
+ hook can be registered with Pytorch native hook methods.
+ """
+ hook_line_num = get_last_variable_assignment_position(module, varname, fname)
+ curr_trace_fn = gettrace()
+ if hook_line_num is None:
+ raise ValueError(f"Could not find assignment to {varname} in {module}'s {fname}() method")
+
+ def var_tracer(frame, event, arg=None):
+ curr_line_num = frame.f_lineno
+ curr_func_name = frame.f_code.co_name
+
+ # Matches the first executable line after hook_line_num in the same function of the same module
+ if (
+ event == "line"
+ and curr_line_num >= hook_line_num
+ and curr_func_name == fname
+ and isinstance(frame.f_locals.get("self"), nn.Module)
+ and frame.f_locals.get("self")._get_name() == module._get_name()
+ ):
+ # Call the custom hook providing the current frame and any additional arguments as context
+ hook_fn(frame, **kwargs)
+ settrace(curr_trace_fn)
+ return var_tracer
+
+ def hook(*args, **kwargs):
+ settrace(var_tracer)
+
+ return hook
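
Note: a hedged usage sketch for the new hook utility. The module path and the `value` variable name follow the GPT-2 entry in `model_config.yaml`; registering the returned hook as a pre-forward hook is one possible way to trigger it and is an assumption of this example, not a prescribed API.

```python
import torch
from transformers import GPT2LMHeadModel

from inseq.utils import get_post_variable_assignment_hook

captured = {}

def capture_value(frame, **kwargs):
    # frame.f_locals holds the locals of GPT2Attention.forward right after
    # the last assignment to the `value` variable.
    captured["value"] = frame.f_locals["value"].detach()

model = GPT2LMHeadModel.from_pretrained("gpt2")
attn = model.transformer.h[0].attn
hook = get_post_variable_assignment_hook(attn, varname="value", hook_fn=capture_value)
# The returned hook only enables the tracer, so it must run before forward() executes.
handle = attn.register_forward_pre_hook(lambda module, args: hook())
with torch.no_grad():
    model(torch.tensor([[50256, 318, 257]]))
handle.remove()
print(captured["value"].shape)  # e.g. (batch, num_heads, seq_len, head_dim) for GPT-2
```
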
diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py
index e09e5df7..628995bc 100644
--- a/inseq/utils/misc.py
+++ b/inseq/utils/misc.py
@@ -10,7 +10,6 @@
from functools import wraps
from importlib import import_module
from inspect import signature
-from itertools import dropwhile
from numbers import Number
from os import PathLike, fsync
from typing import Any, Callable, Optional, Union
@@ -171,10 +170,10 @@ def pad(seq: Sequence[Sequence[Any]], pad_id: Any):
return seq
-def drop_padding(seq: Sequence[Any], pad_id: Any):
+def drop_padding(seq: Sequence[TokenWithId], pad_id: str):
if pad_id is None:
return seq
- return list(reversed(list(dropwhile(lambda x: x == pad_id, reversed(seq)))))
+ return [x for x in seq if x.token != pad_id]
def isnotebook():
@@ -435,3 +434,8 @@ def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str]
else:
removed_token_idxs += [idx]
return clean_tokens, removed_token_idxs
+
+
+def get_left_padding(text: str):
+ """Returns the number of spaces at the beginning of a string."""
+ return len(text) - len(text.lstrip())
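
Note: the behavioural change in `drop_padding` (pad tokens are now filtered wherever they occur, based on the token string, rather than only being stripped from the end) can be illustrated as follows. The token strings and ids are made up, and the `(token, id)` constructor order is assumed from the dataclass definition.

```python
from inseq.utils.misc import drop_padding
from inseq.utils.typing import TokenWithId

toks = [TokenWithId("▁Hello", 35378), TokenWithId("<pad>", 0), TokenWithId("!", 38)]
# Keeps only ▁Hello and !, dropping the pad token in the middle of the sequence.
print(drop_padding(toks, pad_id="<pad>"))
```
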
diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py
index 88e807cc..86acd635 100644
--- a/inseq/utils/torch_utils.py
+++ b/inseq/utils/torch_utils.py
@@ -5,11 +5,14 @@
import torch
import torch.nn.functional as F
from jaxtyping import Int, Num
+from torch import nn
from torch.backends.cuda import is_built as is_cuda_built
from torch.backends.mps import is_available as is_mps_available
from torch.backends.mps import is_built as is_mps_built
from torch.cuda import is_available as is_cuda_available
+from .typing import OneOrMoreIndices
+
if TYPE_CHECKING:
pass
@@ -244,3 +247,118 @@ def get_default_device() -> str:
return "cpu"
else:
return "cpu"
+
+
+def find_block_stack(module):
+ """Recursively searches for the first instance of a `nn.ModuleList` submodule within a given `torch.nn.Module`.
+
+ Args:
+ module (:obj:`torch.nn.Module`): A Pytorch :obj:`nn.Module` object.
+
+ Returns:
+ :obj:`torch.nn.ModuleList`: The first :obj:`nn.ModuleList` submodule found within the given module.
+ None: If no `nn.ModuleList` submodule is found within the given `nn.Module` object.
+ """
+ # Check if the current module is an instance of nn.ModuleList
+ if isinstance(module, nn.ModuleList):
+ return module
+
+ # Recursively search for nn.ModuleList in the submodules of the current module
+ for submodule in module.children():
+ module_list = find_block_stack(submodule)
+ if module_list is not None:
+ return module_list
+
+ # If nn.ModuleList is not found in any submodules, return None
+ return None
+
+
+def validate_indices(
+ scores: torch.Tensor,
+ dim: int = -1,
+ indices: Optional[OneOrMoreIndices] = None,
+) -> OneOrMoreIndices:
+ """Validates a set of indices for a given dimension of a tensor of scores. Supports single indices, spans and lists
+ of indices, including negative indices to specify positions relative to the end of the tensor.
+
+ Args:
+ scores (torch.Tensor): The tensor of scores.
+ dim (int, optional): The dimension of the tensor that will be indexed. Defaults to -1.
+ indices (Union[int, tuple[int, int], list[int], None], optional):
+ - If an integer, it is interpreted as a single index for the dimension.
+ - If a tuple of two integers, it is interpreted as a span of indices for the dimension.
+ - If a list of integers, it is interpreted as a list of individual indices for the dimension.
+ - If None, all indices of the dimension are used.
+
+ Returns:
+ ``torch.Tensor``: A tensor of validated positive indices for indexing the given dimension.
+ """
+ if dim >= scores.ndim:
+ raise IndexError(f"Dimension {dim} is greater than tensor dimension {scores.ndim}")
+ n_units = scores.shape[dim]
+ if not isinstance(indices, (int, tuple, list)) and indices is not None:
+ raise TypeError(
+ "Indices must be an integer, a (start, end) tuple of indices representing a span, a list of individual"
+ " indices or a single index."
+ )
+ if hasattr(indices, "__iter__"):
+ if len(indices) == 0:
+ raise RuntimeError("An empty sequence of indices is not allowed.")
+ if len(indices) == 1:
+ indices = indices[0]
+
+ if isinstance(indices, int):
+ if indices not in range(-n_units, n_units):
+ raise IndexError(f"Index out of range. Scores only have {n_units} units.")
+ indices = indices if indices >= 0 else n_units + indices
+ return torch.tensor(indices)
+ else:
+ if indices is None:
+ indices = (0, n_units)
+ logger.info("No indices specified. Using all indices by default.")
+
+ # Convert negative indices to positive indices
+ if hasattr(indices, "__iter__"):
+ indices = type(indices)([h_idx if h_idx >= 0 else n_units + h_idx for h_idx in indices])
+ if not hasattr(indices, "__iter__") or (
+ len(indices) == 2 and isinstance(indices, tuple) and indices[0] >= indices[1]
+ ):
+ raise RuntimeError(
+ "A (start, end) tuple of indices representing a span, a list of individual indices"
+ " or a single index must be specified."
+ )
+ max_idx_val = n_units if isinstance(indices, list) else n_units + 1
+ if not all(h in range(-n_units, max_idx_val) for h in indices):
+ raise IndexError(f"One or more index out of range. Scores only have {n_units} units.")
+ if len(set(indices)) != len(indices):
+ raise IndexError("Duplicate indices are not allowed.")
+ if isinstance(indices, tuple):
+ return torch.arange(indices[0], indices[1])
+ else:
+ return torch.tensor(indices)
+
+
+def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int, front: bool = False) -> torch.Tensor:
+ """Utility to pad a tensor with nan values along a given dimension."""
+ nan_tensor = torch.ones(
+ *t.shape[:dim],
+ pad_size,
+ *t.shape[dim + 1 :],
+ device=t.device,
+ ) * float("nan")
+ if front:
+ return torch.cat([nan_tensor, t], dim=dim)
+ return torch.cat([t, nan_tensor], dim=dim)
+
+
+def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]:
+ if target == "":
+ return parent
+ mod = None
+ if hasattr(parent, target):
+ mod = getattr(parent, target)
+ else:
+ for submodule in parent.children():
+ mod = recursive_get_submodule(submodule, target)
+ if mod is not None:
+ break
+ return mod
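
Note: a quick sketch of how the new tensor helpers behave. The shapes and the GPT-2 checkpoint are arbitrary choices for the example; all three functions are exported from `inseq.utils`.

```python
import torch
from transformers import GPT2LMHeadModel

from inseq.utils import find_block_stack, pad_with_nan, validate_indices

# find_block_stack returns the first nn.ModuleList in the module tree,
# i.e. the stack of transformer blocks for standard Hugging Face models.
model = GPT2LMHeadModel.from_pretrained("gpt2")
print(len(find_block_stack(model)))  # 12 blocks for the base GPT-2 checkpoint

scores = torch.rand(2, 8, 12)  # e.g. (batch, heads, positions)
# A (start, end) tuple over the head dimension is expanded to explicit indices.
print(validate_indices(scores, dim=1, indices=(0, 4)))  # tensor([0, 1, 2, 3])
# Negative single indices are resolved relative to the dimension size.
print(validate_indices(scores, dim=1, indices=-1))      # tensor(7)

# pad_with_nan appends (or prepends, with front=True) NaN entries along a dimension.
padded = pad_with_nan(scores, dim=2, pad_size=3)
print(padded.shape)  # torch.Size([2, 8, 15])
```
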
diff --git a/inseq/utils/typing.py b/inseq/utils/typing.py
index 7599bbc7..4eec4a5b 100644
--- a/inseq/utils/typing.py
+++ b/inseq/utils/typing.py
@@ -1,13 +1,17 @@
from collections.abc import Sequence
from dataclasses import dataclass
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
+from captum.attr._utils.attribution import Attribution
from jaxtyping import Float, Float32, Int64
from transformers import PreTrainedModel
TextInput = Union[str, Sequence[str]]
+if TYPE_CHECKING:
+ from inseq.models import AttributionModel
+
@dataclass
class TokenWithId:
@@ -28,6 +32,34 @@ def __eq__(self, other: Union[str, int, "TokenWithId"]):
return False
+class InseqAttribution(Attribution):
+ """A wrapper class for the Captum library's Attribution class to type hint the ``forward_func`` attribute
+ as an :class:`~inseq.models.AttributionModel`.
+ """
+
+ def __init__(self, forward_func: "AttributionModel") -> None:
+ r"""
+ Args:
+ forward_func (:class:`~inseq.models.AttributionModel`): The model hooked to the attribution method.
+ """
+ self.forward_func = forward_func
+
+ attribute: Callable
+
+ @property
+ def multiplies_by_inputs(self):
+ return False
+
+ def has_convergence_delta(self) -> bool:
+ return False
+
+ compute_convergence_delta: Callable
+
+ @classmethod
+ def get_name(cls: type["InseqAttribution"]) -> str:
+ return "".join([char if char.islower() or idx == 0 else " " + char for idx, char in enumerate(cls.__name__)])
+
+
@dataclass
class TextSequences:
targets: TextInput
@@ -40,6 +72,8 @@ class TextSequences:
OneOrMoreAttributionSequences = Sequence[Sequence[float]]
IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]]
+OneOrMoreIndices = Union[int, list[int], tuple[int, int]]
+OneOrMoreIndicesDict = dict[int, OneOrMoreIndices]
IdsTensor = Int64[torch.Tensor, "batch_size seq_len"]
TargetIdsTensor = Int64[torch.Tensor, "batch_size"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 91a4d3f2..92a9ca95 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -361,7 +361,7 @@ traitlets==5.14.1
# jupyter-client
# jupyter-core
# matplotlib-inline
-transformers==4.37.2
+transformers==4.38.1
typeguard==2.13.3
# via jaxtyping
typer==0.9.0
diff --git a/requirements.txt b/requirements.txt
index 9f392d72..a0a99e61 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -93,7 +93,7 @@ tqdm==4.66.2
# captum
# huggingface-hub
# transformers
-transformers==4.37.2
+transformers==4.38.1
typeguard==2.13.3
# via jaxtyping
typing-extensions==4.9.0
diff --git a/tests/attr/feat/test_feature_attribution.py b/tests/attr/feat/test_feature_attribution.py
index 07b2045c..80856176 100644
--- a/tests/attr/feat/test_feature_attribution.py
+++ b/tests/attr/feat/test_feature_attribution.py
@@ -1,8 +1,14 @@
+from typing import Any, Optional
+
import torch
+from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from pytest import fixture
import inseq
+from inseq.attr.feat.internals_attribution import InternalsAttributionRegistry
+from inseq.data import MultiDimensionalFeatureAttributionStepOutput
from inseq.models import HuggingfaceDecoderOnlyModel, HuggingfaceEncoderDecoderModel
+from inseq.utils.typing import InseqAttribution, MultiLayerMultiUnitScoreTensor
@fixture(scope="session")
@@ -69,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu
"orig_tgt": "I soldati della pace ONU",
"contrast_tgt": "Le forze militari di pace delle Nazioni Unite",
"alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]],
- "aligned_tgts": ["āLe ā āI", "āforze ā āsoldati", "ādi ā ādella", "āpace", "āNazioni ā āONU", ""],
+ "aligned_tgts": ["", "āLe ā āI", "āforze ā āsoldati", "ādi ā ādella", "āpace", "āNazioni ā āONU", ""],
}
out = saliency_mt_model_larger.attribute(
aligned["src"],
@@ -129,3 +135,122 @@ def test_mcd_weighted_attribution_gpt(saliency_gpt_model):
)
attribution_scores = out.sequence_attributions[0].target_attributions
assert isinstance(attribution_scores, torch.Tensor)
+
+
+class MultiStepAttentionWeights(InseqAttribution):
+ """Variant of the AttentionWeights class with is_final_step_method = False.
+ As a result, the attention matrix is computed and sliced at every generation step.
+ We define it here to test consistency with the final step method.
+ """
+
+ def attribute(
+ self,
+ inputs: TensorOrTupleOfTensorsGeneric,
+ additional_forward_args: TensorOrTupleOfTensorsGeneric,
+ encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+ decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+ cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
+ ) -> MultiDimensionalFeatureAttributionStepOutput:
+ # We adopt the format [batch_size, sequence_length, num_layers, num_heads]
+ # for consistency with other multi-unit methods (e.g. gradient attribution)
+ decoder_self_attentions = decoder_self_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2)
+ if self.forward_func.is_encoder_decoder:
+ sequence_scores = {}
+ if len(inputs) > 1:
+ target_attributions = decoder_self_attentions
+ else:
+ target_attributions = None
+ sequence_scores["decoder_self_attentions"] = decoder_self_attentions
+ sequence_scores["encoder_self_attentions"] = (
+ encoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
+ )
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=cross_attentions[..., -1, :].to("cpu").clone().permute(0, 3, 1, 2),
+ target_attributions=target_attributions,
+ sequence_scores=sequence_scores,
+ _num_dimensions=2, # num_layers, num_heads
+ )
+ else:
+ return MultiDimensionalFeatureAttributionStepOutput(
+ source_attributions=None,
+ target_attributions=decoder_self_attentions,
+ _num_dimensions=2, # num_layers, num_heads
+ )
+
+
+class MultiStepAttentionWeightsAttribution(InternalsAttributionRegistry):
+ """Variant of the basic attention attribution method computing attention weights at every generation step."""
+
+ method_name = "per_step_attention"
+
+ def __init__(self, attribution_model, **kwargs):
+ super().__init__(attribution_model)
+ # Attention weights will be passed to the attribute_step method
+ self.use_attention_weights = True
+ # Does not rely on predicted output (i.e. decoding strategy agnostic)
+ self.use_predicted_target = False
+ self.method = MultiStepAttentionWeights(attribution_model)
+
+ def attribute_step(
+ self,
+ attribute_fn_main_args: dict[str, Any],
+ attribution_args: dict[str, Any],
+ ) -> MultiDimensionalFeatureAttributionStepOutput:
+ return self.method.attribute(**attribute_fn_main_args, **attribution_args)
+
+
+def test_seq2seq_final_step_per_step_conformity(saliency_mt_model_larger: HuggingfaceEncoderDecoderModel):
+ out_per_step = saliency_mt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="per_step_attention",
+ attribute_target=True,
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ out_final_step = saliency_mt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="attention",
+ attribute_target=True,
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ assert out_per_step[0] == out_final_step[0]
+
+
+def test_gpt_final_step_per_step_conformity(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel):
+ out_per_step = saliency_gpt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="per_step_attention",
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ out_final_step = saliency_gpt_model_larger.attribute(
+ "Hello ladies and badgers!",
+ method="attention",
+ show_progress=False,
+ output_step_attributions=True,
+ )
+ assert out_per_step[0] == out_final_step[0]
+
+
+# Batching for Seq2Seq models is not supported when using is_final_step methods
+# Passing several sentences will attribute them one by one under the hood
+# def test_seq2seq_multi_step_attention_weights_batched_full_match(saliency_mt_model: HuggingfaceEncoderDecoderModel):
+
+
+def test_gpt_multi_step_attention_weights_batched_full_match(saliency_gpt_model_larger: HuggingfaceDecoderOnlyModel):
+ out_per_step = saliency_gpt_model_larger.attribute(
+ ["Hello world!", "Colorless green ideas sleep furiously."],
+ method="per_step_attention",
+ show_progress=False,
+ )
+ out_final_step = saliency_gpt_model_larger.attribute(
+ ["Hello world!", "Colorless green ideas sleep furiously."],
+ method="attention",
+ show_progress=False,
+ )
+ for i in range(2):
+ assert out_per_step[i].target_attributions.shape == out_final_step[i].target_attributions.shape
+ assert torch.allclose(
+ out_per_step[i].target_attributions, out_final_step[i].target_attributions, equal_nan=True, atol=1e-5
+ )
diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py
index eb5086ca..f7e7c3e5 100644
--- a/tests/data/test_aggregator.py
+++ b/tests/data/test_aggregator.py
@@ -39,14 +39,14 @@ def test_sequence_attribution_aggregator(saliency_mt_model: HuggingfaceEncoderDe
)
seqattr = out.sequence_attributions[0]
assert seqattr.source_attributions.shape == (6, 7, 512)
- assert seqattr.target_attributions.shape == (7, 7, 512)
+ assert seqattr.target_attributions.shape == (8, 7, 512)
assert seqattr.step_scores["probability"].shape == (7,)
for i, step in enumerate(out.step_attributions):
assert step.source_attributions.shape == (1, 6, 512)
assert step.target_attributions.shape == (1, i + 1, 512)
out_agg = seqattr.aggregate()
assert out_agg.source_attributions.shape == (6, 7)
- assert out_agg.target_attributions.shape == (7, 7)
+ assert out_agg.target_attributions.shape == (8, 7)
assert out_agg.step_scores["probability"].shape == (7,)
@@ -56,9 +56,9 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder
)
seqattr = out.sequence_attributions[0]
out_agg = seqattr.aggregate(ContiguousSpanAggregator, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg.source_attributions.shape == (5, 4, 512)
- assert out_agg.target_attributions.shape == (4, 4, 512)
- assert out_agg.step_scores["probability"].shape == (4,)
+ assert out_agg.source_attributions.shape == (5, 5, 512)
+ assert out_agg.target_attributions.shape == (5, 5, 512)
+ assert out_agg.step_scores["probability"].shape == (5,)
def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel):
@@ -76,14 +76,14 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel):
seqattr = out.sequence_attributions[0]
squeezesum = AggregatorPipeline([ContiguousSpanAggregator, SequenceAttributionAggregator])
out_agg_squeezesum = seqattr.aggregate(squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg_squeezesum.source_attributions.shape == (5, 4)
- assert out_agg_squeezesum.target_attributions.shape == (4, 4)
- assert out_agg_squeezesum.step_scores["probability"].shape == (4,)
+ assert out_agg_squeezesum.source_attributions.shape == (5, 5)
+ assert out_agg_squeezesum.target_attributions.shape == (5, 5)
+ assert out_agg_squeezesum.step_scores["probability"].shape == (5,)
sumsqueeze = AggregatorPipeline([SequenceAttributionAggregator, ContiguousSpanAggregator])
out_agg_sumsqueeze = seqattr.aggregate(sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg_sumsqueeze.source_attributions.shape == (5, 4)
- assert out_agg_sumsqueeze.target_attributions.shape == (4, 4)
- assert out_agg_sumsqueeze.step_scores["probability"].shape == (4,)
+ assert out_agg_sumsqueeze.source_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze.target_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze.step_scores["probability"].shape == (5,)
assert not torch.allclose(out_agg_squeezesum.source_attributions, out_agg_sumsqueeze.source_attributions)
assert not torch.allclose(out_agg_squeezesum.target_attributions, out_agg_sumsqueeze.target_attributions)
# Named indexing version
@@ -91,12 +91,12 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel):
named_sumsqueeze = ["scores", "spans"]
out_agg_squeezesum_named = seqattr.aggregate(named_squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
out_agg_sumsqueeze_named = seqattr.aggregate(named_sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)])
- assert out_agg_squeezesum_named.source_attributions.shape == (5, 4)
- assert out_agg_squeezesum_named.target_attributions.shape == (4, 4)
- assert out_agg_squeezesum_named.step_scores["probability"].shape == (4,)
- assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 4)
- assert out_agg_sumsqueeze_named.target_attributions.shape == (4, 4)
- assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (4,)
+ assert out_agg_squeezesum_named.source_attributions.shape == (5, 5)
+ assert out_agg_squeezesum_named.target_attributions.shape == (5, 5)
+ assert out_agg_squeezesum_named.step_scores["probability"].shape == (5,)
+ assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze_named.target_attributions.shape == (5, 5)
+ assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (5,)
assert not torch.allclose(
out_agg_squeezesum_named.source_attributions, out_agg_sumsqueeze_named.source_attributions
)
diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json
index fc029eec..53123526 100644
--- a/tests/fixtures/aggregator.json
+++ b/tests/fixtures/aggregator.json
@@ -36,6 +36,7 @@
],
"target": "Inseq \u00e8 un framework per l'attribuzione automatica di modelli sequenziali.",
"target_subwords": [
+ "",
"\u2581In",
"se",
"q",
@@ -58,6 +59,7 @@
""
],
"target_merged": [
+ "",
"\u2581Inseq",
"\u2581\u00e8",
"\u2581un",
diff --git a/tests/inference_commons.py b/tests/inference_commons.py
index 3da21068..19810018 100644
--- a/tests/inference_commons.py
+++ b/tests/inference_commons.py
@@ -1,3 +1,6 @@
+import json
+import os
+
from inseq.data import EncoderDecoderBatch
from inseq.utils import json_advanced_load
@@ -9,3 +12,8 @@ def get_example_batches():
dict_batches["batches"] = [batch.torch() for batch in dict_batches["batches"]]
assert all(isinstance(batch, EncoderDecoderBatch) for batch in dict_batches["batches"])
return dict_batches
+
+
+def load_examples() -> dict:
+ file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/huggingface_model.json")
+ return json.load(open(file))
diff --git a/tests/models/test_huggingface_model.py b/tests/models/test_huggingface_model.py
index 993c07ac..72da4a2f 100644
--- a/tests/models/test_huggingface_model.py
+++ b/tests/models/test_huggingface_model.py
@@ -2,8 +2,6 @@
since it is bugged is not very elegant, this will need to be refactored.
"""
-import json
-import os
import pytest
import torch
@@ -15,8 +13,9 @@
from inseq.data import FeatureAttributionOutput, FeatureAttributionSequenceOutput
from inseq.utils import get_default_device
-EXAMPLES_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../fixtures/huggingface_model.json")
-EXAMPLES = json.load(open(EXAMPLES_FILE))
+from ..inference_commons import load_examples
+
+EXAMPLES = load_examples()
USE_REFERENCE_TEXT = [True, False]
ATTRIBUTE_TARGET = [True, False]
@@ -275,8 +274,8 @@ def test_attribute_slice_seq2seq(saliency_mt_model):
assert ex2.attr_pos_start == len(ex2.target)
assert ex2.attr_pos_end == len(ex2.target)
assert ex2.source_attributions.shape[1] == 0 and ex2.target_attributions.shape[1] == 0
- assert ex3.attr_pos_start == 12
- assert ex3.attr_pos_end == 15
+ assert ex3.attr_pos_start == 13
+ assert ex3.attr_pos_end == 16
assert ex1.source_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start
assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start
assert ex1.target_attributions.shape[0] == ex1.attr_pos_end
@@ -303,12 +302,12 @@ def test_attribute_decoder(saliency_gpt2_model):
assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start
assert ex1.target_attributions.shape[0] == ex1.attr_pos_end
# Empty attributions outputs have start and end set to seq length
- assert ex2.attr_pos_start == 17
- assert ex2.attr_pos_end == 22
+ assert ex2.attr_pos_start == 9
+ assert ex2.attr_pos_end == 14
assert ex2.target_attributions.shape[1] == ex2.attr_pos_end - ex2.attr_pos_start
assert ex2.target_attributions.shape[0] == ex2.attr_pos_end
- assert ex3.attr_pos_start == 17
- assert ex3.attr_pos_end == 22
+ assert ex3.attr_pos_start == 12
+ assert ex3.attr_pos_end == 17
assert ex3.target_attributions.shape[1] == ex3.attr_pos_end - ex3.attr_pos_start
assert ex3.target_attributions.shape[0] == ex3.attr_pos_end
assert out.info["attr_pos_start"] == 17