Skip to content

Commit 76112c0

Browse files
authored
Merge pull request #8 from MurrellGroup/murrellb-patch-1
Add DoobMatchingFlow
2 parents b4525f7 + d300d36 commit 76112c0

File tree

6 files changed

+170
-3
lines changed

6 files changed

+170
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
/docs/Manifest.toml
55
/docs/build/
6+
Manifest*.toml

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.1.2"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
9+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
910
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
1011
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1112
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
@@ -14,6 +15,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415
[compat]
1516
Adapt = "4.1.1"
1617
ForwardBackward = "0.1.0"
18+
LogExpFunctions = "0.3.29"
1719
Manifolds = "0.10.12"
1820
NNlib = "0.9.27"
1921
OneHotArrays = "0.2.6"

examples/doob.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#Note: I haven't yet worked out exactly which method in the literature this corresponds to.
2+
using Pkg
3+
Pkg.activate(".")
4+
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots
5+
6+
struct DModel{A}
7+
layers::A
8+
end
9+
10+
Flux.@layer DModel
11+
12+
#Keyword constructor for the example network.
#  embeddim : width of every hidden representation
#  l        : number of positions mixed together (the input width of `mix` is l*embeddim)
#  K        : number of discrete categories per position
#  layers   : number of residual feed-forward blocks
#The sub-layers are collected in a NamedTuple and wrapped in a DModel.
function DModel(; embeddim = 64, l = 2, K = 32, layers = 5)
    embed_time = Chain(RandomFourierFeatures(1 => embeddim, 2.0f0), Dense(embeddim => embeddim, leakyrelu))
    embed_char = Dense(K => embeddim, bias = false)
    mix = Dense(l*embeddim => embeddim, leakyrelu)
    ffs = [Dense(embeddim => embeddim, leakyrelu) for _ in 1:layers]
    decode = Dense(embeddim => l*K)
    #Use a separate name here: rebinding the `layers` keyword (an Int) to a NamedTuple
    #changes the variable's type mid-function.
    components = (; embed_time, embed_char, mix, ffs, decode)
    return DModel(components)
end
21+
22+
#Forward pass: maps (t, Xt) to one output column block per position.
#Assumes tensor(Xt) is laid out as (K, positions, batch) - TODO confirm against callers.
#The result is reshaped to (:, positions, batch).
function (f::DModel)(t, Xt)
    l = f.layers
    tXt = tensor(Xt)
    #Read the number of positions from the input instead of hard-coding 2, so that
    #models built with `l != 2` decode to the right shape.
    npos = size(tXt, 2)
    len = size(tXt)[end]
    #A 1 x batch Float32 row holding the time of each sample.
    tv = zero(similar(Float32.(tXt), 1, len)) .+ expand(t, 2)
    x = l.mix(reshape(l.embed_char(tXt), :, len)) .+ l.embed_time(tv)
    for ff in l.ffs
        x = x .+ ff(x) #residual connection
    end
    return reshape(l.decode(x), :, npos, len)
end
33+
34+
T = Float32
35+
n_samples = 1000
36+
37+
sampleX1(n_samples) = Flowfusion.random_discrete_cat(n_samples)
38+
sampleX0(n_samples) = rand(25:32, 2, n_samples)
39+
P = DoobMatchingFlow(UniformDiscrete(1f0)) #The rate of the inner process controls how noisy the paths are
40+
41+
#If you use a UniformUnmasking process, every X0 position must start at the last token (33 here) for the Doob h-transform to be defined.
42+
#Generally, an X0 without token overlap with the training data might give better results!
43+
#sampleX0(n_samples) = [33 for _ in zeros(2, n_samples)]
44+
#P = DoobMatchingFlow(UniformUnmasking(1f0)) #The rate of the inner process controls how noisy the paths are
45+
46+
model = DModel(embeddim = 128, l = 2, K = 33, layers = 2)
47+
48+
orig_eta = eta = 0.001
49+
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.0001), model)
50+
51+
iters = 3500
52+
for i in 1:iters
53+
#Set up a batch of training pairs, and t
54+
X1 = DiscreteState(33, sampleX1(n_samples))
55+
X0 = DiscreteState(33, sampleX0(n_samples))
56+
t = rand(T, 1, n_samples)
57+
#Construct the bridge:
58+
Xt = bridge(P, X0, X1, t)
59+
Xt = onehot(Xt) #<-Need this for the doob loss
60+
denseXt = dense(Xt) #<-Zygote doesn't like the onehot input, so we make it dense.
61+
G = Guide(P, t, Xt, onehot(X1)) #This sets up the "training target rate" via a Doob h-transform
62+
#Gradient
63+
l,g = Flux.withgradient(model) do m
64+
floss(P, Xt, m(t,denseXt), G, scalefloss(P,t,1))
65+
end
66+
#Update
67+
Flux.update!(opt_state, model, g[1])
68+
if i % 10 == 0
69+
if i > iters - 1000
70+
global eta = max(eta - orig_eta/100, 1e-9)
71+
Optimisers.adjust!(opt_state, eta)
72+
end
73+
println("i: $i; Loss: $l; eta: $eta")
74+
end
75+
end
76+
77+
78+
n_inference_samples = 10000
79+
X0 = DiscreteState(33, sampleX0(n_inference_samples))
80+
paths = Tracker()
81+
@time samp = gen(P, X0, (t,Xt) -> model(t,onehot(Xt)), 0f0:0.005f0:1f0, tracker = paths)
82+
83+
pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, color = "blue", alpha = 0.4, label = "Initial", size = (400,400), legend = :topleft, xlim = (1,33), ylim = (1,33))
84+
scatter!(samp.state[1,:],samp.state[2,:], msw = 0, color = "green", alpha = 0.04, label = :none)
85+
scatter!([-10],[-10], msw = 0, color = "green", alpha = 0.3, label = "Sampled")
86+
tvec = stack_tracker(paths, :t)
87+
xttraj = stack_tracker(paths, :xt)
88+
for i in 1:200:n_inference_samples
89+
plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = :none, alpha = 0.15)
90+
end
91+
plot!([-10],[-10], color = "red", label = "Trajectory", alpha = 0.4)
92+
pl
93+
savefig("discrete_doob.svg")

src/Flowfusion.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,20 @@ Later:
2222

2323
module Flowfusion
2424

25-
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib
25+
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib, LogExpFunctions
2626

2727
include("types.jl")
2828
include("mask.jl")
2929
include("bridge.jl")
3030
include("loss.jl")
3131
include("processes.jl")
32+
include("doob.jl")
3233

3334
export
3435
#Processes not in ForwardBackward.jl
3536
InterpolatingDiscreteFlow,
3637
NoisyInterpolatingDiscreteFlow,
38+
DoobMatchingFlow,
3739
MaskedState,
3840
Guide,
3941
tangent_guide,

src/bridge.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,13 @@ end
5050
#resolveprediction exists to stop bridge from needing multiple definitions.
5151
#Tuple broadcast:
5252
resolveprediction(dest::Tuple, src::Tuple) = map(resolveprediction, dest, src)
53+
5354
#Default if X̂₁ is a plain tensor:
54-
resolveprediction(X̂₁, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}) = copytensor!(stochastic(Xₜ), X̂₁) #Returns a Likelihood
55-
resolveprediction(X̂₁, Xₜ::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = copytensor!(stochastic(unhot(Xₜ)), X̂₁) #Probably inefficient
55+
#I think these were serving processes with a faulty assumption, so I'm swapping them out to make Doob flows easier.
56+
#resolveprediction(X̂₁, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}) = copytensor!(stochastic(Xₜ), X̂₁) #Returns a Likelihood
57+
#resolveprediction(X̂₁, Xₜ::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = copytensor!(stochastic(unhot(Xₜ)), X̂₁) #Probably inefficient
58+
#Integer-coded discrete states: the model output is handed back untouched.
#TODO: still need to test whether this breaks anything else.
function resolveprediction(X̂₁, Xₜ::DiscreteState{<:AbstractArray{<:Signed}})
    return X̂₁
end
59+
#One-hot discrete states: the model output is handed back untouched.
#TODO: still need to test whether this breaks anything else.
function resolveprediction(X̂₁, Xₜ::DiscreteState{<:Union{OneHotArray, OneHotMatrix}})
    return X̂₁
end
5660

5761
resolveprediction(X̂₁, Xₜ::State) = copytensor!(copy(Xₜ), X̂₁) #Returns a State - Handles Continuous and Manifold cases
5862
#Passthrough if the user returns a State or Likelihood

src/doob.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#Note: I haven't yet worked out exactly which method in the literature this corresponds to. Not thoroughly tested!
2+
3+
#DoobMatchingFlow wraps an inner process and trains a model to match its Doob h-transform rates.
#Fields:
#  P         - the wrapped inner process.
#  onescale  - controls whether the "step" is unit scale or "time remaining" scale.
#              Need to think carefully about schedules in all this...
#  transform - maps the raw model output into rate space. Must act on the whole tensor.
#Note: losses can be compared for different transforms, but not for different onescale.
struct DoobMatchingFlow{Proc, B, F} <: Process
    P::Proc
    onescale::B
    transform::F
end
9+
10+
#Convenience constructors. Unless given, onescale = true and transform = NNlib.softplus.
function DoobMatchingFlow(P::DiscreteProcess)
    #x -> exp.(clamp.(x, -100, 11)) also works as a transform, but is scary
    return DoobMatchingFlow(P, true, NNlib.softplus)
end

function DoobMatchingFlow(P::DiscreteProcess, transform::Function)
    return DoobMatchingFlow(P, true, transform)
end

function DoobMatchingFlow(P::DiscreteProcess, onescale::Bool)
    return DoobMatchingFlow(P, onescale, NNlib.softplus)
end
13+
14+
#Scale factor applied to the guide: the time remaining (1 .- t) when P.onescale is set,
#otherwise the constant one in t's element type.
function onescale(P::DoobMatchingFlow, t)
    if P.onescale
        return 1 .- t
    end
    return eltype(t)(1)
end
15+
#Multiply x elementwise by t, after expanding t to x's number of dimensions.
function mulexpand(t, x)
    texp = expand(t, ndims(x))
    return texp .* x
end
16+
17+
#Bridging for a DoobMatchingFlow is delegated to the wrapped inner process.
function Flowfusion.bridge(
    p::DoobMatchingFlow,
    x0::DiscreteState{<:AbstractArray{<:Signed}},
    x1::DiscreteState{<:AbstractArray{<:Signed}},
    t,
)
    return bridge(p.P, x0, x1, t)
end
18+
19+
#Finite-difference estimate of the Doob-transformed rates, for processes without a closed form.
#Combines a forward step of length `delta` from Xt with the backward message from X1 over the
#remaining time, subtracts the current one-hot state, and divides by `delta`.
#  delta : finite-difference step, defaults to 1e-5 in t's element type.
#NOTE(review): presumably `⊙` yields a normalized distribution, which this difference quotient
#relies on - confirm against ForwardBackward.
function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5))
    #The two messages must be combined with the pointwise product `⊙`; with no operator
    #between the two calls the expression does not parse.
    stepped = forward(Xt, P, delta) ⊙ backward(X1, P, (1 .- t) .- delta)
    return (tensor(stepped) .- tensor(onehot(Xt))) ./ delta
end
22+
23+
#Generic fallback: any DiscreteProcess without a specialised method uses the finite-difference estimate.
function doob_guide(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState)
    return fallback_doob(P, t, Xt, X1)
end
24+
25+
#Closed-form Doob-transformed rates for processes that define forward_positive_velocities.
#Off-diagonal rates are the base rates times the backward message, divided by the backward
#message at the current state; the diagonal entry is minus the sum of the off-diagonal rates.
function closed_form_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState)
    here = tensor(onehot(Xt))
    hmsg = tensor(backward(X1, P, 1 .- t))
    base_rates = forward_positive_velocities(onehot(Xt), P)
    #Backward message evaluated at the current state (one value per column).
    h_here = sum(hmsg .* here, dims = 1)
    offdiag = (base_rates .* hmsg) ./ h_here
    outflow = sum(offdiag, dims = 1)
    return offdiag .- here .* outflow
end
32+
33+
#Per-process off-diagonal base rates (zeroed at the current state), each paired with a
#doob_guide method that selects the closed form.

#PiQ: P.r times P.π normalised to sum to one.
function forward_positive_velocities(Xt::DiscreteState, P::PiQ)
    weights = P.π ./ sum(P.π)
    away = 1 .- tensor(onehot(Xt))
    return (P.r .* weights) .* away
end

function doob_guide(P::PiQ, t, Xt::DiscreteState, X1::DiscreteState)
    return closed_form_doob(P, t, Xt, X1)
end

#UniformUnmasking: P.μ spread over the K-1 other states.
function forward_positive_velocities(Xt::DiscreteState, P::UniformUnmasking{T}) where T
    rate = P.μ .* T((1 ./ (Xt.K-1)))
    away = 1 .- tensor(onehot(Xt))
    return rate .* away
end

function doob_guide(P::UniformUnmasking, t, Xt::DiscreteState, X1::DiscreteState)
    return closed_form_doob(P, t, Xt, X1)
end

#UniformDiscrete: P.μ scaled by 1/(K*(1-1/K)).
function forward_positive_velocities(Xt::DiscreteState, P::UniformDiscrete{T}) where T
    rate = P.μ * T(1/(Xt.K*(1-1/Xt.K)))
    away = 1 .- tensor(onehot(Xt))
    return rate .* away
end

function doob_guide(P::UniformDiscrete, t, Xt::DiscreteState, X1::DiscreteState)
    return closed_form_doob(P, t, Xt, X1)
end
39+
40+
#Training target for a DoobMatchingFlow: the Doob guide of the inner process, multiplied by onescale(P, t).
function Guide(P::DoobMatchingFlow, t, Xt::DiscreteState, X1::DiscreteState)
    target = mulexpand(onescale(P, t), doob_guide(P.P, t, Xt, X1))
    return Flowfusion.Guide(target)
end

#Masked variant: the same target, carrying the cmask and lmask of mX1.
#NOTE(review): the doob_guide methods in this file are typed on DiscreteState; confirm that
#MaskedState arguments dispatch as intended.
function Guide(P::DoobMatchingFlow, t, mXt::Union{MaskedState{<:DiscreteState}, DiscreteState}, mX1::MaskedState{<:DiscreteState})
    target = mulexpand(onescale(P, t), doob_guide(P.P, t, mXt, mX1))
    return Guide(target, mX1.cmask, mX1.lmask)
end
42+
43+
#Turn a raw model output into a rate matrix whose columns sum to zero.
#  Xt  : one-hot encoding of the current state
#  X̂₁  : raw model output
#  f   : transform applied to the whole of X̂₁ to produce the rates
#Entries away from the current state keep f(X̂₁); the entry at the current state is minus their sum.
function rate_constraint(Xt, X̂₁, f)
    offdiag = f(X̂₁) .* (1 .- Xt)
    outflow = sum(offdiag, dims = 1)
    diagonal = (-outflow) .* Xt
    return offdiag .+ diagonal
end
48+
49+
#Take one Euler-style step of size delta_t and sample the next discrete state.
#The raw model output `log_velocity` is mapped to a rate matrix with P.transform (via rate_constraint),
#multiplied by `scale`, added to the one-hot current state, and the result is sampled.
function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, log_velocity, scale)
    current = tensor(onehot(Xₜ))
    rates = rate_constraint(current, log_velocity, P.transform) .* scale
    weights = eltype(delta_t).(current .+ (delta_t .* rates))
    proposal = CategoricalLikelihood(weights)
    #One rate per column is negative, so a large step might push an entry below zero.
    clamp!(tensor(proposal), 0, Inf)
    return rand(proposal)
end
56+
57+
#Step from time s₁ to s₂. The prediction is divided by onescale(P, s₁) before stepping.

#Prediction wrapped in a Guide: its H field is used.
function step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁::Flowfusion.Guide, s₁, s₂)
    raw = veloX̂₁.H
    unscale = expand(1 ./ onescale(P, s₁), ndims(raw))
    return velo_step(P, Xₜ, s₂ .- s₁, raw, unscale)
end

#Prediction given as a plain tensor.
function step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁, s₁, s₂)
    unscale = expand(1 ./ onescale(P, s₁), ndims(veloX̂₁))
    return velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, unscale)
end
59+
60+
#Per-column loss between predicted rates P.transform(X̂₁) and target rates doobX₁.
#Only entries away from the current state Xt contribute: the diagonals are ignored
#(implicit zero sum). The result is summed over the first dimension.
function cgm_dloss(P, Xt, X̂₁, doobX₁)
    rates = P.transform(X̂₁)
    offdiag = 1 .- Xt
    return sum(offdiag .* (rates .- xlogy.(doobX₁, rates)), dims = 1)
end
64+
65+
#Loss for a DoobMatchingFlow: the cgm_dloss terms against the Guide's target H, reduced by
#scaledmaskedmean with scale c and the Guide's lmask.
function floss(P::Flowfusion.fbu(DoobMatchingFlow), Xt::Flowfusion.msu(DiscreteState), X̂₁, X₁::Guide, c)
    per_site = cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H)
    return Flowfusion.scaledmaskedmean(per_site, c, Flowfusion.getlmask(X₁))
end

0 commit comments

Comments
 (0)