Skip to content

Commit

Permalink
Wigner matrices (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dsantra92 authored Sep 9, 2024
1 parent 527d1a7 commit 5d9feff
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 10 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
name = "e3nn"
uuid = "1c50a8ea-cbe2-4d3e-83e0-d59f5e8851b3"
authors = ["Deeptendu Santra <deeptendu.santra@protonmail.com> and contributors"]
authors = [
"Deeptendu Santra <deeptendu.santra@protonmail.com> and contributors",
]
version = "0.1.0"

[deps]
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
Expand Down
11 changes: 9 additions & 2 deletions src/o3/o3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@ export remove_zero_multiplicities, num_irreps, ls, lmax
# not implemented yet
export D_from_angles, D_from_quaternion, D_from_axis_angle, D_from_matrix

include("irrepsarray.jl")
export IrrepsArray

include("wigner.jl")
export so3_generators, su2_generators, wigner_D

include("rotations.jl")
using .rot
export Quaternion
export RotMatrix, RotMatrix3, AngleAxis, RotYXY, QuatRotation
export euler_angles, CartesianToSphericalAngles, SphercialAnglesToCartesian

include("spherical_harmonics.jl")
export spherical_harmonics, SphericalHarmonics

include("s2grid.jl")
using .S2Grid

end
8 changes: 4 additions & 4 deletions src/o3/rotations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module rot

using LinearAlgebra
using Quaternions
using Rotations
using StaticArrays
Expand Down Expand Up @@ -41,14 +42,13 @@ function (::Type{Q})(R::RotMatrix3) where {Q <: Quaternion}
return QuatRotation(R).q |> Q
end

function CartesianToSphericalAngles(x::AbstractVector{T}) where {T <: Real}
length(x) == 3 || error("Spherical transform takes a 3D coordinate")
function CartesianToSphericalAngles(x::SVector{3, T}) where {T <: Real}

# done in e3nn to remove NaNs
# need to check for Julia
normalize!(x, p = 2)
normalize!(x, 2)
clamp!(x, -1, 1)
return SVector(acosx[2], atan(x[1], x[3]))
return SVector(acos[2], atan(x[1], x[3]))
end

function SphercialAnglesToCartesian::Real, β::Real)
Expand Down
3 changes: 1 addition & 2 deletions src/o3/s2grid.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module S2Grid

using LinearAlgebra
using FFTW
using StaticArrays

import Base: *, +, -, /
Expand All @@ -20,7 +19,7 @@ struct SphericalSignal{T <: AbstractArray}
throw(ArgumentError("Grid values should have at least 2 axes. Got grid_values of shape $(size(grid_values))."))
end

if !(quadraturie in ["soft", "gausslegendre"])
if !(quadrature in ["soft", "gausslegendre"])
throw(ArgumentError("Invalid quadrature for SphericalSignal: $quadrature"))
end

Expand Down
2 changes: 2 additions & 0 deletions src/o3/spherical_harmonics.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using LinearAlgebra

struct SphericalHarmonics
normalize::Bool
normalization::String
Expand Down
59 changes: 59 additions & 0 deletions src/o3/wigner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using LinearAlgebra
using StaticArrays

function su2_generators(j::Int)
m = range(-j, j - 1, step = 1)
raising = diagm(-1 => -sqrt.(j * (j + 1) .- m .* (m .+ 1)))

m = range(-j + 1, j, step = 1)
lowering = diagm(1 => sqrt.(j * (j + 1) .- m .* (m .- 1)))

m = range(-j, j, step = 1)
return stack(
[0.5 * (raising + lowering), # x (usually)
Diagonal(1im * m), # z (usually)
-0.5im * (raising - lowering)], # -y (usually)
dims = 3)
end

function change_basis_real_to_complex(l::Int)
# https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
q = zeros(ComplexF64, 2 * l + 1, 2 * l + 1)
for m in (-l):-1
q[l + m + 1, l + abs(m) + 1] = 1 / sqrt(2)
q[l + m + 1, l - abs(m) + 1] = -im / sqrt(2)
end
q[l + 1, l + 1] = 1
for m in 1:l
q[l + m + 1, l + abs(m) + 1] = (-1)^m / sqrt(2)
q[l + m + 1, l - abs(m) + 1] = im * (-1)^m / sqrt(2)
end
q = (-im)^l * q # Added factor of im^l to make the Clebsch-Gordan coefficients real
return q
end

function so3_generators(l::Int)
X = su2_generators(l)
Q = change_basis_real_to_complex(l)
Q_c_T = conj(transpose(Q))

for i in 1:size(X, 3)
@views X[:, :, i] = Q_c_T * X[:, :, i] * Q
end
return real.(X)
end

function wigner_D(l::Int, α::T, β::T, γ::T) where {T <: Real}
X = so3_generators(l)
return exp* X[:, :, 2]) * exp* X[:, :, 1]) * exp* X[:, :, 2])
end

function wigner_D(l::Int, angles::Tuple{Vararg{T, 3}}) where {T <: Real}
return wigner_D(l, angles[1], angles[2], angles[3])
end

function wigner_D(l::Int, angles::SVector{3, T}) where {T <: Real}
return wigner_D(l, angles[1], angles[2], angles[3])
end

Broadcast.broadcast(::typeof(wigner_D), l, α, β, γ) = broadcast(wigner_D, Ref(l), α, β, γ)
13 changes: 13 additions & 0 deletions test/o3/wigner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using e3nn.o3
using Test
using Rotations
using MLUtils: batch

@testset "WignerD" begin
@testset "basic" begin
R = rand(RotYXY, 10)
angles = R .|> Rotations.params
D = wigner_D.(1, angles)
@test batch(R - D) .|> abs |> maximum < 1e-8
end
end

0 comments on commit 5d9feff

Please sign in to comment.