Update to support HuggingFace #2

Closed
15 changes: 15 additions & 0 deletions .github/workflows/pre-commit.yaml
@@ -0,0 +1,15 @@
name: Pre-commit
on:
pull_request:
push:
branches:
- "*"
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.0
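
Reviewer note: contributors can reproduce this check locally with the pre-commit CLI before pushing:

# one-time setup: install the git hook
pre-commit install
# run every configured hook against the whole repo, as CI does
pre-commit run --all-files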
22 changes: 22 additions & 0 deletions .github/workflows/run-pytest.yaml
@@ -0,0 +1,22 @@
name: PyTest
on:
push:
branches:
- "*" # TODO: run on all for now (move to main later)
pull_request:
branches:
- "*"
jobs:
run-pytest:
name: python
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y libsndfile1
- name: Install uv
uses: astral-sh/setup-uv@v5
- name: Install the project
run: uv sync --all-extras --dev
- name: Run tests
run: uv run pytest test/
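
Reviewer note: for context, a minimal sketch of the kind of test this workflow would execute. The file name test/test_inference.py and the assertion on the return shape are assumptions; the API calls mirror the examples added elsewhere in this PR.

# test/test_inference.py (hypothetical)
from audiobox_aesthetics.inference import AudioBoxAesthetics


def test_single_file_prediction():
    # Load the published checkpoint and score one bundled sample.
    model = AudioBoxAesthetics.from_pretrained("thunnai/audiobox-aesthetics")
    model.eval()
    wav = model.load_audio("sample_audio/libritts_spk-84.wav")
    predictions = model.predict_from_wavs(wav)
    # One prediction per input waveform (assumed return shape).
    assert len(predictions) == 1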
6 changes: 5 additions & 1 deletion .gitignore
@@ -1,3 +1,7 @@
.gradio/
*.pt
*.pth

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -161,4 +165,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
.vscode/
.ruff_cache/
.ruff_cache/
22 changes: 22 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,22 @@
repos:
- repo: https://github.com/google/yamlfmt
rev: v0.16.0
hooks:
- id: yamlfmt
- repo: https://github.com/gitleaks/gitleaks
rev: v8.23.3
hooks:
- id: gitleaks
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.30
hooks:
- id: uv-lock
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.6
hooks:
- id: ruff
types_or: [python, pyi]
args: [--fix]
- id: ruff-format
types_or: [python, pyi]
1 change: 1 addition & 0 deletions examples/example.jsonl
@@ -0,0 +1 @@
{"path": "sample_audio/libritts_spk-84.wav"}
9 changes: 9 additions & 0 deletions examples/predict_from_jsonl.py
@@ -0,0 +1,9 @@
from audiobox_aesthetics.inference import AudioBoxAesthetics, AudioFileList

model = AudioBoxAesthetics.from_pretrained("thunnai/audiobox-aesthetics")
model.eval()


audio_file_list = AudioFileList.from_jsonl("examples/example.jsonl")
predictions = model.predict_from_files(audio_file_list)
print(predictions)
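
Reviewer note: judging by demo.py in this PR, each prediction is a dict mapping axis keys (CE, CU, PC, PQ) to scores. Assuming that shape, results could be persisted back to JSONL alongside the inputs:

import json

# Sketch only: one JSON object per input file, in input order (assumed).
with open("examples/predictions.jsonl", "w") as f:
    for pred in predictions:
        f.write(json.dumps(pred) + "\n")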
8 changes: 8 additions & 0 deletions examples/predict_single_file.py
@@ -0,0 +1,8 @@
from audiobox_aesthetics.inference import AudioBoxAesthetics

model = AudioBoxAesthetics.from_pretrained("thunnai/audiobox-aesthetics")
model.eval()

wav = model.load_audio("sample_audio/libritts_spk-84.wav")
predictions = model.predict_from_wavs(wav)
print(predictions)
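
Reviewer note: assuming the same prediction shape as in demo.py (a list with one axis-to-score dict per waveform), the scores can be printed per axis:

# Sketch under the assumption above; the axis keys correspond to
# AXIS_NAME_LOOKUP in audiobox_aesthetics.inference.
for axis, score in predictions[0].items():
    print(f"{axis}: {score:.2f}")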
48 changes: 25 additions & 23 deletions pyproject.toml
@@ -6,38 +6,39 @@ build-backend = "setuptools.build_meta"
name = "audiobox_aesthetics"
version = "0.0.1"
authors = [
{name="Andros Tjandra", email="androstj@meta.com"},
{name="Yi-Chiao Wu"},
{name="Baishan Guo"},
{name="John Hoffman"},
{name="Brian Ellis"},
{name="Apoorv Vyas"},
{name="Bowen Shi"},
{name="Sanyuan Chen"},
{name="Matt Le"},
{name="Nick Zacharov"},
{name="Carleigh Wood"},
{name="Ann Lee"},
{name="Wei-ning Hsu"}
]
maintainers = [
{name="Andros Tjandra", email="androstj@meta.com"}
{ name = "Andros Tjandra", email = "androstj@meta.com" },
{ name = "Yi-Chiao Wu" },
{ name = "Baishan Guo" },
{ name = "John Hoffman" },
{ name = "Brian Ellis" },
{ name = "Apoorv Vyas" },
{ name = "Bowen Shi" },
{ name = "Sanyuan Chen" },
{ name = "Matt Le" },
{ name = "Nick Zacharov" },
{ name = "Carleigh Wood" },
{ name = "Ann Lee" },
{ name = "Wei-ning Hsu" },
]
maintainers = [{ name = "Andros Tjandra", email = "androstj@meta.com" }]
description = "Unified automatic quality assessment for speech, music, and sound."
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]
readme = "README.md"
license = {file = "LICENSE"}
license = { file = "LICENSE" }

dependencies = [
"numpy",
"torch>=2.2.0",
"torchaudio",
"tqdm",
"submitit"
"numpy",
"torch>=2.2.0",
"torchaudio",
"tqdm",
"submitit",
"huggingface-hub>=0.28.1",
"pydantic>=2.10.6",
"safetensors>=0.5.2",
]

[project.scripts]
@@ -47,4 +48,5 @@ audio-aes = "audiobox_aesthetics.cli:app"
Homepage = "https://github.com/facebookresearch/audiobox-aesthetics"
Issues = "https://github.com/facebookresearch/audiobox-aesthetics/issues"


[dependency-groups]
dev = ["gradio>=4.44.1", "ipykernel>=6.29.5", "pytest>=8.3.4"]
Binary file added sample_audio/libritts_spk-3170.wav
Binary file added sample_audio/libritts_spk-84.wav
1 change: 1 addition & 0 deletions sample_audio/test.jsonl
@@ -0,0 +1 @@
{"path": "sample_audio/libritts_spk-84.wav"}
2 changes: 1 addition & 1 deletion src/audiobox_aesthetics/cli.py
@@ -14,7 +14,7 @@

import submitit
from tqdm import tqdm
from .infer import load_dataset, main_predict
from audiobox_aesthetics.infer import load_dataset, main_predict

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

86 changes: 86 additions & 0 deletions src/audiobox_aesthetics/demo.py
@@ -0,0 +1,86 @@
import gradio as gr
from audiobox_aesthetics.inference import (
AudioBoxAesthetics,
AudioFile,
AXIS_NAME_LOOKUP,
)

# Load the pre-trained model
model = AudioBoxAesthetics.from_pretrained("thunnai/audiobox-aesthetics")
model.eval()


def predict_aesthetics(audio_file):
# Create an AudioFile instance
audio_file_instance = AudioFile(path=audio_file)

# Predict using the model
predictions = model.predict_from_files(audio_file_instance)

single_prediction = predictions[0]

data_view = [
[AXIS_NAME_LOOKUP[key], value] for key, value in single_prediction.items()
]

return single_prediction, data_view


def create_demo():
# Create a Gradio Blocks interface
with gr.Blocks() as demo:
gr.Markdown("# AudioBox Aesthetics Prediction")
with gr.Group():
gr.Markdown("""Upload an audio file to predict its aesthetic scores.

This demo uses the AudioBox Aesthetics model to predict aesthetic scores for audio files along 4 axes:
- Content Enjoyment (CE)
- Content Usefulness (CU)
- Production Complexity (PC)
- Production Quality (PQ)

Scores range from 0 to 10.

For more details, see the [paper](https://arxiv.org/abs/2502.05139) or [code](https://github.com/facebookresearch/audiobox-aesthetics/tree/main).
""")

with gr.Row():
with gr.Group():
with gr.Column():
audio_input = gr.Audio(
sources="upload", type="filepath", label="Upload Audio"
)
submit_button = gr.Button("Predict", variant="primary")
with gr.Group():
with gr.Column():
output_data = gr.Dataframe(
headers=["Axes name", "Score"],
datatype=["str", "number"],
label="Aesthetic Scorest",
)
output_text = gr.Textbox(label="Raw prediction", interactive=False)

submit_button.click(
predict_aesthetics,
inputs=audio_input,
outputs=[output_text, output_data],
)

# Add examples
gr.Examples(
examples=[
"sample_audio/libritts_spk-84.wav",
"sample_audio/libritts_spk-3170.wav",
],
inputs=audio_input,
outputs=[output_text, output_data],
fn=predict_aesthetics,
cache_examples=True,
)

return demo


if __name__ == "__main__":
demo = create_demo()
demo.launch()
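
Reviewer note: the demo should be runnable locally with python src/audiobox_aesthetics/demo.py after uv sync --all-extras --dev (mirroring the CI setup), since gradio is declared in the dev dependency group.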
79 changes: 79 additions & 0 deletions src/audiobox_aesthetics/export_model_to_hf.py
@@ -0,0 +1,79 @@
import requests
import os
import argparse
import torch

from audiobox_aesthetics.inference import AudioBoxAesthetics

if __name__ == "__main__":
# Set up argument parser
parser = argparse.ArgumentParser(
description="Download and test AudioBox Aesthetics model"
)
parser.add_argument(
"--checkpoint-url",
default="https://dl.fbaipublicfiles.com/audiobox-aesthetics/checkpoint.pt",
help="URL for the base checkpoint",
)
parser.add_argument(
"--model-name",
default="audiobox-aesthetics",
help="Name to save/load the pretrained model",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the model to the Hugging Face Hub",
)
args = parser.parse_args()

checkpoint_local_path = "base_checkpoint.pth"

if not os.path.exists(checkpoint_local_path):
print("Downloading base checkpoint")
response = requests.get(args.checkpoint_url)
with open(checkpoint_local_path, "wb") as f:
f.write(response.content)

# get model config from the base checkpoint
checkpoint = torch.load(
checkpoint_local_path, map_location="cpu", weights_only=True
)
model_cfg = checkpoint["model_cfg"]

# extract normalization params from the base checkpoint
target_transform = checkpoint["target_transform"]

target_transform = {
axis: {
"mean": checkpoint["target_transform"][axis]["mean"],
"std": checkpoint["target_transform"][axis]["std"],
}
for axis in target_transform.keys()
}
# force precision to be bfloat16 to match infer class
model_cfg["precision"] = "bf16"

model = AudioBoxAesthetics(
sample_rate=16_000, target_transform=target_transform, **model_cfg
)

model._load_base_checkpoint(checkpoint_local_path)
print("βœ… Loaded model from base checkpoint")

model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
print(f"βœ… Saved model to {args.model_name}")
if args.push_to_hub:
model.push_to_hub(args.model_name)
print(f"βœ… Pushed model to Hub under {args.model_name}")

# test load from pretrained
model = AudioBoxAesthetics.from_pretrained(args.model_name)
model.eval()
print(f"βœ… Loaded model from pretrained {args.model_name}")

# test inference
wav = model.load_audio("sample_audio/libritts_spk-84.wav")
predictions = model.predict_from_wavs(wav)
print(predictions)
print("βœ… Inference test passed")
10 changes: 6 additions & 4 deletions src/audiobox_aesthetics/infer.py
@@ -14,7 +14,7 @@
import torchaudio
import torch.nn.functional as F

from .model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput
from audiobox_aesthetics.model.aes_wavlm import Normalize, WavlmAudioEncoderMultiOutput

Batch = Dict[str, Any]

@@ -113,6 +113,8 @@ def setup_model(self):
"bf16": torch.bfloat16,
}.get(self.precision)

print("using precision", self.precision)

self.target_transform = {
axis: Normalize(
mean=ckpt["target_transform"][axis]["mean"],
@@ -205,8 +207,8 @@ def main_predict(input_file, ckpt, batch_size=10):
for ii in tqdm(range(0, len(metadata), batch_size)):
output = predictor.forward(metadata[ii : ii + batch_size])
outputs.extend(output)
assert len(outputs) == len(
metadata
), f"Output {len(outputs)} != input {len(metadata)} length"
assert len(outputs) == len(metadata), (
f"Output {len(outputs)} != input {len(metadata)} length"
)

return outputs