Skip to content

Adding some States->Elements and back #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
Expand All @@ -14,16 +15,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Adapt = "4.1.1"
FillArrays = "1.13.0"
ForwardBackward = "0.1.0"
LogExpFunctions = "0.3.29"
Manifolds = "0.10.12"
NNlib = "0.9.27"
OneHotArrays = "0.2.6"
StatsBase = "0.34.4"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
5 changes: 2 additions & 3 deletions src/Flowfusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ Later:
=#





module Flowfusion

using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib, LogExpFunctions
Expand All @@ -31,6 +28,8 @@ include("loss.jl")
include("processes.jl")
include("doob.jl")

include("batching.jl")

export
#Processes not in ForwardBackward.jl
InterpolatingDiscreteFlow,
Expand Down
43 changes: 43 additions & 0 deletions src/batching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#This will be in Flowfusion:
element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex)
element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex)
element(S::MaskedState, inds...) = element(S.S, inds...)
element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...))
element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...))
element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...))
element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...)

Check warning on line 8 in src/batching.jl

View check run for this annotation

Codecov / codecov/patch

src/batching.jl#L2-L8

Added lines #L2 - L8 were not covered by tests


#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0)
zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= element.K)
zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K))
function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}}
newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0
for i in 1:manifold_dimension(element.M)
selectdim(selectdim(newtensor, 1,i),1,i) .= 1
end
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))

Check warning on line 20 in src/batching.jl

View check run for this annotation

Codecov / codecov/patch

src/batching.jl#L12-L20

Added lines #L12 - L20 were not covered by tests
end
#Pls test this general version with other manifolds? Not sure this will handle the various underlying representations
function zerostate(element::ManifoldState, expandsize...)
newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))

Check warning on line 25 in src/batching.jl

View check run for this annotation

Codecov / codecov/patch

src/batching.jl#L23-L25

Added lines #L23 - L25 were not covered by tests
end

#In general, these will be different lengths, so we use an array of arrays as input.
#Doesn't work for onehot states yet.
function regroup(elarray::AbstractArray{<:AbstractArray})
example_tuple = elarray[1][1]
maxlen = maximum(length.(elarray))
b = length(elarray)
newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)]
for i in 1:b
for j in 1:length(elarray[i])
for k in 1:length(example_tuple)
element(tensor(newstates[k]),j,i) .= tensor(elarray[i][j][k])
end
end
end
return Tuple(newstates)

Check warning on line 42 in src/batching.jl

View check run for this annotation

Codecov / codecov/patch

src/batching.jl#L30-L42

Added lines #L30 - L42 were not covered by tests
end
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ using ForwardBackward

@testset "Bridge, step" begin

siz = (5,6,7)
siz = (5,6)
XC() = ContinuousState(randn(5, siz...))
XD() = DiscreteState(5, rand(1:5, siz...))
MT = Torus(2)
Expand Down