1
+ # Note: Haven't figured out exactly what, in the literature, this is. Not very tested!
2
+
3
+ struct DoobMatchingFlow{Proc, B, F} <: Process
4
+ P:: Proc
5
+ onescale:: B # Controls whether the "step" is unit scale or "time remaining" scale. Need to think carefully about schedules in all this...
6
+ transform:: F # Transforms the output of the model to the rate space. Must act on the whole tensor.
7
+ # Note: losses can be compared for different transforms, but not for different onescale.
8
+ end
9
+
10
+ DoobMatchingFlow (P:: DiscreteProcess ) = DoobMatchingFlow (P, true , NNlib. softplus) # x -> exp.(clamp.(x, -100, 11)) also works, but is scary
11
+ DoobMatchingFlow (P:: DiscreteProcess , transform:: Function ) = DoobMatchingFlow (P, true , transform)
12
+ DoobMatchingFlow (P:: DiscreteProcess , onescale:: Bool ) = DoobMatchingFlow (P, onescale, NNlib. softplus)
13
+
14
+ onescale (P:: DoobMatchingFlow ,t) = P. onescale ? (1 .- t) : eltype (t)(1 )
15
+ mulexpand (t,x) = expand (t, ndims (x)) .* x
16
+
17
+ Flowfusion. bridge (p:: DoobMatchingFlow , x0:: DiscreteState{<:AbstractArray{<:Signed}} , x1:: DiscreteState{<:AbstractArray{<:Signed}} , t) = bridge (p. P, x0, x1, t)
18
+
19
+ function fallback_doob (P:: DiscreteProcess , t, Xt:: DiscreteState , X1:: DiscreteState ; delta = eltype (t)(1e-5 ))
20
+ return (tensor (forward (Xt, P, delta) ⊙ backward (X1, P, (1 .- t) .- delta)) .- tensor (onehot (Xt))) ./ delta;
21
+ end
22
+
23
+ doob_guide (P:: DiscreteProcess , t, Xt:: DiscreteState , X1:: DiscreteState ) = fallback_doob (P, t, Xt, X1)
24
+
25
+ function closed_form_doob (P:: DiscreteProcess , t, Xt:: DiscreteState , X1:: DiscreteState )
26
+ tenXt = tensor (onehot (Xt))
27
+ bk = tensor (backward (X1, P, 1 .- t))
28
+ fv = forward_positive_velocities (onehot (Xt), P)
29
+ positive_doob = (fv .* bk) ./ sum (bk .* tenXt, dims = 1 )
30
+ return positive_doob .- tenXt .* sum (positive_doob, dims = 1 )
31
+ end
32
+
33
+ forward_positive_velocities (Xt:: DiscreteState , P:: PiQ )= (P. r .* (P. π ./ sum (P. π))) .* (1 .- tensor (onehot (Xt)))
34
+ doob_guide (P:: PiQ , t, Xt:: DiscreteState , X1:: DiscreteState ) = closed_form_doob (P, t, Xt, X1)
35
+ forward_positive_velocities (Xt:: DiscreteState , P:: UniformUnmasking{T} ) where T = (P. μ .* T ((1 ./ (Xt. K- 1 )))) .* (1 .- tensor (onehot (Xt)))
36
+ doob_guide (P:: UniformUnmasking , t, Xt:: DiscreteState , X1:: DiscreteState ) = closed_form_doob (P, t, Xt, X1)
37
+ forward_positive_velocities (Xt:: DiscreteState , P:: UniformDiscrete{T} ) where T = (P. μ * T (1 / (Xt. K* (1 - 1 / Xt. K)))) .* (1 .- tensor (onehot (Xt)))
38
+ doob_guide (P:: UniformDiscrete , t, Xt:: DiscreteState , X1:: DiscreteState ) = closed_form_doob (P, t, Xt, X1)
39
+
40
+ Guide (P:: DoobMatchingFlow , t, Xt:: DiscreteState , X1:: DiscreteState ) = Flowfusion. Guide (mulexpand (onescale (P, t), doob_guide (P. P, t, Xt, X1)))
41
+ Guide (P:: DoobMatchingFlow , t, mXt:: Union{MaskedState{<:DiscreteState}, DiscreteState} , mX1:: MaskedState{<:DiscreteState} ) = Guide (mulexpand (onescale (P, t), doob_guide (P. P, t, mXt, mX1)), mX1. cmask, mX1. lmask)
42
+
43
+ function rate_constraint (Xt, X̂₁, f)
44
+ posQt = f (X̂₁) .* (1 .- Xt)
45
+ diagQt = - sum (posQt, dims = 1 ) .* Xt
46
+ return posQt .+ diagQt
47
+ end
48
+
49
+ function velo_step (P, Xₜ:: DiscreteState{<:AbstractArray{<:Signed}} , delta_t, log_velocity, scale)
50
+ ohXₜ = onehot (Xₜ)
51
+ velocity = rate_constraint (tensor (ohXₜ), log_velocity, P. transform) .* scale
52
+ newXₜ = CategoricalLikelihood (eltype (delta_t).(tensor (ohXₜ) .+ (delta_t .* velocity)))
53
+ clamp! (tensor (newXₜ), 0 , Inf ) # Because one velo will be < 0 and a large step might push Xₜ < 0
54
+ return rand (newXₜ)
55
+ end
56
+
57
+ step (P:: DoobMatchingFlow , Xₜ:: DiscreteState{<:AbstractArray{<:Signed}} , veloX̂₁:: Flowfusion.Guide , s₁, s₂) = velo_step (P, Xₜ, s₂ .- s₁, veloX̂₁. H, expand (1 ./ onescale (P, s₁), ndims (veloX̂₁. H)))
58
+ step (P:: DoobMatchingFlow , Xₜ:: DiscreteState{<:AbstractArray{<:Signed}} , veloX̂₁, s₁, s₂) = velo_step (P, Xₜ, s₂ .- s₁, veloX̂₁, expand (1 ./ onescale (P, s₁), ndims (veloX̂₁)))
59
+
60
+ function cgm_dloss (P, Xt, X̂₁, doobX₁)
61
+ Qt = P. transform (X̂₁)
62
+ return sum ((1 .- Xt) .* (Qt .- xlogy .(doobX₁, Qt)), dims = 1 ) # <- note, diagonals ignored; implicit zero sum
63
+ end
64
+
65
+ floss (P:: Flowfusion.fbu (DoobMatchingFlow), Xt:: Flowfusion.msu (DiscreteState), X̂₁, X₁:: Guide , c) = Flowfusion. scaledmaskedmean (cgm_dloss (P, tensor (Xt), tensor (X̂₁), X₁. H), c, Flowfusion. getlmask (X₁))
0 commit comments