From 6c4d1df8e4713d4e3ad9983f30378cc4941c59ca Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Tue, 27 Feb 2024 07:53:24 +0100 Subject: [PATCH] VZ working for encoder-decoder --- CHANGELOG.md | 6 +- inseq/attr/feat/attribution_utils.py | 8 ++- inseq/attr/feat/ops/value_zeroing.py | 74 +++++++++++++++------ inseq/attr/feat/perturbation_attribution.py | 22 ++++-- inseq/data/attribution.py | 9 +-- inseq/models/model_config.py | 14 ++-- inseq/models/model_config.yaml | 53 ++++++++------- tests/attr/feat/test_feature_attribution.py | 2 +- tests/data/test_aggregator.py | 34 +++++----- tests/fixtures/aggregator.json | 2 + tests/models/test_huggingface_model.py | 4 +- 11 files changed, 151 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index faeca60e..4c8fe04c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ - Support for multi-GPU attribution ([#238](https://github.com/inseq-team/inseq/pull/238)) - Added `inseq attribute-context` CLI command to support the [PECoRe framework] for detecting and attributing context reliance in generative LMs ([#237](https://github.com/inseq-team/inseq/pull/237)) +- Added `value_zeroing` (`inseq.attr.feat.perturbation_attribution.ValueZeroingAttribution`) attribution method ([#173](https://github.com/inseq-team/inseq/pull/173)) +- Added `rollout` (`inseq.data.aggregation_functions.RolloutAggregationFunction`) aggregation function for `SequenceAttributionAggregator` class ([#173](https://github.com/inseq-team/inseq/pull/173)). +- `value_zeroing` and `attention` use scores from the last generation step to produce outputs more efficiently (`is_final_step_method = True`) ([#173](https://github.com/inseq-team/inseq/pull/173)). ## šŸ”§ Fixes & Refactoring @@ -26,4 +29,5 @@ ## šŸ’„ Breaking Changes -*No changes* +- If `attention` is used as attribution method in `model.attribute`, `step_scores` cannot be extracted at the same time since the method does not require iterating over the full sequence anymore. ([#173](https://github.com/inseq-team/inseq/pull/173)) As an alternative, step scores can be extracted separately using the `dummy` attribution method (i.e. no attribution). +- BOS is always included in target-side attribution and generated sequences if present. ([#173](https://github.com/inseq-team/inseq/pull/173)) diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py index 8da4f899..a9679845 100644 --- a/inseq/attr/feat/attribution_utils.py +++ b/inseq/attr/feat/attribution_utils.py @@ -144,11 +144,15 @@ def extract_args( def get_source_target_attributions( attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]], is_encoder_decoder: bool, + has_sequence_scores: bool = False, ) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]: if isinstance(attr, tuple): if is_encoder_decoder: - return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None) + if has_sequence_scores: + return (attr[0], attr[1], attr[2]) + else: + return (attr[0], attr[1]) if len(attr) > 1 else (attr[0], None) else: - return (None, attr[0]) + return (None, None, attr[0]) if has_sequence_scores else (None, attr[0]) else: return (attr, None) if is_encoder_decoder else (None, attr) diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py index a75dcc86..0ee2f86a 100644 --- a/inseq/attr/feat/ops/value_zeroing.py +++ b/inseq/attr/feat/ops/value_zeroing.py @@ -45,7 +45,6 @@ class ValueZeroingSimilarityMetric(Enum): class ValueZeroingModule(Enum): DECODER = "decoder" ENCODER = "encoder" - CROSS = "cross" class ValueZeroing(InseqAttribution): @@ -155,10 +154,13 @@ def compute_modules_post_zeroing_similarity( inputs: TensorOrTupleOfTensorsGeneric, additional_forward_args: TensorOrTupleOfTensorsGeneric, hidden_states: MultiLayerEmbeddingsTensor, + attention_module_name: str, + attributed_seq_len: Optional[int] = None, similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, mode: str = ValueZeroingModule.DECODER.value, zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, - threshold: float = 1e-5, + min_score_threshold: float = 1e-5, + use_causal_mask: bool = False, ) -> MultiLayerScoreTensor: """Given a ``nn.ModuleList``, computes the similarity between the clean and corrupted states for each block. @@ -166,9 +168,12 @@ def compute_modules_post_zeroing_similarity( modules (:obj:`nn.ModuleList`): The list of modules to compute the similarity for. hidden_states (:obj:`MultiLayerEmbeddingsTensor`): The cached hidden states of the modules to use as clean counterparts when computing the similarity. - similarity_scores_shape (:obj:`torch.Size`): The shape of the similarity scores tensor to be returned. + attention_module_name (:obj:`str`): The name of the attention module to zero the values for. + attributed_seq_len (:obj:`int`): The length of the sequence to attribute. If not specified, it is assumed + to be the same as the length of the hidden states. similarity_metric (:obj:`str`): The name of the similarity metric used. Default: "cosine". mode (:obj:`str`): The mode of the model to compute the similarity for. Default: "decoder". + zeroed_units_indices (:obj:`Union[int, tuple[int, int], list[int]]` or :obj:`dict` with :obj:`int` keys and `Union[int, tuple[int, int], list[int]]` values, optional): The indices of the attention heads that should be zeroed to compute corrupted states. @@ -179,6 +184,9 @@ def compute_modules_post_zeroing_similarity( - If a dictionary, the keys are the layer indices and the values are the zeroed attention heads for the corresponding layer. Any missing layer will not be zeroed. Default: None. + min_score_threshold (:obj:`float`, optional): The minimum score threshold to consider when computing the + similarity. Default: 1e-5. + use_causal_mask (:obj:`bool`, optional): Whether a causal mask is applied to zeroing scores Default: False. Returns: :obj:`MultiLayerScoreTensor`: A tensor of shape ``[batch_size, seq_len, num_layer]`` containing distances @@ -186,11 +194,15 @@ def compute_modules_post_zeroing_similarity( """ if mode == ValueZeroingModule.DECODER.value: modules: nn.ModuleList = find_block_stack(self.forward_func.get_decoder()) - batch_size = hidden_states.size(0) - num_layers = len(modules) - sequence_length = hidden_states.size(2) + elif mode == ValueZeroingModule.ENCODER.value: + modules: nn.ModuleList = find_block_stack(self.forward_func.get_encoder()) else: raise NotImplementedError(f"Mode {mode} not implemented for value zeroing.") + if attributed_seq_len is None: + attributed_seq_len = hidden_states.size(2) + batch_size = hidden_states.size(0) + generated_seq_len = hidden_states.size(2) + num_layers = len(modules) # Store clean hidden states for later use. Starts at 1 since the first element of the modules stack is the # embedding layer, and we are only interested in the transformer blocks outputs. @@ -199,7 +211,7 @@ def compute_modules_post_zeroing_similarity( } # Scores for every layer of the model all_scores = torch.ones( - batch_size, num_layers, sequence_length, sequence_length, device=hidden_states.device + batch_size, num_layers, generated_seq_len, attributed_seq_len, device=hidden_states.device ) * float("nan") # Hooks: @@ -218,11 +230,11 @@ def compute_modules_post_zeroing_similarity( modules[block_idx].register_forward_hook(states_extract_and_patch_hook) ) # Zeroing is done for every token in the sequence separately (O(n) complexity) - for token_idx in range(sequence_length): + for token_idx in range(attributed_seq_len): value_zeroing_hook_handles: list[RemovableHandle] = [] # Value zeroing hooks are registered for every token separately since they are token-dependent for block_idx, block in enumerate(modules): - attention_module = block.get_submodule(self.forward_func.config.attention_module) + attention_module = block.get_submodule(attention_module_name) if isinstance(zeroed_units_indices, dict): if block_idx not in zeroed_units_indices: continue @@ -259,19 +271,22 @@ def compute_modules_post_zeroing_similarity( for block_idx in range(len(modules)): similarity_scores = self.SIMILARITY_METRICS[similarity_metric]( self.clean_block_output_states[block_idx].float(), self.corrupted_block_output_states[block_idx] - )[:, token_idx:] - all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores + ) + if use_causal_mask: + all_scores[:, block_idx, token_idx:, token_idx] = 1 - similarity_scores[:, token_idx:] + else: + all_scores[:, block_idx, :, token_idx] = 1 - similarity_scores self.corrupted_block_output_states = {} for handle in states_extraction_hook_handles: handle.remove() self.clean_block_output_states = {} - all_scores = torch.where(all_scores < threshold, torch.zeros_like(all_scores), all_scores) + all_scores = torch.where(all_scores < min_score_threshold, torch.zeros_like(all_scores), all_scores) # Normalize scores to sum to 1 - per_token_sum_score = all_scores.sum(dim=-1, keepdim=True) + per_token_sum_score = all_scores.nansum(dim=-1, keepdim=True) per_token_sum_score[per_token_sum_score == 0] = 1 all_scores = all_scores / per_token_sum_score - # Final shape: [batch_size, seq_len, seq_len, num_layers] + # Final shape: [batch_size, attributed_seq_len, generated_seq_len, num_layers] return all_scores.permute(0, 3, 2, 1) def attribute( @@ -312,18 +327,39 @@ def attribute( f"Similarity metric {similarity_metric} not available." f"Available metrics: {','.join(self.SIMILARITY_METRICS.keys())}" ) + decoder_scores = self.compute_modules_post_zeroing_similarity( inputs=inputs, additional_forward_args=additional_forward_args, hidden_states=decoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, similarity_metric=similarity_metric, mode=ValueZeroingModule.DECODER.value, zeroed_units_indices=zeroed_units_indices, + use_causal_mask=True, ) - return decoder_scores # Encoder-decoder models also perform zeroing on the encoder self-attention and cross-attention values # Adapted from https://github.com/hmohebbi/ContextMixingASR/blob/master/scoring/valueZeroing.py - # if is_encoder_decoder: - # encoder_hidden_states = torch.stack(outputs.encoder_hidden_states) - # encoder = self.forward_func.get_encoder() - # encoder_stack = find_block_stack(encoder) + if self.forward_func.is_encoder_decoder: + # TODO: Enable different encoder/decoder/cross zeroing indices + encoder_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=encoder_hidden_states, + attention_module_name=self.forward_func.config.self_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.ENCODER.value, + zeroed_units_indices=zeroed_units_indices, + ) + cross_scores = self.compute_modules_post_zeroing_similarity( + inputs=inputs, + additional_forward_args=additional_forward_args, + hidden_states=decoder_hidden_states, + attributed_seq_len=encoder_hidden_states.size(2), + attention_module_name=self.forward_func.config.cross_attention_module, + similarity_metric=similarity_metric, + mode=ValueZeroingModule.DECODER.value, + zeroed_units_indices=zeroed_units_indices, + ) + return encoder_scores, cross_scores, decoder_scores + return (decoder_scores,) diff --git a/inseq/attr/feat/perturbation_attribution.py b/inseq/attr/feat/perturbation_attribution.py index b18d3d0d..40da111c 100644 --- a/inseq/attr/feat/perturbation_attribution.py +++ b/inseq/attr/feat/perturbation_attribution.py @@ -144,11 +144,25 @@ def attribute_step( attribution_args: dict[str, Any] = {}, ) -> MultiDimensionalFeatureAttributionStepOutput: attr = self.method.attribute(**attribute_fn_main_args, **attribution_args) - source_attributions, target_attributions = get_source_target_attributions( - attr, self.attribution_model.is_encoder_decoder + encoder_self_scores, decoder_cross_scores, decoder_self_scores = get_source_target_attributions( + attr, self.attribution_model.is_encoder_decoder, has_sequence_scores=True ) + sequence_scores = {} + if self.attribution_model.is_encoder_decoder: + if len(attribute_fn_main_args["inputs"]) > 1: + target_attributions = decoder_self_scores.to("cpu") + else: + target_attributions = None + sequence_scores["decoder_self_scores"] = decoder_self_scores.to("cpu") + sequence_scores["encoder_self_scores"] = encoder_self_scores.to("cpu") + return MultiDimensionalFeatureAttributionStepOutput( + source_attributions=decoder_cross_scores.to("cpu"), + target_attributions=target_attributions, + sequence_scores=sequence_scores, + _num_dimensions=1, # num_layers + ) return MultiDimensionalFeatureAttributionStepOutput( - source_attributions=source_attributions, - target_attributions=target_attributions, + source_attributions=None, + target_attributions=decoder_self_scores, _num_dimensions=1, # num_layers ) diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py index 5dcea071..5c58a414 100644 --- a/inseq/data/attribution.py +++ b/inseq/data/attribution.py @@ -209,8 +209,11 @@ def from_step_attributions( curr_target = [a.target[seq_idx][0] for a in attributions] targets.append(drop_padding(curr_target, pad_token)) if has_bos_token: - tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][1:] - tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token) + tokenized_target_sentences[seq_idx] = tokenized_target_sentences[seq_idx][:1] + drop_padding( + tokenized_target_sentences[seq_idx], pad_token + ) + else: + tokenized_target_sentences[seq_idx] = drop_padding(tokenized_target_sentences[seq_idx], pad_token) if attr_pos_end is None: attr_pos_end = max(len(t) for t in tokenized_target_sentences) for seq_idx in range(num_sequences): @@ -238,8 +241,6 @@ def from_step_attributions( [att.target_attributions for att in attributions], padding_dims=[1] ) for seq_id in range(num_sequences): - if has_bos_token: - target_attributions[seq_id] = target_attributions[seq_id][1:, ...] start_idx = max(pos_start) - pos_start[seq_id] end_idx = start_idx + len(tokenized_target_sentences[seq_id]) target_attributions[seq_id] = target_attributions[seq_id][ diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py index a86d5fb3..52d9d47b 100644 --- a/inseq/models/model_config.py +++ b/inseq/models/model_config.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass from pathlib import Path +from typing import Optional import yaml @@ -12,18 +13,23 @@ class ModelConfig: """Configuration used by the methods for which the attribute ``use_model_config=True``. Args: - attention_module (:obj:`str`): - The name of the module performing the attention computation (e.g.``attn`` for the GPT-2 model in - transformers). Can be identified by looking at the name of the attribute instantiating the attention module + self_attention_module (:obj:`str`): + The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in + transformers). Can be identified by looking at the name of the self-attention module attribute in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2). + cross_attention_module (:obj:`str`): + The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models + in transformers). Can be identified by looking at the name of the cross-attention module attribute + in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`). value_vector (:obj:`str`): The name of the variable in the forward pass of the attention module containing the value vector (e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2). """ - attention_module: str + self_attention_module: str value_vector: str + cross_attention_module: Optional[str] = None MODEL_CONFIGS = { diff --git a/inseq/models/model_config.yaml b/inseq/models/model_config.yaml index 68288930..afd564a1 100644 --- a/inseq/models/model_config.yaml +++ b/inseq/models/model_config.yaml @@ -1,46 +1,53 @@ +# AutoModelForCausalLM +BloomForCausalLM: + self_attention_module: "self_attention" + value_vector: "value_layer" GPT2LMHeadModel: - attention_module: "attn" + self_attention_module: "attn" value_vector: "value" OpenAIGPTLMHeadModel: - attention_module: "attn" + self_attention_module: "attn" value_vector: "value" GPTNeoXForCausalLM: - attention_module: "attention" + self_attention_module: "attention" value_vector: "value" -BloomForCausalLM: - attention_module: "self_attention" - value_vector: "value_layer" LlamaForCausalLM: - attention_module: "self_attn" + self_attention_module: "self_attn" value_vector: "value_states" GPTBigCodeForCausalLM: - attention_module: "attn" + self_attention_module: "attn" value_vector: "value" CodeGenForCausalLM: - attention_module: "attn" + self_attention_module: "attn" value_vector: "value" - -# TODO ForCausalLM +# TODO +# BioGptForCausalLM +# GemmaForCausalLM # GPTNeoForCausalLM # GPTJForCausalLM +# MistralForCausalLM +# MixtralForCausalLM +# MptForCausalLM +# OpenLlamaForCausalLM # OPTForCausalLM +# PhiForCausalLM +# StableLmForCausalLM # XGLMForCausalLM -# BioGptForCausalLM -# XLNetLMHeadModel + +# AutoModelForSeq2SeqLM +MarianMTModel: + self_attention_module: "self_attn" + cross_attention_module: "encoder_attn" + value_vector: "value_states" # TODO ForConditionalGeneration # BartForConditionalGeneration -# BlenderbotForConditionalGeneration -# T5ForConditionalGeneration -# MarianMTModel -# LongT5ForConditionalGeneration # FSMTForConditionalGeneration +# LongT5ForConditionalGeneration # M2M100ForConditionalGeneration # MBartForConditionalGeneration -# PegasusForConditionalGeneration -# ProphetNetForConditionalGeneration -# LEDForConditionalGeneration -# BigBirdPegasusForConditionalGeneration -# PLBartForConditionalGeneration -# SwitchTransformerForConditionalGeneration +# MT5ForConditionalGeneration # NllbMoeForConditionalGeneration +# SeamlessM4TForTextToText +# SeamlessM4Tv2ForTextToText +# T5ForConditionalGeneration diff --git a/tests/attr/feat/test_feature_attribution.py b/tests/attr/feat/test_feature_attribution.py index 841ea95d..f08d9a09 100644 --- a/tests/attr/feat/test_feature_attribution.py +++ b/tests/attr/feat/test_feature_attribution.py @@ -75,7 +75,7 @@ def test_contrastive_attribution_seq2seq_alignments(saliency_mt_model_larger: Hu "orig_tgt": "I soldati della pace ONU", "contrast_tgt": "Le forze militari di pace delle Nazioni Unite", "alignments": [[(0, 0), (1, 1), (2, 2), (3, 4), (4, 5), (5, 7), (6, 9)]], - "aligned_tgts": ["ā–Le ā†’ ā–I", "ā–forze ā†’ ā–soldati", "ā–di ā†’ ā–della", "ā–pace", "ā–Nazioni ā†’ ā–ONU", ""], + "aligned_tgts": ["", "ā–Le ā†’ ā–I", "ā–forze ā†’ ā–soldati", "ā–di ā†’ ā–della", "ā–pace", "ā–Nazioni ā†’ ā–ONU", ""], } out = saliency_mt_model_larger.attribute( aligned["src"], diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py index eb5086ca..f7e7c3e5 100644 --- a/tests/data/test_aggregator.py +++ b/tests/data/test_aggregator.py @@ -39,14 +39,14 @@ def test_sequence_attribution_aggregator(saliency_mt_model: HuggingfaceEncoderDe ) seqattr = out.sequence_attributions[0] assert seqattr.source_attributions.shape == (6, 7, 512) - assert seqattr.target_attributions.shape == (7, 7, 512) + assert seqattr.target_attributions.shape == (8, 7, 512) assert seqattr.step_scores["probability"].shape == (7,) for i, step in enumerate(out.step_attributions): assert step.source_attributions.shape == (1, 6, 512) assert step.target_attributions.shape == (1, i + 1, 512) out_agg = seqattr.aggregate() assert out_agg.source_attributions.shape == (6, 7) - assert out_agg.target_attributions.shape == (7, 7) + assert out_agg.target_attributions.shape == (8, 7) assert out_agg.step_scores["probability"].shape == (7,) @@ -56,9 +56,9 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder ) seqattr = out.sequence_attributions[0] out_agg = seqattr.aggregate(ContiguousSpanAggregator, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg.source_attributions.shape == (5, 4, 512) - assert out_agg.target_attributions.shape == (4, 4, 512) - assert out_agg.step_scores["probability"].shape == (4,) + assert out_agg.source_attributions.shape == (5, 5, 512) + assert out_agg.target_attributions.shape == (5, 5, 512) + assert out_agg.step_scores["probability"].shape == (5,) def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel): @@ -76,14 +76,14 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel): seqattr = out.sequence_attributions[0] squeezesum = AggregatorPipeline([ContiguousSpanAggregator, SequenceAttributionAggregator]) out_agg_squeezesum = seqattr.aggregate(squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_squeezesum.source_attributions.shape == (5, 4) - assert out_agg_squeezesum.target_attributions.shape == (4, 4) - assert out_agg_squeezesum.step_scores["probability"].shape == (4,) + assert out_agg_squeezesum.source_attributions.shape == (5, 5) + assert out_agg_squeezesum.target_attributions.shape == (5, 5) + assert out_agg_squeezesum.step_scores["probability"].shape == (5,) sumsqueeze = AggregatorPipeline([SequenceAttributionAggregator, ContiguousSpanAggregator]) out_agg_sumsqueeze = seqattr.aggregate(sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_sumsqueeze.source_attributions.shape == (5, 4) - assert out_agg_sumsqueeze.target_attributions.shape == (4, 4) - assert out_agg_sumsqueeze.step_scores["probability"].shape == (4,) + assert out_agg_sumsqueeze.source_attributions.shape == (5, 5) + assert out_agg_sumsqueeze.target_attributions.shape == (5, 5) + assert out_agg_sumsqueeze.step_scores["probability"].shape == (5,) assert not torch.allclose(out_agg_squeezesum.source_attributions, out_agg_sumsqueeze.source_attributions) assert not torch.allclose(out_agg_squeezesum.target_attributions, out_agg_sumsqueeze.target_attributions) # Named indexing version @@ -91,12 +91,12 @@ def test_aggregator_pipeline(saliency_mt_model: HuggingfaceEncoderDecoderModel): named_sumsqueeze = ["scores", "spans"] out_agg_squeezesum_named = seqattr.aggregate(named_squeezesum, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) out_agg_sumsqueeze_named = seqattr.aggregate(named_sumsqueeze, source_spans=(3, 5), target_spans=[(0, 3), (4, 6)]) - assert out_agg_squeezesum_named.source_attributions.shape == (5, 4) - assert out_agg_squeezesum_named.target_attributions.shape == (4, 4) - assert out_agg_squeezesum_named.step_scores["probability"].shape == (4,) - assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 4) - assert out_agg_sumsqueeze_named.target_attributions.shape == (4, 4) - assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (4,) + assert out_agg_squeezesum_named.source_attributions.shape == (5, 5) + assert out_agg_squeezesum_named.target_attributions.shape == (5, 5) + assert out_agg_squeezesum_named.step_scores["probability"].shape == (5,) + assert out_agg_sumsqueeze_named.source_attributions.shape == (5, 5) + assert out_agg_sumsqueeze_named.target_attributions.shape == (5, 5) + assert out_agg_sumsqueeze_named.step_scores["probability"].shape == (5,) assert not torch.allclose( out_agg_squeezesum_named.source_attributions, out_agg_sumsqueeze_named.source_attributions ) diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json index fc029eec..53123526 100644 --- a/tests/fixtures/aggregator.json +++ b/tests/fixtures/aggregator.json @@ -36,6 +36,7 @@ ], "target": "Inseq \u00e8 un framework per l'attribuzione automatica di modelli sequenziali.", "target_subwords": [ + "", "\u2581In", "se", "q", @@ -58,6 +59,7 @@ "" ], "target_merged": [ + "", "\u2581Inseq", "\u2581\u00e8", "\u2581un", diff --git a/tests/models/test_huggingface_model.py b/tests/models/test_huggingface_model.py index f81e635f..72da4a2f 100644 --- a/tests/models/test_huggingface_model.py +++ b/tests/models/test_huggingface_model.py @@ -274,8 +274,8 @@ def test_attribute_slice_seq2seq(saliency_mt_model): assert ex2.attr_pos_start == len(ex2.target) assert ex2.attr_pos_end == len(ex2.target) assert ex2.source_attributions.shape[1] == 0 and ex2.target_attributions.shape[1] == 0 - assert ex3.attr_pos_start == 12 - assert ex3.attr_pos_end == 15 + assert ex3.attr_pos_start == 13 + assert ex3.attr_pos_end == 16 assert ex1.source_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[1] == ex1.attr_pos_end - ex1.attr_pos_start assert ex1.target_attributions.shape[0] == ex1.attr_pos_end