Skip to content

Add DoobMatchingFlow #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.jl.mem
/docs/Manifest.toml
/docs/build/
Manifest*.toml
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.2"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Expand All @@ -14,6 +15,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
Adapt = "4.1.1"
ForwardBackward = "0.1.0"
LogExpFunctions = "0.3.29"
Manifolds = "0.10.12"
NNlib = "0.9.27"
OneHotArrays = "0.2.6"
Expand Down
93 changes: 93 additions & 0 deletions examples/doob.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#Note: Haven't figured out exactly what, in the literature, this is
using Pkg
Pkg.activate(".")
using ForwardBackward, Flowfusion, NNlib, Flux, RandomFeatureMaps, Optimisers, Plots

#Thin model wrapper: `layers` holds a NamedTuple of Flux layers
#(fields: embed_time, embed_char, mix, ffs, decode — see the keyword constructor below).
struct DModel{A}
layers::A
end

#Register with Flux so the wrapped layers' parameters are visible to setup/withgradient.
Flux.@layer DModel

"""
    DModel(; embeddim = 64, l = 2, K = 32, layers = 5)

Build the residual MLP used for Doob-flow training: a Fourier-feature time embedder,
a bias-free one-hot character embedder over `K` tokens, a mixing layer over `l`
positions, `layers` residual dense blocks, and a decoder producing `l*K` logits.
"""
function DModel(; embeddim = 64, l = 2, K = 32, layers = 5)
    time_embedder = Chain(RandomFourierFeatures(1 => embeddim, 2.0f0),
                          Dense(embeddim => embeddim, leakyrelu))
    char_embedder = Dense(K => embeddim, bias = false)
    mixer = Dense(l * embeddim => embeddim, leakyrelu)
    residual_blocks = [Dense(embeddim => embeddim, leakyrelu) for _ in 1:layers]
    decoder = Dense(embeddim => l * K)
    return DModel((; embed_time = time_embedder, embed_char = char_embedder,
                     mix = mixer, ffs = residual_blocks, decode = decoder))
end

#Forward pass: embed the (one-hot) state and the time, mix them, apply residual
#dense blocks, and decode to per-position logits.
function (f::DModel)(t, Xt)
l = f.layers
tXt = tensor(Xt)
len = size(tXt)[end]
#Broadcast t into a (1, len) Float32 row; `similar` keeps it on the same device as tXt.
tv = zero(similar(Float32.(tXt), 1, len)) .+ expand(t, 2)
x = l.mix(reshape(l.embed_char(tXt), :, len)) .+ l.embed_time(tv)
for ff in l.ffs
x = x .+ ff(x)
end
#NOTE(review): the `2` here hard-codes l = 2 from the constructor defaults; the struct
#does not store `l`, so this functor silently assumes two positions — confirm/refactor.
reshape(l.decode(x), :, 2, len)
end

#Training setup.
T = Float32
n_samples = 1000

#X1 sampler: Flowfusion's built-in discrete example distribution (presumably a 2D
#point cloud over the 33-token grid — confirm against Flowfusion.random_discrete_cat).
#X0 sampler: uniform noise confined to the 25:32 corner of the grid.
sampleX1(n_samples) = Flowfusion.random_discrete_cat(n_samples)
sampleX0(n_samples) = rand(25:32, 2, n_samples)
P = DoobMatchingFlow(UniformDiscrete(1f0)) #The rate of the inner process controls how noisy the paths are

#If you use a UniformUnmasking process, you must start in the last token for Doob h to be defined.
#Generally, an X0 without token overlap with the training data might give better results!
#sampleX0(n_samples) = [33 for _ in zeros(2, n_samples)]
#P = DoobMatchingFlow(UniformUnmasking(1f0)) #The rate of the inner process controls how noisy the paths are

#K = 33 tokens (32 grid values + 1 extra), l = 2 coordinates per sample.
model = DModel(embeddim = 128, l = 2, K = 33, layers = 2)

#AdamW with weight decay; `eta` is decayed near the end of training (see the loop below).
orig_eta = eta = 0.001
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.0001), model)

#Training loop: sample a bridge state Xt between noise X0 and data X1, build the
#Doob-rate target Guide, and regress the model's rates onto it.
iters = 3500
for i in 1:iters
#Set up a batch of training pairs, and t
X1 = DiscreteState(33, sampleX1(n_samples))
X0 = DiscreteState(33, sampleX0(n_samples))
t = rand(T, 1, n_samples)
#Construct the bridge:
Xt = bridge(P, X0, X1, t)
Xt = onehot(Xt) #<-Need this for the doob loss
denseXt = dense(Xt) #<-Zygote doesn't like the onehot input, so we make it dense.
G = Guide(P, t, Xt, onehot(X1)) #This sets up the "training target rate" via a Doob h-transform
#Gradient
l,g = Flux.withgradient(model) do m
floss(P, Xt, m(t,denseXt), G, scalefloss(P,t,1))
end
#Update
Flux.update!(opt_state, model, g[1])
if i % 10 == 0
#Over the last 1000 iterations, decay eta to ~0 in 100 steps of orig_eta/100
#(one step per 10 iterations), floored at 1e-9.
if i > iters - 1000
global eta = max(eta - orig_eta/100, 1e-9)
Optimisers.adjust!(opt_state, eta)
end
println("i: $i; Loss: $l; eta: $eta")
end
end


#Inference: integrate the learned velocity field from fresh X0 samples over t ∈ [0,1],
#recording intermediate states in `paths` for trajectory plotting.
n_inference_samples = 10000
X0 = DiscreteState(33, sampleX0(n_inference_samples))
paths = Tracker()
@time samp = gen(P, X0, (t,Xt) -> model(t,onehot(Xt)), 0f0:0.005f0:1f0, tracker = paths)

#Scatter the initial points (blue), the generated samples (green), and every 200th
#trajectory (red). The off-screen [-10] scatters exist only to create legend entries
#with a readable alpha.
pl = scatter(X0.state[1,:],X0.state[2,:], msw = 0, color = "blue", alpha = 0.4, label = "Initial", size = (400,400), legend = :topleft, xlim = (1,33), ylim = (1,33))
scatter!(samp.state[1,:],samp.state[2,:], msw = 0, color = "green", alpha = 0.04, label = :none)
scatter!([-10],[-10], msw = 0, color = "green", alpha = 0.3, label = "Sampled")
tvec = stack_tracker(paths, :t)
xttraj = stack_tracker(paths, :xt)
for i in 1:200:n_inference_samples
plot!(xttraj[1,i,:], xttraj[2,i,:], color = "red", label = :none, alpha = 0.15)
end
plot!([-10],[-10], color = "red", label = "Trajectory", alpha = 0.4)
pl
savefig("discrete_doob.svg")
4 changes: 3 additions & 1 deletion src/Flowfusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ Later:

module Flowfusion

using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib, LogExpFunctions

include("types.jl")
include("mask.jl")
include("bridge.jl")
include("loss.jl")
include("processes.jl")
include("doob.jl")

export
#Processes not in ForwardBackward.jl
InterpolatingDiscreteFlow,
NoisyInterpolatingDiscreteFlow,
DoobMatchingFlow,
MaskedState,
Guide,
tangent_guide,
Expand Down
8 changes: 6 additions & 2 deletions src/bridge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@
#resolveprediction exists to stop bridge from needing multiple definitions.
#Tuple broadcast:
resolveprediction(dest::Tuple, src::Tuple) = map(resolveprediction, dest, src)

#Default if X̂₁ is a plain tensor:
resolveprediction(X̂₁, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}) = copytensor!(stochastic(Xₜ), X̂₁) #Returns a Likelihood
resolveprediction(X̂₁, Xₜ::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = copytensor!(stochastic(unhot(Xₜ)), X̂₁) #Probably inefficient
#I think these were serving processes with a faulty assumption, so I'm swapping them out to make Doob flows easier.
#resolveprediction(X̂₁, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}) = copytensor!(stochastic(Xₜ), X̂₁) #Returns a Likelihood
#resolveprediction(X̂₁, Xₜ::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = copytensor!(stochastic(unhot(Xₜ)), X̂₁) #Probably inefficient
#Passthrough: hand the raw model output straight through, leaving interpretation to the
#process-specific `step`. NOTE(review): this replaces the old copytensor!-to-Likelihood
#behavior (commented out above); the author flagged it as untested against other
#processes — confirm downstream `step` methods accept a plain tensor.
resolveprediction(X̂₁, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}) = X̂₁ #<-Need to test if this breaking anything else
resolveprediction(X̂₁, Xₜ::DiscreteState{<:Union{OneHotArray, OneHotMatrix}}) = X̂₁ #<-Need to test if this breaking anything else

Check warning on line 59 in src/bridge.jl

View check run for this annotation

Codecov / codecov/patch

src/bridge.jl#L58-L59

Added lines #L58 - L59 were not covered by tests

#Continuous/Manifold case: copy the predicted tensor into a fresh copy of Xₜ.
resolveprediction(X̂₁, Xₜ::State) = copytensor!(copy(Xₜ), X̂₁) #Returns a State - Handles Continuous and Manifold cases
#Passthrough if the user returns a State or Likelihood
Expand Down
65 changes: 65 additions & 0 deletions src/doob.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#Note: Haven't figured out exactly what, in the literature, this is. Not very tested!

#Wraps a discrete process `P` so that training targets are Doob h-transform rates
#(velocity matching for discrete state spaces) rather than X1 predictions.
struct DoobMatchingFlow{Proc, B, F} <: Process
P::Proc
onescale::B #Controls whether the "step" is unit scale or "time remaining" scale. Need to think carefully about schedules in all this...
transform::F #Transforms the output of the model to the rate space. Must act on the whole tensor.
#Note: losses can be compared for different transforms, but not for different onescale.
end

#Convenience constructors: default to "time remaining" scaling (onescale = true) and a
#softplus transform mapping model outputs onto nonnegative rates.
#(x -> exp.(clamp.(x, -100, 11)) also works as a transform, but is scary.)
function DoobMatchingFlow(P::DiscreteProcess)
    return DoobMatchingFlow(P, true, NNlib.softplus)
end
function DoobMatchingFlow(P::DiscreteProcess, transform::Function)
    return DoobMatchingFlow(P, true, transform)
end
function DoobMatchingFlow(P::DiscreteProcess, onescale::Bool)
    return DoobMatchingFlow(P, onescale, NNlib.softplus)
end

Check warning on line 12 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L10-L12

Added lines #L10 - L12 were not covered by tests

#Scale applied to Doob rates: "time remaining" (1 - t) when P.onescale, else unit scale.
function onescale(P::DoobMatchingFlow, t)
    if P.onescale
        return 1 .- t
    else
        return eltype(t)(1)
    end
end

#Broadcast-multiply `t` against `x`, expanding `t` up to ndims(x) first.
mulexpand(t, x) = x .* expand(t, ndims(x))

Check warning on line 15 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L14-L15

Added lines #L14 - L15 were not covered by tests

#Bridging under a DoobMatchingFlow is just bridging under the wrapped discrete process.
Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t) = bridge(p.P, x0, x1, t)

Check warning on line 17 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L17

Added line #L17 was not covered by tests

"""
    fallback_doob(P, t, Xt, X1; delta = eltype(t)(1e-5))

Finite-difference approximation of the Doob h-transform rates for any discrete
process: step `Xt` forward by `delta`, condition on `X1` via the backward message,
and difference against the current one-hot state.
"""
function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5))
    conditioned = forward(Xt, P, delta) ⊙ backward(X1, P, (1 .- t) .- delta)
    return (tensor(conditioned) .- tensor(onehot(Xt))) ./ delta
end

#Generic fallback: processes without a closed-form Doob rate use the finite-difference
#approximation (fallback_doob).
doob_guide(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState) = fallback_doob(P, t, Xt, X1)

Check warning on line 23 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L23

Added line #L23 was not covered by tests

#Closed-form Doob h-transform rates for processes that expose
#forward_positive_velocities: off-diagonal rates are the forward rates tilted by the
#backward message (normalized by the current state's backward mass), and the diagonal
#entry is set so each column sums to zero.
function closed_form_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState)
    xt_hot = tensor(onehot(Xt))
    back_msg = tensor(backward(X1, P, 1 .- t))
    fwd_velo = forward_positive_velocities(onehot(Xt), P)
    off_diag = (fwd_velo .* back_msg) ./ sum(back_msg .* xt_hot, dims = 1)
    return off_diag .- xt_hot .* sum(off_diag, dims = 1)
end

#Positive (off-diagonal) forward rates out of the current state, per supported process.
#Each is paired with a doob_guide method selecting the closed-form Doob rates.
forward_positive_velocities(Xt::DiscreteState, P::PiQ) = (P.r .* (P.π ./ sum(P.π))) .* (1 .- tensor(onehot(Xt)))
doob_guide(P::PiQ, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1)

forward_positive_velocities(Xt::DiscreteState, P::UniformUnmasking{T}) where T = (P.μ .* T(1 / (Xt.K - 1))) .* (1 .- tensor(onehot(Xt)))
doob_guide(P::UniformUnmasking, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1)

forward_positive_velocities(Xt::DiscreteState, P::UniformDiscrete{T}) where T = (P.μ * T(1 / (Xt.K * (1 - 1 / Xt.K)))) .* (1 .- tensor(onehot(Xt)))
doob_guide(P::UniformDiscrete, t, Xt::DiscreteState, X1::DiscreteState) = closed_form_doob(P, t, Xt, X1)

Check warning on line 38 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L33-L38

Added lines #L33 - L38 were not covered by tests

#Build the training-target Guide: the wrapped process's Doob rates, scaled by onescale
#(so targets are either unit-step or "time remaining" velocities).
Guide(P::DoobMatchingFlow, t, Xt::DiscreteState, X1::DiscreteState) = Flowfusion.Guide(mulexpand(onescale(P, t), doob_guide(P.P, t, Xt, X1)))
#Masked variant: propagates the conditioning and loss masks carried by the masked X1.
Guide(P::DoobMatchingFlow, t, mXt::Union{MaskedState{<:DiscreteState}, DiscreteState}, mX1::MaskedState{<:DiscreteState}) = Guide(mulexpand(onescale(P, t), doob_guide(P.P, t, mXt, mX1)), mX1.cmask, mX1.lmask)

Check warning on line 41 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L40-L41

Added lines #L40 - L41 were not covered by tests

#Assemble a valid rate matrix from a model prediction: `f` maps the prediction to
#nonnegative off-diagonal rates (entries at the current one-hot state are zeroed), and
#the current-state entry is the negative column sum, so every column sums to zero.
function rate_constraint(Xt, X̂₁, f)
    off_diag = f(X̂₁) .* (1 .- Xt)
    on_diag = -sum(off_diag, dims = 1) .* Xt
    return off_diag .+ on_diag
end

#One Euler step of size `delta_t` along the transformed, zero-column-sum velocity,
#then sample the next discrete state from the resulting (clamped) probabilities.
function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, log_velocity, scale)
    xt_hot = onehot(Xₜ)
    rates = rate_constraint(tensor(xt_hot), log_velocity, P.transform) .* scale
    stepped = CategoricalLikelihood(eltype(delta_t).(tensor(xt_hot) .+ (delta_t .* rates)))
    #The current-state velocity is negative, so a large step can push Xₜ below 0.
    clamp!(tensor(stepped), 0, Inf)
    return rand(stepped)
end

#Sampling steps used by gen(): the model was trained on onescale-scaled targets, so we
#divide by onescale(P, s₁) here to recover the raw rate before the Euler step.
#Guide-wrapped prediction (e.g. when sampling directly from training targets):
step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁.H, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁.H)))
#Plain tensor prediction (raw model output passed through resolveprediction):
step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁)))

Check warning on line 58 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L57-L58

Added lines #L57 - L58 were not covered by tests

#Per-position Doob-matching loss: a Poisson/KL-type divergence between the predicted
#rates and the target Doob rates, summed over off-diagonal entries only (diagonals are
#ignored — the zero column sum is implicit). xlogy handles target rates of exactly 0.
function cgm_dloss(P, Xt, X̂₁, doobX₁)
    predicted = P.transform(X̂₁)
    return sum((1 .- Xt) .* (predicted .- xlogy.(doobX₁, predicted)), dims = 1)
end

#Hook into Flowfusion's loss API: masked, scale-weighted mean of the per-position
#Doob-matching loss, with X₁ supplied as a Guide carrying the target rates (and masks).
floss(P::Flowfusion.fbu(DoobMatchingFlow), Xt::Flowfusion.msu(DiscreteState), X̂₁, X₁::Guide, c) = Flowfusion.scaledmaskedmean(cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H), c, Flowfusion.getlmask(X₁))

Check warning on line 65 in src/doob.jl

View check run for this annotation

Codecov / codecov/patch

src/doob.jl#L65

Added line #L65 was not covered by tests