Skip to content

Commit

Permalink
Add flag for returning congruence results; fix token for diarization
Browse files — browse the repository at this point in the history
  • Loading branch information
Lameus committed Dec 11, 2023
1 parent b186c02 commit f69a314
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
27 changes: 17 additions & 10 deletions expert/core/congruence/congruence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
sr: int = 44100,
device: torch.device | None = None,
output_dir: str | PathLike | None = None,
return_path: bool = False,
):
if lang not in ["en", "ru"]:
raise NotImplementedError("'lang' must be 'en' or 'ru'.")
Expand Down Expand Up @@ -92,6 +93,8 @@ def __init__(
if not os.path.exists(self.temp_path):
os.makedirs(self.temp_path)

self.return_path = return_path

@property
def device(self) -> torch.device:
"""Check the device type.
Expand Down Expand Up @@ -194,15 +197,19 @@ def get_congruence(self):
emotions_data["audio"] = audio_data.to_dict(orient="records")
emotions_data["text"] = text_data.to_dict(orient="records")

with open(
os.path.join(self.temp_path, "emotions.json"), "w"
) as filename:
json.dump(emotions_data, filename)
if self.return_path:
with open(
os.path.join(self.temp_path, "emotions.json"), "w"
) as filename:
json.dump(emotions_data, filename)

cong_data[["video_path", "time_sec", "congruence"]].to_json(
os.path.join(self.temp_path, "congruence.json"), orient="records"
)
cong_data[["video_path", "time_sec", "congruence"]].to_json(
os.path.join(self.temp_path, "congruence.json"),
orient="records",
)

return os.path.join(self.temp_path, "emotions.json"), os.path.join(
self.temp_path, "congruence.json"
)
return os.path.join(self.temp_path, "emotions.json"), os.path.join(
self.temp_path, "congruence.json"
)
else:
return {"emotions": emotions_data, "congruence": cong_data}
2 changes: 1 addition & 1 deletion expert/data/annotation/speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def transcribe_video(
video_path: Union[str, PathLike],
lang: Optional[str] = "en",
model: Optional[str] = "server",
model: Optional[str] = "local",
device: Optional[Union[torch.device, None]] = None,
) -> Dict:
"""Speech recognition module from video.
Expand Down
2 changes: 1 addition & 1 deletion expert/data/diarization/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
if device is not None:
self._device = device

token = "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" # FIXME убрать харкод пароля # nosec
token = "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"
self.pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization@2.1", use_auth_token=token
)
Expand Down

0 comments on commit f69a314

Please sign in to comment.