diff --git a/Project.toml b/Project.toml index 6759854..0e298c0 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" @@ -14,6 +15,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] Adapt = "4.1.1" +FillArrays = "1.13.0" ForwardBackward = "0.1.0" LogExpFunctions = "0.3.29" Manifolds = "0.10.12" @@ -21,9 +23,3 @@ NNlib = "0.9.27" OneHotArrays = "0.2.6" StatsBase = "0.34.4" julia = "1.9" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] diff --git a/src/Flowfusion.jl b/src/Flowfusion.jl index c9a7481..8ce1390 100644 --- a/src/Flowfusion.jl +++ b/src/Flowfusion.jl @@ -17,9 +17,6 @@ Later: =# - - - module Flowfusion using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib, LogExpFunctions @@ -31,6 +28,8 @@ include("loss.jl") include("processes.jl") include("doob.jl") +include("batching.jl") + export #Processes not in ForwardBackward.jl InterpolatingDiscreteFlow, diff --git a/src/batching.jl b/src/batching.jl new file mode 100644 index 0000000..b7afbaa --- /dev/null +++ b/src/batching.jl @@ -0,0 +1,43 @@ +#This will be in Flowfusion: +element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex) +element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex) +element(S::MaskedState, inds...) = element(S.S, inds...) +element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...)) +element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...)) +element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...)) +element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...) + + +#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think. +zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0) +zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= element.K) +zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K)) +function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}} + newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0 + for i in 1:manifold_dimension(element.M) + selectdim(selectdim(newtensor, 1,i),1,i) .= 1 + end + return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize)))) +end +#Pls test this general version with other manifolds? Not sure this will handle the various underlying representations +function zerostate(element::ManifoldState, expandsize...) + newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0 + return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize)))) +end + +#In general, these will be different lengths, so we use an array of arrays as input. +#Doesn't work for onehot states yet. +function regroup(elarray::AbstractArray{<:AbstractArray}) + example_tuple = elarray[1][1] + maxlen = maximum(length.(elarray)) + b = length(elarray) + newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)] + for i in 1:b + for j in 1:length(elarray[i]) + for k in 1:length(example_tuple) + element(tensor(newstates[k]),j,i) .= tensor(elarray[i][j][k]) + end + end + end + return Tuple(newstates) +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..d9f3840 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,4 @@ +[deps] +ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d" +Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 0658410..5a4cb67 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ using ForwardBackward @testset "Bridge, step" begin - siz = (5,6,7) + siz = (5,6) XC() = ContinuousState(randn(5, siz...)) XD() = DiscreteState(5, rand(1:5, siz...)) MT = Torus(2)