GKD #2624

Open
Tracked by #2617
qgallouedec opened this issue Jan 23, 2025 · 1 comment

@qgallouedec
Member

No description provided.

@qgallouedec
Member Author

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GKDTrainer, GKDConfig

batch_size = 4
gradient_accumulation_steps = 2
output_dir = f"GKD-bsz{batch_size}-grad_acc{gradient_accumulation_steps}-fixed"


model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
teacher_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token


training_args = GKDConfig(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    dataloader_drop_last=True,
    gradient_accumulation_steps=gradient_accumulation_steps,
    logging_steps=2,
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")

trainer = GKDTrainer(
    model=model,  # pass the student model instance loaded above
    teacher_model=teacher_model,  # pass the teacher model instance loaded above
    args=training_args,
    train_dataset=dummy_dataset["train"],
    processing_class=tokenizer,
)

trainer.train()

I'm not sure how to demonstrate that the problem arises with GKD, but I do expect it to occur. @kashif, can you take a look?
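
One possible way to look for it, assuming the problem depends on how a batch is split between per_device_train_batch_size and gradient_accumulation_steps (an assumption based only on the output_dir naming in the script above), is to run the script once per split of the same effective batch size and compare the logged losses. The helpers below are hypothetical and are not part of TRL; they only enumerate the splits and name the runs in the same scheme as the script.

```python
# Hypothetical helpers (not part of TRL): enumerate batch-size splits that
# share one effective batch size, so the resulting runs can be compared.

def equivalent_configs(effective_batch_size):
    """Return (batch_size, gradient_accumulation_steps) pairs whose product
    equals effective_batch_size, ordered by increasing batch_size."""
    return [
        (bsz, effective_batch_size // bsz)
        for bsz in range(1, effective_batch_size + 1)
        if effective_batch_size % bsz == 0
    ]


def run_name(batch_size, gradient_accumulation_steps, suffix="fixed"):
    """Build an output_dir name in the same scheme as the script above."""
    return f"GKD-bsz{batch_size}-grad_acc{gradient_accumulation_steps}-{suffix}"


# The script above uses batch_size=4 and gradient_accumulation_steps=2,
# i.e. an effective batch size of 8 on a single device.
for bsz, grad_acc in equivalent_configs(8):
    print(run_name(bsz, grad_acc))
```

Each printed name corresponds to one run of the script with those two values substituted. If the runs are otherwise identical (same seed, same data order), noticeably different loss curves between splits would be a hint that the problem is present.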
