Skip to content

Commit

Permalink
Treescope viz working for attribute_context
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 12, 2024
1 parent 0c2eb0c commit 00a504a
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 119 deletions.
4 changes: 3 additions & 1 deletion inseq/commands/attribute_context/attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def attribute_context(args: AttributeContextArgs) -> AttributeContextOutput:
model_kwargs=deepcopy(args.model_kwargs),
tokenizer_kwargs=deepcopy(args.tokenizer_kwargs),
)
if not isinstance(args.model_name_or_path, str):
args.model_name_or_path = model.model_name
return attribute_context_with_model(args, model)


Expand Down Expand Up @@ -241,7 +243,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
)
output.cci_scores.append(cci_out)
if args.show_viz or args.viz_path:
visualize_attribute_context(output, model, cti_threshold)
visualize_attribute_context(output, model, cti_threshold, args.show_viz, args.viz_path)
if not args.add_output_info:
output.info = None
if args.save_path:
Expand Down
141 changes: 57 additions & 84 deletions inseq/commands/attribute_context/attribute_context_helpers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
import logging
import re
from collections.abc import Callable
from dataclasses import dataclass, field, fields
from typing import Any

import treescope as ts
import treescope.figures as fg
import treescope.rendering_parts as rp
from rich import print as rprint
from rich.prompt import Confirm, Prompt
from torch import tensor

from ...data import FeatureAttributionSequenceOutput
from ...data.viz import get_tokens_heatmap_treescope
from ...models import HuggingfaceModel
from ...utils import pretty_dict
from ...utils.alignment_utils import compute_word_aligns
from ...utils.misc import clean_tokens
from ...utils.viz_utils import treescope_cmap
from .attribute_context_args import AttributeContextArgs, HandleOutputContextSetting

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -59,6 +52,15 @@ def maximum(self) -> float:
scores.extend(self.output_context_scores)
return max(scores)

@property
def all_scores(self) -> list[float]:
    """Return the input-context scores followed by the output-context scores.

    A group that is ``None`` or empty contributes nothing, so the result is
    always a flat (possibly empty) list.
    """
    return [
        score
        for group in (self.input_context_scores, self.output_context_scores)
        if group
        for score in group
    ]


@dataclass
class AttributeContextOutput:
Expand All @@ -77,83 +79,12 @@ class AttributeContextOutput:
def __repr__(self):
    """Render the instance as ``ClassName(<pretty_dict of its attributes>)``."""
    class_name = self.__class__.__name__
    formatted_attrs = pretty_dict(self.__dict__)
    return f"{class_name}({formatted_attrs})"

# Treescope rendering hook: builds a rendering of the current output tokens,
# colored by their CTI score, where tokens that have an entry in
# `self.cci_scores` become foldable nodes that expand to CCI heatmaps.
# `path` and `subtree_renderer` are accepted but not used in this body.
# NOTE(review): leading indentation was lost in this view, so the exact nesting
# of the statements below cannot be confirmed from here.
def __treescope_repr__(
self,
path: str,
subtree_renderer: Callable[[Any, str | None], ts.rendering_parts.Rendering],
) -> ts.rendering_parts.Rendering:
# Separate colormaps for CTI ("greens") and CCI ("blues") scores.
cmap_cti = treescope_cmap("greens")
cmap_cci = treescope_cmap("blues")
# Legend sentence; the two highlighted phrases use the same colormaps as the
# scores rendered below.
parts = [
fg.treescope_part_from_display_object(
fg.text_on_color("Context-sensitive tokens", value=1, colormap=cmap_cti)
),
rp.text(" in the generated output can be expanded to visualize the "),
fg.treescope_part_from_display_object(fg.text_on_color("contextual cues", value=1, colormap=cmap_cci)),
rp.text(" motivating their prediction.\n\n"),
]
if self.output_current is not None:
parts += [rp.custom_style(rp.text("Current output:\n\n"), css_style="font-weight: bold;")]
# An empty replacement mapping is passed to every clean_tokens call.
replace_chars = {}
cleaned_output_current = clean_tokens(self.output_current_tokens, replace_chars=replace_chars)
# NOTE(review): cleaned_input_context / cleaned_output_context are only assigned
# when the matching *_context_tokens attribute is not None, but they are read
# below whenever the matching CCI scores are not None - confirm these always
# coincide.
if self.input_context_tokens is not None:
cleaned_input_context = clean_tokens(self.input_context_tokens, replace_chars=replace_chars)
if self.output_context_tokens is not None:
cleaned_output_context = clean_tokens(self.output_context_tokens, replace_chars=replace_chars)
# Index CCI entries by their `cti_idx` so they can be looked up per token position.
cci_idx_map = {cci.cti_idx: cci for cci in self.cci_scores} if self.cci_scores is not None else {}
for curr_tok_idx, curr_tok in enumerate(cleaned_output_current):
# Token text colored by its CTI score (rounded to 4 decimals, scaled by max_cti).
curr_tok_part = fg.treescope_part_from_display_object(
fg.text_on_color(
curr_tok,
value=round(self.cti_scores[curr_tok_idx], 4),
vmax=self.max_cti,
colormap=cmap_cti,
)
)
if curr_tok_idx in cci_idx_map:
# This token has CCI scores: collect the heatmaps shown when it is expanded.
cci_parts = [rp.text("\n")]
cci = cci_idx_map[curr_tok_idx]
if cci.input_context_scores is not None:
cci_parts.append(
get_tokens_heatmap_treescope(
tokens=cleaned_input_context,
scores=cci.input_context_scores,
title=f'Input context CCI scores for "{cci.cti_token}"',
title_style="font-style: italic; color: #888888;",
min_val=self.min_cci,
max_val=self.max_cci,
colormap=cmap_cci,
)
)
cci_parts.append(rp.text("\n\n"))
if cci.output_context_scores is not None:
cci_parts.append(
get_tokens_heatmap_treescope(
tokens=cleaned_output_context,
scores=cci.output_context_scores,
title=f'Output context CCI scores for "{cci.cti_token}"',
title_style="font-style: italic; color: #888888;",
min_val=self.min_cci,
max_val=self.max_cci,
colormap=cmap_cci,
)
)
# Foldable node: nothing when collapsed, the indented heatmaps when expanded.
curr_tok_part_final = rp.custom_style(
rp.build_full_line_with_annotations(
rp.build_custom_foldable_tree_node(
label=curr_tok_part,
contents=rp.fold_condition(
collapsed=rp.empty_part(),
expanded=rp.indented_children([rp.siblings(*cci_parts)]),
),
)
),
css_style="margin-left: 0.7em;",
)
else:
# No CCI entry for this position: render the colored token as-is.
curr_tok_part_final = curr_tok_part
parts.append(rp.build_full_line_with_annotations(curr_tok_part_final))
# All parts are rendered as siblings with whitespace preserved.
return rp.custom_style(rp.siblings(*parts), css_style="white-space: pre-wrap")
def __treescope_repr__(self, *args, **kwargs):
    """Delegate treescope rendering to ``visualize_attribute_context_treescope``.

    The viz helper module is imported at call time rather than at module
    import. Any positional or keyword arguments are accepted and ignored;
    only ``self`` is forwarded.
    """
    from inseq.commands.attribute_context import attribute_context_viz_helpers as viz_helpers

    return viz_helpers.visualize_attribute_context_treescope(self)

def to_dict(self) -> dict[str, Any]:
out_dict = {k: v for k, v in self.__dict__.items() if k not in ["cci_scores", "info"]}
Expand Down Expand Up @@ -186,6 +117,18 @@ def max_cti(self) -> float:
return -1
return max(self.cti_scores)

@property
def mean_cti(self) -> float:
    """Arithmetic mean of ``self.cti_scores``.

    Returns:
        The mean of the scores, or ``0`` when ``cti_scores`` is ``None`` or
        empty (an empty list would otherwise raise ``ZeroDivisionError``).
    """
    if not self.cti_scores:
        return 0
    return sum(self.cti_scores) / len(self.cti_scores)

@property
def std_cti(self) -> float:
    """Standard deviation of ``self.cti_scores`` as computed by ``Tensor.std``.

    Returns:
        The standard deviation of the scores, or ``0`` when ``cti_scores`` is
        ``None`` or holds fewer than two values. ``Tensor.std`` yields NaN
        (and warns) in that case, which would propagate into any arithmetic
        built on this value.
    """
    if self.cti_scores is None or len(self.cti_scores) < 2:
        return 0
    return tensor(self.cti_scores).std().item()

@property
def min_cci(self) -> float:
if self.cci_scores is None:
Expand All @@ -198,6 +141,36 @@ def max_cci(self) -> float:
return -1
return max(cci.maximum for cci in self.cci_scores)

@property
def cci_all_scores(self) -> list[float]:
    """Flatten the ``all_scores`` of every entry in ``self.cci_scores``.

    Returns an empty list when ``cci_scores`` is ``None``.
    """
    flattened = []
    if self.cci_scores is not None:
        for entry in self.cci_scores:
            flattened.extend(entry.all_scores)
    return flattened

@property
def mean_cci(self) -> float:
    """Arithmetic mean of all CCI scores in ``self.cci_all_scores``.

    Returns:
        The mean of the flattened scores, or ``0`` when there are none. That
        covers ``cci_scores`` being ``None``, being an empty list, or holding
        only entries without scores, all of which would otherwise end in a
        ``ZeroDivisionError``.
    """
    # Read the property once: every access rebuilds the flattened list.
    scores = self.cci_all_scores
    if not scores:
        return 0
    return sum(scores) / len(scores)

@property
def std_cci(self) -> float:
    """Standard deviation of all CCI scores as computed by ``Tensor.std``.

    Returns:
        The standard deviation of ``self.cci_all_scores``, or ``0`` when fewer
        than two scores are available (including ``cci_scores`` being ``None``
        or empty). ``Tensor.std`` yields NaN (and warns) in that case.
    """
    # Read the property once: every access rebuilds the flattened list.
    scores = self.cci_all_scores
    if len(scores) < 2:
        return 0
    return tensor(scores).std().item()

@property
def input_context_scores(self) -> list[float] | None:
if self.cci_scores is None or self.cci_scores[0].input_context_scores is None:
return None
return [cci.input_context_scores for cci in self.cci_scores]

@property
def output_context_scores(self) -> list[float] | None:
if self.cci_scores is None or self.cci_scores[0].output_context_scores is None:
return None
return [cci.output_context_scores for cci in self.cci_scores]


def concat_with_sep(s1: str, s2: str, sep: str) -> bool:
"""Adds separator between two strings if needed."""
Expand Down
Loading

0 comments on commit 00a504a

Please sign in to comment.