Improve attribution viz and SliceAggregator #282

Merged
merged 4 commits into from
Jul 3, 2024

@gsarti gsarti commented Jul 3, 2024

Description

This PR performs the following changes:

Visualization

  • Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment

  • The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y


Slicing

  • A new SliceAggregator ("slices") is added to allow slicing source tokens (in encoder-decoder models) or target tokens (in decoder-only models) from a FeatureAttributionSequenceOutput object, using the same syntax as ContiguousSpanAggregator. The __getitem__ method of FeatureAttributionSequenceOutput is a shortcut for this, allowing slicing with [start:stop] syntax.
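The shortcut pattern described above can be sketched with a toy class. This is only an illustration of how a `[start:stop]` slice can be forwarded to a named aggregator: `ToySequenceOutput` and its token-only `aggregate` method are hypothetical stand-ins, not the inseq implementation, and the sketch does not model the preservation of generated tokens.

```python
from dataclasses import dataclass


@dataclass
class ToySequenceOutput:
    """Hypothetical stand-in for an attribution output; not the inseq class."""

    tokens: list[str]

    def aggregate(self, name: str, target_spans: tuple[int, int]) -> "ToySequenceOutput":
        # Toy "slices" aggregation: keep only the tokens in the [start, stop) span.
        if name != "slices":
            raise ValueError(f"unsupported aggregator: {name}")
        start, stop = target_spans
        return ToySequenceOutput(self.tokens[start:stop])

    def __getitem__(self, key: slice) -> "ToySequenceOutput":
        # out[start:stop] forwards to the named aggregator with the same span.
        if not isinstance(key, slice) or key.step is not None:
            raise TypeError("only [start:stop] slices are supported in this sketch")
        start, stop, _ = key.indices(len(self.tokens))
        return self.aggregate("slices", target_spans=(start, stop))


toy_out = ToySequenceOutput(["The", "quick", "brown", "fox", "jumps"])
toy_sliced = toy_out[1:4]
```

In this sketch, `toy_out[1:4]` and `toy_out.aggregate("slices", target_spans=(1, 4))` return equal objects, which mirrors the equivalence the PR describes between the two call styles.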

⚠️ Important: Because of the FeatureAttributionSequenceOutput class design, the generated output is always preserved in the sliced context (i.e. even if it is not explicitly included in target_spans). Analyses that ignore the generated output attributions will need to manually post-process the attribution tensors to consider only the spans of interest.

import inseq
from inseq.data.aggregator import SliceAggregator

attrib_model = inseq.load_model("gpt2", "attention")
input_prompt = """Instruction: Summarize this article.
Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.
Summary:"""

full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."

# Attribute the prompt + summary and take the first sequence-level output
out = attrib_model.attribute(input_prompt, full_output_prompt)[0]

# These are all equivalent ways to slice only the input text contents
out_sliced = out.aggregate(SliceAggregator, target_spans=(13,73))
out_sliced = out.aggregate("slices", target_spans=(13,73))
out_sliced = out[13:73]
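As a rough sketch of the manual post-processing mentioned in the warning above, the helper below drops the rows that belong to generated tokens from a sliced score matrix. The layout is an assumption, not something the PR confirms: one row per attributed token, with the sliced context rows first and the preserved generated-token rows after them, and one column per generation step. `keep_context_rows` is a hypothetical helper that works on plain nested lists; the shape of the actual attribution tensor should be checked before applying the same idea.

```python
def keep_context_rows(scores, n_context):
    """Return only the first n_context rows of a sliced attribution matrix.

    Assumed layout (not confirmed by the PR): sliced context rows first,
    preserved generated-token rows after them, one column per generation step.
    """
    if not 0 <= n_context <= len(scores):
        raise ValueError("n_context must be between 0 and the number of rows")
    # Copy each kept row so the original matrix is left untouched.
    return [list(row) for row in scores[:n_context]]


# Toy matrix: 3 context rows followed by 2 generated rows, 2 generation steps.
toy_scores = [
    [0.1, 0.2],
    [0.3, 0.1],
    [0.2, 0.4],
    [0.4, 0.0],  # generated token 1 (preserved by the slice)
    [0.0, 0.3],  # generated token 2 (preserved by the slice)
]
context_only = keep_context_rows(toy_scores, n_context=3)
```

Under the same assumption, a slice such as `out[13:73]` would correspond to `n_context=60`.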

Other aggregation

  • The __sub__ method in FeatureAttributionSequenceOutput is now used as a shortcut for PairAggregator:

import inseq

attrib_model = inseq.load_model("gpt2", "saliency")

out_male = attrib_model.attribute(
    "The director went home because",
    "The director went home because he was tired",
    step_scores=["probability"]
)[0]
out_female = attrib_model.attribute(
    "The director went home because",
    "The director went home because she was tired",
    step_scores=["probability"]
)[0]
(out_male - out_female).show()
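The operator shortcut can be illustrated with a small self-contained sketch. `ToyAttribution` is a hypothetical class, and the element-wise difference with `a -> b` token labels is an illustrative choice; the actual PairAggregator may combine and label the two outputs differently.

```python
from dataclasses import dataclass


@dataclass
class ToyAttribution:
    """Hypothetical per-token scores; not the inseq output class."""

    tokens: list[str]
    scores: list[float]

    def __sub__(self, other: "ToyAttribution") -> "ToyAttribution":
        # A pairwise comparison only makes sense for equally long sequences.
        if len(self.scores) != len(other.scores):
            raise ValueError("both outputs must have the same number of tokens")
        # Label positions where the two sequences differ as "a -> b".
        tokens = [
            a if a == b else f"{a} -> {b}"
            for a, b in zip(self.tokens, other.tokens)
        ]
        # The pair score is the element-wise difference of the two outputs.
        scores = [a - b for a, b in zip(self.scores, other.scores)]
        return ToyAttribution(tokens, scores)


toy_male = ToyAttribution(["because", "he", "was"], [0.5, 0.75, 0.25])
toy_female = ToyAttribution(["because", "she", "was"], [0.25, 0.25, 0.25])
toy_diff = toy_male - toy_female
```

Defining the comparison through `__sub__` keeps contrastive analyses to a single expression, as in `(out_male - out_female).show()` above.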

@gsarti gsarti merged commit 5a46d51 into main Jul 3, 2024
4 checks passed
@gsarti gsarti deleted the viz-slice branch July 3, 2024 12:00