Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve permutations #183

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hassil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Home Assistant Intent Language parser"""

from .expression import (
Alternative,
Group,
ListReference,
Permutation,
arturpragacz marked this conversation as resolved.
Show resolved Hide resolved
RuleReference,
Sentence,
Sequence,
SequenceType,
TextChunk,
)
from .intents import Intents
Expand Down
120 changes: 72 additions & 48 deletions hassil/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import re
from abc import ABC
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Iterator, List, Optional


Expand All @@ -22,7 +22,7 @@ class TextChunk(Expression):
# Set in __post_init__
original_text: str = None # type: ignore

parent: "Optional[Sequence]" = None
parent: "Optional[Group]" = None

def __post_init__(self):
if self.original_text is None:
Expand All @@ -39,40 +39,22 @@ def empty() -> "TextChunk":
return TextChunk()


class SequenceType(str, Enum):
"""Type of a sequence. Optionals are alternatives with an empty option."""

# Sequence of expressions
GROUP = "group"

# Expressions where only one will be recognized
ALTERNATIVE = "alternative"

# Permutations of a set of expressions
PERMUTATION = "permutation"


@dataclass
class Sequence(Expression):
"""Ordered sequence of expressions. Supports groups, optionals, and alternatives."""
class Group(Expression):
"""Ordered group of expressions. Supports sequences, optionals, and alternatives."""

# Items in the sequence
# Items in the group
items: List[Expression] = field(default_factory=list)

# Group or alternative
type: SequenceType = SequenceType.GROUP

is_optional: bool = False

def text_chunk_count(self) -> int:
"""Return the number of TextChunk expressions in this sequence (recursive)."""
"""Return the number of TextChunk expressions in this group (recursive)."""
num_text_chunks = 0
for item in self.items:
if isinstance(item, TextChunk):
num_text_chunks += 1
elif isinstance(item, Sequence):
seq: Sequence = item
num_text_chunks += seq.text_chunk_count()
elif isinstance(item, Group):
grp: Group = item
num_text_chunks += grp.text_chunk_count()

return num_text_chunks

Expand All @@ -93,16 +75,41 @@ def _list_names(
if isinstance(item, ListReference):
list_ref: ListReference = item
yield list_ref.list_name
elif isinstance(item, Sequence):
seq: Sequence = item
yield from seq.list_names(expansion_rules)
elif isinstance(item, Group):
grp: Group = item
yield from grp.list_names(expansion_rules)
elif isinstance(item, RuleReference):
rule_ref: RuleReference = item
if expansion_rules and (rule_ref.rule_name in expansion_rules):
rule_body = expansion_rules[rule_ref.rule_name]
rule_body = expansion_rules[rule_ref.rule_name].exp
yield from self._list_names(rule_body, expansion_rules)


@dataclass
class Sequence(Group):
"""Sequence of expressions"""


@dataclass
class Alternative(Group):
"""Expressions where only one will be recognized"""

is_optional: bool = False


@dataclass
class Permutation(Group):
"""Permutations of a set of expressions"""

def iterate_permutations(self) -> Iterable[tuple[Expression, "Permutation"]]:
"""Iterate over all permutations."""
for i, item in enumerate(self.items):
items = self.items[:]
del items[i]
rest = Permutation(items=items)
yield (item, rest)


@dataclass
class RuleReference(Expression):
"""Reference to an expansion rule by <name>."""
Expand Down Expand Up @@ -135,49 +142,66 @@ def slot_name(self) -> str:


@dataclass
class Sentence(Sequence):
"""Sequence representing a complete sentence template."""
class Sentence:
"""A complete sentence template."""

exp: Expression
text: Optional[str] = None
pattern: Optional[re.Pattern] = None

def text_chunk_count(self) -> int:
"""Return the number of TextChunk expressions in this sentence."""
assert isinstance(self.exp, Group)
return self.exp.text_chunk_count() # pylint: disable=no-member

def list_names(
self,
expansion_rules: Optional[Dict[str, "Sentence"]] = None,
) -> Iterator[str]:
"""Return names of list references in this sentence."""
assert isinstance(self.exp, Group)
return self.exp.list_names(expansion_rules) # pylint: disable=no-member

def compile(self, expansion_rules: Dict[str, "Sentence"]) -> None:
if self.pattern is not None:
# Already compiled
return

pattern_chunks: List[str] = []
self._compile_expression(self, pattern_chunks, expansion_rules)

self._compile_expression(self.exp, pattern_chunks, expansion_rules)
pattern_str = "".join(pattern_chunks).replace(r"\ ", r"[ ]*")
self.pattern = re.compile(f"^{pattern_str}$", re.IGNORECASE)

def _compile_expression(
self, exp: Expression, pattern_chunks: List[str], rules: Dict[str, "Sentence"]
):
) -> None:
if isinstance(exp, TextChunk):
# Literal text
chunk: TextChunk = exp
if chunk.text:
escaped_text = re.escape(chunk.text)
pattern_chunks.append(escaped_text)
elif isinstance(exp, Sequence):
# Linear sequence or alternative choices
seq: Sequence = exp
if seq.type == SequenceType.GROUP:
# Linear sequence
for item in seq.items:
elif isinstance(exp, Group):
grp: Group = exp
if isinstance(grp, Sequence):
for item in grp.items:
self._compile_expression(item, pattern_chunks, rules)
elif seq.type == SequenceType.ALTERNATIVE:
# Alternative choices
if seq.items:
elif isinstance(grp, Alternative):
if grp.items:
pattern_chunks.append("(?:")
for item in seq.items:
for item in grp.items:
self._compile_expression(item, pattern_chunks, rules)
pattern_chunks.append("|")
pattern_chunks[-1] = ")"
elif isinstance(grp, Permutation):
arturpragacz marked this conversation as resolved.
Show resolved Hide resolved
if grp.items:
pattern_chunks.append("(?:")
for item in grp.items:
self._compile_expression(item, pattern_chunks, rules)
pattern_chunks.append("|")
pattern_chunks[-1] = f"){{{len(grp.items)}}}"
else:
raise ValueError(seq)
raise ValueError(grp)
elif isinstance(exp, ListReference):
# Slot list
pattern_chunks.append("(?:.+)")
Expand All @@ -189,6 +213,6 @@ def _compile_expression(
raise ValueError(rule_ref)

e_rule = rules[rule_ref.rule_name]
self._compile_expression(e_rule, pattern_chunks, rules)
self._compile_expression(e_rule.exp, pattern_chunks, rules)
else:
raise ValueError(exp)
2 changes: 1 addition & 1 deletion hassil/intents.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,6 @@ def _parse_data_settings(settings_dict: Dict[str, Any]) -> IntentDataSettings:
def _maybe_parse_template(text: str, allow_template: bool = True) -> Expression:
"""Parse string as a sentence template if it has template syntax."""
if allow_template and is_template(text):
return parse_sentence(text)
return parse_sentence(text).exp

return TextChunk(normalize_text(text))
Loading
Loading