Show only part of the input when calling show() method #281

Closed
RibinMTC opened this issue Jul 2, 2024 · 2 comments
Labels
question Further information is requested

Comments

RibinMTC commented Jul 2, 2024

Question

Hi, thank you for the awesome library. How can I show only part of my input prompt in the heatmap? For example, my prompt has the structure "instruction - {input_text}", and I want to ignore the scores for the instruction and show the heatmap only for the input_text. I have used the following code:

import inseq
from inseq.data.aggregator import SubwordAggregator
from transformers import AutoModelForCausalLM

model_path = "google/gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True, device_map="auto")
attrib_model = inseq.load_model(
    model=model,
    attribution_method="attention"
)
input_prompt = """Instruction: Summarize this article.
Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic."""

full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."

out = attrib_model.attribute(
    input_texts=input_prompt,
    generated_texts=full_output_prompt
)
# Merge subword tokens, then apply the default aggregation before rendering
subw_sqa_agg = out.aggregate(SubwordAggregator, special_chars="▁").aggregate()
subw_viz = subw_sqa_agg.show(return_html=True, do_aggregation=False)
RibinMTC added the question label Jul 2, 2024
gsarti (Member) commented Jul 3, 2024

Hi @RibinMTC,

Thanks for reaching out! This option was notably missing from the current version of Inseq, so I opened PR #282 introducing a SliceAggregator class to handle this behavior. You can try it out (pip install git+https://github.com/inseq-team/inseq.git@viz-slice) and let me know whether it addresses your use case:

import inseq

attrib_model = inseq.load_model("google/gemma-2b", "attention")
input_prompt = """Instruction: Summarize this article.
Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.
Summary:"""

full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."
out = attrib_model.attribute(input_prompt, full_output_prompt)[0]

# Slice the summary -> aggregate subwords -> default attention aggregation (mean head, mean layer) + show
out[13:71].aggregate("subwords").show()
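
A minimal sketch of how the slice bounds (13 and 71 above) could be derived instead of hard-coded, assuming the slice indices correspond to token positions and that the loaded model exposes its Hugging Face tokenizer as attrib_model.tokenizer (both assumptions, not confirmed in this thread):

# Count the tokens of the instruction prefix to find where the article starts.
# Assumes slice indices are token positions; a BOS token added by the tokenizer
# may shift these bounds by one, so verify against out.show() first.
prefix = "Instruction: Summarize this article.\nInput_text:"
slice_start = len(attrib_model.tokenizer(prefix)["input_ids"])
slice_end = len(attrib_model.tokenizer(input_prompt)["input_ids"])

out[slice_start:slice_end].aggregate("subwords").show()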

RibinMTC (Author) commented Jul 3, 2024

This was exactly what I was looking for, thank you very much :)

RibinMTC closed this as completed Jul 3, 2024