From d5ec9c93b4f66271306fb0e77de2fb6fa1b9cbb7 Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Thu, 24 Oct 2024 00:26:50 +0200 Subject: [PATCH 1/8] Clarify difference between Group and Sequence --- hassil/__init__.py | 4 +- hassil/expression.py | 51 ++++++------ hassil/parse_expression.py | 158 +++++++++++++++++++------------------ hassil/sample.py | 18 ++--- hassil/string_matcher.py | 20 ++--- tests/test_expression.py | 42 +++++----- 6 files changed, 148 insertions(+), 145 deletions(-) diff --git a/hassil/__init__.py b/hassil/__init__.py index ec38dc6..8f24fff 100644 --- a/hassil/__init__.py +++ b/hassil/__init__.py @@ -1,11 +1,11 @@ """Home Assistant Intent Language parser""" from .expression import ( + Group, + GroupType, ListReference, RuleReference, Sentence, - Sequence, - SequenceType, TextChunk, ) from .intents import Intents diff --git a/hassil/expression.py b/hassil/expression.py index 9ffbfde..08977ba 100644 --- a/hassil/expression.py +++ b/hassil/expression.py @@ -22,7 +22,7 @@ class TextChunk(Expression): # Set in __post_init__ original_text: str = None # type: ignore - parent: "Optional[Sequence]" = None + parent: "Optional[Group]" = None def __post_init__(self): if self.original_text is None: @@ -39,11 +39,11 @@ def empty() -> "TextChunk": return TextChunk() -class SequenceType(str, Enum): - """Type of a sequence. Optionals are alternatives with an empty option.""" +class GroupType(str, Enum): + """Type of a group. Optionals are alternatives with an empty option.""" # Sequence of expressions - GROUP = "group" + SEQUENCE = "sequence" # Expressions where only one will be recognized ALTERNATIVE = "alternative" @@ -53,26 +53,25 @@ class SequenceType(str, Enum): @dataclass -class Sequence(Expression): - """Ordered sequence of expressions. Supports groups, optionals, and alternatives.""" +class Group(Expression): + """Ordered group of expressions. Supports sequences, optionals, and alternatives.""" - # Items in the sequence + # Items in the group items: List[Expression] = field(default_factory=list) - # Group or alternative - type: SequenceType = SequenceType.GROUP + type: GroupType = GroupType.SEQUENCE is_optional: bool = False def text_chunk_count(self) -> int: - """Return the number of TextChunk expressions in this sequence (recursive).""" + """Return the number of TextChunk expressions in this group (recursive).""" num_text_chunks = 0 for item in self.items: if isinstance(item, TextChunk): num_text_chunks += 1 - elif isinstance(item, Sequence): - seq: Sequence = item - num_text_chunks += seq.text_chunk_count() + elif isinstance(item, Group): + grp: Group = item + num_text_chunks += grp.text_chunk_count() return num_text_chunks @@ -93,9 +92,9 @@ def _list_names( if isinstance(item, ListReference): list_ref: ListReference = item yield list_ref.list_name - elif isinstance(item, Sequence): - seq: Sequence = item - yield from seq.list_names(expansion_rules) + elif isinstance(item, Group): + grp: Group = item + yield from grp.list_names(expansion_rules) elif isinstance(item, RuleReference): rule_ref: RuleReference = item if expansion_rules and (rule_ref.rule_name in expansion_rules): @@ -135,8 +134,8 @@ def slot_name(self) -> str: @dataclass -class Sentence(Sequence): - """Sequence representing a complete sentence template.""" +class Sentence(Group): + """Group representing a complete sentence template.""" text: Optional[str] = None pattern: Optional[re.Pattern] = None @@ -161,23 +160,23 @@ def _compile_expression( if chunk.text: escaped_text = re.escape(chunk.text) pattern_chunks.append(escaped_text) - elif isinstance(exp, Sequence): + elif isinstance(exp, Group): # Linear sequence or alternative choices - seq: Sequence = exp - if seq.type == SequenceType.GROUP: + grp: Group = exp + if grp.type == GroupType.SEQUENCE: # Linear sequence - for item in seq.items: + for item in grp.items: self._compile_expression(item, pattern_chunks, rules) - elif seq.type == SequenceType.ALTERNATIVE: + elif grp.type == GroupType.ALTERNATIVE: # Alternative choices - if seq.items: + if grp.items: pattern_chunks.append("(?:") - for item in seq.items: + for item in grp.items: self._compile_expression(item, pattern_chunks, rules) pattern_chunks.append("|") pattern_chunks[-1] = ")" else: - raise ValueError(seq) + raise ValueError(grp) elif isinstance(exp, ListReference): # Slot list pattern_chunks.append("(?:.+)") diff --git a/hassil/parse_expression.py b/hassil/parse_expression.py index 4be2937..a2cb647 100644 --- a/hassil/parse_expression.py +++ b/hassil/parse_expression.py @@ -5,11 +5,11 @@ from .expression import ( Expression, + Group, + GroupType, ListReference, RuleReference, Sentence, - Sequence, - SequenceType, TextChunk, ) from .parser import ( @@ -48,45 +48,45 @@ def __str__(self) -> str: return f"Error in chunk {self.chunk} at {self.metadata}" -def _ensure_alternative(seq: Sequence): - if seq.type != SequenceType.ALTERNATIVE: - seq.type = SequenceType.ALTERNATIVE +def _ensure_alternative(grp: Group): + if grp.type != GroupType.ALTERNATIVE: + grp.type = GroupType.ALTERNATIVE # Collapse items into a single group - seq.items = [ - Sequence( - type=SequenceType.GROUP, - items=seq.items, + grp.items = [ + Group( + type=GroupType.SEQUENCE, + items=grp.items, ) ] -def _ensure_permutation(seq: Sequence): - if seq.type != SequenceType.PERMUTATION: - seq.type = SequenceType.PERMUTATION +def _ensure_permutation(grp: Group): + if grp.type != GroupType.PERMUTATION: + grp.type = GroupType.PERMUTATION # Collapse items into a single group - seq.items = [ - Sequence( - type=SequenceType.GROUP, - items=seq.items, + grp.items = [ + Group( + type=GroupType.SEQUENCE, + items=grp.items, ) ] -def parse_group_or_alt_or_perm( - seq_chunk: ParseChunk, metadata: Optional[ParseMetadata] = None -) -> Sequence: - seq = Sequence(type=SequenceType.GROUP) - if seq_chunk.parse_type == ParseType.GROUP: - seq_text = _remove_delimiters(seq_chunk.text, GROUP_START, GROUP_END) - elif seq_chunk.parse_type == ParseType.OPT: - seq_text = _remove_delimiters(seq_chunk.text, OPT_START, OPT_END) +def parse_group( + grp_chunk: ParseChunk, metadata: Optional[ParseMetadata] = None +) -> Group: + grp = Group(type=GroupType.SEQUENCE) + if grp_chunk.parse_type == ParseType.GROUP: + grp_text = _remove_delimiters(grp_chunk.text, GROUP_START, GROUP_END) + elif grp_chunk.parse_type == ParseType.OPT: + grp_text = _remove_delimiters(grp_chunk.text, OPT_START, OPT_END) else: - raise ParseExpressionError(seq_chunk, metadata=metadata) + raise ParseExpressionError(grp_chunk, metadata=metadata) - item_chunk = next_chunk(seq_text) - last_seq_text = seq_text + item_chunk = next_chunk(grp_text) + last_grp_text = grp_text while item_chunk is not None: if item_chunk.parse_type in ( @@ -98,59 +98,60 @@ def parse_group_or_alt_or_perm( ): item = parse_expression(item_chunk, metadata=metadata) - if seq.type in (SequenceType.ALTERNATIVE, SequenceType.PERMUTATION): + if grp.type in (GroupType.ALTERNATIVE, GroupType.PERMUTATION): # Add to most recent group - if not seq.items: - seq.items.append(Sequence(type=SequenceType.GROUP)) + if not grp.items: + grp.items.append(Group(type=GroupType.SEQUENCE)) - # Must be group or alternative - last_item = seq.items[-1] - if not isinstance(last_item, Sequence): - raise ParseExpressionError(seq_chunk, metadata=metadata) + # Must be a group + last_item = grp.items[-1] + if not isinstance(last_item, Group): + raise ParseExpressionError(grp_chunk, metadata=metadata) last_item.items.append(item) else: # Add to parent group - seq.items.append(item) + grp.items.append(item) if isinstance(item, TextChunk): item_tc: TextChunk = item - item_tc.parent = seq + item_tc.parent = grp + elif item_chunk.parse_type == ParseType.ALT: - _ensure_alternative(seq) + _ensure_alternative(grp) # Begin new group - seq.items.append(Sequence(type=SequenceType.GROUP)) + grp.items.append(Group(type=GroupType.SEQUENCE)) elif item_chunk.parse_type == ParseType.PERM: - _ensure_permutation(seq) + _ensure_permutation(grp) # Begin new group - seq.items.append(Sequence(type=SequenceType.GROUP)) + grp.items.append(Group(type=GroupType.SEQUENCE)) else: - raise ParseExpressionError(seq_chunk, metadata=metadata) + raise ParseExpressionError(grp_chunk, metadata=metadata) # Next chunk - seq_text = seq_text[item_chunk.end_index :] + grp_text = grp_text[item_chunk.end_index :] - if seq_text == last_seq_text: + if grp_text == last_grp_text: # No change, unable to proceed - raise ParseExpressionError(seq_chunk, metadata=metadata) + raise ParseExpressionError(grp_chunk, metadata=metadata) - item_chunk = next_chunk(seq_text) - last_seq_text = seq_text + item_chunk = next_chunk(grp_text) + last_grp_text = grp_text - if seq.type == SequenceType.PERMUTATION: + if grp.type == GroupType.PERMUTATION: permuted_items: List[Expression] = [] - for permutation in permutations(seq.items): + for permutation in permutations(grp.items): permutation_with_spaces = _add_spaces_between_items(list(permutation)) permuted_items.append( - Sequence(type=SequenceType.GROUP, items=permutation_with_spaces) + Group(type=GroupType.SEQUENCE, items=permutation_with_spaces) ) - seq = Sequence(type=SequenceType.ALTERNATIVE, items=permuted_items) + grp = Group(type=GroupType.ALTERNATIVE, items=permuted_items) - return seq + return grp def parse_expression( @@ -162,14 +163,14 @@ def parse_expression( return TextChunk(text=text, original_text=original_text) if chunk.parse_type == ParseType.GROUP: - return parse_group_or_alt_or_perm(chunk, metadata=metadata) + return parse_group(chunk, metadata=metadata) if chunk.parse_type == ParseType.OPT: - seq = parse_group_or_alt_or_perm(chunk, metadata=metadata) - _ensure_alternative(seq) - seq.items.append(TextChunk(text="", parent=seq)) - seq.is_optional = True - return seq + grp = parse_group(chunk, metadata=metadata) + _ensure_alternative(grp) + grp.items.append(TextChunk(text="", parent=grp)) + grp.is_optional = True + return grp if chunk.parse_type == ParseType.LIST: text = _remove_escapes(chunk.text) @@ -192,7 +193,7 @@ def parse_sentence( text = text.strip() # text = fix_pattern_whitespace(text.strip()) - # Wrap in a group because sentences need to always be sequences. + # Wrap in a group because sentences need to always be groups. text = f"({text})" chunk = next_chunk(text) @@ -208,21 +209,21 @@ def parse_sentence( if chunk.end_index != len(text): raise ParseError(f"Expected chunk to end at index {chunk.end_index} in: {text}") - seq = parse_expression(chunk, metadata=metadata) - if not isinstance(seq, Sequence): - raise ParseError(f"Expected Sequence, got: {seq}") + grp = parse_expression(chunk, metadata=metadata) + if not isinstance(grp, Group): + raise ParseError(f"Expected Group, got: {grp}") - # Unpack redundant sequence - if len(seq.items) == 1: - first_item = seq.items[0] - if isinstance(first_item, Sequence): - seq = first_item + # Unpack redundant group + if len(grp.items) == 1: + first_item = grp.items[0] + if isinstance(first_item, Group): + grp = first_item return Sentence( - type=seq.type, - items=seq.items, + type=grp.type, + items=grp.items, text=original_text if keep_text else None, - is_optional=seq.is_optional, + is_optional=grp.is_optional, ) @@ -382,15 +383,16 @@ def _escape_text(text: str) -> str: def _add_spaces_between_items(items: List[Expression]) -> List[Expression]: - """Add spaces between each 2 items of a sequence, used for permutations""" + """Add spaces between each 2 items of a group, used for permutations""" + spaced_items: List[Expression] = [] # Unpack single item sequences to make pattern matching easier below unpacked_items: List[Expression] = [] for item in items: while ( - isinstance(item, Sequence) - and (item.type == SequenceType.GROUP) + isinstance(item, Group) + and (item.type == GroupType.SEQUENCE) and (len(item.items) == 1) ): item = item.items[0] @@ -401,10 +403,10 @@ def _add_spaces_between_items(items: List[Expression]) -> List[Expression]: for item_idx, item in enumerate(unpacked_items): if item_idx > 0: # Only add whitespace after the first item - if isinstance(previous_item, Sequence) and previous_item.is_optional: + if isinstance(previous_item, Group) and previous_item.is_optional: # Modify the previous optional to include a space at the end of # each item. - opt: Sequence = previous_item + opt: Group = previous_item fixed_items: List[Expression] = [] for opt_item in opt.items: fix_item = True @@ -420,16 +422,16 @@ def _add_spaces_between_items(items: List[Expression]) -> List[Expression]: if fix_item: fixed_items.append( - Sequence( - type=SequenceType.GROUP, + Group( + type=GroupType.SEQUENCE, items=[opt_item, TextChunk(" ")], ) ) else: fixed_items.append(opt_item) - spaced_items[-1] = Sequence( - type=SequenceType.ALTERNATIVE, is_optional=True, items=fixed_items + spaced_items[-1] = Group( + type=GroupType.ALTERNATIVE, is_optional=True, items=fixed_items ) else: # Add a space in front diff --git a/hassil/sample.py b/hassil/sample.py index c627250..0899044 100644 --- a/hassil/sample.py +++ b/hassil/sample.py @@ -15,11 +15,11 @@ from .errors import MissingListError, MissingRuleError from .expression import ( Expression, + Group, + GroupType, ListReference, RuleReference, Sentence, - Sequence, - SequenceType, TextChunk, ) from .intents import Intents, RangeSlotList, SlotList, TextSlotList, WildcardSlotList @@ -117,10 +117,10 @@ def sample_expression( if isinstance(expression, TextChunk): chunk: TextChunk = expression yield chunk.original_text - elif isinstance(expression, Sequence): - seq: Sequence = expression - if seq.type == SequenceType.ALTERNATIVE: - for item in seq.items: + elif isinstance(expression, Group): + grp: Group = expression + if grp.type == GroupType.ALTERNATIVE: + for item in grp.items: yield from sample_expression( item, slot_lists, @@ -129,7 +129,7 @@ def sample_expression( expand_lists=expand_lists, expand_ranges=expand_ranges, ) - elif seq.type == SequenceType.GROUP: + elif grp.type == GroupType.SEQUENCE: seq_sentences = map( partial( sample_expression, @@ -139,13 +139,13 @@ def sample_expression( expand_lists=expand_lists, expand_ranges=expand_ranges, ), - seq.items, + grp.items, ) sentence_texts = itertools.product(*seq_sentences) for sentence_words in sentence_texts: yield normalize_whitespace("".join(sentence_words)) else: - raise ValueError(f"Unexpected sequence type: {seq}") + raise ValueError(f"Unexpected group type: {grp}") elif isinstance(expression, ListReference): # {list} list_ref: ListReference = expression diff --git a/hassil/string_matcher.py b/hassil/string_matcher.py index 06becd5..3e235fd 100644 --- a/hassil/string_matcher.py +++ b/hassil/string_matcher.py @@ -11,11 +11,11 @@ from .errors import MissingListError, MissingRuleError from .expression import ( Expression, + Group, + GroupType, ListReference, RuleReference, Sentence, - Sequence, - SequenceType, TextChunk, ) from .intents import IntentData, RangeSlotList, SlotList, TextSlotList, WildcardSlotList @@ -398,19 +398,19 @@ def match_expression( else: # Match failed pass - elif isinstance(expression, Sequence): - seq: Sequence = expression - if seq.type == SequenceType.ALTERNATIVE: + elif isinstance(expression, Group): + grp: Group = expression + if grp.type == GroupType.ALTERNATIVE: # Any may match (words | in | alternative) # NOTE: [optional] = (optional | ) - for item in seq.items: + for item in grp.items: yield from match_expression(settings, context, item) - elif seq.type == SequenceType.GROUP: - if seq.items: + elif grp.type == GroupType.SEQUENCE: + if grp.items: # All must match (words in group) group_contexts = [context] - for item in seq.items: + for item in grp.items: # Next step group_contexts = [ item_context @@ -424,7 +424,7 @@ def match_expression( yield from group_contexts else: - raise ValueError(f"Unexpected sequence type: {seq}") + raise ValueError(f"Unexpected group type: {grp}") elif isinstance(expression, ListReference): # {list} diff --git a/tests/test_expression.py b/tests/test_expression.py index 15de448..b5f6e32 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,11 +1,11 @@ from unittest.mock import ANY from hassil.expression import ( + Group, + GroupType, ListReference, RuleReference, Sentence, - Sequence, - SequenceType, TextChunk, ) from hassil.parse_expression import parse_expression, parse_sentence @@ -19,13 +19,13 @@ def test_word(): def test_group_in_group(): - assert parse_expression(next_chunk("((test test2))")) == group( - items=[group(items=[t(text="test "), t(text="test2")])], + assert parse_expression(next_chunk("((test test2))")) == sequence( + items=[sequence(items=[t(text="test "), t(text="test2")])], ) def test_escapes(): - assert parse_expression(next_chunk(r"(test\<\>\{\}\)\( test2)")) == group( + assert parse_expression(next_chunk(r"(test\<\>\{\}\)\( test2)")) == sequence( items=[t(text="test<>{})( "), t(text="test2")], ) @@ -33,7 +33,7 @@ def test_escapes(): def test_optional(): assert parse_expression(next_chunk("[test test2]")) == alt( items=[ - group( + sequence( items=[t(text="test "), t(text="test2")], ), t(text=""), @@ -44,15 +44,15 @@ def test_optional(): def test_group_alternative(): assert parse_expression(next_chunk("(test | test2)")) == alt( - items=[group(items=[t(text="test ")]), group(items=[t(text=" test2")])], + items=[sequence(items=[t(text="test ")]), sequence(items=[t(text=" test2")])], ) def test_group_permutation(): assert parse_expression(next_chunk("(test; test2)")) == alt( items=[ - group(items=[t(text="test"), t(text=" "), t(text=" test2")]), - group(items=[t(text=" test2"), t(text=" "), t(text="test")]), + sequence(items=[t(text="test"), t(text=" "), t(text=" test2")]), + sequence(items=[t(text=" test2"), t(text=" "), t(text="test")]), ], ) @@ -79,9 +79,9 @@ def test_sentence_group(): def test_sentence_optional(): assert parse_sentence("[this is a test]") == Sentence( - type=SequenceType.ALTERNATIVE, + type=GroupType.ALTERNATIVE, items=[ - group( + sequence( items=[ t(text="this "), t(text="is "), @@ -97,9 +97,9 @@ def test_sentence_optional(): def test_sentence_optional_prefix(): assert parse_sentence("[t]est") == Sentence( - type=SequenceType.GROUP, + type=GroupType.SEQUENCE, items=[ - alt(items=[group(items=[t(text="t")]), t(text="")], is_optional=True), + alt(items=[sequence(items=[t(text="t")]), t(text="")], is_optional=True), t(text="est"), ], ) @@ -107,20 +107,22 @@ def test_sentence_optional_prefix(): def test_sentence_optional_suffix(): assert parse_sentence("test[s]") == Sentence( - type=SequenceType.GROUP, + type=GroupType.SEQUENCE, items=[ t(text="test"), - alt(items=[group(items=[t(text="s")]), t(text="")], is_optional=True), + alt(items=[sequence(items=[t(text="s")]), t(text="")], is_optional=True), ], ) def test_sentence_alternative_whitespace(): assert parse_sentence("test ( 1 | 2)") == Sentence( - type=SequenceType.GROUP, + type=GroupType.SEQUENCE, items=[ t(text="test "), - alt(items=[group(items=[t(text=" 1 ")]), group(items=[t(text=" 2")])]), + alt( + items=[sequence(items=[t(text=" 1 ")]), sequence(items=[t(text=" 2")])] + ), ], ) @@ -141,9 +143,9 @@ def t(**kwargs): return TextChunk(parent=ANY, **kwargs) -def group(**kwargs): - return Sequence(type=SequenceType.GROUP, **kwargs) +def sequence(**kwargs): + return Group(type=GroupType.SEQUENCE, **kwargs) def alt(**kwargs): - return Sequence(type=SequenceType.ALTERNATIVE, **kwargs) + return Group(type=GroupType.ALTERNATIVE, **kwargs) From a5377934d1ea99cf80e157f2ea63f72dce1092d0 Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Thu, 24 Oct 2024 07:48:36 +0200 Subject: [PATCH 2/8] Improve permutations --- hassil/expression.py | 12 ++++-- hassil/parse_expression.py | 81 ++++---------------------------------- hassil/string_matcher.py | 14 +++++++ tests/test_expression.py | 10 +++-- 4 files changed, 38 insertions(+), 79 deletions(-) diff --git a/hassil/expression.py b/hassil/expression.py index 08977ba..6e62846 100644 --- a/hassil/expression.py +++ b/hassil/expression.py @@ -147,13 +147,12 @@ def compile(self, expansion_rules: Dict[str, "Sentence"]) -> None: pattern_chunks: List[str] = [] self._compile_expression(self, pattern_chunks, expansion_rules) - pattern_str = "".join(pattern_chunks).replace(r"\ ", r"[ ]*") self.pattern = re.compile(f"^{pattern_str}$", re.IGNORECASE) def _compile_expression( self, exp: Expression, pattern_chunks: List[str], rules: Dict[str, "Sentence"] - ): + ) -> None: if isinstance(exp, TextChunk): # Literal text chunk: TextChunk = exp @@ -161,7 +160,6 @@ def _compile_expression( escaped_text = re.escape(chunk.text) pattern_chunks.append(escaped_text) elif isinstance(exp, Group): - # Linear sequence or alternative choices grp: Group = exp if grp.type == GroupType.SEQUENCE: # Linear sequence @@ -175,6 +173,14 @@ def _compile_expression( self._compile_expression(item, pattern_chunks, rules) pattern_chunks.append("|") pattern_chunks[-1] = ")" + elif grp.type == GroupType.PERMUTATION: + # Permutation + if grp.items: + pattern_chunks.append("(?:") + for item in grp.items: + self._compile_expression(item, pattern_chunks, rules) + pattern_chunks.append("|") + pattern_chunks[-1] = f"){{{len(grp.items)}}}" else: raise ValueError(grp) elif isinstance(exp, ListReference): diff --git a/hassil/parse_expression.py b/hassil/parse_expression.py index a2cb647..7f24ab7 100644 --- a/hassil/parse_expression.py +++ b/hassil/parse_expression.py @@ -1,7 +1,6 @@ import re from dataclasses import dataclass -from itertools import permutations -from typing import List, Optional +from typing import Optional from .expression import ( Expression, @@ -100,10 +99,6 @@ def parse_group( if grp.type in (GroupType.ALTERNATIVE, GroupType.PERMUTATION): # Add to most recent group - if not grp.items: - grp.items.append(Group(type=GroupType.SEQUENCE)) - - # Must be a group last_item = grp.items[-1] if not isinstance(last_item, Group): raise ParseExpressionError(grp_chunk, metadata=metadata) @@ -141,15 +136,7 @@ def parse_group( last_grp_text = grp_text if grp.type == GroupType.PERMUTATION: - permuted_items: List[Expression] = [] - - for permutation in permutations(grp.items): - permutation_with_spaces = _add_spaces_between_items(list(permutation)) - permuted_items.append( - Group(type=GroupType.SEQUENCE, items=permutation_with_spaces) - ) - - grp = Group(type=GroupType.ALTERNATIVE, items=permuted_items) + _add_spaces_between_items(grp) return grp @@ -382,62 +369,10 @@ def _escape_text(text: str) -> str: return re.sub(r"([()\[\]{}<>])", r"\\\1", text) -def _add_spaces_between_items(items: List[Expression]) -> List[Expression]: +def _add_spaces_between_items(grp: Group) -> None: """Add spaces between each 2 items of a group, used for permutations""" - - spaced_items: List[Expression] = [] - - # Unpack single item sequences to make pattern matching easier below - unpacked_items: List[Expression] = [] - for item in items: - while ( - isinstance(item, Group) - and (item.type == GroupType.SEQUENCE) - and (len(item.items) == 1) - ): - item = item.items[0] - - unpacked_items.append(item) - - previous_item: Optional[Expression] = None - for item_idx, item in enumerate(unpacked_items): - if item_idx > 0: - # Only add whitespace after the first item - if isinstance(previous_item, Group) and previous_item.is_optional: - # Modify the previous optional to include a space at the end of - # each item. - opt: Group = previous_item - fixed_items: List[Expression] = [] - for opt_item in opt.items: - fix_item = True - if isinstance(opt_item, TextChunk): - opt_tc: TextChunk = opt_item - if not opt_tc.text: - # Don't fix empty text chunks - fix_item = False - else: - # Remove ending whitespace since we'll be adding a - # whitespace text chunk after. - opt_tc.text = opt_tc.text.rstrip() - - if fix_item: - fixed_items.append( - Group( - type=GroupType.SEQUENCE, - items=[opt_item, TextChunk(" ")], - ) - ) - else: - fixed_items.append(opt_item) - - spaced_items[-1] = Group( - type=GroupType.ALTERNATIVE, is_optional=True, items=fixed_items - ) - else: - # Add a space in front - spaced_items.append(TextChunk(text=" ")) - - spaced_items.append(item) - previous_item = item - - return spaced_items + for seq in grp.items: + assert isinstance(seq, Group), "Item is not a group" + assert seq.type == GroupType.SEQUENCE, "Item is not a sequence" + seq.items.insert(0, TextChunk(text=" ")) + seq.items.append(TextChunk(text=" ")) diff --git a/hassil/string_matcher.py b/hassil/string_matcher.py index 3e235fd..73b9667 100644 --- a/hassil/string_matcher.py +++ b/hassil/string_matcher.py @@ -423,6 +423,20 @@ def match_expression( break yield from group_contexts + + elif grp.type == GroupType.PERMUTATION: + if len(grp.items) == 1: + yield from match_expression(settings, context, grp.items[0]) + else: + # All must match (in arbitrary order) + for i, item in enumerate(grp.items): + items = grp.items[:] + del items[i] + perm = Group(type=GroupType.PERMUTATION, items=items) + + for item_context in match_expression(settings, context, item): + yield from match_expression(settings, item_context, perm) + else: raise ValueError(f"Unexpected group type: {grp}") diff --git a/tests/test_expression.py b/tests/test_expression.py index b5f6e32..5d02b60 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -49,10 +49,10 @@ def test_group_alternative(): def test_group_permutation(): - assert parse_expression(next_chunk("(test; test2)")) == alt( + assert parse_expression(next_chunk("(test; test2)")) == perm( items=[ - sequence(items=[t(text="test"), t(text=" "), t(text=" test2")]), - sequence(items=[t(text=" test2"), t(text=" "), t(text="test")]), + sequence(items=[t(text=" "), t(text="test"), t(text=" ")]), + sequence(items=[t(text=" "), t(text=" test2"), t(text=" ")]), ], ) @@ -149,3 +149,7 @@ def sequence(**kwargs): def alt(**kwargs): return Group(type=GroupType.ALTERNATIVE, **kwargs) + + +def perm(**kwargs): + return Group(type=GroupType.PERMUTATION, **kwargs) From 3647df47deac9932192b29574a07b6b0eed04e8f Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Fri, 20 Dec 2024 00:28:33 +0100 Subject: [PATCH 3/8] Subgroup classes --- hassil/__init__.py | 4 +- hassil/expression.py | 68 ++++++++++++---------- hassil/intents.py | 2 +- hassil/parse_expression.py | 66 +++++++++------------ hassil/recognize.py | 4 +- hassil/sample.py | 24 ++++++-- hassil/string_matcher.py | 14 +++-- tests/test_expression.py | 116 +++++++++++++++++++------------------ tests/test_sample.py | 24 ++++---- 9 files changed, 169 insertions(+), 153 deletions(-) diff --git a/hassil/__init__.py b/hassil/__init__.py index 8f24fff..407c254 100644 --- a/hassil/__init__.py +++ b/hassil/__init__.py @@ -1,11 +1,13 @@ """Home Assistant Intent Language parser""" from .expression import ( + Alternative, Group, - GroupType, ListReference, + Permutation, RuleReference, Sentence, + Sequence, TextChunk, ) from .intents import Intents diff --git a/hassil/expression.py b/hassil/expression.py index 6e62846..ead21e4 100644 --- a/hassil/expression.py +++ b/hassil/expression.py @@ -3,7 +3,6 @@ import re from abc import ABC from dataclasses import dataclass, field -from enum import Enum from typing import Dict, Iterator, List, Optional @@ -39,19 +38,6 @@ def empty() -> "TextChunk": return TextChunk() -class GroupType(str, Enum): - """Type of a group. Optionals are alternatives with an empty option.""" - - # Sequence of expressions - SEQUENCE = "sequence" - - # Expressions where only one will be recognized - ALTERNATIVE = "alternative" - - # Permutations of a set of expressions - PERMUTATION = "permutation" - - @dataclass class Group(Expression): """Ordered group of expressions. Supports sequences, optionals, and alternatives.""" @@ -59,10 +45,6 @@ class Group(Expression): # Items in the group items: List[Expression] = field(default_factory=list) - type: GroupType = GroupType.SEQUENCE - - is_optional: bool = False - def text_chunk_count(self) -> int: """Return the number of TextChunk expressions in this group (recursive).""" num_text_chunks = 0 @@ -98,10 +80,27 @@ def _list_names( elif isinstance(item, RuleReference): rule_ref: RuleReference = item if expansion_rules and (rule_ref.rule_name in expansion_rules): - rule_body = expansion_rules[rule_ref.rule_name] + rule_body = expansion_rules[rule_ref.rule_name].exp yield from self._list_names(rule_body, expansion_rules) +@dataclass +class Sequence(Group): + """Sequence of expressions""" + + +@dataclass +class Alternative(Group): + """Expressions where only one will be recognized""" + + is_optional: bool = False + + +@dataclass +class Permutation(Group): + """Permutations of a set of expressions""" + + @dataclass class RuleReference(Expression): """Reference to an expansion rule by .""" @@ -134,19 +133,33 @@ def slot_name(self) -> str: @dataclass -class Sentence(Group): - """Group representing a complete sentence template.""" +class Sentence: + """A complete sentence template.""" + exp: Expression text: Optional[str] = None pattern: Optional[re.Pattern] = None + def text_chunk_count(self) -> int: + """Return the number of TextChunk expressions in this sentence.""" + assert isinstance(self.exp, Group) + return self.exp.text_chunk_count() # pylint: disable=no-member + + def list_names( + self, + expansion_rules: Optional[Dict[str, "Sentence"]] = None, + ) -> Iterator[str]: + """Return names of list references in this sentence.""" + assert isinstance(self.exp, Group) + return self.exp.list_names(expansion_rules) # pylint: disable=no-member + def compile(self, expansion_rules: Dict[str, "Sentence"]) -> None: if self.pattern is not None: # Already compiled return pattern_chunks: List[str] = [] - self._compile_expression(self, pattern_chunks, expansion_rules) + self._compile_expression(self.exp, pattern_chunks, expansion_rules) pattern_str = "".join(pattern_chunks).replace(r"\ ", r"[ ]*") self.pattern = re.compile(f"^{pattern_str}$", re.IGNORECASE) @@ -161,20 +174,17 @@ def _compile_expression( pattern_chunks.append(escaped_text) elif isinstance(exp, Group): grp: Group = exp - if grp.type == GroupType.SEQUENCE: - # Linear sequence + if isinstance(grp, Sequence): for item in grp.items: self._compile_expression(item, pattern_chunks, rules) - elif grp.type == GroupType.ALTERNATIVE: - # Alternative choices + elif isinstance(grp, Alternative): if grp.items: pattern_chunks.append("(?:") for item in grp.items: self._compile_expression(item, pattern_chunks, rules) pattern_chunks.append("|") pattern_chunks[-1] = ")" - elif grp.type == GroupType.PERMUTATION: - # Permutation + elif isinstance(grp, Permutation): if grp.items: pattern_chunks.append("(?:") for item in grp.items: @@ -194,6 +204,6 @@ def _compile_expression( raise ValueError(rule_ref) e_rule = rules[rule_ref.rule_name] - self._compile_expression(e_rule, pattern_chunks, rules) + self._compile_expression(e_rule.exp, pattern_chunks, rules) else: raise ValueError(exp) diff --git a/hassil/intents.py b/hassil/intents.py index 5e2a800..719f24a 100644 --- a/hassil/intents.py +++ b/hassil/intents.py @@ -454,6 +454,6 @@ def _parse_data_settings(settings_dict: Dict[str, Any]) -> IntentDataSettings: def _maybe_parse_template(text: str, allow_template: bool = True) -> Expression: """Parse string as a sentence template if it has template syntax.""" if allow_template and is_template(text): - return parse_sentence(text) + return parse_sentence(text).exp return TextChunk(normalize_text(text)) diff --git a/hassil/parse_expression.py b/hassil/parse_expression.py index 7f24ab7..f47da2c 100644 --- a/hassil/parse_expression.py +++ b/hassil/parse_expression.py @@ -3,12 +3,14 @@ from typing import Optional from .expression import ( + Alternative, Expression, Group, - GroupType, ListReference, + Permutation, RuleReference, Sentence, + Sequence, TextChunk, ) from .parser import ( @@ -47,36 +49,24 @@ def __str__(self) -> str: return f"Error in chunk {self.chunk} at {self.metadata}" -def _ensure_alternative(grp: Group): - if grp.type != GroupType.ALTERNATIVE: - grp.type = GroupType.ALTERNATIVE - - # Collapse items into a single group - grp.items = [ - Group( - type=GroupType.SEQUENCE, - items=grp.items, - ) - ] - +def _ensure_alternative(grp: Group, **kw) -> Alternative: + if isinstance(grp, Alternative): + return grp + # Collapse items into a single group + return Alternative(items=[Sequence(items=grp.items)], **kw) -def _ensure_permutation(grp: Group): - if grp.type != GroupType.PERMUTATION: - grp.type = GroupType.PERMUTATION - # Collapse items into a single group - grp.items = [ - Group( - type=GroupType.SEQUENCE, - items=grp.items, - ) - ] +def _ensure_permutation(grp: Group) -> Permutation: + if isinstance(grp, Permutation): + return grp + # Collapse items into a single group + return Permutation(items=[Sequence(items=grp.items)]) def parse_group( grp_chunk: ParseChunk, metadata: Optional[ParseMetadata] = None ) -> Group: - grp = Group(type=GroupType.SEQUENCE) + grp: Group = Sequence() if grp_chunk.parse_type == ParseType.GROUP: grp_text = _remove_delimiters(grp_chunk.text, GROUP_START, GROUP_END) elif grp_chunk.parse_type == ParseType.OPT: @@ -97,7 +87,7 @@ def parse_group( ): item = parse_expression(item_chunk, metadata=metadata) - if grp.type in (GroupType.ALTERNATIVE, GroupType.PERMUTATION): + if isinstance(grp, (Alternative, Permutation)): # Add to most recent group last_item = grp.items[-1] if not isinstance(last_item, Group): @@ -113,15 +103,15 @@ def parse_group( item_tc.parent = grp elif item_chunk.parse_type == ParseType.ALT: - _ensure_alternative(grp) + grp = _ensure_alternative(grp) # Begin new group - grp.items.append(Group(type=GroupType.SEQUENCE)) + grp.items.append(Sequence()) elif item_chunk.parse_type == ParseType.PERM: - _ensure_permutation(grp) + grp = _ensure_permutation(grp) # Begin new group - grp.items.append(Group(type=GroupType.SEQUENCE)) + grp.items.append(Sequence()) else: raise ParseExpressionError(grp_chunk, metadata=metadata) @@ -135,7 +125,7 @@ def parse_group( item_chunk = next_chunk(grp_text) last_grp_text = grp_text - if grp.type == GroupType.PERMUTATION: + if isinstance(grp, Permutation): _add_spaces_between_items(grp) return grp @@ -154,9 +144,8 @@ def parse_expression( if chunk.parse_type == ParseType.OPT: grp = parse_group(chunk, metadata=metadata) - _ensure_alternative(grp) + grp = _ensure_alternative(grp, is_optional=True) grp.items.append(TextChunk(text="", parent=grp)) - grp.is_optional = True return grp if chunk.parse_type == ParseType.LIST: @@ -207,10 +196,8 @@ def parse_sentence( grp = first_item return Sentence( - type=grp.type, - items=grp.items, + exp=grp, text=original_text if keep_text else None, - is_optional=grp.is_optional, ) @@ -369,10 +356,9 @@ def _escape_text(text: str) -> str: return re.sub(r"([()\[\]{}<>])", r"\\\1", text) -def _add_spaces_between_items(grp: Group) -> None: - """Add spaces between each 2 items of a group, used for permutations""" - for seq in grp.items: - assert isinstance(seq, Group), "Item is not a group" - assert seq.type == GroupType.SEQUENCE, "Item is not a sequence" +def _add_spaces_between_items(perm: Permutation) -> None: + """Add spaces between each 2 items of a permutation""" + for seq in perm.items: + assert isinstance(seq, Sequence), "Item is not a sequence" seq.items.insert(0, TextChunk(text=" ")) seq.items.append(TextChunk(text=" ")) diff --git a/hassil/recognize.py b/hassil/recognize.py index b614d39..0e8fbda 100644 --- a/hassil/recognize.py +++ b/hassil/recognize.py @@ -264,7 +264,7 @@ def recognize_all( intent_data=intent_data, ) maybe_match_contexts = match_expression( - match_settings, match_context, intent_sentence + match_settings, match_context, intent_sentence.exp ) yield from _process_match_contexts( maybe_match_contexts, @@ -432,7 +432,7 @@ def is_match( intent_sentence=sentence, ) - for maybe_match_context in match_expression(settings, match_context, sentence): + for maybe_match_context in match_expression(settings, match_context, sentence.exp): if maybe_match_context.is_match: return maybe_match_context diff --git a/hassil/sample.py b/hassil/sample.py index 0899044..d47a7bb 100644 --- a/hassil/sample.py +++ b/hassil/sample.py @@ -14,12 +14,13 @@ from .errors import MissingListError, MissingRuleError from .expression import ( + Alternative, Expression, Group, - GroupType, ListReference, RuleReference, Sentence, + Sequence, TextChunk, ) from .intents import Intents, RangeSlotList, SlotList, TextSlotList, WildcardSlotList @@ -81,7 +82,7 @@ def sample_intents( ): continue - sentence_texts = sample_expression( + sentence_texts = sample_sentence( intent_sentence, slot_lists, local_expansion_rules, @@ -105,6 +106,19 @@ def sample_intents( break +def sample_sentence( + sentence: Sentence, + slot_lists: Optional[Dict[str, SlotList]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, + language: Optional[str] = None, + expand_lists: bool = True, + expand_ranges: bool = True, +) -> Iterable[str]: + return sample_expression( + sentence.exp, slot_lists, expansion_rules, language, expand_lists, expand_ranges + ) + + def sample_expression( expression: Expression, slot_lists: Optional[Dict[str, SlotList]] = None, @@ -119,7 +133,7 @@ def sample_expression( yield chunk.original_text elif isinstance(expression, Group): grp: Group = expression - if grp.type == GroupType.ALTERNATIVE: + if isinstance(grp, Alternative): for item in grp.items: yield from sample_expression( item, @@ -129,7 +143,7 @@ def sample_expression( expand_lists=expand_lists, expand_ranges=expand_ranges, ) - elif grp.type == GroupType.SEQUENCE: + elif isinstance(grp, Sequence): seq_sentences = map( partial( sample_expression, @@ -228,7 +242,7 @@ def sample_expression( if (not expansion_rules) or (rule_ref.rule_name not in expansion_rules): raise MissingRuleError(f"Missing expansion rule <{rule_ref.rule_name}>") - rule_body = expansion_rules[rule_ref.rule_name] + rule_body = expansion_rules[rule_ref.rule_name].exp yield from sample_expression( rule_body, slot_lists, diff --git a/hassil/string_matcher.py b/hassil/string_matcher.py index 73b9667..65e40e4 100644 --- a/hassil/string_matcher.py +++ b/hassil/string_matcher.py @@ -10,12 +10,14 @@ from .errors import MissingListError, MissingRuleError from .expression import ( + Alternative, Expression, Group, - GroupType, ListReference, + Permutation, RuleReference, Sentence, + Sequence, TextChunk, ) from .intents import IntentData, RangeSlotList, SlotList, TextSlotList, WildcardSlotList @@ -400,13 +402,13 @@ def match_expression( pass elif isinstance(expression, Group): grp: Group = expression - if grp.type == GroupType.ALTERNATIVE: + if isinstance(grp, Alternative): # Any may match (words | in | alternative) # NOTE: [optional] = (optional | ) for item in grp.items: yield from match_expression(settings, context, item) - elif grp.type == GroupType.SEQUENCE: + elif isinstance(grp, Sequence): if grp.items: # All must match (words in group) group_contexts = [context] @@ -424,7 +426,7 @@ def match_expression( yield from group_contexts - elif grp.type == GroupType.PERMUTATION: + elif isinstance(grp, Permutation): if len(grp.items) == 1: yield from match_expression(settings, context, grp.items[0]) else: @@ -432,7 +434,7 @@ def match_expression( for i, item in enumerate(grp.items): items = grp.items[:] del items[i] - perm = Group(type=GroupType.PERMUTATION, items=items) + perm = Permutation(items=items) for item_context in match_expression(settings, context, item): yield from match_expression(settings, item_context, perm) @@ -812,7 +814,7 @@ def match_expression( raise MissingRuleError(f"Missing expansion rule <{rule_ref.rule_name}>") yield from match_expression( - settings, context, settings.expansion_rules[rule_ref.rule_name] + settings, context, settings.expansion_rules[rule_ref.rule_name].exp ) else: raise ValueError(f"Unexpected expression: {expression}") diff --git a/tests/test_expression.py b/tests/test_expression.py index 5d02b60..3881d90 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,11 +1,12 @@ from unittest.mock import ANY from hassil.expression import ( - Group, - GroupType, + Alternative, ListReference, + Permutation, RuleReference, Sentence, + Sequence, TextChunk, ) from hassil.parse_expression import parse_expression, parse_sentence @@ -19,21 +20,21 @@ def test_word(): def test_group_in_group(): - assert parse_expression(next_chunk("((test test2))")) == sequence( - items=[sequence(items=[t(text="test "), t(text="test2")])], + assert parse_expression(next_chunk("((test test2))")) == Sequence( + items=[Sequence(items=[t(text="test "), t(text="test2")])], ) def test_escapes(): - assert parse_expression(next_chunk(r"(test\<\>\{\}\)\( test2)")) == sequence( + assert parse_expression(next_chunk(r"(test\<\>\{\}\)\( test2)")) == Sequence( items=[t(text="test<>{})( "), t(text="test2")], ) def test_optional(): - assert parse_expression(next_chunk("[test test2]")) == alt( + assert parse_expression(next_chunk("[test test2]")) == Alternative( items=[ - sequence( + Sequence( items=[t(text="test "), t(text="test2")], ), t(text=""), @@ -43,16 +44,16 @@ def test_optional(): def test_group_alternative(): - assert parse_expression(next_chunk("(test | test2)")) == alt( - items=[sequence(items=[t(text="test ")]), sequence(items=[t(text=" test2")])], + assert parse_expression(next_chunk("(test | test2)")) == Alternative( + items=[Sequence(items=[t(text="test ")]), Sequence(items=[t(text=" test2")])], ) def test_group_permutation(): - assert parse_expression(next_chunk("(test; test2)")) == perm( + assert parse_expression(next_chunk("(test; test2)")) == Permutation( items=[ - sequence(items=[t(text=" "), t(text="test"), t(text=" ")]), - sequence(items=[t(text=" "), t(text=" test2"), t(text=" ")]), + Sequence(items=[t(text=" "), t(text="test"), t(text=" ")]), + Sequence(items=[t(text=" "), t(text=" test2"), t(text=" ")]), ], ) @@ -67,63 +68,78 @@ def test_rule_reference(): def test_sentence_no_group(): assert parse_sentence("this is a test") == Sentence( - items=[t(text="this "), t(text="is "), t(text="a "), t(text="test")] + exp=Sequence( + items=[t(text="this "), t(text="is "), t(text="a "), t(text="test")] + ) ) def test_sentence_group(): assert parse_sentence("(this is a test)") == Sentence( - items=[t(text="this "), t(text="is "), t(text="a "), t(text="test")] + exp=Sequence( + items=[t(text="this "), t(text="is "), t(text="a "), t(text="test")] + ) ) def test_sentence_optional(): assert parse_sentence("[this is a test]") == Sentence( - type=GroupType.ALTERNATIVE, - items=[ - sequence( - items=[ - t(text="this "), - t(text="is "), - t(text="a "), - t(text="test"), - ] - ), - t(text=""), - ], - is_optional=True, + exp=Alternative( + items=[ + Sequence( + items=[ + t(text="this "), + t(text="is "), + t(text="a "), + t(text="test"), + ] + ), + t(text=""), + ], + is_optional=True, + ) ) def test_sentence_optional_prefix(): assert parse_sentence("[t]est") == Sentence( - type=GroupType.SEQUENCE, - items=[ - alt(items=[sequence(items=[t(text="t")]), t(text="")], is_optional=True), - t(text="est"), - ], + exp=Sequence( + items=[ + Alternative( + items=[Sequence(items=[t(text="t")]), t(text="")], is_optional=True + ), + t(text="est"), + ], + ) ) def test_sentence_optional_suffix(): assert parse_sentence("test[s]") == Sentence( - type=GroupType.SEQUENCE, - items=[ - t(text="test"), - alt(items=[sequence(items=[t(text="s")]), t(text="")], is_optional=True), - ], + exp=Sequence( + items=[ + t(text="test"), + Alternative( + items=[Sequence(items=[t(text="s")]), t(text="")], is_optional=True + ), + ], + ) ) def test_sentence_alternative_whitespace(): assert parse_sentence("test ( 1 | 2)") == Sentence( - type=GroupType.SEQUENCE, - items=[ - t(text="test "), - alt( - items=[sequence(items=[t(text=" 1 ")]), sequence(items=[t(text=" 2")])] - ), - ], + exp=Sequence( + items=[ + t(text="test "), + Alternative( + items=[ + Sequence(items=[t(text=" 1 ")]), + Sequence(items=[t(text=" 2")]), + ] + ), + ], + ) ) @@ -141,15 +157,3 @@ def test_sentence_alternative_whitespace(): def t(**kwargs): return TextChunk(parent=ANY, **kwargs) - - -def sequence(**kwargs): - return Group(type=GroupType.SEQUENCE, **kwargs) - - -def alt(**kwargs): - return Group(type=GroupType.ALTERNATIVE, **kwargs) - - -def perm(**kwargs): - return Group(type=GroupType.PERMUTATION, **kwargs) diff --git a/tests/test_sample.py b/tests/test_sample.py index 74ca027..6fd85aa 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -1,22 +1,20 @@ from hassil import parse_sentence from hassil.intents import RangeSlotList, TextSlotList -from hassil.sample import sample_expression +from hassil.sample import sample_sentence def test_text_chunk(): - assert set(sample_expression(parse_sentence("this is a test"))) == { - "this is a test" - } + assert set(sample_sentence(parse_sentence("this is a test"))) == {"this is a test"} def test_group(): - assert set(sample_expression(parse_sentence("this (is a) test"))) == { + assert set(sample_sentence(parse_sentence("this (is a) test"))) == { "this is a test" } def test_optional(): - assert set(sample_expression(parse_sentence("turn on [the] light[s]"))) == { + assert set(sample_sentence(parse_sentence("turn on [the] light[s]"))) == { "turn on light", "turn on lights", "turn on the light", @@ -25,7 +23,7 @@ def test_optional(): def test_double_optional(): - assert set(sample_expression(parse_sentence("turn [on] [the] light[s]"))) == { + assert set(sample_sentence(parse_sentence("turn [on] [the] light[s]"))) == { "turn light", "turn lights", "turn on light", @@ -38,7 +36,7 @@ def test_double_optional(): def test_alternative(): - assert set(sample_expression(parse_sentence("this is (the | a) test"))) == { + assert set(sample_sentence(parse_sentence("this is (the | a) test"))) == { "this is a test", "this is the test", } @@ -47,7 +45,7 @@ def test_alternative(): def test_list(): sentence = parse_sentence("turn off {area}") areas = TextSlotList.from_strings(["kitchen", "living room"]) - assert set(sample_expression(sentence, slot_lists={"area": areas})) == { + assert set(sample_sentence(sentence, slot_lists={"area": areas})) == { "turn off kitchen", "turn off living room", } @@ -56,7 +54,7 @@ def test_list(): def test_list_range(): sentence = parse_sentence("run test {num}") num_list = RangeSlotList(name=None, start=1, stop=3) - assert set(sample_expression(sentence, slot_lists={"num": num_list})) == { + assert set(sample_sentence(sentence, slot_lists={"num": num_list})) == { "run test 1", "run test 2", "run test 3", @@ -68,7 +66,7 @@ def test_list_range_missing_language(): num_list = RangeSlotList(name=None, start=1, stop=3, words=True) # Range slot digits cannot be converted to words without a language available. - assert set(sample_expression(sentence, slot_lists={"num": num_list})) == { + assert set(sample_sentence(sentence, slot_lists={"num": num_list})) == { "run test 1", "run test 2", "run test 3", @@ -79,7 +77,7 @@ def test_list_range_words(): sentence = parse_sentence("run test {num}") num_list = RangeSlotList(name=None, start=1, stop=3, words=True) assert set( - sample_expression(sentence, slot_lists={"num": num_list}, language="en") + sample_sentence(sentence, slot_lists={"num": num_list}, language="en") ) == { "run test 1", "run test one", @@ -93,7 +91,7 @@ def test_list_range_words(): def test_rule(): sentence = parse_sentence("turn off ") assert set( - sample_expression( + sample_sentence( sentence, expansion_rules={"area": parse_sentence("[the] kitchen")}, ) From d45f3a7f19144ae9269975a1a9bb4260d939cd04 Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Fri, 20 Dec 2024 17:18:09 +0100 Subject: [PATCH 4/8] Iterate permutations --- hassil/expression.py | 9 +++++++++ hassil/string_matcher.py | 8 ++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/hassil/expression.py b/hassil/expression.py index ead21e4..0776eaa 100644 --- a/hassil/expression.py +++ b/hassil/expression.py @@ -2,6 +2,7 @@ import re from abc import ABC +from collections.abc import Iterable from dataclasses import dataclass, field from typing import Dict, Iterator, List, Optional @@ -100,6 +101,14 @@ class Alternative(Group): class Permutation(Group): """Permutations of a set of expressions""" + def iterate_permutations(self) -> Iterable[tuple[Expression, "Permutation"]]: + """Iterate over all permutations.""" + for i, item in enumerate(self.items): + items = self.items[:] + del items[i] + rest = Permutation(items=items) + yield (item, rest) + @dataclass class RuleReference(Expression): diff --git a/hassil/string_matcher.py b/hassil/string_matcher.py index 65e40e4..fc2970b 100644 --- a/hassil/string_matcher.py +++ b/hassil/string_matcher.py @@ -431,13 +431,9 @@ def match_expression( yield from match_expression(settings, context, grp.items[0]) else: # All must match (in arbitrary order) - for i, item in enumerate(grp.items): - items = grp.items[:] - del items[i] - perm = Permutation(items=items) - + for item, rest in grp.iterate_permutations(): for item_context in match_expression(settings, context, item): - yield from match_expression(settings, item_context, perm) + yield from match_expression(settings, item_context, rest) else: raise ValueError(f"Unexpected group type: {grp}") From 07c3ca9303242600a4982de74ee54e6cce1e9891 Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Fri, 20 Dec 2024 17:48:21 +0100 Subject: [PATCH 5/8] Sample permutations --- hassil/sample.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/hassil/sample.py b/hassil/sample.py index d47a7bb..ede8033 100644 --- a/hassil/sample.py +++ b/hassil/sample.py @@ -18,6 +18,7 @@ Expression, Group, ListReference, + Permutation, RuleReference, Sentence, Sequence, @@ -158,6 +159,22 @@ def sample_expression( sentence_texts = itertools.product(*seq_sentences) for sentence_words in sentence_texts: yield normalize_whitespace("".join(sentence_words)) + elif isinstance(grp, Permutation): + seq_sentences = map( + partial( + sample_expression, + slot_lists=slot_lists, + expansion_rules=expansion_rules, + language=language, + expand_lists=expand_lists, + expand_ranges=expand_ranges, + ), + grp.items, + ) + for perm_sentences in itertools.permutations(seq_sentences): + sentence_texts = itertools.product(*perm_sentences) + for sentence_words in sentence_texts: + yield normalize_whitespace("".join(sentence_words)) else: raise ValueError(f"Unexpected group type: {grp}") elif isinstance(expression, ListReference): From fbf21bc509d087eb103eb07a8b3108b20fefa91f Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Fri, 20 Dec 2024 18:20:46 +0100 Subject: [PATCH 6/8] Fix tuple typing in older Python --- hassil/expression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hassil/expression.py b/hassil/expression.py index 0776eaa..8d42dff 100644 --- a/hassil/expression.py +++ b/hassil/expression.py @@ -4,7 +4,7 @@ from abc import ABC from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional, Tuple @dataclass @@ -101,7 +101,7 @@ class Alternative(Group): class Permutation(Group): """Permutations of a set of expressions""" - def iterate_permutations(self) -> Iterable[tuple[Expression, "Permutation"]]: + def iterate_permutations(self) -> Iterable[Tuple[Expression, "Permutation"]]: """Iterate over all permutations.""" for i, item in enumerate(self.items): items = self.items[:] From 5369291ea800cb364deaaac72fdd1ca3e8c493b9 Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Fri, 20 Dec 2024 18:33:31 +0100 Subject: [PATCH 7/8] Fix typing once more --- hassil/expression.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/hassil/expression.py b/hassil/expression.py index 8d42dff..a40d7bc 100644 --- a/hassil/expression.py +++ b/hassil/expression.py @@ -1,10 +1,11 @@ """Classes for representing sentence templates.""" +from __future__ import annotations + import re from abc import ABC -from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterable, Iterator, List, Optional, Tuple @dataclass @@ -34,7 +35,7 @@ def is_empty(self) -> bool: return self.text == "" @staticmethod - def empty() -> "TextChunk": + def empty() -> TextChunk: """Returns an empty text chunk""" return TextChunk() @@ -60,7 +61,7 @@ def text_chunk_count(self) -> int: def list_names( self, - expansion_rules: Optional[Dict[str, "Sentence"]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, ) -> Iterator[str]: """Return names of list references (recursive).""" for item in self.items: @@ -69,7 +70,7 @@ def list_names( def _list_names( self, item: Expression, - expansion_rules: Optional[Dict[str, "Sentence"]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, ) -> Iterator[str]: """Return names of list references (recursive).""" if isinstance(item, ListReference): @@ -101,7 +102,7 @@ class Alternative(Group): class Permutation(Group): """Permutations of a set of expressions""" - def iterate_permutations(self) -> Iterable[Tuple[Expression, "Permutation"]]: + def iterate_permutations(self) -> Iterable[Tuple[Expression, Permutation]]: """Iterate over all permutations.""" for i, item in enumerate(self.items): items = self.items[:] @@ -156,13 +157,13 @@ def text_chunk_count(self) -> int: def list_names( self, - expansion_rules: Optional[Dict[str, "Sentence"]] = None, + expansion_rules: Optional[Dict[str, Sentence]] = None, ) -> Iterator[str]: """Return names of list references in this sentence.""" assert isinstance(self.exp, Group) return self.exp.list_names(expansion_rules) # pylint: disable=no-member - def compile(self, expansion_rules: Dict[str, "Sentence"]) -> None: + def compile(self, expansion_rules: Dict[str, Sentence]) -> None: if self.pattern is not None: # Already compiled return @@ -173,7 +174,7 @@ def compile(self, expansion_rules: Dict[str, "Sentence"]) -> None: self.pattern = re.compile(f"^{pattern_str}$", re.IGNORECASE) def _compile_expression( - self, exp: Expression, pattern_chunks: List[str], rules: Dict[str, "Sentence"] + self, exp: Expression, pattern_chunks: List[str], rules: Dict[str, Sentence] ) -> None: if isinstance(exp, TextChunk): # Literal text From fcef7cf970727923f92926eaa7e2b06a4413060f Mon Sep 17 00:00:00 2001 From: Artur Pragacz Date: Fri, 20 Dec 2024 19:04:36 +0100 Subject: [PATCH 8/8] Fix optional --- hassil/parse_expression.py | 20 +++++++++++--------- tests/test_expression.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/hassil/parse_expression.py b/hassil/parse_expression.py index f47da2c..9d9e138 100644 --- a/hassil/parse_expression.py +++ b/hassil/parse_expression.py @@ -49,18 +49,18 @@ def __str__(self) -> str: return f"Error in chunk {self.chunk} at {self.metadata}" -def _ensure_alternative(grp: Group, **kw) -> Alternative: +def _ensure_alternative(grp: Group) -> Alternative: if isinstance(grp, Alternative): return grp # Collapse items into a single group - return Alternative(items=[Sequence(items=grp.items)], **kw) + return Alternative(items=[grp]) def _ensure_permutation(grp: Group) -> Permutation: if isinstance(grp, Permutation): return grp # Collapse items into a single group - return Permutation(items=[Sequence(items=grp.items)]) + return Permutation(items=[grp]) def parse_group( @@ -88,9 +88,9 @@ def parse_group( item = parse_expression(item_chunk, metadata=metadata) if isinstance(grp, (Alternative, Permutation)): - # Add to most recent group + # Add to the most recent sequence last_item = grp.items[-1] - if not isinstance(last_item, Group): + if not isinstance(last_item, Sequence): raise ParseExpressionError(grp_chunk, metadata=metadata) last_item.items.append(item) @@ -105,12 +105,12 @@ def parse_group( elif item_chunk.parse_type == ParseType.ALT: grp = _ensure_alternative(grp) - # Begin new group + # Begin new sequence grp.items.append(Sequence()) elif item_chunk.parse_type == ParseType.PERM: grp = _ensure_permutation(grp) - # Begin new group + # Begin new sequence grp.items.append(Sequence()) else: raise ParseExpressionError(grp_chunk, metadata=metadata) @@ -144,8 +144,10 @@ def parse_expression( if chunk.parse_type == ParseType.OPT: grp = parse_group(chunk, metadata=metadata) - grp = _ensure_alternative(grp, is_optional=True) - grp.items.append(TextChunk(text="", parent=grp)) + alt = _ensure_alternative(grp) + alt.is_optional = True + alt.items.append(TextChunk(text="", parent=grp)) + grp = alt return grp if chunk.parse_type == ParseType.LIST: diff --git a/tests/test_expression.py b/tests/test_expression.py index 3881d90..adfc7c7 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -19,7 +19,7 @@ def test_word(): assert parse_expression(next_chunk("test")) == t(text="test") -def test_group_in_group(): +def test_sequence_in_sequence(): assert parse_expression(next_chunk("((test test2))")) == Sequence( items=[Sequence(items=[t(text="test "), t(text="test2")])], ) @@ -43,13 +43,13 @@ def test_optional(): ) -def test_group_alternative(): +def test_alternative(): assert parse_expression(next_chunk("(test | test2)")) == Alternative( items=[Sequence(items=[t(text="test ")]), Sequence(items=[t(text=" test2")])], ) -def test_group_permutation(): +def test_permutation(): assert parse_expression(next_chunk("(test; test2)")) == Permutation( items=[ Sequence(items=[t(text=" "), t(text="test"), t(text=" ")]), @@ -58,6 +58,32 @@ def test_group_permutation(): ) +def test_optional_alternative(): + assert parse_expression(next_chunk("[test | test2]")) == Alternative( + items=[ + Sequence(items=[t(text="test ")]), + Sequence(items=[t(text=" test2")]), + t(text=""), + ], + is_optional=True, + ) + + +def test_optional_permutation(): + assert parse_expression(next_chunk("[test; test2]")) == Alternative( + items=[ + Permutation( + items=[ + Sequence(items=[t(text=" "), t(text="test"), t(text=" ")]), + Sequence(items=[t(text=" "), t(text=" test2"), t(text=" ")]), + ], + ), + t(text=""), + ], + is_optional=True, + ) + + def test_slot_reference(): assert parse_expression(next_chunk("{test}")) == ListReference(list_name="test")