diff --git a/nemo/collections/multimodal/data/clip/augmentations/augmentations.py b/nemo/collections/multimodal/data/clip/augmentations/augmentations.py index d1de22f687e5..79c762d16610 100644 --- a/nemo/collections/multimodal/data/clip/augmentations/augmentations.py +++ b/nemo/collections/multimodal/data/clip/augmentations/augmentations.py @@ -46,6 +46,8 @@ @dataclass class AugmentationCfg: + """Augmentation Config""" + scale: Tuple[float, float] = (0.9, 1.0) ratio: Optional[Tuple[float, float]] = None color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None @@ -56,6 +58,8 @@ class AugmentationCfg: class ResizeMaxSize(nn.Module): + """Resize module""" + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): super().__init__() if not isinstance(max_size, int): @@ -66,6 +70,7 @@ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', self.fill = fill def forward(self, img): + # pylint: disable=C0116 if isinstance(img, torch.Tensor): height, width = img.shape[:2] else: @@ -82,6 +87,7 @@ def forward(self, img): def _convert_to_rgb(image): + # pylint: disable=C0116 return image.convert('RGB') @@ -94,7 +100,9 @@ def image_transform( fill_color: int = 0, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): + # pylint: disable=C0116 assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." + mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): mean = (mean,) * 3 @@ -139,7 +147,9 @@ def image_transform( train_transform = Compose( [ RandomResizedCrop( - image_size, scale=aug_cfg_dict.pop('scale'), interpolation=InterpolationMode.BICUBIC, + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, ), _convert_to_rgb, ToTensor(), @@ -160,6 +170,10 @@ def image_transform( CenterCrop(image_size), ] transforms.extend( - [_convert_to_rgb, ToTensor(), normalize,] + [ + _convert_to_rgb, + ToTensor(), + normalize, + ] ) return Compose(transforms) diff --git a/nemo/collections/multimodal/data/clip/clip_dataset.py b/nemo/collections/multimodal/data/clip/clip_dataset.py index 448efba4b8ba..22798dda4290 100644 --- a/nemo/collections/multimodal/data/clip/clip_dataset.py +++ b/nemo/collections/multimodal/data/clip/clip_dataset.py @@ -72,6 +72,31 @@ def tokenize(texts: Union[str, List[str]], tokenizer: Any, context_length: int = return result +# pylint: disable=C0116 +def get_preprocess_fns_params( + img_h, img_w, img_mean=None, img_std=None, is_train=True, max_position_embedding=None, tokenizer=None +): + + # This is equivalent to `get_preprocess_fns` but does not need the whole config to get the functions. 
This is + # Particularly used in Nemo2 + # Define transforms + img_size = (img_h, img_w) + img_transform = image_transform( + img_size, + is_train=is_train, + mean=img_mean, + std=img_std, + ) + text_transform = lambda x: x + if tokenizer is not None: + text_transform = partial( + tokenize, + tokenizer=tokenizer, + context_length=max_position_embedding, + ) + return img_transform, text_transform + + def get_preprocess_fns(model_cfg, tokenizer=None, is_train=True): # Define transforms img_size = (model_cfg.vision.get("img_h"), model_cfg.vision.get("img_w")) @@ -104,7 +129,8 @@ def tuple_to_dict(inp): def transform_fn(sample, img_transform, text_transform): image, text = sample["jpg"], sample["txt"] - return img_transform(image), text_transform(text) + img_transformed, text_transformed = img_transform(image), text_transform(text) + return img_transformed, text_transformed def build_train_valid_datasets( @@ -144,8 +170,79 @@ def custom_collate(batch): return default_collate(batch) +def build_imagenet_validation_dataloader_params( + imagenet_val, + img_h, + img_w, + mbs, + gbs, + num_workers=0, + pin_memory=True, + img_mean=None, + img_std=None, + is_train=False, + max_position_embedding=None, + tokenizer=None, +): + # This is equivalent to `build_imagenet_validation_dataloader` but does not need the whole config. + # Particularly used in Nemo2 + val_image_transform, text_transform = get_preprocess_fns_params( + img_h, + img_w, + img_mean, + img_std, + is_train=is_train, + max_position_embedding=max_position_embedding, + tokenizer=tokenizer, + ) + + imagenet_val_data = {} + + imagenet_path = imagenet_val + if imagenet_path is None: + return None + + image_dataset = ImageFolder( + root=imagenet_path, + transform=val_image_transform, + ) + + image_batch_sampler = MegatronPretrainingSampler( + total_samples=len(image_dataset), + consumed_samples=0, + micro_batch_size=mbs, + global_batch_size=gbs, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=False, + ) + + imagenet_val_data["images"] = torch.utils.data.DataLoader( + image_dataset, + batch_sampler=image_batch_sampler, + num_workers=num_workers, + collate_fn=custom_collate, + pin_memory=pin_memory, + persistent_workers=True, + ) + text_dataset = ImagenetClassnameDataset(imagenet_classnames, openai_imagenet_template, text_transform) + + imagenet_val_data["texts"] = torch.utils.data.DataLoader( + text_dataset, + batch_size=text_dataset.num_templates, + num_workers=0, + pin_memory=True, + persistent_workers=False, + drop_last=False, + ) + + return imagenet_val_data + + +# pylint: enable=C0116 # For zero-shot imagenet validation def build_imagenet_validation_dataloader(model_cfg, tokenizer=None): + """Build dataloaders""" val_image_transform, text_transform = get_preprocess_fns(model_cfg, tokenizer, is_train=False) data_cfg = model_cfg.data @@ -192,7 +289,10 @@ def build_imagenet_validation_dataloader(model_cfg, tokenizer=None): class ImagenetClassnameDataset(Dataset): + """Imagenet class dataset""" + def __init__(self, classnames, templates, text_transform): + # pylint: disable=C0116 self.num_templates = len(templates) self.samples = [] for classname in classnames: @@ -200,7 +300,9 @@ def __init__(self, classnames, templates, text_transform): self.samples.extend(text_transform(texts)) def __getitem__(self, index): + # pylint: disable=C0116 return self.samples[index] def __len__(self): + # pylint: disable=C0116 return len(self.samples) diff --git 
a/nemo/collections/multimodal/data/energon/base.py b/nemo/collections/multimodal/data/energon/base.py index c29935880889..4aeb142ffa23 100644 --- a/nemo/collections/multimodal/data/energon/base.py +++ b/nemo/collections/multimodal/data/energon/base.py @@ -64,11 +64,16 @@ def __init__( micro_batch_size: int = 1, global_batch_size: int = 1, num_workers: int = 1, + num_val_workers: int | None = None, pin_memory: bool = True, + shuffle_buffer_size: int = 100, + max_samples_per_sequence: int | None = None, multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(), task_encoder: Optional[MultiModalTaskEncoder] = None, decoder_seq_length: Optional[int] = None, packing_buffer_size: Optional[int] = None, + validation_task_encoder: Optional[MultiModalTaskEncoder] = None, + **kwargs, ) -> None: """ Initialize the EnergonMultiModalDataModule. @@ -80,13 +85,20 @@ def __init__( seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. num_workers (int, optional): Number of workers for data loading. Defaults to 1. + num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers. pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. Defaults to MultiModalSampleConfig(). + shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100. + max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory. + Defaults to None (loads the whole tar file at once). task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None. - decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models. + decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. + validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding + and batching samples for validation. Defaults to None and will be the same as task_encoder. + **kwargs: Additional keyword arguments. 
Will be passed to get_train_dataset() of Energon """ super().__init__() @@ -102,6 +114,8 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.multimodal_sample_config = multimodal_sample_config + self.shuffle_buffer_size = shuffle_buffer_size + self.max_samples_per_sequence = max_samples_per_sequence self.task_encoder = task_encoder or MultiModalTaskEncoder( tokenizer=self.tokenizer, image_processor=self.image_processor, @@ -117,10 +131,17 @@ def __init__( self.train_dataloader_object = None self.val_dataloader_object = None self.packing_buffer_size = packing_buffer_size + self.validation_task_encoder = validation_task_encoder or self.task_encoder + self.num_val_workers = num_val_workers or self.num_workers + self.kwargs = kwargs def io_init(self, **kwargs) -> fdl.Config[Self]: - cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items() if k not in ['image_processor', 'task_encoder']} + cfg_kwargs = { + k: deepcopy(v) + for k, v in kwargs.items() + if k not in ['image_processor', 'task_encoder', 'validation_task_encoder'] + } for val in cfg_kwargs.values(): if not serialization.find_node_traverser(type(val)): @@ -142,18 +163,27 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val Returns: Dataset: The dataset configured for the specified split. """ + if split not in {'train', 'val'}: raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + + if split == "train": + task_encoder = self.task_encoder + else: + task_encoder = self.validation_task_encoder + _dataset = get_train_dataset( self.path, batch_size=self.micro_batch_size, - task_encoder=self.task_encoder, + task_encoder=task_encoder, worker_config=worker_config, - max_samples_per_sequence=None, packing_buffer_size=self.packing_buffer_size, - shuffle_buffer_size=100, split_part=split, + shuffle_buffer_size=self.shuffle_buffer_size, + max_samples_per_sequence=self.max_samples_per_sequence, + **self.kwargs, ) + return _dataset def train_dataloader(self) -> TRAIN_DATALOADERS: @@ -216,9 +246,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS: if not parallel_state.is_initialized(): logging.info( f"Muiltimodal val data loader parallel state is not initialized," - "using default worker config with no_workers {self.num_workers}" + f"using default worker config with no_workers {self.num_workers}" ) - worker_config = WorkerConfig.default_worker_config(self.num_workers) + worker_config = WorkerConfig.default_worker_config(self.num_val_workers) else: rank = parallel_state.get_data_parallel_rank() world_size = parallel_state.get_data_parallel_world_size() @@ -248,7 +278,7 @@ def test_dataloader(self) -> None: Returns: None """ - logging.warning(f"Multimodal dataloader test dataset split does not exist") + logging.warning("Multimodal dataloader test dataset split does not exist") return None def state_dict(self) -> Dict[str, Any]: @@ -264,7 +294,7 @@ def state_dict(self) -> Dict[str, Any]: if self.trainer: dataloader_obj = self.trainer.train_dataloader - state = dataloader_obj.save_state() + state = dataloader_obj.save_state_global(dst_rank=0) consumed_samples = self.data_sampler.compute_consumed_samples( self.trainer.global_step - self.init_global_step ) @@ -286,7 +316,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ if not 'dataloader_state' in state_dict: logging.warning( - f"Data loader state cannot be resumed from state_dict," + f"Data loader state cannot be resumed from state_dict, " f"it does not have the required key dataloader_state. 
It has {state_dict.keys()}" ) return @@ -294,16 +324,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: state = state_dict['dataloader_state'] try: if self.trainer: - self.trainer.datamodule.train_dataloader().restore_state(state) - logging.info(f" Multimodal dataloader state restored") + self.trainer.datamodule.train_dataloader().restore_state(state, src_rank=0) + logging.info("Multimodal dataloader state restored") else: logging.error(f"Cannot restore state from state_dict {state_dict}") raise ValueError( - f"Cannot restore state from state_dict: " - f"Is the trainer object is initialized and attached to datamodule???" + "Cannot restore state from state_dict: " + "Is the trainer object is initialized and attached to datamodule???" ) except Exception as e: - raise RuntimeError(f"Failed to dataloader restore state due to: {e}") + logging.warning( + f"Failed to dataloader restore state due to [Please ensure you are using same version " + f"of energon while saving and loading, Continuing without restoring data loader] : {e}" + ) try: from megatron.core.num_microbatches_calculator import update_num_microbatches diff --git a/nemo/collections/multimodal/losses/clip_loss.py b/nemo/collections/multimodal/losses/clip_loss.py index 694f29a86a9e..9811cafc7ee6 100644 --- a/nemo/collections/multimodal/losses/clip_loss.py +++ b/nemo/collections/multimodal/losses/clip_loss.py @@ -30,7 +30,10 @@ def gather_features( - image_features, text_features, local_loss=False, gather_with_grad=False, + image_features, + text_features, + local_loss=False, + gather_with_grad=False, ): """ Gathers image and text features across multiple data parallel processes. @@ -109,8 +112,12 @@ class ClipLoss(nn.Module): """ def __init__( - self, local_loss=False, gather_with_grad=False, cache_labels=False, + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, ): + """Init""" super().__init__() self.local_loss = local_loss self.gather_with_grad = gather_with_grad @@ -124,6 +131,7 @@ def __init__( self.rank = parallel_state.get_data_parallel_rank() def forward(self, output_tensor): + """Forward for loss""" image_features, text_features, logit_scale = output_tensor device = image_features.device if self.world_size > 1: diff --git a/nemo/collections/vlm/__init__.py b/nemo/collections/vlm/__init__.py index 3e9eebe47cbe..ffa6ae124cc1 100644 --- a/nemo/collections/vlm/__init__.py +++ b/nemo/collections/vlm/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
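The energon data-module changes above introduce several new knobs (`num_val_workers`, `shuffle_buffer_size`, `max_samples_per_sequence`, `validation_task_encoder`, plus pass-through `**kwargs` for `get_train_dataset`). A minimal usage sketch follows; the import path, the Hugging Face processor/tokenizer, and the dataset path are illustrative assumptions and not part of this diff, while the constructor arguments are the ones documented in the docstring above.

```python
# Sketch only: wiring up the new EnergonMultiModalDataModule options added in this PR.
# Import path, processor choice, and dataset path are assumptions for illustration.
from transformers import AutoProcessor

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.multimodal.data.energon import EnergonMultiModalDataModule

processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

data = EnergonMultiModalDataModule(
    path="/data/my_energon_dataset",           # hypothetical energon-prepared dataset
    tokenizer=AutoTokenizer("openai/clip-vit-large-patch14"),
    image_processor=processor.image_processor,
    micro_batch_size=4,
    global_batch_size=32,
    num_workers=8,
    num_val_workers=2,                          # new: separate worker count for validation
    shuffle_buffer_size=1000,                   # new: previously hard-coded to 100
    max_samples_per_sequence=100,               # new: avoid loading a whole tar shard at once
    # validation_task_encoder=my_val_encoder,   # new: optional encoder used only for the val split
    # any extra keyword arguments are forwarded to energon's get_train_dataset()
)
```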
+from nemo.collections.vlm.clip.data import ClipMockDataModule +from nemo.collections.vlm.clip.model import ClipConfigB32, ClipConfigL14, CLIPModel from nemo.collections.vlm.hf.data.hf_dataset import HFDatasetDataModule from nemo.collections.vlm.hf.model.hf_auto_model_for_image_text_to_text import HFAutoModelForImageTextToText from nemo.collections.vlm.llava_next.data import LlavaNextMockDataModule, LlavaNextTaskEncoder @@ -88,9 +90,15 @@ "mllama_11b", "mllama_90b", "llava_next_7b", + "LlavaNextConfig", "LlavaNextConfig7B", "LlavaNextConfig13B", "LlavaNextModel", "LlavaNextMockDataModule", "LlavaNextTaskEncoder", + "CLIPModel", + "LoRA", + "ClipConfigL14", + "ClipConfigB32", + "ClipMockDataModule", ] diff --git a/nemo/collections/vlm/clip/__init__.py b/nemo/collections/vlm/clip/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/vlm/clip/data/__init__.py b/nemo/collections/vlm/clip/data/__init__.py new file mode 100644 index 000000000000..58185aad1fb6 --- /dev/null +++ b/nemo/collections/vlm/clip/data/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo.collections.vlm.clip.data.mock import MockDataModule as ClipMockDataModule + +__all__ = ['ClipMockDataModule'] diff --git a/nemo/collections/vlm/clip/data/clip_data_module.py b/nemo/collections/vlm/clip/data/clip_data_module.py new file mode 100644 index 000000000000..8dca7b286d67 --- /dev/null +++ b/nemo/collections/vlm/clip/data/clip_data_module.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from megatron.energon import Cooker, DefaultTaskEncoder, SkipSample, basic_sample_keys +from torchvision import transforms + +from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform +from nemo.collections.multimodal.data.clip.clip_dataset import tokenize +from nemo.lightning.io.mixin import IOMixin +from nemo.utils import logging + + +def cook_raw_iamges(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. 
+ + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'jpg': original images + - 'txt': contains raw text + """ + if "jpg" not in sample or "txt" not in sample: + logging.info(f"Raw sample {sample} does not contain a jpg or txt file") + raise SkipSample + + return dict( + **basic_sample_keys(sample), + image=sample['jpg'], + txt=sample['txt'], + ) + + +class ClipTaskEncoder(DefaultTaskEncoder, IOMixin): + """ + A simple task encoder for CLIP. The input sample is expected to have a 'jpg' key and a 'txt' key. + Cookers are used to process raw sample data into a dictionary with specific keys. raw_sample -> (jpg, txt) + + Args: + img_h: height of the image + img_w: width of the image + img_mean: mean of the image + img_std: standard deviation of the image + max_length: max length of the text + tokenizer: tokenizer for text + image_processor: image processor for image + is_train: whether it is training or not + + This class augments and tokenizes the sample using the provided tokenizer and image processor + and returns a dictionary. + """ + + cookers = [Cooker(cook_raw_iamges)] + + def __init__( + self, + img_h: int = 224, + img_w: int = 224, + img_mean: int = None, + img_std: int = None, + max_length: int = 77, + tokenizer: Optional = None, + image_processor: Optional = None, + is_train: bool = True, + ): + super().__init__() + + self.tokenizer = tokenizer + self.image_processor = image_processor + + if image_processor is None or tokenizer is None: + logging.warning("Processor or tokenizer are not provided! Fall back to `openai/clip-vit-large-patch14`.") + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.tokenizer = AutoTokenizer("openai/clip-vit-large-patch14") + self.image_processor = processor.image_processor + + img_size = (img_h, img_w) + self.img_size = img_size + + self.img_transform = image_transform( + img_size, + is_train=is_train, + mean=img_mean, + std=img_std, + ) + self.toPIL = transforms.ToPILImage() + self.max_length = max_length + + def encode_sample(self, sample: dict) -> dict: + """ + Encodes a sample dictionary into a dictionary with specific keys. Applied the augmenters and tokenizers. + """ + sample_new = {} + sample_new["images"] = self.img_transform(sample["image"]) + sample_new["captions"] = tokenize(sample["txt"], self.tokenizer, context_length=self.max_length) + return sample_new diff --git a/nemo/collections/vlm/clip/data/mock.py b/nemo/collections/vlm/clip/data/mock.py new file mode 100644 index 000000000000..4df3e3e6034a --- /dev/null +++ b/nemo/collections/vlm/clip/data/mock.py @@ -0,0 +1,233 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
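Taken together, the cooker and `encode_sample` turn a raw webdataset entry into the `images`/`captions` tensors the model consumes. A quick sketch of that path on a synthetic sample; the dummy PIL image and caption are illustrative, and the fallback `openai/clip-vit-large-patch14` processor is downloaded when no tokenizer or image processor is passed, as the warning above notes.

```python
# Sketch of what ClipTaskEncoder.encode_sample produces for one raw sample.
from PIL import Image

from nemo.collections.vlm.clip.data.clip_data_module import ClipTaskEncoder

encoder = ClipTaskEncoder(img_h=224, img_w=224, max_length=77, is_train=False)

raw = {"image": Image.new("RGB", (640, 480), color="gray"), "txt": "a photo of a cat"}
encoded = encoder.encode_sample(raw)

print(encoded["images"].shape)    # expected: torch.Size([3, 224, 224]) after resize/crop/normalize
print(encoded["captions"].shape)  # expected: torch.Size([1, 77]) token ids padded to max_length
```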
+ +from typing import Any, Dict, List, Optional + +import lightning.pytorch as pl +import numpy as np +import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils import data +from torch.utils.data import DataLoader, Dataset + +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + + +class MockDataModule(pl.LightningDataModule): + """ + Mock data module with data sampling and preprocessing configurations. + """ + + def __init__( + self, + seq_length: int = 77, + decoder_seq_length: Optional[int] = None, + tokenizer: Optional = None, + image_processor: Optional = None, + micro_batch_size: int = 4, + global_batch_size: int = 8, + rampup_batch_size: Optional[List[int]] = None, + num_train_samples: int = 10_000_000, + num_val_samples: int = 10_000_000, + num_test_samples: int = 10_000_000, + num_workers: int = 8, + pin_memory: bool = True, + persistent_workers: bool = False, + task_encoder: Optional[Any] = None, + ): + """ + Initializes the mock data module with data sampling and preprocessing configurations. + task_encoder: This Mock data module uses Energon Task encoder if provided. + + Args: + seq_length (int): Maximum sequence length for tokens. + decoder_seq_length (Optional[int]): Sequence length for the decoder. Used by Megatron Sampler. + tokenizer: Tokenizer for text processing. + image_processor: Processor for image preprocessing. + micro_batch_size (int): Batch size for training and validation. + global_batch_size (int): Total batch size across GPUs. + rampup_batch_size (Optional[List[int]]): Batch size ramp-up schedule. Used by Megatron Sampler. + num_train_samples (int): Number of training samples. + num_val_samples (int): Number of validation samples. + num_test_samples (int): Number of testing samples. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory for data loading. + persistent_workers (bool): Whether workers should remain persistent. + task_encoder: Task encoder for Energon tasks. + """ + super().__init__() + self.seq_length = seq_length + self.decoder_seq_length = decoder_seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_train_samples = num_train_samples + self.num_val_samples = num_val_samples + self.num_test_samples = num_test_samples + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.task_encoder = task_encoder + self.tokenizer = tokenizer + self.image_processor = image_processor + + if tokenizer is None or image_processor is None: + logging.warning("Processor or tokenizer are not provided! 
Fall back to `openai/clip-vit-large-patch14`.") + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.tokenizer = tokenizer or AutoTokenizer("openai/clip-vit-large-patch14") + self.image_processor = image_processor or processor.image_processor + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + rampup_batch_size=rampup_batch_size, + ) + + def setup(self, stage: str = "") -> None: + # pylint: disable=C0116 + self._train_ds = _MockClipDataset( + self.tokenizer, + self.image_processor, + "train", + self.num_train_samples, + self.seq_length, + task_encoder=self.task_encoder, + ) + self._validation_ds = _MockClipDataset( + self.tokenizer, self.image_processor, "valid", self.num_val_samples, self.seq_length + ) + self._test_ds = _MockClipDataset( + self.tokenizer, self.image_processor, "test", self.num_test_samples, self.seq_length + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + # pylint: disable=C0116 + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def val_dataloader(self) -> EVAL_DATALOADERS: + # pylint: disable=C0116 + if not hasattr(self, "_validation_ds"): + self.setup() + return self._create_dataloader(self._validation_ds) + + def test_dataloader(self) -> EVAL_DATALOADERS: + # pylint: disable=C0116 + if not hasattr(self, "_test_ds"): + self.setup() + return self._create_dataloader(self._test_ds) + + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + # pylint: disable=C0116 + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=dataset.collate_fn, + **kwargs, + ) + + def state_dict(self) -> Dict[str, Any]: + """ + Save the state of the data module. + + This method is called when saving a checkpoint. It generates and saves the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Returns: + Dict[str, Any]: A dictionary containing the state of the data module. + """ + + logging.warning("trainer object not connected to data module object returning empty state") + return {} + + +class _MockClipDataset(Dataset): + def __init__( + self, + tokenizer, + image_processor, + name: str, + num_samples: int, + seq_length: int, + seed: int = 42, + task_encoder=None, + ) -> None: + super().__init__() + self.name = name + self.seq_length = seq_length + + self.vocab_size = tokenizer.vocab_size + + crop_size = image_processor.crop_size + self.image_height, self.image_width = crop_size["height"], crop_size["width"] + + self.length = num_samples + self.seed = seed + self.task_encoder = task_encoder + + def __len__(self) -> int: + return self.length + + def __getitem__(self, idx) -> Dict[str, torch.Tensor]: + # Generate data of the expected size and datatype (based on GPTDataset). 
+ + np_gen = np.random.default_rng(seed=(self.seed + idx)) + tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)) + images = torch.from_numpy(np_gen.random(size=[3, self.image_height, self.image_width], dtype=np.float32)) + + if self.task_encoder is not None: + # Use energon task encoder if provided + return self.task_encoder.encode_sample({"image": images, "txt": "This is Random Mock Text"}) + + return { + "images": images, + "captions": tokens, + } + + def _collate_fn(self, batch): + """ + A default implementation of a collation function. + Users should override this method to define custom data loaders. + """ + collated_batch = data.dataloader.default_collate(batch) + collated_batch["attention_mask"] = None + return collated_batch + + def collate_fn(self, batch): + """Method that user pass as functor to DataLoader. + + The method optionally performs neural type checking and add types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. + + # Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns + ------- + Collated batch, with or without types. + """ + return self._collate_fn(batch) diff --git a/nemo/collections/vlm/clip/loss/__init__.py b/nemo/collections/vlm/clip/loss/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/vlm/clip/loss/clip_loss.py b/nemo/collections/vlm/clip/loss/clip_loss.py new file mode 100644 index 000000000000..01171934a5f4 --- /dev/null +++ b/nemo/collections/vlm/clip/loss/clip_loss.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Tuple + +import torch +import torch.distributed.nn +from megatron.core import parallel_state +from torch import distributed as dist +from torch.nn import functional as F + +from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.lightning.megatron_parallel import MegatronLossReduction + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, +): + """ + Gathers image and text features across multiple data parallel processes. + + This function is designed to work in a distributed environment where multiple + processes are handling different portions of data. It gathers the image and text + features from all processes to form a complete set of features across the entire dataset. + This is crucial for calculating loss in models like CLIP, especially when the model is + trained in a data parallel fashion. + + Parameters: + image_features (Tensor): A tensor containing the image features. + text_features (Tensor): A tensor containing the text features. + local_loss (bool, optional): A flag to determine whether to use local loss calculation. + Defaults to False. 
+ gather_with_grad (bool, optional): A flag to enable gathering with gradient computation. + This is not currently working in the latest PyTorch version. + Defaults to False. + + Returns: + Tuple[Tensor, Tensor]: A tuple containing the gathered image features and text features + across all processes. + """ + data_parallel_world_size = parallel_state.get_data_parallel_world_size() + data_parallel_rank = parallel_state.get_data_parallel_rank() + data_parallel_group = parallel_state.get_data_parallel_group() + + if gather_with_grad: + # https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py#L48 + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(data_parallel_world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(data_parallel_world_size)] + dist.all_gather(gathered_image_features, image_features, group=data_parallel_group) + dist.all_gather(gathered_text_features, text_features, group=data_parallel_group) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + # https://amsword.medium.com/gradient-backpropagation-with-torch-distributed-all-gather-9f3941a381f8 + gathered_image_features[data_parallel_rank] = image_features + gathered_text_features[data_parallel_rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipMegatronLoss(MegatronLossReduction): + """ + A custom loss module for CLIP (Contrastive Languageā€“Image Pretraining) training. + + This module is specifically designed for calculating the loss in CLIP model training, + supporting features like local loss calculation, gradient gathering, and label caching + for efficiency in a distributed training setup. + + Parameters: + local_loss (bool, optional): If True, calculates loss locally on each data parallel process. + Defaults to False. + gather_with_grad (bool, optional): If True, gathers gradients during loss calculation. + Currently not functional in the latest PyTorch version. + Defaults to False. + cache_labels (bool, optional): If True, caches labels for reuse in subsequent iterations, + improving performance. Defaults to False. + + Attributes: + world_size (int): The size of the data parallel group (number of processes). + rank (int): The rank of the current process within the data parallel group. + + Methods: + forward(output_tensor): Computes the loss given the model's output tensor. This involves + gathering features across processes, computing logits, and + calculating the final cross-entropy loss. 
+ """ + + def __init__( + self, + local_loss=False, + gather_with_grad=True, + cache_labels=False, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + self.world_size = parallel_state.get_data_parallel_world_size() + self.rank = parallel_state.get_data_parallel_rank() + + def forward( + self, batch: Dict[str, torch.Tensor], forward_out: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + image_features, text_features, logit_scale = forward_out + device = image_features.device + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, self.local_loss, self.gather_with_grad + ) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 + + reduced_loss = average_losses_across_data_parallel_group([total_loss]) + + return total_loss, {"avg": reduced_loss} + + def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: + if losses_reduced_per_micro_batch: + if "avg" in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + return loss_tensor.mean() + return torch.tensor(0.0, device=torch.cuda.current_device()) diff --git a/nemo/collections/vlm/clip/model/__init__.py b/nemo/collections/vlm/clip/model/__init__.py new file mode 100644 index 000000000000..0b18680fc53c --- /dev/null +++ b/nemo/collections/vlm/clip/model/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
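`ClipMegatronLoss` above is the NeMo 2 counterpart of the existing `ClipLoss`; with a single data-parallel rank its forward pass reduces to the symmetric cross-entropy below. Random, already-normalized features are used purely to illustrate the shapes.

```python
# Self-contained sketch of the contrastive loss computed in forward() when world_size == 1.
import torch
import torch.nn.functional as F

batch, dim = 8, 512
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)  # exp() of the learnable log-scale initialized to log(1/0.07)

logits_per_image = logit_scale * image_features @ text_features.T  # [batch, batch]
logits_per_text = logits_per_image.T

labels = torch.arange(batch)  # the i-th image matches the i-th caption
total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
print(total_loss)
```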
+ +from nemo.collections.vlm.clip.model.base import ClipConfig, CLIPModel, CLIPTextModelConfig, CLIPViTConfig +from nemo.collections.vlm.clip.model.clip import ( + ClipConfigB32, + ClipConfigL14, + CLIPTextModelL_14_224_Config, + CLIPViTL_14_224_Config, +) + +__all__ = [ + "CLIPViTConfig", + "CLIPTextModelConfig", + "ClipConfig", + "CLIPViTL_14_224_Config", + "CLIPTextModelL_14_224_Config", + "ClipConfigL14", + "ClipConfigB32", + "CLIPModel", +] diff --git a/nemo/collections/vlm/clip/model/base.py b/nemo/collections/vlm/clip/model/base.py new file mode 100644 index 000000000000..518372f75b72 --- /dev/null +++ b/nemo/collections/vlm/clip/model/base.py @@ -0,0 +1,497 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, Dict, Optional + +import lightning.pytorch as L +import numpy as np +import torch +import torch.distributed +import torch.nn.functional as F +from megatron.core.enums import ModelType +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.models.vision.clip_vit_model import CLIPViTModel as MCoreCLIPViTModel +from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer import MegatronModule +from megatron.core.transformer.enums import AttnMaskType as MCoreAttnMaskType +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_config import TransformerConfig +from tqdm import tqdm + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.llm import fn +from nemo.collections.llm.gpt.model import transformer_engine_layer_spec +from nemo.collections.llm.gpt.model.base import default_layer_spec +from nemo.collections.multimodal.data.clip.clip_dataset import build_imagenet_validation_dataloader_params +from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.collections.vlm.clip.loss.clip_loss import ClipMegatronLoss +from nemo.lightning import MegatronOptimizerModule, OptimizerModule, get_vocab_size, io +from nemo.utils import logging + + +# pylint: disable=C0116 +def clip_forward_step(model, batch) -> torch.Tensor: + forward_args = {"images": batch["images"], "captions": batch["captions"]} + return model(**forward_args) + + +def clip_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: + batch = next(dataloader_iter) + + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + if "captions" in _batch and len(_batch["captions"].shape) == 3: + _batch["captions"] = _batch["captions"].squeeze() + + _batch = {key: val.cuda(non_blocking=True) if val is not None else None for key, val in _batch.items()} + return _batch + + +def set_input_tensor(self, tensor): + pass + + +# pylint: enable=C0116 +@dataclass +class CLIPViTConfig(TransformerConfig, io.IOMixin): + 
"""Clip ViT model config""" + + output_dim: int = 512 + add_class_token: bool = True + class_token_len: int = 8 + + patch_dim: int = 16 + img_h: int = 224 + img_w: int = 224 + vision_model_type: str = "clip" # ["clip", "siglip"] + transformer_layer_spec: ModuleSpec = transformer_engine_layer_spec + gated_linear_unit: bool = False + attention_softmax_in_fp32: bool = False + + # Without these the init for transformer will give error + num_layers: int = 1 # Placeholder, NOT used! + num_attention_heads: int = 8 # Placeholder, NOT used! + + def configure_model(self) -> "CLIPViTModel": + # pylint: disable=C0116 + transformer_layer_spec = self.transformer_layer_spec + if not isinstance(transformer_layer_spec, ModuleSpec): + from nemo.collections.vlm.layer_specs import get_layer_spec_te + + transformer_layer_spec = get_layer_spec_te(is_vit=True) + + transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = MCoreAttnMaskType.no_mask + self.transformer_layer_spec = transformer_layer_spec + + return CLIPViTModel( + self, + transformer_layer_spec=transformer_layer_spec, + add_class_token=self.add_class_token, + class_token_len=self.class_token_len, + patch_dim=self.patch_dim, + img_h=self.img_h, + img_w=self.img_w, + model_subtype=self.vision_model_type, + output_dim=self.output_dim, + ) + + +class CLIPViTModel(MCoreCLIPViTModel): + """Clip ViT model""" + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + add_class_token: bool = True, + class_token_len: int = 8, + patch_dim: int = 16, + img_h: int = 224, + img_w: int = 224, + model_subtype: str = "clip", + output_dim: int = 1024, + ): + # pylint: disable=C0116 + # TODO (yuya): need to handle post_process correctly in order to enable PP + self.output_dim = output_dim + + super().__init__( + transformer_config=transformer_config, + transformer_layer_spec=transformer_layer_spec, + add_class_token=add_class_token, + class_token_len=class_token_len, + patch_dim=patch_dim, + img_h=img_h, + img_w=img_w, + model_subtype=model_subtype, + ) + + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.head = torch.nn.Linear( + self.config.hidden_size, + self.output_dim, + bias=False, + ) + + def set_input_tensor(self, tensor): + # pylint: disable=C0116 + pass + + def forward(self, x): + # pylint: disable=C0116 + x = super().forward( + x, + ) + x = self.final_layernorm(x) + x = x[:, 0] + x = self.head(x) + return x + + +@dataclass +class CLIPTextModelConfig(TransformerConfig, io.IOMixin): + """Clip text model config""" + + output_dim: int = 512 + make_vocab_size_divisible_by: int = 128 + max_seq_length: int = 1024 + + share_embeddings_and_output_weights: bool = False + + # Imported from gpt/base model + use_transformer_engine_full_layer_spec: bool = False + transformer_layer_spec: ModuleSpec = default_layer_spec + + # Without these the init for transformer will give error + + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "CLIPTextModel": + # pylint: disable=C0116 + transformer_layer_spec = self.transformer_layer_spec + if not isinstance(transformer_layer_spec, ModuleSpec): + transformer_layer_spec = transformer_layer_spec(self) + + if hasattr(self, 'vocab_size'): + vocab_size = self.vocab_size + if tokenizer is not None: + logging.info( + f"Use preset vocab_size: {vocab_size}, original vocab_size: {tokenizer.vocab_size}, dummy tokens:" + f" {vocab_size - tokenizer.vocab_size}." 
+ ) + else: + vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by) + + return CLIPTextModel( + transformer_config=self, + transformer_layer_spec=transformer_layer_spec, + vocab_size=vocab_size, + max_sequence_length=self.max_seq_length, + output_dim=self.output_dim, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + ) + + +class CLIPTextModel(MCoreGPTModel): + """Clip text model""" + + def __init__( + self, + transformer_config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + output_dim: int = 1024, + share_embeddings_and_output_weights: bool = False, + ): + # pylint: disable=C0116 + # TODO (yuya): need to handle post_process correctly in order to enable PP + self.output_dim = output_dim + + # We give post_process as false to get hidden states instead of logits as we have one more layer head + super().__init__( + transformer_config, + transformer_layer_spec, + vocab_size, + max_sequence_length, + True, + False, + share_embeddings_and_output_weights, + ) + self.final_layernorm = TENorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.head = torch.nn.Linear( + self.config.hidden_size, + self.output_dim, + bias=False, + ) + self.position_ids = None + if self.pre_process: + self.position_ids = torch.arange(max_sequence_length).expand(1, -1).cuda() + + def forward(self, input_ids): + # pylint: disable=C0116 + x = super().forward(input_ids, position_ids=self.position_ids, attention_mask=None) + x = self.final_layernorm(x) + x = x[input_ids.argmax(dim=-1), torch.arange(x.shape[1])] + x = self.head(x) + return x + + def set_input_tensor(self, tensor): + # pylint: disable=C0116 + pass + + +@dataclass +class ClipConfig(TransformerConfig, io.IOMixin): + """Clip model config""" + + text_transformer_config: Optional[CLIPTextModelConfig] = None + vision_transformer_config: Optional[CLIPViTConfig] = None + get_attention_mask_from_fusion: bool = True + forward_step_fn: Callable = clip_forward_step + data_step_fn: Callable = clip_data_step + + # Without these the init for transformer will give error + num_layers: int = 1 # Placeholder, NOT used! + num_attention_heads: int = 8 # Placeholder, NOT used! + hidden_size: int = 768 # Placeholder, NOT used! 
+ + def configure_model(self, tokenizer, pre_process=True, post_process=True): + # pylint: disable=C0116 + print(self.kv_channels) + return MCoreClipModel( + self, + tokenizer=tokenizer, + pre_process=pre_process, + post_process=post_process, + ) + + +class MCoreClipModel(MegatronModule): + """Clip model""" + + def __init__(self, config: ClipConfig, tokenizer, pre_process=True, post_process=True) -> None: + # pylint: disable=C0116 + super().__init__(config=config) + self.pre_process = pre_process + self.post_process = post_process + vision_transformer_config = config.vision_transformer_config + text_transformer_config = config.text_transformer_config + self.output_dim = config.vision_transformer_config.output_dim + self.vision_model = vision_transformer_config.configure_model() + self.text_model = text_transformer_config.configure_model( + tokenizer=tokenizer, pre_process=pre_process, post_process=post_process + ) + self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.model_type = ModelType.encoder_or_decoder + + def forward(self, images: torch.Tensor, captions: torch.Tensor): + # pylint: disable=C0116 + image_features = self.vision_model(images) + text_features = self.text_model(captions) + if self.post_process: + return F.normalize(image_features, dim=-1), F.normalize(text_features, dim=-1), self.logit_scale.exp() + + return image_features, text_features + + def set_input_tensor(self, tensor): + # pylint: disable=C0116 + pass + + +class CLIPModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): + """ + CLIPModel is the base class for all CLIP models. + + Args: + config: CLIPConfig. The configuration of the CLIP model. Please see the `CLIPConfig` for details. + optim: OptimizerModule. This module is just used for init and the actual optimizer is created via trainer API. + tokenizer: TokenizerSpec. This module is used for deciding the output length of the language model. + + # These parameters are just for imagenet validation + imagenet_val: Optional[str] = None: Optional path to imagenet validation dataset. + mbs: int = 8: Batch size for imagenet validation. + gbs: int = 8: Global Batch for imagenet validation. + max_workers: int = 4: Maximum number of workers used for imagenet validation. 
+ + + """ + + def __init__( + self, + config: ClipConfig, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + imagenet_val: Optional[str] = None, + mbs: int = 8, + gbs: int = 8, + max_workers: int = 4, + ): + # pylint: disable=C0116 + super().__init__() + self.config = config + self.tokenizer = tokenizer + self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)) + self.optim.connect(self) # This will bind the `configure_optimizers` method + self._training_loss_reduction = None + self._validation_loss_reduction = None + + # These parameters are just for imagenet validation + self.imagenet_val = imagenet_val + self.mbs = mbs + self.gbs = gbs + self.max_workers = max_workers + + def on_fit_start(self): + """Initialize the dataloader parameters for imagenet validation""" + if self.imagenet_val is not None: + self.imagenet_val = build_imagenet_validation_dataloader_params( + self.imagenet_val, + self.config.vision_transformer_config.img_h, + self.config.vision_transformer_config.img_w, + self.mbs, + self.gbs, + num_workers=self.max_workers, + max_position_embedding=self.config.text_transformer_config.max_seq_length, + tokenizer=self.tokenizer, + ) + + def configure_model(self) -> None: + """Configure the model""" + if not hasattr(self, "module"): + self.module = self.config.configure_model(self.tokenizer) + + def forward(self, images: torch.Tensor, captions: torch.Tensor): + # pylint: disable=C0116 + return self.module(images, captions) + + def data_step(self, dataloader_iter) -> Dict[str, torch.Tensor]: + # pylint: disable=C0116 + return self.config.data_step_fn(dataloader_iter) + + def forward_step(self, batch) -> torch.Tensor: + # pylint: disable=C0116 + return self.config.forward_step_fn(self, batch) + + def training_step(self, batch, batch_idx=None) -> torch.Tensor: + """In mcore the loss-function is part of the forward-pass (when labels are provided)""" + return self.forward_step(batch) + + def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + """In mcore the loss-function is part of the forward-pass (when labels are provided)""" + return self.forward_step(batch) + + def zero_shot_classifier(self): + """Zero shot classifier for imagenet validation""" + text_encoder = self.module.module.module.text_model + with torch.no_grad(): + zeroshot_weights = [] + for texts in self.imagenet_val["texts"]: + texts = texts.cuda(non_blocking=True) + with torch.cuda.amp.autocast( + enabled=True, + dtype=torch.bfloat16, + ): + class_embeddings = text_encoder(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1) + return zeroshot_weights + + def zero_shot_eval(self): + """Zero shot evaluation for imagenet validation""" + + def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + logging.info('Starting zero-shot imagenet.') + + logging.info('Building zero-shot classifier') + classifier = self.zero_shot_classifier() + + logging.info('Using classifier') + + vision_encoder = self.module.module.module.vision_model + + with torch.no_grad(): + top1, top5, n = 0.0, 0.0, 0.0 + for images, target in tqdm(self.imagenet_val["images"], desc="Imagenet 
Zero-shot Evaluation", leave=False): + if images is None or target is None: + continue + + images = images.cuda(non_blocking=True).to(torch.bfloat16) + target = target.cuda(non_blocking=True) + + # predict + with torch.cuda.amp.autocast( + enabled=True, + dtype=torch.bfloat16, + ): + + image_features = vision_encoder(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100.0 * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + logging.info('Finished zero-shot imagenet.') + top1 = top1 / n + top5 = top5 / n + return top1, top5 + + def on_validation_epoch_end(self): + """Run zero shot evaluation for imagenet validation""" + if self.imagenet_val is not None: + imagenet_metric = torch.zeros(2).cuda() + imagenet_metric[0], imagenet_metric[1] = self.zero_shot_eval() + imagenet_metric = average_losses_across_data_parallel_group(imagenet_metric) + self.log('imagenet_top1', imagenet_metric[0], prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('imagenet_top5', imagenet_metric[1], prog_bar=True, rank_zero_only=True, batch_size=1) + + @property + def training_loss_reduction(self) -> ClipMegatronLoss: + # pylint: disable=C0116 + if not self._training_loss_reduction: + self._training_loss_reduction = ClipMegatronLoss() + + return self._training_loss_reduction + + @property + def validation_loss_reduction(self) -> ClipMegatronLoss: + # pylint: disable=C0116 + if not self._validation_loss_reduction: + self._validation_loss_reduction = ClipMegatronLoss() + + return self._validation_loss_reduction diff --git a/nemo/collections/vlm/clip/model/clip.py b/nemo/collections/vlm/clip/model/clip.py new file mode 100644 index 000000000000..b7bb09cfe84c --- /dev/null +++ b/nemo/collections/vlm/clip/model/clip.py @@ -0,0 +1,437 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass, field +from pathlib import Path + +import torch +import torch.distributed +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + +from nemo.collections.nlp.modules.common.megatron.utils import ApproxGELUActivation +from nemo.collections.vlm.clip.model import ClipConfig, CLIPModel, CLIPTextModelConfig, CLIPViTConfig +from nemo.lightning import io, teardown + + +@dataclass +class CLIPViTL_14_224_Config(CLIPViTConfig): + """Clip vit large patch14 config""" + + # Will handle it later + vision_model_type: str = "clip" + patch_dim: int = 14 + img_h: int = 224 + img_w: int = 224 + num_layers: int = 12 + num_attention_heads: int = 12 + hidden_size: int = 768 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + ffn_hidden_size: int = 3072 + gated_linear_unit: bool = False + kv_channels: int = 64 + layernorm_zero_centered_gamma: bool = False + apply_query_key_layer_scaling: bool = False + bias_activation_fusion: bool = False + bias_dropout_fusion: bool = True + attention_softmax_in_fp32: bool = False + normalization: str = 'LayerNorm' + apply_rope_fusion: bool = False + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + + +@dataclass +class CLIPViTB_32_224_Config(CLIPViTConfig): + """Clip vit large patch14 config""" + + vision_model_type: str = "clip" + patch_dim: int = 32 + img_h: int = 224 + img_w: int = 224 + num_layers: int = 12 + num_attention_heads: int = 12 + hidden_size: int = 768 + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + ffn_hidden_size: int = 3072 + gated_linear_unit: bool = False + kv_channels: int = None + class_token_len: int = 7 + init_method_std: float = 0.02 + layernorm_zero_centered_gamma: bool = False + apply_query_key_layer_scaling: bool = False + bias_activation_fusion: bool = False + bias_dropout_fusion: bool = True + attention_softmax_in_fp32: bool = False + normalization: str = 'LayerNorm' + apply_rope_fusion: bool = False + masked_softmax_fusion: bool = True + persist_layer_norm: bool = True + + +@dataclass +class CLIPTextModelB_32_224_Config(CLIPTextModelConfig): + """Clip text model Base config""" + + # model architecture + max_seq_length: int = 80 + max_position_embeddings: int = 80 + num_layers: int = 12 + hidden_size: int = 512 + ffn_hidden_size: int = 2048 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: int = 8 + init_method_std: float = ( + 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + ) + use_scaled_init_method: bool = True # use scaled residuals initialization + hidden_dropout: float = 0.0 # Dropout probability for hidden state transformer. + attention_dropout: float = 0.0 + apply_query_key_layer_scaling: bool = False # scale Q * K^T by 1 / layer-number. + attention_softmax_in_fp32: bool = False + normalization: bool = "LayerNorm" + do_layer_norm_weight_decay: bool = False # True means weight decay on all params + + persist_layer_norm: bool = True # Use of persistent fused layer norm kernel. + masked_softmax_fusion: bool = True + bias_dropout_fusion: bool = True + bias_activation_fusion: False + + +@dataclass +class CLIPTextModelL_14_224_Config(CLIPTextModelConfig): + """Clip text model large config""" + + # model architecture + max_seq_length: int = 77 + max_position_embeddings: int = 77 + num_layers: int = 12 + hidden_size: int = 512 + ffn_hidden_size: int = 2048 # Transformer FFN hidden size. Usually 4 * hidden_size. 
+ num_attention_heads: int = 8 + init_method_std: float = ( + 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + ) + use_scaled_init_method: bool = True # use scaled residuals initialization + hidden_dropout: float = 0.0 # Dropout probability for hidden state transformer. + attention_dropout: float = 0.0 + apply_query_key_layer_scaling: bool = False # scale Q * K^T by 1 / layer-number. + do_layer_norm_weight_decay: bool = False # True means weight decay on all params + + persist_layer_norm: bool = True # Use of persistent fused layer norm kernel. + masked_softmax_fusion: bool = True + bias_dropout_fusion: bool = True + + +@dataclass +class ClipConfigL14(ClipConfig): + """Main Clip config for Large model""" + + text_transformer_config: CLIPTextModelConfig = field(default_factory=lambda: CLIPTextModelL_14_224_Config()) + vision_transformer_config: CLIPViTConfig = field(default_factory=lambda: CLIPViTL_14_224_Config()) + + +@dataclass +class ClipConfigB32(ClipConfig): + """Main Clip config for Base model""" + + text_transformer_config: CLIPTextModelConfig = field(default_factory=lambda: CLIPTextModelB_32_224_Config()) + vision_transformer_config: CLIPViTConfig = field(default_factory=lambda: CLIPViTB_32_224_Config()) + + +@io.model_importer(CLIPModel, "hf") +class HFClipImporter(io.ModelConnector["CLIPModel", CLIPModel]): + """Import model from Hugging Face""" + + def init(self) -> CLIPModel: + # pylint: disable=C0116 + return CLIPModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + # pylint: disable=C0116 + from transformers import CLIPModel + + # Get source model from HF + source = CLIPModel.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + + # Convert both to bfloat16 + target = target.to(torch.bfloat16) + source = source.to(torch.bfloat16) + self.convert_state(source, target) + + print(f"Converted Clip model to Nemo, saving to {output_path}") + + self.nemo_save(output_path, trainer) + + print(f"Converted Clip model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target, image_newline=False): + # pylint: disable=C0116, line-too-long + # Start with the heads + mapping = { + 'text_projection.weight': "text_model.head.weight", + 'visual_projection.weight': 'vision_model.head.weight', + } + + mapping.update( + { + "text_model.embeddings.token_embedding.weight": "text_model.embedding.word_embeddings.weight", + "text_model.embeddings.position_embedding.weight": "text_model.embedding.position_embeddings.weight", + "text_model.final_layer_norm.weight": "text_model.final_layernorm.weight", + "text_model.final_layer_norm.bias": "text_model.final_layernorm.bias", + "vision_model.embeddings.class_embedding": "vision_model.class_token", + "vision_model.embeddings.patch_embedding.weight": "vision_model.conv1.weight", + "vision_model.embeddings.position_embedding.weight": "vision_model.position_embeddings.weight", + "vision_model.pre_layrnorm.weight": "vision_model.ln_pre.weight", + "vision_model.pre_layrnorm.bias": "vision_model.ln_pre.bias", + "vision_model.post_layernorm.weight": "vision_model.final_layernorm.weight", + "vision_model.post_layernorm.bias": "vision_model.final_layernorm.bias", + "text_model.encoder.layers.*.self_attn.out_proj.weight": "text_model.decoder.layers.*.self_attention.linear_proj.weight", + "text_model.encoder.layers.*.self_attn.out_proj.bias": 
"text_model.decoder.layers.*.self_attention.linear_proj.bias", + "text_model.encoder.layers.*.layer_norm1.weight": "text_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "text_model.encoder.layers.*.layer_norm1.bias": "text_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "text_model.encoder.layers.*.mlp.fc1.weight": "text_model.decoder.layers.*.mlp.linear_fc1.weight", + "text_model.encoder.layers.*.mlp.fc1.bias": "text_model.decoder.layers.*.mlp.linear_fc1.bias", + "text_model.encoder.layers.*.mlp.fc2.weight": "text_model.decoder.layers.*.mlp.linear_fc2.weight", + "text_model.encoder.layers.*.mlp.fc2.bias": "text_model.decoder.layers.*.mlp.linear_fc2.bias", + "text_model.encoder.layers.*.layer_norm2.weight": "text_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "text_model.encoder.layers.*.layer_norm2.bias": "text_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "vision_model.encoder.layers.*.self_attn.out_proj.weight": "vision_model.decoder.layers.*.self_attention.linear_proj.weight", + "vision_model.encoder.layers.*.self_attn.out_proj.bias": "vision_model.decoder.layers.*.self_attention.linear_proj.bias", + "vision_model.encoder.layers.*.layer_norm1.weight": "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "vision_model.encoder.layers.*.layer_norm1.bias": "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "vision_model.encoder.layers.*.mlp.fc1.weight": "vision_model.decoder.layers.*.mlp.linear_fc1.weight", + "vision_model.encoder.layers.*.mlp.fc1.bias": "vision_model.decoder.layers.*.mlp.linear_fc1.bias", + "vision_model.encoder.layers.*.mlp.fc2.weight": "vision_model.decoder.layers.*.mlp.linear_fc2.weight", + "vision_model.encoder.layers.*.mlp.fc2.bias": "vision_model.decoder.layers.*.mlp.linear_fc2.bias", + "vision_model.encoder.layers.*.layer_norm2.weight": "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "vision_model.encoder.layers.*.layer_norm2.bias": "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + } + ) + # pylint: enable=line-too-long + + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[ + _import_cls_token, + _import_vision_qkv_bias, + _import_vision_qkv, + _import_language_qkv_bias, + _import_language_qkv, + ], + ) + + @property + def tokenizer(self) -> "AutoTokenizer": + # pylint: disable=C0116 + + return AutoTokenizer(str(self)) + + @property + def config(self) -> ClipConfig: + # pylint: disable=C0116 + from transformers import CLIPConfig as HFCLIPConfig + + source = HFCLIPConfig.from_pretrained(str(self)) + + text_conifg = source.text_config + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + language_transformer_config = CLIPTextModelConfig( + output_dim=text_conifg.projection_dim, + num_layers=text_conifg.num_hidden_layers, + hidden_size=text_conifg.hidden_size, + ffn_hidden_size=text_conifg.intermediate_size, + num_attention_heads=text_conifg.num_attention_heads, + init_method_std=text_conifg.initializer_range, + layernorm_epsilon=text_conifg.layer_norm_eps, + gated_linear_unit=False, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(text_conifg.vocab_size), + share_embeddings_and_output_weights=False, + attention_dropout=text_conifg.attention_dropout, + hidden_dropout=text_conifg.dropout, + activation_func=ApproxGELUActivation, + max_seq_length=text_conifg.max_position_embeddings, + 
apply_query_key_layer_scaling=False, + ) + + vision_config = source.vision_config + + vision_transformer_config = CLIPViTConfig( + vision_model_type="clip", + patch_dim=vision_config.patch_size, + img_h=vision_config.image_size, + img_w=vision_config.image_size, + num_layers=vision_config.num_hidden_layers, + num_attention_heads=vision_config.num_attention_heads, + hidden_size=vision_config.hidden_size, + hidden_dropout=vision_config.dropout, + attention_dropout=vision_config.attention_dropout, + ffn_hidden_size=vision_config.intermediate_size, + gated_linear_unit=False, # TODO (ask Yao, This was False in the config) Does he knows if they use GLU? + apply_query_key_layer_scaling=False, + activation_func=ApproxGELUActivation, + output_dim=vision_config.projection_dim, + init_method_std=vision_config.initializer_range, + layernorm_epsilon=vision_config.layer_norm_eps, + # HF only uses one class token + class_token_len=1, + ) + + output = ClipConfig( + text_transformer_config=language_transformer_config, vision_transformer_config=vision_transformer_config + ) + + return output + + +def import_qkv(q, k, v, head_num, num_query_groups, heads_per_group, hidden_size, head_size): + # pylint: disable=C0116 + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key=( + "vision_model.encoder.layers.*.self_attn.q_proj.weight", + "vision_model.encoder.layers.*.self_attn.k_proj.weight", + "vision_model.encoder.layers.*.self_attn.v_proj.weight", + ), + target_key="vision_model.decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_vision_qkv(ctx: io.TransformCTX, q, k, v): + # pylint: disable=C0116 + megatron_config = ctx.target.config.vision_transformer_config + return import_qkv( + q, + k, + v, + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=megatron_config.hidden_size, + head_size=megatron_config.kv_channels, + ) + + +@io.state_transform( + source_key=( + "vision_model.encoder.layers.*.self_attn.q_proj.bias", + "vision_model.encoder.layers.*.self_attn.k_proj.bias", + "vision_model.encoder.layers.*.self_attn.v_proj.bias", + ), + target_key="vision_model.decoder.layers.*.self_attention.linear_qkv.bias", +) +def _import_vision_qkv_bias(ctx: io.TransformCTX, q_bias, k_bias, v_bias): + # pylint: disable=C0116 + megatron_config = ctx.target.config.vision_transformer_config + return import_qkv( + q_bias.unsqueeze(-1), + k_bias.unsqueeze(-1), + v_bias.unsqueeze(-1), + head_num=megatron_config.num_attention_heads, + 
num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=1, + head_size=megatron_config.kv_channels, + ).squeeze(-1) + + +@io.state_transform( + source_key=( + "text_model.encoder.layers.*.self_attn.q_proj.bias", + "text_model.encoder.layers.*.self_attn.k_proj.bias", + "text_model.encoder.layers.*.self_attn.v_proj.bias", + ), + target_key="text_model.decoder.layers.*.self_attention.linear_qkv.bias", +) +def _import_language_qkv_bias(ctx: io.TransformCTX, q_bias, k_bias, v_bias): + # pylint: disable=C0116 + megatron_config = ctx.target.config.text_transformer_config + return import_qkv( + q_bias.unsqueeze(-1), + k_bias.unsqueeze(-1), + v_bias.unsqueeze(-1), + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=1, + head_size=megatron_config.kv_channels, + ).squeeze(-1) + + +@io.state_transform( + source_key=( + "text_model.encoder.layers.*.self_attn.q_proj.weight", + "text_model.encoder.layers.*.self_attn.k_proj.weight", + "text_model.encoder.layers.*.self_attn.v_proj.weight", + ), + target_key="text_model.decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_language_qkv(ctx: io.TransformCTX, q, k, v): + # pylint: disable=C0116 + megatron_config = ctx.target.config.text_transformer_config + + return import_qkv( + q, + k, + v, + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=megatron_config.hidden_size, + head_size=megatron_config.kv_channels, + ) + + +@io.state_transform( + source_key=("vision_model.embeddings.class_embedding",), + target_key="vision_model.class_token", +) +def _import_cls_token(ctx: io.TransformCTX, cls_token): + # pylint: disable=C0116 + return cls_token.reshape(1, 1, -1) diff --git a/nemo/collections/vlm/recipes/__init__.py b/nemo/collections/vlm/recipes/__init__.py index e3225dec8c4f..fc1642224f59 100644 --- a/nemo/collections/vlm/recipes/__init__.py +++ b/nemo/collections/vlm/recipes/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. -from nemo.collections.vlm.recipes import llava15_7b, llava15_13b, llava_next_7b, mllama_11b, mllama_90b +from nemo.collections.vlm.recipes import clip_b32, llava15_7b, llava15_13b, llava_next_7b, mllama_11b, mllama_90b __all__ = [ "llava15_7b", @@ -21,4 +21,5 @@ "mllama_11b", "mllama_90b", "llava_next_7b", + "clip_b32", ] diff --git a/nemo/collections/vlm/recipes/clip_b32.py b/nemo/collections/vlm/recipes/clip_b32.py new file mode 100644 index 000000000000..f2f52043a618 --- /dev/null +++ b/nemo/collections/vlm/recipes/clip_b32.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
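+"""NeMo 2 pretraining recipe and nemo_run factories for the CLIP ViT-B/32 (clip_b32) model."""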
+
+
+from typing import Optional
+
+import lightning.pytorch as pl
+import nemo_run as run
+import torch
+
+from nemo import lightning as nl
+from nemo.collections import llm, vlm
+from nemo.collections.llm.recipes.log.default import tensorboard_logger
+from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
+from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
+from nemo.utils.exp_manager import TimingCallback
+
+NAME = "clip_b32"
+
+
+@run.cli.factory(name=NAME)
+def model() -> run.Config[pl.LightningModule]:
+    """
+    Factory function to create a Clip B32 model configuration.
+
+    Returns:
+        run.Config[pl.LightningModule]: Configuration for the Clip B32 model.
+
+    Examples:
+        CLI usage:
+            $ nemo llm pretrain model=clip_b32 ...
+
+        Python API usage:
+            >>> model_config = model()
+            >>> print(model_config)
+    """
+    return run.Config(vlm.CLIPModel, config=run.Config(vlm.ClipConfigB32))
+
+
+@run.cli.factory(target=llm.pretrain, name=NAME)
+def pretrain_recipe(
+    dir: Optional[str] = None,
+    name: str = "default",
+    num_nodes: int = 1,
+    num_gpus_per_node: int = 8,
+) -> run.Partial:
+    """
+    Create a pretraining recipe for the Clip B32 model.
+
+    This function sets up a complete configuration for pretraining, including
+    model, trainer, data, logging, optimization, and resumption settings.
+
+    Args:
+        dir (Optional[str]): Directory for saving logs and checkpoints.
+        name (str): Name of the pretraining run.
+        num_nodes (int): Number of compute nodes to use.
+        num_gpus_per_node (int): Number of GPUs per node.
+
+    Returns:
+        run.Partial: Partial configuration for pretraining.
+
+    Examples:
+        CLI usage:
+            $ nemo llm pretrain --factory clip_b32
+
+        Python API usage:
+            >>> recipe = pretrain_recipe(name="clip_b32", num_nodes=1)
+            >>> print(recipe)
+
+    Note:
+        This recipe uses the Mock dataset for pretraining. For more information
+        on pretraining CLIP with NeMo, see the example scripts in the
+        `scripts/vlm/` directory.
+ """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5000, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[ + run.Config(TimingCallback), + ], + ) + + recipe = run.Partial( + llm.pretrain, + model=model(), + trainer=trainer, + data=run.Config( + vlm.ClipMockDataModule, + seq_length=80, + global_batch_size=128, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing( + max_lr=1e-3, + min_lr=1e-5, + warmup_steps=2000, + adam_beta1=0.9, + adam_beta2=0.98, + ), + resume=run.Config( + nl.AutoResume, + resume_if_exists=False, + resume_ignore_no_checkpoint=True, + ), + ) + + return recipe diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt index 585e277be72a..a9c73870f644 100644 --- a/requirements/requirements_multimodal.txt +++ b/requirements/requirements_multimodal.txt @@ -6,7 +6,7 @@ diffusers>=0.19.3 einops_exts imageio kornia -megatron-energon==4.0.0 +megatron-energon==5.1.0 nerfacc>=0.5.3 open_clip_torch==2.24.0 PyMCubes diff --git a/scripts/vlm/clip_infer.py b/scripts/vlm/clip_infer.py new file mode 100644 index 000000000000..d532ce4ccd73 --- /dev/null +++ b/scripts/vlm/clip_infer.py @@ -0,0 +1,155 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + python scripts/vlm/clip_infer.py \ + --image_url https://upload.wikimedia.org/wikipedia/commons/0/0f/1665_Girl_with_a_Pearl_Earring.jpg \ + --hf_path hf://openai/clip-vit-large-patch14 \ + --classes "a dog" "a boy" "a girl" + + + +It should generate a high probability for "a girl" tag, e.g. +Nemo: CLIP text probability: [('a dog', 0.0048940657), ('a boy', 0.002311793), ('a girl', 0.9927942)] +HF: CLIP text probability: [('a dog', 0.0048940657), ('a boy', 0.002311793), ('a girl', 0.9927942)] +""" +import argparse +import os + +import requests +import torch +from PIL import Image +from transformers import AutoProcessor +from transformers import CLIPModel as HFCLIPModel + +import nemo.lightning as nl +from nemo.collections.vlm import CLIPModel + + +def load_image(image_path: str) -> Image.Image: + """ + Load an image from a URL or local file path. + + Args: + image_path (str): The URL or local path to the image. + + Returns: + Image.Image: The loaded PIL image object, or None if loading fails. 
+ """ + try: + if os.path.exists(image_path): # Check if it's a local file path + image = Image.open(image_path) + else: # Assume it's a remote URL + response = requests.get(image_path, stream=True) + response.raise_for_status() + image = Image.open(response.raw) + return image + except (requests.exceptions.RequestException, FileNotFoundError, IOError) as e: + print(f"Error loading image from {image_path}: {e}") + return None + + +def main(args) -> None: + # pylint: disable=C0115,C0116 + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + ckpt_include_optimizer=False, + ckpt_save_optimizer=False, + ) + + trainer = nl.Trainer( + devices=1, + max_steps=1000, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + val_check_interval=1000, + limit_val_batches=50, + ) + + hf_repo = args.hf_path.split("//")[1] + processor = AutoProcessor.from_pretrained(hf_repo) + max_length = processor.tokenizer.model_max_length + + # Load the image + raw_image = load_image(args.image_url) + if raw_image is None: + return # Exit if the image can't be loaded + + fabric = trainer.to_fabric() + model = fabric.import_model(args.hf_path, CLIPModel) + model = model.module.cuda() + + # Freeze the models, We have a few nesting in the model + vision_model = model.module.module.module.vision_model.eval() + text_model = model.module.module.module.text_model.eval() + + with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + # %% Zero-shot classification + classes = args.classes + + inputs = processor( + text=classes, + images=[raw_image], + return_tensors="pt", + truncation=True, # Truncate if the sentence is longer than max_seq_length + padding='max_length', # Pad to max_seq_length + max_length=max_length, + ) + + inputs = {key: value.to("cuda") for key, value in inputs.items()} + + model_hf = HFCLIPModel.from_pretrained(hf_repo) + model_hf = model_hf.to("cuda") + output_hf = model_hf(**inputs) + + image_embeds_nemo = vision_model(inputs["pixel_values"].cuda().to(torch.bfloat16)) + image_embeds_hf = output_hf["image_embeds"] + + text_embeds_nemo = text_model(inputs["input_ids"].cuda()) + text_embeds_hf = output_hf["text_embeds"] + + image_embeds_nemo /= image_embeds_nemo.norm(dim=-1, keepdim=True) + text_embeds_nemo /= text_embeds_nemo.norm(dim=-1, keepdim=True) + + nemo_probs = (100.0 * image_embeds_nemo @ text_embeds_nemo.T).softmax(dim=-1) + hf_probs = (100.0 * image_embeds_hf @ text_embeds_hf.T).softmax(dim=-1) + + print(f"Nemo: CLIP text probability: ", list(zip(classes, nemo_probs[0].cpu().numpy()))) + print(f"HF: CLIP text probability: ", list(zip(classes, hf_probs[0].cpu().numpy()))) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Clip Verification Script") + parser.add_argument( + "--image_url", + type=str, + default="1665_Girl_with_a_Pearl_Earring.jpg", + help="URL of the image to use for inference.", + ) + + parser.add_argument( + "--hf_path", + type=str, + default="hf://openai/clip-vit-large-patch14", + help="Path to the Huggingface model.", + ) + + parser.add_argument( + '--classes', nargs='+', type=str, help="Classes for texts", default=["a dog", "a boy", "a girl"] + ) + args = parser.parse_args() + + main(args) diff --git a/scripts/vlm/clip_pretrain.py b/scripts/vlm/clip_pretrain.py new file mode 100644 index 000000000000..1e0dc3944c04 --- /dev/null +++ b/scripts/vlm/clip_pretrain.py @@ -0,0 +1,198 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: + python scripts/vlm/clip_pretrain.py \ + --data_type=mock +""" + +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True +import argparse +import os + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule +from nemo.collections.vlm.clip.data.clip_data_module import ClipTaskEncoder +from nemo.collections.vlm.clip.model import ClipConfigB32, CLIPModel +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + max_steps = args.max_steps + + train_task_encoder = ClipTaskEncoder(max_length=args.decoder_seq_length) + valid_task_encoder = ClipTaskEncoder(max_length=args.decoder_seq_length, is_train=False) + if args.data_type == "energon": + data = EnergonMultiModalDataModule( + args.data_path, + seq_length=args.decoder_seq_length, + image_processor=None, + micro_batch_size=args.mbs, + global_batch_size=args.gbs, + num_workers=args.num_workers, + task_encoder=train_task_encoder, + tokenizer=train_task_encoder.tokenizer, + validation_task_encoder=valid_task_encoder, + image_decode="pil", + ignore_decoder_errors=True, + ) + elif args.data_type == "mock": + data = vlm.ClipMockDataModule( + seq_length=args.decoder_seq_length, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + tokenizer=None, + num_train_samples=10_000_000_000, + image_processor=None, + num_workers=8, + ) + else: + raise ValueError(f"Data type {args.data_type} not supported") + + model = CLIPModel( + ClipConfigB32(), + tokenizer=train_task_encoder.tokenizer, + imagenet_val=args.imagenet_val, + mbs=args.mbs, + gbs=args.gbs, + max_workers=8, + ) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, + pipeline_dtype=torch.bfloat16, + ) + + # Checkpoint callback setup + checkpoint_callback = nl.ModelCheckpoint( + save_last="link", + monitor="reduced_train_loss", + save_top_k=2, + every_n_train_steps=2000, + dirpath=os.path.join(args.log_dir, args.name), + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=2000, + check_val_every_n_epoch=None, + limit_val_batches=1, # We limit validation batches as we are using imagenet validation set + log_every_n_steps=10, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + 
wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=False, + resume_ignore_no_checkpoint=True, + resume_from_directory=os.path.join(args.log_dir, args.name), + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-3, + adam_beta1=0.9, + adam_beta2=0.98, + weight_decay=0.2, + ) + + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=2000, + constant_steps=0, + min_lr=1e-5, + ) + opt = MegatronOptimizerModule( + opt_config, + sched, + ) + + llm.pretrain( + model=model, + data=data, + trainer=trainer, + log=nemo_logger, + optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Clip Model Training Script") + + # Argument parsing + parser.add_argument("--data_type", type=str, required=False, default="energon", help="mock | energon") + parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset director") + + parser.add_argument( + "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" + ) + + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + + parser.add_argument("--mbs", type=int, required=False, default=32, help="Micro batch size") + parser.add_argument("--gbs", type=int, required=False, default=64, help="Global batch size") + parser.add_argument( + "--decoder_seq_length", + type=int, + required=False, + default=80, + help="Decoder" " sequence length for encoding the input text", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--num_workers", type=int, required=False, default=8) + + parser.add_argument("--max_steps", type=int, required=False, default=375000) + parser.add_argument("--tp_size", type=int, required=False, default=1) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--name", type=str, required=False, default="clip_pretrain") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate") + parser.add_argument("--imagenet_val", type=str, required=False, default=None, help="imagenet path for val") + + args = parser.parse_args() + main(args)
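
For reference, the clip_b32 recipe added in this patch can also be launched from Python with nemo_run rather than through the `nemo llm pretrain` CLI or the script above. The sketch below is illustrative only and is not part of the patch: the log directory, run name, and local torchrun executor settings are assumptions, and a cluster executor (for example run.SlurmExecutor) would replace the local one for multi-node jobs.

    # Minimal sketch (assumed paths and executor settings) for launching the
    # clip_b32 pretraining recipe via nemo_run.
    import nemo_run as run

    from nemo.collections.vlm.recipes import clip_b32

    if __name__ == "__main__":
        # Build the run.Partial for llm.pretrain with the ViT-B/32 CLIP config.
        recipe = clip_b32.pretrain_recipe(
            dir="/results/clip_b32",  # assumed checkpoint/log directory
            name="clip_b32_pretrain",
            num_nodes=1,
            num_gpus_per_node=8,
        )
        # Run locally with torchrun, one task per GPU; swap in a different
        # executor for cluster jobs.
        executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")
        run.run(recipe, executor=executor)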