diff --git a/src/doob.jl b/src/doob.jl index 2045417..293e875 100644 --- a/src/doob.jl +++ b/src/doob.jl @@ -14,7 +14,9 @@ DoobMatchingFlow(P::DiscreteProcess, onescale::Bool) = DoobMatchingFlow(P, onesc onescale(P::DoobMatchingFlow,t) = P.onescale ? (1 .- t) : eltype(t)(1) mulexpand(t,x) = expand(t, ndims(x)) .* x -Flowfusion.bridge(p::DoobMatchingFlow, x0::DiscreteState{<:AbstractArray{<:Signed}}, x1::DiscreteState{<:AbstractArray{<:Signed}}, t) = bridge(p.P, x0, x1, t) +#We could consider making this preserve one-hotness: +bridge(p::DoobMatchingFlow, x0::DiscreteState, x1::DiscreteState, t) = bridge(p.P, x0, x1, t) +bridge(p::DoobMatchingFlow, x0::DiscreteState, x1::DiscreteState, t0, t) = bridge(p.P, x0, x1, t0, t) function fallback_doob(P::DiscreteProcess, t, Xt::DiscreteState, X1::DiscreteState; delta = eltype(t)(1e-5)) return (tensor(forward(Xt, P, delta) ⊙ backward(X1, P, (1 .- t) .- delta)) .- tensor(onehot(Xt))) ./ delta; @@ -46,7 +48,7 @@ function rate_constraint(Xt, X̂₁, f) return posQt .+ diagQt end -function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, log_velocity, scale) +function velo_step(P, Xₜ::DiscreteState, delta_t, log_velocity, scale) ohXₜ = onehot(Xₜ) velocity = rate_constraint(tensor(ohXₜ), log_velocity, P.transform) .* scale newXₜ = CategoricalLikelihood(eltype(delta_t).(tensor(ohXₜ) .+ (delta_t .* velocity))) @@ -54,12 +56,12 @@ function velo_step(P, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, delta_t, l return rand(newXₜ) end -step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁.H, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁.H))) -step(P::DoobMatchingFlow, Xₜ::DiscreteState{<:AbstractArray{<:Signed}}, veloX̂₁, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁))) +step(P::DoobMatchingFlow, Xₜ::DiscreteState, veloX̂₁::Flowfusion.Guide, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁.H, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁.H))) +step(P::DoobMatchingFlow, Xₜ::DiscreteState, veloX̂₁, s₁, s₂) = velo_step(P, Xₜ, s₂ .- s₁, veloX̂₁, expand(1 ./ onescale(P, s₁), ndims(veloX̂₁))) function cgm_dloss(P, Xt, X̂₁, doobX₁) Qt = P.transform(X̂₁) return sum((1 .- Xt) .* (Qt .- xlogy.(doobX₁, Qt)), dims = 1) #<- note, diagonals ignored; implicit zero sum end -floss(P::Flowfusion.fbu(DoobMatchingFlow), Xt::Flowfusion.msu(DiscreteState), X̂₁, X₁::Guide, c) = Flowfusion.scaledmaskedmean(cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H), c, Flowfusion.getlmask(X₁)) \ No newline at end of file +floss(P::Flowfusion.fbu(DoobMatchingFlow), Xt::Flowfusion.msu(DiscreteState), X̂₁, X₁::Guide, c) = Flowfusion.scaledmaskedmean(cgm_dloss(P, tensor(Xt), tensor(X̂₁), X₁.H), c, Flowfusion.getlmask(X₁))