This guide covers advanced topics for extending MatterTune with custom components. You'll learn how to create custom backbones (model architectures) and datasets.
To implement a custom backbone, you need to create two classes:
- A configuration class inheriting from FinetuneModuleBaseConfig
- A model class inheriting from FinetuneModuleBase
Here's the basic template:
from typing import Literal

from typing_extensions import override

import mattertune as mt


@mt.backbone_registry.register
class MyBackboneConfig(mt.FinetuneModuleBaseConfig):
    name: Literal["my_backbone"] = "my_backbone"

    # Add your configuration parameters
    hidden_size: int
    num_layers: int

    @override
    def create_model(self):
        return MyBackboneModule(self)

    @override
    @classmethod
    def ensure_dependencies(cls):
        # Check for required packages
        pass


class MyBackboneModule(mt.FinetuneModuleBase["MyData", "MyBatch", MyBackboneConfig]):
    @override
    @classmethod
    def hparams_cls(cls):
        return MyBackboneConfig

    # Your backbone module must implement these abstract methods:

    @override
    def create_model(self):
        """Initialize your model architecture here"""
        pass

    @override
    def model_forward(self, batch, return_backbone_output=False):
        """Forward pass implementation"""
        pass

    @override
    def pretrained_backbone_parameters(self):
        """Return backbone parameters"""
        pass

    @override
    def output_head_parameters(self):
        """Return output head parameters"""
        pass

    @override
    def cpu_data_transform(self, data):
        """Transform data on CPU before batching"""
        pass

    @override
    def collate_fn(self, data_list):
        """Combine data samples into a batch"""
        pass

    @override
    def gpu_batch_transform(self, batch):
        """Transform batch on GPU before forward pass"""
        pass

    @override
    def batch_to_labels(self, batch):
        """Extract ground truth labels from batch"""
        pass

    @override
    def atoms_to_data(self, atoms, has_labels):
        """Convert ASE Atoms to your data format"""
        pass

    @override
    def create_normalization_context_from_batch(self, batch):
        """Create context for property normalization"""
        pass
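The docstrings in the template suggest an order for the data hooks: atoms_to_data, then cpu_data_transform, then collate_fn, then gpu_batch_transform, with batch_to_labels reading targets from the finished batch. The sketch below traces that order with plain Python stand-ins. The order is an assumption drawn from the docstrings, not a description of MatterTune's training loop, and the dict-based sample and batch formats are hypothetical.

```python
# Toy stand-ins for the backbone hooks; none of this is MatterTune code.
# Samples are plain dicts instead of ase.Atoms so the sketch runs anywhere.

def atoms_to_data(atoms, has_labels):
    """Convert a raw structure into the backbone's per-sample format."""
    data = {"numbers": list(atoms["numbers"])}
    if has_labels:
        data["energy"] = atoms["energy"]
    return data

def cpu_data_transform(data):
    """Per-sample work done on CPU before batching (here: count atoms)."""
    return {**data, "n_atoms": len(data["numbers"])}

def collate_fn(data_list):
    """Combine samples into one batch by concatenating per-atom fields."""
    return {
        "numbers": [z for d in data_list for z in d["numbers"]],
        "n_atoms": [d["n_atoms"] for d in data_list],
        "energy": [d["energy"] for d in data_list],
    }

def gpu_batch_transform(batch):
    """Batch-level work done just before the forward pass (here: a no-op)."""
    return batch

def batch_to_labels(batch):
    """Pull the ground-truth targets out of the batch."""
    return {"energy": batch["energy"]}

# Two hypothetical labeled structures (atomic numbers plus an energy).
raw = [
    {"numbers": [1, 1, 8], "energy": -14.2},
    {"numbers": [6, 8], "energy": -10.5},
]
samples = [cpu_data_transform(atoms_to_data(a, has_labels=True)) for a in raw]
batch = gpu_batch_transform(collate_fn(samples))
labels = batch_to_labels(batch)
print(batch["n_atoms"], labels)
```

Keeping per-sample work in cpu_data_transform and batch-level work in gpu_batch_transform makes it easier to see which step is responsible when a batch comes out malformed.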
For complete examples, see the existing backbone implementations in the MatterTune source code:
- JMP backbone: mattertune/backbones/jmp/model.py
- EquiformerV2 backbone: mattertune/backbones/eqV2/model.py
- M3GNet backbone: mattertune/backbones/m3gnet/model.py
- ORB backbone: mattertune/backbones/orb/model.py
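Returning to the backbone template, the ensure_dependencies hook is left as pass. One way to fill it in is to look up each required package with the standard library's importlib.util.find_spec and raise a descriptive ImportError if any are missing. This is a minimal sketch, not how the built-in backbones do it: the helper missing_packages is hypothetical, and json is only a placeholder for whatever top-level packages your backbone needs.

```python
import importlib.util

def missing_packages(required):
    """Return the top-level package names in `required` that cannot be found."""
    return [name for name in required if importlib.util.find_spec(name) is None]

def ensure_dependencies(required=("json",)):
    """Raise a descriptive ImportError if any required package is absent.

    `json` is a placeholder; list your backbone's real dependencies here.
    """
    missing = missing_packages(required)
    if missing:
        raise ImportError(
            "MyBackbone requires the following packages, which are not installed: "
            + ", ".join(missing)
        )

ensure_dependencies()  # passes, because json ships with Python
```

Inside a config class, the same body would go in the ensure_dependencies classmethod, so a missing package is reported before any model code runs.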
Custom datasets require two classes:
- A configuration class inheriting from DatasetConfigBase
- A dataset class inheriting from Dataset[ase.Atoms]
Basic template:
from typing import Literal

from ase import Atoms
from torch.utils.data import Dataset
from typing_extensions import override

import mattertune as mt


@mt.data_registry.register
class MyDatasetConfig(mt.DatasetConfigBase):
    type: Literal["my_dataset"] = "my_dataset"

    # Add your configuration parameters
    data_path: str

    @override
    def create_dataset(self):
        return MyDataset(self)


class MyDataset(Dataset[Atoms]):
    def __init__(self, config: MyDatasetConfig):
        self.config = config
        # Initialize your dataset: fill self.data with the structures
        # loaded from config.data_path (left empty in this template)
        self.data: list[Atoms] = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx) -> Atoms:
        # Return an ASE Atoms object
        return self.data[idx]
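The dataset class only has to satisfy the map-style contract: __len__ returns the number of structures and __getitem__ returns one structure by index. The sketch below follows the template's shape with stand-ins so that it runs without PyTorch or ASE. FakeAtoms is a hypothetical placeholder for ase.Atoms, ToyDataset is a hypothetical class name, and the one-formula-per-line text file is a format invented for this illustration.

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path

@dataclass
class FakeAtoms:
    """Hypothetical stand-in for ase.Atoms; holds only a formula string."""
    formula: str

class ToyDataset:
    """Map-style dataset: eager load in __init__, index access afterwards."""

    def __init__(self, data_path: str):
        path = Path(data_path)
        if not path.is_file():
            raise FileNotFoundError(f"data_path does not exist: {path}")
        # Load everything up front so __len__ and __getitem__ stay cheap.
        lines = path.read_text().splitlines()
        self.data = [FakeAtoms(line.strip()) for line in lines if line.strip()]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> FakeAtoms:
        return self.data[idx]

# Usage: write a small file, then read it back through the dataset.
with tempfile.TemporaryDirectory() as tmp:
    sample_file = Path(tmp) / "structures.txt"
    sample_file.write_text("H2O\nCO2\n\nNaCl\n")
    dataset = ToyDataset(str(sample_file))
    print(len(dataset), dataset[1].formula)
```

Loading eagerly is the simplest choice for small datasets; for large ones, storing only file offsets or paths in __init__ and reading inside __getitem__ is a common alternative.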
For complete examples, see the existing dataset implementations in the MatterTune source code:
- XYZ dataset: mattertune/data/xyz.py
- ASE database dataset: mattertune/data/db.py
- Materials Project dataset: mattertune/data/mp.py
- Matbench dataset: mattertune/data/matbench.py
After implementing your custom components, you can use them in your training configuration:
config = mt.configs.MatterTunerConfig(
    model=MyBackboneConfig(
        hidden_size=256,
        num_layers=4,
        properties=[
            mt.configs.EnergyPropertyConfig(...)
        ],
    ),
    data=mt.configs.AutoSplitDataModuleConfig(
        dataset=MyDatasetConfig(
            data_path="path/to/data"
        ),
        train_split=0.8,
        batch_size=32,
    ),
)

tuner = mt.MatterTuner(config)
model, trainer = tuner.tune()
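To make train_split=0.8 concrete, the sketch below partitions dataset indices into an 80/20 split. It only illustrates the arithmetic. The function split_indices is hypothetical, and the shuffling, the seed, the rounding, and the assumption that the remainder becomes the validation set are choices made for this example, not a description of how AutoSplitDataModuleConfig is implemented.

```python
import random

def split_indices(n_samples, train_split, seed=0):
    """Shuffle range(n_samples) and cut it at the train_split fraction."""
    if not 0.0 < train_split < 1.0:
        raise ValueError(f"train_split must be in (0, 1), got {train_split}")
    indices = list(range(n_samples))
    # A local RNG keeps the split reproducible without touching global state.
    random.Random(seed).shuffle(indices)
    n_train = int(n_samples * train_split)
    return indices[:n_train], indices[n_train:]

train_idx, val_idx = split_indices(n_samples=100, train_split=0.8)
print(len(train_idx), len(val_idx))
```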
When building custom components, keep the following practices in mind:
- Type Hints: Annotate configuration fields and method signatures so type checkers catch errors early
- Documentation: Document your custom components thoroughly
- Error Handling: Provide clear error messages for configuration issues
- Testing: Write tests for your custom components
- Dependencies: Clearly document any additional dependencies
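As an illustration of the error-handling point, the sketch below validates configuration values at construction time and names the field, the constraint, and the offending value in each message. ToyBackboneConfig is a hypothetical plain dataclass used as a stand-in; it is not a MatterTune class, and the positivity constraints are assumptions chosen for the example.

```python
from dataclasses import dataclass

@dataclass
class ToyBackboneConfig:
    """Hypothetical stand-in for a backbone config; not a MatterTune class."""
    hidden_size: int
    num_layers: int

    def __post_init__(self):
        # Check every field right after construction so bad values fail fast.
        for field_name in ("hidden_size", "num_layers"):
            value = getattr(self, field_name)
            # bool is a subclass of int, so reject it explicitly.
            if not isinstance(value, int) or isinstance(value, bool):
                raise TypeError(
                    f"{field_name} must be an int, got {type(value).__name__}"
                )
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")

config = ToyBackboneConfig(hidden_size=256, num_layers=4)  # valid

try:
    ToyBackboneConfig(hidden_size=256, num_layers=0)
except ValueError as err:
    message = str(err)
    print(message)
```

Failing at construction time means a bad value is reported where the configuration is written, rather than surfacing later as a shape error inside the model.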
For more examples, check the source code of the built-in backbones and datasets in the MatterTune repository.