Skip to content

Commit 9c81eab

Browse files
authored
Merge pull request #7 from MurrellGroup/elements_and_batching
Adding some States->Elements and back
2 parents 76112c0 + 450d08b commit 9c81eab

File tree

5 files changed

+52
-10
lines changed

5 files changed

+52
-10
lines changed

Project.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
910
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1011
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
@@ -14,16 +15,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415

1516
[compat]
1617
Adapt = "4.1.1"
18+
FillArrays = "1.13.0"
1719
ForwardBackward = "0.1.0"
1820
LogExpFunctions = "0.3.29"
1921
Manifolds = "0.10.12"
2022
NNlib = "0.9.27"
2123
OneHotArrays = "0.2.6"
2224
StatsBase = "0.34.4"
2325
julia = "1.9"
24-
25-
[extras]
26-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
27-
28-
[targets]
29-
test = ["Test"]

src/Flowfusion.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ Later:
1717
=#
1818

1919

20-
21-
22-
2320
module Flowfusion
2421

2522
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib, LogExpFunctions
@@ -31,6 +28,8 @@ include("loss.jl")
3128
include("processes.jl")
3229
include("doob.jl")
3330

31+
include("batching.jl")
32+
3433
export
3534
#Processes not in ForwardBackward.jl
3635
InterpolatingDiscreteFlow,

src/batching.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#This will be in Flowfusion:
2+
element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex)
3+
element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex)
4+
element(S::MaskedState, inds...) = element(S.S, inds...)
5+
element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...))
6+
element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...))
7+
element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...))
8+
element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...)
9+
10+
11+
#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think.
12+
zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0)
13+
zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= element.K)
14+
zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K))
15+
function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}}
16+
newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0
17+
for i in 1:manifold_dimension(element.M)
18+
selectdim(selectdim(newtensor, 1,i),1,i) .= 1
19+
end
20+
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))
21+
end
22+
#Pls test this general version with other manifolds? Not sure this will handle the various underlying representations
23+
function zerostate(element::ManifoldState, expandsize...)
24+
newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0
25+
return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize))))
26+
end
27+
28+
#In general, these will be different lengths, so we use an array of arrays as input.
29+
#Doesn't work for onehot states yet.
30+
function regroup(elarray::AbstractArray{<:AbstractArray})
31+
example_tuple = elarray[1][1]
32+
maxlen = maximum(length.(elarray))
33+
b = length(elarray)
34+
newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)]
35+
for i in 1:b
36+
for j in 1:length(elarray[i])
37+
for k in 1:length(example_tuple)
38+
element(tensor(newstates[k]),j,i) .= tensor(elarray[i][j][k])
39+
end
40+
end
41+
end
42+
return Tuple(newstates)
43+
end

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
ForwardBackward = "e879419d-bb0f-4252-adee-d266c51ac92d"
3+
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ using ForwardBackward
4444

4545
@testset "Bridge, step" begin
4646

47-
siz = (5,6,7)
47+
siz = (5,6)
4848
XC() = ContinuousState(randn(5, siz...))
4949
XD() = DiscreteState(5, rand(1:5, siz...))
5050
MT = Torus(2)

0 commit comments

Comments
 (0)