Merge pull request #8 from soran-ghaderi/new-feature
feat: Enhance energy function visualization and calculations
soran-ghaderi authored Dec 25, 2024
2 parents adee159 + 58df859 commit ba93f8b
Showing 9 changed files with 602 additions and 105 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tag-release.yml
@@ -153,7 +153,7 @@ jobs:
git tag $new_tag
git push origin $new_tag
# =========== it works until here ==================
# =========== it works until here ================== # fixme: this part does not release the package on pypi
release:
needs: update-tag
if: contains(github.event.head_commit.message, '#release')
4 changes: 3 additions & 1 deletion .gitignore
@@ -11,6 +11,7 @@ __pycache__/
build/
develop-eggs/
dist/
dist
downloads/
eggs/
.eggs/
@@ -159,4 +160,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#.idea/
/scope.txt
39 changes: 39 additions & 0 deletions examples/energy_fn_visualization.py
@@ -0,0 +1,39 @@
import numpy as np
import torch
from matplotlib import pyplot as plt

from torchebm.core.energy_function import RosenbrockEnergy, AckleyEnergy, RastriginEnergy, DoubleWellEnergy, \
    GaussianEnergy, HarmonicEnergy

def plot_energy_function(energy_fn, x_range, y_range, title):
    x = np.linspace(x_range[0], x_range[1], 100)
    y = np.linspace(y_range[0], y_range[1], 100)
    X, Y = np.meshgrid(x, y)
    Z = np.zeros_like(X)

    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32).unsqueeze(0)
            Z[i, j] = energy_fn(point).item()

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, Z, cmap='viridis')
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('Energy')
    plt.show()

energy_functions = [
    (RosenbrockEnergy(), [-2, 2], [-1, 3], 'Rosenbrock Energy Function'),
    (AckleyEnergy(), [-5, 5], [-5, 5], 'Ackley Energy Function'),
    (RastriginEnergy(), [-5, 5], [-5, 5], 'Rastrigin Energy Function'),
    (DoubleWellEnergy(), [-2, 2], [-2, 2], 'Double Well Energy Function'),
    (GaussianEnergy(torch.tensor([0.0, 0.0]), torch.tensor([[1.0, 0.0], [0.0, 1.0]])), [-3, 3], [-3, 3], 'Gaussian Energy Function'),
    (HarmonicEnergy(), [-3, 3], [-3, 3], 'Harmonic Energy Function')
]

# Plot each energy function
for energy_fn, x_range, y_range, title in energy_functions:
    plot_energy_function(energy_fn, x_range, y_range, title)
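A note on the grid evaluation above: the nested loop calls the energy function once per grid point, which is slow for fine grids. If the torchebm energy functions accept a batch of points of shape (N, 2) in a single call — an assumption suggested by the unsqueeze(0) above, not something this diff confirms — the same Z array can be computed in one pass. A minimal sketch, where evaluate_energy_grid is a hypothetical helper and not part of this commit:

import numpy as np
import torch

def evaluate_energy_grid(energy_fn, x_range, y_range, resolution=100):
    # Build the same 2D grid as plot_energy_function above.
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)
    # Flatten the grid into a (resolution**2, 2) batch and evaluate the energy
    # for every point in a single forward call instead of a Python loop.
    points = torch.from_numpy(np.stack([X.ravel(), Y.ravel()], axis=1)).float()
    with torch.no_grad():
        Z = energy_fn(points).reshape(resolution, resolution).cpu().numpy()
    return X, Y, Z

The returned X, Y, Z can then be passed to ax.plot_surface exactly as in the loop-based version.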
248 changes: 206 additions & 42 deletions examples/langevin_dynamics_sampling.py
@@ -1,61 +1,225 @@

"""
Examples for using the Langevin Dynamics sampler.
"""

import torch
from torchebm.core.energy_function import EnergyFunction
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple

from torch.xpu import device

from torchebm.samplers.langevin_dynamics import LangevinDynamics
from matplotlib import pyplot as plt

class QuadraticEnergy(EnergyFunction):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * torch.sum(x**2, dim=-1)

    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        return x
# ===================== Example 1 =====================

def basic_example():
"""
simple Langevin dynamics sampling from a 2D Gaussian distribution.
"""

# Define a simple 2D Gaussian energy function
class GaussianEnergy:
def __init__(self, mean: torch.Tensor, cov: torch.Tensor, device=None):
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
self.mean = mean.to(self.device)
self.precision = torch.inverse(cov).to(self.device)

def __call__(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
delta = x - self.mean
return 0.5 * torch.einsum('...i,...ij,...j->...', delta, self.precision, delta)

def gradient(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
return torch.einsum('...ij,...j->...i', self.precision, x - self.mean)

def to(self, device):
self.device = device
self.mean = self.mean.to(device)
self.precision = self.precision.to(device)
return self
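    # The __call__ above evaluates the quadratic form 0.5 * (x - mean)^T @ precision @ (x - mean),
    # i.e. the Gaussian negative log-density up to an additive constant, and gradient()
    # returns its analytic derivative, precision @ (x - mean).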

def plot_samples(samples, energy_function, title):
    x = torch.linspace(-3, 3, 100)
    y = torch.linspace(-3, 3, 100)
    X, Y = torch.meshgrid(x, y)
    Z = energy_function(torch.stack([X, Y], dim=-1)).numpy()
    # Create energy function for a 2D Gaussian
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mean = torch.tensor([1.0, -1.0])
    cov = torch.tensor([[1.0, 0.5], [0.5, 2.0]])
    energy_fn = GaussianEnergy(mean, cov, device=device)

    plt.figure(figsize=(10, 8))
    plt.contourf(X.numpy(), Y.numpy(), Z, levels=20, cmap='viridis')
    plt.colorbar(label='Energy')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', alpha=0.5)
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('y')
    # Initialize sampler
    sampler = LangevinDynamics(
        energy_function=energy_fn,
        step_size=0.01,
        noise_scale=0.1,
        device=device # Make sure to pass the same device
    )

    # Generate samples
    initial_state = torch.zeros(2, device=device)
    samples = sampler.sample_chain(
        initial_state=initial_state,
        n_steps=100, # steps between samples
        n_samples=1000 # number of samples to collect
    )

    # Plot results
    samples = samples.cpu().numpy()
    plt.figure(figsize=(10, 5))
    plt.scatter(samples[:, 0], samples[:, 1], alpha=0.1)
    plt.title('Samples from 2D Gaussian using Langevin Dynamics')
    plt.xlabel('x₁')
    plt.ylabel('x₂')
    plt.show()


def main():
    try:
        # Set up the energy function and sampler
        energy_function = QuadraticEnergy()
        sampler = LangevinDynamics(energy_function, step_size=0.1, noise_scale=0.1)
# ===================== Example 2 =====================

def advanced_example():
"""
Advanced example showing:
1. Custom energy functions
2. Parallel sampling
3. Diagnostics and trajectory tracking
4. Different initialization strategies
5. Handling multimodal distributions
"""

# Define a double-well potential
class DoubleWellEnergy:
def __init__(self, barrier_height: float = 2.0):
self.barrier_height = barrier_height

def __call__(self, x: torch.Tensor) -> torch.Tensor:
return self.barrier_height * (x.pow(2) - 1).pow(2)

def gradient(self, x: torch.Tensor) -> torch.Tensor:
return 4 * self.barrier_height * x * (x.pow(2) - 1)

def to(self, device):
self.device = device
return self
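    # barrier_height * (x^2 - 1)^2 has two minima at x = -1 and x = +1, separated by a
    # barrier of height barrier_height at x = 0; gradient() is the analytic derivative,
    # 4 * barrier_height * x * (x^2 - 1).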

        # Generate samples
        initial_state = torch.tensor([2.0, 2.0])
        n_steps = 1000
        n_samples = 500
    # Create energy function and sampler
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    energy_fn = DoubleWellEnergy(barrier_height=2.0)
    sampler = LangevinDynamics(
        energy_function=energy_fn,
        step_size=0.001,
        noise_scale=0.1,
        decay=0.1, # for stability
        device=device
    )

        samples = sampler.sample_chain(initial_state, n_steps, n_samples)
    # 1. Generate trajectory with diagnostics
    initial_state = torch.tensor([0.0], device=device)
    trajectory, diagnostics = sampler.sample(
        initial_state=initial_state,
        n_steps=1000,
        return_trajectory=True,
        return_diagnostics=True
    )

        # Plot the results
        plot_samples(samples, energy_function, "Langevin Dynamics Sampling")
    # 2. Parallel sampling from multiple initial points
    initial_states = torch.linspace(-2, 2, 10).unsqueeze(1)
    parallel_samples, parallel_diagnostics = sampler.sample_parallel(
        initial_states=initial_states,
        n_steps=1000,
        return_diagnostics=True
    )

        # Visualize a single trajectory
        trajectory = sampler.sample(initial_state, n_steps, return_trajectory=True)
        plot_samples(trajectory, energy_function, "Single Langevin Dynamics Trajectory")
    if isinstance(parallel_diagnostics, list) and all(isinstance(diag, dict) for diag in parallel_diagnostics):
        # Extract diagnostics safely
        energies = [diag['energies'] for diag in parallel_diagnostics if 'energies' in diag]
    else:
        energies = None
    # Plot results
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        # Demonstrate parallel sampling
        n_chains = 10
        initial_states = torch.randn(n_chains, 2) * 2
        parallel_samples = sampler.sample_parallel(initial_states, n_steps)
        plot_samples(parallel_samples, energy_function, "Parallel Langevin Dynamics Sampling")
    # Plot trajectory
    # parallel_samples = parallel_samples.cpu().numpy()
    # parallel_diagnostics = [diag['energies'] for diag in parallel_diagnostics]
    ax1.plot(trajectory.cpu().numpy())
    ax1.set_title('Single Chain Trajectory')
    ax1.set_xlabel('Step')
    ax1.set_ylabel('Position')

    except ValueError as e:
        print(f"Error: {e}")
    # Plot energy over time
    ax2.plot(diagnostics['energies'].cpu().numpy())
    ax2.set_title('Energy Evolution')
    ax2.set_xlabel('Step')
    ax2.set_ylabel('Energy')

    # Plot parallel chain results
    parallel_samples_cpu = parallel_samples.cpu().numpy()
    for sample in parallel_samples_cpu:
        ax3.axhline(sample.item(), alpha=0.3, color='blue')
    ax3.set_title('Final States of Parallel Chains')
    ax3.set_ylabel('Position')

    if 'mean_energies' in parallel_diagnostics:
        ax2.plot(parallel_diagnostics['mean_energies'].cpu().numpy(),
                 linestyle='--', label='Parallel Mean Energy')
        ax2.legend()

    plt.tight_layout()
    plt.show()

# ===================== Examples =====================

def sampling_utilities_example():
"""
Example demonstrating various utility features:
1. Chain thinning
2. Device management
3. Custom diagnostics
4. Convergence checking
"""

# Define simple harmonic oscillator
class HarmonicEnergy:
def __call__(self, x: torch.Tensor) -> torch.Tensor:
return 0.5 * x.pow(2)

def gradient(self, x: torch.Tensor) -> torch.Tensor:
return x
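    # exp(-0.5 * x**2) is the unnormalized standard normal density, so a well-mixed chain
    # targeting this energy should show a mean near 0 and a standard deviation near 1 in
    # the analyze_convergence summary below.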

    # Initialize sampler with GPU support if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sampler = LangevinDynamics(
        energy_function=HarmonicEnergy(),
        step_size=0.01,
        noise_scale=0.1
    ).to(device)

    # Generate samples with thinning
    initial_state = torch.tensor([2.0], device=device)
    samples, diagnostics = sampler.sample_chain(
        initial_state=initial_state,
        n_steps=50, # steps between samples
        n_samples=1000, # number of samples
        thin=10, # keep every 10th sample
        return_diagnostics=True
    )

    # Custom analysis of results
    def analyze_convergence(samples: torch.Tensor, diagnostics: list) -> Tuple[float, float]:
        """Example utility function to analyze convergence."""
        mean = samples.mean().item()
        std = samples.std().item()
        return mean, std

    mean, std = analyze_convergence(samples, diagnostics)
    print(f"Sample Statistics - Mean: {mean:.3f}, Std: {std:.3f}")


if __name__ == "__main__":
    main()
    print("Running basic example...")
    basic_example()

    print("\nRunning advanced example...")
    advanced_example()

    print("\nRunning utilities example...")
    sampling_utilities_example()
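For reference, the update these examples rely on can be sketched in a few lines. This is a generic, unadjusted Langevin step written against the energy interface used above (a __call__ plus a gradient method); it is an illustration only, not necessarily the exact rule implemented by torchebm's LangevinDynamics — in particular, how noise_scale and decay enter the update is library-specific:

import torch

def naive_langevin_chain(energy_fn, x0, n_steps=1000, step_size=0.001, noise_scale=0.1):
    # Plain overdamped Langevin iteration: step along the negative energy gradient
    # and inject Gaussian noise at every step. In the textbook formulation the noise
    # term is sqrt(2 * step_size) * N(0, I); noise_scale is kept as a free parameter
    # here only to mirror the examples above.
    x = x0.clone()
    trajectory = [x.clone()]
    for _ in range(n_steps):
        x = x - step_size * energy_fn.gradient(x) + noise_scale * torch.randn_like(x)
        trajectory.append(x.clone())
    return torch.stack(trajectory)

# Usage with an energy object like the DoubleWellEnergy defined in advanced_example:
# traj = naive_langevin_chain(DoubleWellEnergy(), torch.tensor([0.0]))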