Skip to content

Commit

Permalink
Plot data; look at distros. Generally took AGES to run.
Browse files Browse the repository at this point in the history
This is a 2018 style continuous normalising flow, so you are somehow
taking the Jacobian of the solution.

Flow matching is the technique which has made these methods
competitative with diffusion models.

I am slightly concerned whether normalising flow can deal with negative
densities: I think it might have a sign problem where you are forced to
consider everything being a positive definite probability density.
  • Loading branch information
jarvist committed Oct 16, 2024
1 parent 6a180ec commit 1adc80d
Showing 1 changed file with 86 additions and 1 deletion.
87 changes: 86 additions & 1 deletion notebooks/2D_backflow_DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ using ComponentArrays, DiffEqFlux, OrdinaryDiffEq, Optimization, Distributions,
OptimizationOptimisers, OptimizationOptimJL


# ╔═╡ f0b7f232-9f15-45fb-8fd9-cceebbdd651b
using Gnuplot

# ╔═╡ 9927b14e-490a-11ef-0110-6d89dd7c4844
# 2D backflow wavefunction node visualiser
# Following PRB 78 035104 (2008) - Fermionic quantum criticality
Expand Down Expand Up @@ -241,20 +244,68 @@ begin
new_data = rand(ffjord_dist, 100)
end

# ╔═╡ 6a9c8007-2ec5-440f-8cb2-e22b1ae2d2d7
begin
@gp 1:100 new_data "w lp"
@gp :- 1:100 -train_data "w lp"

end

# ╔═╡ 0b4ed2a0-6842-479c-9527-d32b9e578fe5
begin
@gp hist(vec(rand(ffjord_dist, 1000)))
@gp :- -train_data |> vec |> hist
end

# ╔═╡ 1a287044-5cb2-48f6-a932-1652d2acd4ab


# ╔═╡ ff794a70-02c1-4f19-af89-25cef3d7eb54
maximum(new_data)

# ╔═╡ d3fd2fbd-341a-4ec9-aae1-cd6e7f281868
train_data

# ╔═╡ 04a32f43-bb9f-4ed8-8072-e73486f50409
# OK, let's have a look at this model
model

# ╔═╡ c97dda78-5dc9-4b51-a3c7-87b58d395b1d
model

# ╔═╡ 4ac327cf-0b52-4060-bece-00719460437d


# ╔═╡ b3835164-1913-4d7f-916c-b8b32db67dbe


# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Gnuplot = "dc211083-a33a-5b79-959f-2ff34033469d"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[compat]
ComponentArrays = "~0.15.17"
DiffEqFlux = "~4.0.0"
Distances = "~0.10.11"
Distributions = "~0.25.112"
Gnuplot = "~1.6.5"
Images = "~0.24.1"
Optimization = "~3.26.3"
OptimizationOptimJL = "~0.3.2"
OptimizationOptimisers = "~0.2.1"
OrdinaryDiffEq = "~6.89.0"
"""

# ╔═╡ 00000000-0000-0000-0000-000000000002
Expand All @@ -263,7 +314,7 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
julia_version = "1.10.0"
manifest_format = "2.0"
project_hash = "40350df2796768eacfbe3f19b299cd20f2c627de"
project_hash = "49bbc44afb44f41aab42d39ae0b162b084778c9e"
[[deps.ADTypes]]
git-tree-sha1 = "eea5d80188827b35333801ef97a40c2ed653b081"
Expand Down Expand Up @@ -485,6 +536,12 @@ git-tree-sha1 = "05ba0d07cd4fd8b7a39541e31a7b0254704ea581"
uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9"
version = "0.1.13"
[[deps.ColorSchemes]]
deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"]
git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0"
uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
version = "3.26.0"
[[deps.ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d"
Expand Down Expand Up @@ -1002,6 +1059,12 @@ git-tree-sha1 = "43ba3d3c82c18d88471cfd2924931658838c9d8f"
uuid = "61579ee1-b43e-5ca0-a5da-69d92c66a64b"
version = "9.55.0+4"
[[deps.Gnuplot]]
deps = ["ColorSchemes", "ColorTypes", "Colors", "DataStructures", "PrecompileTools", "REPL", "ReplMaker", "StatsBase", "StructC14N", "Test"]
git-tree-sha1 = "72b7242dccedbe153dadbf1e1412f9bff3d81bad"
uuid = "dc211083-a33a-5b79-959f-2ff34033469d"
version = "1.6.5"
[[deps.Graphics]]
deps = ["Colors", "LinearAlgebra", "NaNMath"]
git-tree-sha1 = "d61890399bc535850c4bf08e4e0d3a7ad0f21cbd"
Expand Down Expand Up @@ -2214,6 +2277,12 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.2.2"
[[deps.ReplMaker]]
deps = ["REPL", "Unicode"]
git-tree-sha1 = "f8bb680b97ee232c4c6591e213adc9c1e4ba0349"
uuid = "b873ce64-0db9-51f5-a568-4457d8e49576"
version = "0.2.7"
[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
Expand Down Expand Up @@ -2514,6 +2583,12 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"]
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"
[[deps.StructC14N]]
deps = ["DataStructures", "Test"]
git-tree-sha1 = "a3d153488e0fe30715835e66585532c0bcf460e9"
uuid = "d2514e9c-36c4-5b8e-97e2-51e7675c221c"
version = "0.3.1"
[[deps.StructIO]]
git-tree-sha1 = "c581be48ae1cbf83e899b14c07a807e1787512cc"
uuid = "53d494c1-5632-5724-8f4c-31dff12d585f"
Expand Down Expand Up @@ -2769,5 +2844,15 @@ version = "17.4.0+2"
# ╠═1e72bfb6-7f1c-4d71-81f9-efd3d03fa90d
# ╠═68b11a40-e806-4ce1-a21b-28ed66612672
# ╠═66ece731-d926-4f95-a2bf-c54ac73137bd
# ╠═f0b7f232-9f15-45fb-8fd9-cceebbdd651b
# ╠═6a9c8007-2ec5-440f-8cb2-e22b1ae2d2d7
# ╠═0b4ed2a0-6842-479c-9527-d32b9e578fe5
# ╠═1a287044-5cb2-48f6-a932-1652d2acd4ab
# ╠═ff794a70-02c1-4f19-af89-25cef3d7eb54
# ╠═d3fd2fbd-341a-4ec9-aae1-cd6e7f281868
# ╠═04a32f43-bb9f-4ed8-8072-e73486f50409
# ╠═c97dda78-5dc9-4b51-a3c7-87b58d395b1d
# ╠═4ac327cf-0b52-4060-bece-00719460437d
# ╠═b3835164-1913-4d7f-916c-b8b32db67dbe
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002

0 comments on commit 1adc80d

Please sign in to comment.