From 82f6562d40db94133000ab8659233c61078318be Mon Sep 17 00:00:00 2001 From: Deeptendu Santra Date: Mon, 9 Sep 2024 23:16:34 +0530 Subject: [PATCH] Wigner matrices --- Project.toml | 6 ++-- src/o3/o3.jl | 11 +++++-- src/o3/rotations.jl | 8 ++--- src/o3/s2grid.jl | 3 +- src/o3/spherical_harmonics.jl | 2 ++ src/o3/wigner.jl | 59 +++++++++++++++++++++++++++++++++++ test/o3/wigner.jl | 13 ++++++++ 7 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 src/o3/wigner.jl create mode 100644 test/o3/wigner.jl diff --git a/Project.toml b/Project.toml index fce94f5..5f243f1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,14 @@ name = "e3nn" uuid = "1c50a8ea-cbe2-4d3e-83e0-d59f5e8851b3" -authors = ["Deeptendu Santra and contributors"] +authors = [ + "Deeptendu Santra and contributors", +] version = "0.1.0" [deps] CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" -FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc" diff --git a/src/o3/o3.jl b/src/o3/o3.jl index bac79d0..c599f34 100644 --- a/src/o3/o3.jl +++ b/src/o3/o3.jl @@ -6,13 +6,20 @@ export remove_zero_multiplicities, num_irreps, ls, lmax # not implemented yet export D_from_angles, D_from_quaternion, D_from_axis_angle, D_from_matrix +include("irrepsarray.jl") +export IrrepsArray + +include("wigner.jl") +export so3_generators, su2_generators, wigner_D + include("rotations.jl") using .rot -export Quaternion -export RotMatrix, RotMatrix3, AngleAxis, RotYXY, QuatRotation export euler_angles, CartesianToSphericalAngles, SphercialAnglesToCartesian include("spherical_harmonics.jl") export spherical_harmonics, SphericalHarmonics +include("s2grid.jl") +using .S2Grid + end diff --git a/src/o3/rotations.jl b/src/o3/rotations.jl index 01b3a83..71b656c 100644 --- a/src/o3/rotations.jl +++ b/src/o3/rotations.jl @@ -1,5 +1,6 @@ module rot +using LinearAlgebra using Quaternions using Rotations using StaticArrays @@ -41,14 +42,13 @@ function (::Type{Q})(R::RotMatrix3) where {Q <: Quaternion} return QuatRotation(R).q |> Q end -function CartesianToSphericalAngles(x::AbstractVector{T}) where {T <: Real} - length(x) == 3 || error("Spherical transform takes a 3D coordinate") +function CartesianToSphericalAngles(x::SVector{3, T}) where {T <: Real} # done in e3nn to remove NaNs # need to check for Julia - normalize!(x, p = 2) + normalize!(x, 2) clamp!(x, -1, 1) - return SVector(acosx[2], atan(x[1], x[3])) + return SVector(acos[2], atan(x[1], x[3])) end function SphercialAnglesToCartesian(α::Real, β::Real) diff --git a/src/o3/s2grid.jl b/src/o3/s2grid.jl index 281f20a..28ccabf 100644 --- a/src/o3/s2grid.jl +++ b/src/o3/s2grid.jl @@ -1,7 +1,6 @@ module S2Grid using LinearAlgebra -using FFTW using StaticArrays import Base: *, +, -, / @@ -20,7 +19,7 @@ struct SphericalSignal{T <: AbstractArray} throw(ArgumentError("Grid values should have at least 2 axes. Got grid_values of shape $(size(grid_values)).")) end - if !(quadraturie in ["soft", "gausslegendre"]) + if !(quadrature in ["soft", "gausslegendre"]) throw(ArgumentError("Invalid quadrature for SphericalSignal: $quadrature")) end diff --git a/src/o3/spherical_harmonics.jl b/src/o3/spherical_harmonics.jl index 3b07915..c4a0bcb 100644 --- a/src/o3/spherical_harmonics.jl +++ b/src/o3/spherical_harmonics.jl @@ -1,3 +1,5 @@ +using LinearAlgebra + struct SphericalHarmonics normalize::Bool normalization::String diff --git a/src/o3/wigner.jl b/src/o3/wigner.jl new file mode 100644 index 0000000..843a6aa --- /dev/null +++ b/src/o3/wigner.jl @@ -0,0 +1,59 @@ +using LinearAlgebra +using StaticArrays + +function su2_generators(j::Int) + m = range(-j, j - 1, step = 1) + raising = diagm(-1 => -sqrt.(j * (j + 1) .- m .* (m .+ 1))) + + m = range(-j + 1, j, step = 1) + lowering = diagm(1 => sqrt.(j * (j + 1) .- m .* (m .- 1))) + + m = range(-j, j, step = 1) + return stack( + [0.5 * (raising + lowering), # x (usually) + Diagonal(1im * m), # z (usually) + -0.5im * (raising - lowering)], # -y (usually) + dims = 3) +end + +function change_basis_real_to_complex(l::Int) + # https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form + q = zeros(ComplexF64, 2 * l + 1, 2 * l + 1) + for m in (-l):-1 + q[l + m + 1, l + abs(m) + 1] = 1 / sqrt(2) + q[l + m + 1, l - abs(m) + 1] = -im / sqrt(2) + end + q[l + 1, l + 1] = 1 + for m in 1:l + q[l + m + 1, l + abs(m) + 1] = (-1)^m / sqrt(2) + q[l + m + 1, l - abs(m) + 1] = im * (-1)^m / sqrt(2) + end + q = (-im)^l * q # Added factor of im^l to make the Clebsch-Gordan coefficients real + return q +end + +function so3_generators(l::Int) + X = su2_generators(l) + Q = change_basis_real_to_complex(l) + Q_c_T = conj(transpose(Q)) + + for i in 1:size(X, 3) + @views X[:, :, i] = Q_c_T * X[:, :, i] * Q + end + return real.(X) +end + +function wigner_D(l::Int, α::T, β::T, γ::T) where {T <: Real} + X = so3_generators(l) + return exp(α * X[:, :, 2]) * exp(β * X[:, :, 1]) * exp(γ * X[:, :, 2]) +end + +function wigner_D(l::Int, angles::Tuple{Vararg{T, 3}}) where {T <: Real} + return wigner_D(l, angles[1], angles[2], angles[3]) +end + +function wigner_D(l::Int, angles::SVector{3, T}) where {T <: Real} + return wigner_D(l, angles[1], angles[2], angles[3]) +end + +Broadcast.broadcast(::typeof(wigner_D), l, α, β, γ) = broadcast(wigner_D, Ref(l), α, β, γ) diff --git a/test/o3/wigner.jl b/test/o3/wigner.jl new file mode 100644 index 0000000..9d04ce1 --- /dev/null +++ b/test/o3/wigner.jl @@ -0,0 +1,13 @@ +using e3nn.o3 +using Test +using Rotations +using MLUtils: batch + +@testset "WignerD" begin + @testset "basic" begin + R = rand(RotYXY, 10) + angles = R .|> Rotations.params + D = wigner_D.(1, angles) + @test batch(R - D) .|> abs |> maximum < 1e-8 + end +end