Skip to content

Commit

Permalink
Simple tool for computing nearest phonetic embeddings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720467516
  • Loading branch information
agutkin committed Jan 28, 2025
1 parent a693bde commit 84a2346
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

"""Computes sorted distance vectors for all embeddings."""

from collections.abc import Sequence

from absl import app
from absl import flags
from protoscribe.language.embeddings import embedder
from protoscribe.language.phonology import phoible_segments
from protoscribe.language.phonology import phonetic_embeddings

Expand All @@ -43,16 +44,16 @@
)


def main(unused_argv):
phoible = phoible_segments.PhoibleSegments(
path=_PHOIBLE_PATH.value,
features_path=_PHOIBLE_FEATURES_PATH.value,
)
embeddings = phonetic_embeddings.PhoneticEmbeddings(
phoible_seg=phoible,
embedding_len=embedder.DEFAULT_EMBEDDING_DIM,
def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

# Load phonetic embeddings.
embeddings = phonetic_embeddings.load_phonetic_embedder(
embeddings_file_path=_INPUT_EMBEDDINGS_FILE.value,
phoible_phonemes_path=_PHOIBLE_PATH.value,
phoible_features_path=_PHOIBLE_FEATURES_PATH.value
)
embeddings.read_embeddings(_INPUT_EMBEDDINGS_FILE.value)
embeddings.dump_all_distances(_OUTPUT_DISTANCES_FILE.value)


Expand Down
139 changes: 139 additions & 0 deletions protoscribe/language/phonology/phonetic_nearest_concepts_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Computes k-NN for concepts using the phonetic embeddings.

This is somewhat similar in function to the `phonetic_embeddings_distances`
tool, but supports lookups via, and filtering by, the category names.
"""

from collections.abc import Sequence
import csv
import itertools
import logging

from absl import app
from absl import flags
from protoscribe.language.phonology import phoible_segments
from protoscribe.language.phonology import phonetic_embeddings
from protoscribe.texts import generate_lib

import glob
import os

# Phonetic embeddings trained elsewhere, one embedding per pronunciation,
# stored as tab-separated values.
_INPUT_EMBEDDINGS_FILE = flags.DEFINE_string(
    "input_embeddings_file", None,
    "Path to the input phonetic embeddings file in TSV format.",
    required=True
)

# PHOIBLE phoneme inventory; defaults to the path bundled with the package.
_PHOIBLE_PATH = flags.DEFINE_string(
    "phoible_path", phoible_segments.PHOIBLE, "Path to PHOIBLE segments."
)

# PHOIBLE distinctive-feature table accompanying the segment inventory.
_PHOIBLE_FEATURES_PATH = flags.DEFINE_string(
    "phoible_features_path",
    phoible_segments.PHOIBLE_FEATURES,
    "Path to PHOIBLE features.",
)

# Number of nearest neighbors reported per unseen concept.
_TOP_K = flags.DEFINE_integer(
    "top_k", 3,
    "Keep best k candidates. If negative, compute for all entries."
)

# Output TSV: one row per unseen concept with its top-k seen neighbors.
_OUTPUT_TSV_FILE = flags.DEFINE_string(
    "output_tsv_file", None,
    "Path to the output file in TSV format containing all the closest "
    "neighbors from the seen set.",
    required=True
)

# Following will expose the category and lexicon command-line flags via FLAGS.
# In particular, we will need the main and number lexicons, and the
# administrative and non-administrative categories.
FLAGS = flags.FLAGS


def _pron_for_concept(concept: str, lexicon) -> tuple[str, str]:
  """Looks up the pronunciation of a concept in the lexicon.

  Strips the part-of-speech suffix from the concept name (the "POS kludge",
  e.g. `water_NOUN` -> `water`) before the lookup.

  Args:
    concept: Concept name, possibly carrying a `_POS` suffix.
    lexicon: Mapping from bare concept name to a sequence of phonemes.

  Returns:
    A pair of (bare concept name, space-joined pronunciation string).

  Raises:
    ValueError: If the concept is not present in the lexicon.
  """
  word = concept.split("_")[0]  # POS kludge.
  if word not in lexicon:
    raise ValueError(f"Concept {word} not found in pronunciation lexicon!")
  return word, " ".join(lexicon[word])


def main(argv: Sequence[str]) -> None:
  """Computes k-nearest seen pronunciations for each unseen concept.

  Loads the phonetic embeddings, the seen/unseen concept lists and the
  pronunciation lexicons, then writes one TSV row per unseen concept with
  its top-k nearest neighbors drawn from the seen set.

  Args:
    argv: Positional command-line arguments (none expected).

  Raises:
    app.UsageError: If positional arguments are supplied.
    ValueError: If required flags are missing, a concept has no
      pronunciation, or a seen pronunciation has no embedding.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Load phonetic embeddings.
  embeddings = phonetic_embeddings.load_phonetic_embedder(
      embeddings_file_path=_INPUT_EMBEDDINGS_FILE.value,
      phoible_phonemes_path=_PHOIBLE_PATH.value,
      phoible_features_path=_PHOIBLE_FEATURES_PATH.value
  )

  # Load administrative (seen) and non-administrative (unseen) concepts.
  # Make sure both are specified.
  if not FLAGS.concepts or not FLAGS.unseen_concepts:
    raise ValueError("Specify paths to both seen and unseen concepts!")

  _, seen_concepts = generate_lib.load_concepts(FLAGS.concepts)
  _, unseen_concepts = generate_lib.load_concepts(FLAGS.unseen_concepts)

  # Load category and number lexicon.
  if not FLAGS.main_lexicon or not FLAGS.number_lexicon:
    raise ValueError("Specify --main_lexicon and --number_lexicon!")

  lexicon, _ = generate_lib.load_phonetic_forms(
      main_lexicon_file=FLAGS.main_lexicon,
      number_lexicon_file=FLAGS.number_lexicon
  )
  logging.info("Loaded total of %d pronunciations.", len(lexicon))

  # Cache the pronunciations for seen concepts, checking each one has an
  # embedding. Build a set once so the membership test below is O(1).
  all_terms = set(embeddings.keys)
  seen_terms = []
  for concept in seen_concepts:
    _, pron = _pron_for_concept(concept, lexicon)
    if pron not in all_terms:
      raise ValueError(f"No embedding found for pronunciation '{pron}'!")
    seen_terms.append(pron)

  # For each concept in unseen set compute its $k$-nearest neighbors among
  # the seen pronunciations, emitting one TSV row per unseen concept.
  logging.info("Saving results to %s ...", _OUTPUT_TSV_FILE.value)
  with open(_OUTPUT_TSV_FILE.value, mode="wt") as f:
    writer = csv.writer(f, delimiter="\t")
    top_k_header = [
        (f"Pron{k}", f"Dist{k}") for k in range(1, _TOP_K.value + 1)
    ]
    writer.writerow(
        ["NewConcept", "NewPron"] + list(itertools.chain(*top_k_header))
    )
    for concept in unseen_concepts:
      # Lookup the pronunciation, then its nearest K seen pronunciations.
      concept, pron = _pron_for_concept(concept, lexicon)
      nearest = embeddings.get_k_nearest_neighbors(
          pron, _TOP_K.value, allowed_terms=seen_terms
      )
      # Flatten [(pron, dist), ...] into alternating columns for the row.
      row = list(itertools.chain.from_iterable(
          (other_pron, float(dist)) for other_pron, dist in nearest
      ))
      writer.writerow([concept, pron] + row)
  logging.info("Processed %d concepts.", len(unseen_concepts))


# Script entry point: absl parses the command-line flags and invokes main().
if __name__ == "__main__":
  app.run(main)
11 changes: 11 additions & 0 deletions protoscribe/texts/generate_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,16 +241,27 @@ def load_phonetic_forms(
raise ValueError("Main lexicon file not specified!")

logging.info("Loading main lexicon from %s ...", main_lexicon_file)
pos_collisions = set()
with open(main_lexicon_file) as s:
for line in s:
conc, phon = line.strip("\n").split("\t")
# TODO: This will fail if we have the same term with two different
# parts of speech.
word = conc.split("_")[0]
if word in pronunciation_lexicon:
pos_collisions.add(word)
pronunciation_lexicon[word] = phon.split()
if seen_concepts and conc in seen_concepts:
seen_phonetic_forms.add(phon)

# Removing parts-of-speech results in key collisions. Print these words,
# if any.
if pos_collisions:
logging.warning(
"Removing POS from concepts results in pronunciation "
"collisions for: %s", pos_collisions
)

if not number_lexicon_file:
number_lexicon_file = _NUMBER_LEXICON.value
if not number_lexicon_file:
Expand Down

0 comments on commit 84a2346

Please sign in to comment.