Skip to content

Commit

Permalink
Basic viz working
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 10, 2024
1 parent 904c893 commit 0c2eb0c
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 5 deletions.
127 changes: 127 additions & 0 deletions inseq/commands/attribute_context/attribute_context_helpers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import logging
import re
from collections.abc import Callable
from dataclasses import dataclass, field, fields
from typing import Any

import treescope as ts
import treescope.figures as fg
import treescope.rendering_parts as rp
from rich import print as rprint
from rich.prompt import Confirm, Prompt
from torch import tensor

from ...data import FeatureAttributionSequenceOutput
from ...data.viz import get_tokens_heatmap_treescope
from ...models import HuggingfaceModel
from ...utils import pretty_dict
from ...utils.alignment_utils import compute_word_aligns
from ...utils.misc import clean_tokens
from ...utils.viz_utils import treescope_cmap
from .attribute_context_args import AttributeContextArgs, HandleOutputContextSetting

logger = logging.getLogger(__name__)
Expand All @@ -34,6 +41,24 @@ def __repr__(self):
def to_dict(self) -> dict[str, Any]:
    """Return a shallow copy of this instance's attribute dictionary.

    Returns:
        A new ``dict`` mapping attribute names to their current values. The
        values themselves are not copied.
    """
    return {**self.__dict__}

@property
def minimum(self) -> float:
    """Smallest value among ``0`` and all input/output context scores.

    ``0`` is always part of the candidates, so the result is never positive.
    Score collections that are ``None`` or empty are ignored.
    """
    candidates = [
        0,
        *(self.input_context_scores or []),
        *(self.output_context_scores or []),
    ]
    return min(candidates)

@property
def maximum(self) -> float:
    """Largest value among ``0`` and all input/output context scores.

    ``0`` is always part of the candidates, so the result is never negative.
    Score collections that are ``None`` or empty are ignored.
    """
    pool = [0]
    for group in (self.input_context_scores, self.output_context_scores):
        if group:
            pool.extend(group)
    return max(pool)


@dataclass
class AttributeContextOutput:
Expand All @@ -52,6 +77,84 @@ class AttributeContextOutput:
def __repr__(self):
    """Return ``ClassName(<attributes>)`` with attributes formatted by ``pretty_dict``."""
    class_name = self.__class__.__name__
    formatted_attrs = pretty_dict(self.__dict__)
    return f"{class_name}({formatted_attrs})"

def __treescope_repr__(
    self,
    path: str,
    subtree_renderer: Callable[[Any, str | None], ts.rendering_parts.Rendering],
) -> ts.rendering_parts.Rendering:
    """Build the treescope rendering of this output.

    The rendering starts with a short legend, followed by one coloured part
    per token of the current output (coloured by its entry in ``cti_scores``).
    Tokens whose index matches the ``cti_idx`` of an entry in ``cci_scores``
    are wrapped in a foldable node that, once expanded, shows heatmaps of
    the input and/or output context scores for that token.

    Args:
        path: Not used in this method body.
        subtree_renderer: Not used in this method body.

    Returns:
        All parts joined as siblings and wrapped in a ``white-space: pre-wrap``
        custom style.
    """
    # Separate colormaps: greens for the current-output tokens (CTI),
    # blues for the context heatmaps (CCI).
    cmap_cti = treescope_cmap("greens")
    cmap_cci = treescope_cmap("blues")
    # Legend sentence explaining the two colours used below.
    parts = [
        fg.treescope_part_from_display_object(
            fg.text_on_color("Context-sensitive tokens", value=1, colormap=cmap_cti)
        ),
        rp.text(" in the generated output can be expanded to visualize the "),
        fg.treescope_part_from_display_object(fg.text_on_color("contextual cues", value=1, colormap=cmap_cci)),
        rp.text(" motivating their prediction.\n\n"),
    ]
    if self.output_current is not None:
        parts += [rp.custom_style(rp.text("Current output:\n\n"), css_style="font-weight: bold;")]
    replace_chars = {}
    cleaned_output_current = clean_tokens(self.output_current_tokens, replace_chars=replace_chars)
    # NOTE(review): the two names below are only bound when the matching
    # ``*_context_tokens`` attribute is not None, but they are read later
    # whenever the matching ``*_context_scores`` of a CCI entry is not None.
    # Confirm that scores are never present without tokens.
    if self.input_context_tokens is not None:
        cleaned_input_context = clean_tokens(self.input_context_tokens, replace_chars=replace_chars)
    if self.output_context_tokens is not None:
        cleaned_output_context = clean_tokens(self.output_context_tokens, replace_chars=replace_chars)
    # Map current-output token index -> CCI entry for constant-time lookup in the loop.
    cci_idx_map = {cci.cti_idx: cci for cci in self.cci_scores} if self.cci_scores is not None else {}
    for curr_tok_idx, curr_tok in enumerate(cleaned_output_current):
        # NOTE(review): indexes ``self.cti_scores`` without a None check — confirm
        # it is always populated when this method runs.
        curr_tok_part = fg.treescope_part_from_display_object(
            fg.text_on_color(
                curr_tok,
                value=round(self.cti_scores[curr_tok_idx], 4),
                vmax=self.max_cti,
                colormap=cmap_cti,
            )
        )
        if curr_tok_idx in cci_idx_map:
            # This token has a CCI entry: collect its context heatmaps.
            cci_parts = [rp.text("\n")]
            cci = cci_idx_map[curr_tok_idx]
            if cci.input_context_scores is not None:
                cci_parts.append(
                    get_tokens_heatmap_treescope(
                        tokens=cleaned_input_context,
                        scores=cci.input_context_scores,
                        title=f'Input context CCI scores for "{cci.cti_token}"',
                        title_style="font-style: italic; color: #888888;",
                        min_val=self.min_cci,
                        max_val=self.max_cci,
                        colormap=cmap_cci,
                    )
                )
                # Blank line after the input-context heatmap.
                cci_parts.append(rp.text("\n\n"))
            if cci.output_context_scores is not None:
                cci_parts.append(
                    get_tokens_heatmap_treescope(
                        tokens=cleaned_output_context,
                        scores=cci.output_context_scores,
                        title=f'Output context CCI scores for "{cci.cti_token}"',
                        title_style="font-style: italic; color: #888888;",
                        min_val=self.min_cci,
                        max_val=self.max_cci,
                        colormap=cmap_cci,
                    )
                )
            # Foldable node: nothing when collapsed, the heatmaps when expanded.
            curr_tok_part_final = rp.custom_style(
                rp.build_full_line_with_annotations(
                    rp.build_custom_foldable_tree_node(
                        label=curr_tok_part,
                        contents=rp.fold_condition(
                            collapsed=rp.empty_part(),
                            expanded=rp.indented_children([rp.siblings(*cci_parts)]),
                        ),
                    )
                ),
                css_style="margin-left: 0.7em;",
            )
        else:
            # No CCI entry for this token: render the coloured token as-is.
            curr_tok_part_final = curr_tok_part
        parts.append(rp.build_full_line_with_annotations(curr_tok_part_final))
    return rp.custom_style(rp.siblings(*parts), css_style="white-space: pre-wrap")

def to_dict(self) -> dict[str, Any]:
out_dict = {k: v for k, v in self.__dict__.items() if k not in ["cci_scores", "info"]}
out_dict["cci_scores"] = [cci_out.to_dict() for cci_out in self.cci_scores]
Expand All @@ -71,6 +174,30 @@ def from_dict(cls, out_dict: dict[str, Any]) -> "AttributeContextOutput":
out.info = AttributeContextArgs(**{k: v for k, v in out_dict["info"].items() if k in field_names})
return out

@property
def min_cti(self) -> float:
    """Smallest value in ``cti_scores``.

    Returns:
        ``-1`` when ``cti_scores`` is ``None`` or empty, otherwise its minimum.
    """
    if self.cti_scores is None:
        return -1
    # ``default`` keeps an empty score list from raising ValueError and
    # reuses the same sentinel as the ``None`` case.
    return min(self.cti_scores, default=-1)

@property
def max_cti(self) -> float:
    """Largest value in ``cti_scores``.

    Returns:
        ``-1`` when ``cti_scores`` is ``None`` or empty, otherwise its maximum.
    """
    if self.cti_scores is None:
        return -1
    # ``default`` keeps an empty score list from raising ValueError and
    # reuses the same sentinel as the ``None`` case.
    return max(self.cti_scores, default=-1)

@property
def min_cci(self) -> float:
    """Smallest ``minimum`` across all entries of ``cci_scores``.

    Returns:
        ``-1`` when ``cci_scores`` is ``None`` or empty, otherwise the lowest
        ``minimum`` among its entries.
    """
    if self.cci_scores is None:
        return -1
    # An empty ``cci_scores`` list would make ``min`` raise ValueError;
    # fall back to the same sentinel as the ``None`` case.
    return min((cci.minimum for cci in self.cci_scores), default=-1)

@property
def max_cci(self) -> float:
    """Largest ``maximum`` across all entries of ``cci_scores``.

    Returns:
        ``-1`` when ``cci_scores`` is ``None`` or empty, otherwise the highest
        ``maximum`` among its entries.
    """
    if self.cci_scores is None:
        return -1
    # An empty ``cci_scores`` list would make ``max`` raise ValueError;
    # fall back to the same sentinel as the ``None`` case.
    return max((cci.maximum for cci in self.cci_scores), default=-1)


def concat_with_sep(s1: str, s2: str, sep: str) -> bool:
"""Adds separator between two strings if needed."""
Expand Down
20 changes: 15 additions & 5 deletions inseq/data/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,11 @@ def show_token_attributions(
)
generated_token_parts = []
for attr in attributions:
cleaned_generated_tokens = clean_tokens(t.token for t in attr.target[attr.attr_pos_start : attr.attr_pos_end])
cleaned_input_tokens = clean_tokens(t.token for t in attr.source)
cleaned_target_tokens = clean_tokens(t.token for t in attr.target)
cleaned_generated_tokens = clean_tokens(
[t.token for t in attr.target[attr.attr_pos_start : attr.attr_pos_end]], replace_chars=replace_char
)
cleaned_input_tokens = clean_tokens([t.token for t in attr.source], replace_chars=replace_char)
cleaned_target_tokens = clean_tokens([t.token for t in attr.target], replace_chars=replace_char)
step_scores = None
title = "Generated text:\n\n"
if step_score_highlight is not None:
Expand Down Expand Up @@ -604,6 +606,7 @@ def get_tokens_heatmap_treescope(
min_val: float | None = None,
max_val: float | None = None,
wrap_after: int | str | list[str] | tuple[str] | None = None,
colormap: str | list[tuple[int, int, int]] | None = None,
):
parts = []
if title is not None:
Expand All @@ -613,11 +616,18 @@ def get_tokens_heatmap_treescope(
css_style=title_style,
)
)
if colormap is None:
colormap = treescope_cmap("blue_to_red")
elif isinstance(colormap, str):
colormap = treescope_cmap(colormap)
elif not isinstance(colormap, list):
raise ValueError("If specified, colormap must be a string or a list of RGB tuples.")

for idx, tok in enumerate(tokens):
if not np.isnan(scores[idx]):
if not np.isnan(scores[idx]) and tok != "":
parts.append(
fg.treescope_part_from_display_object(
fg.text_on_color(tok, value=round(scores[idx], 4), vmin=min_val, vmax=max_val)
fg.text_on_color(tok, value=round(scores[idx], 4), vmin=min_val, vmax=max_val, colormap=colormap)
)
)
parts += maybe_add_linebreak(tok, idx, wrap_after)
Expand Down

0 comments on commit 0c2eb0c

Please sign in to comment.