|
| 1 | +#This will be in Flowfusion: |
| 2 | +element(state, seqindex) = selectdim(state, ndims(state), seqindex:seqindex) |
| 3 | +element(state, seqindex, batchindex) = element(selectdim(state, ndims(state), batchindex), seqindex) |
| 4 | +element(S::MaskedState, inds...) = element(S.S, inds...) |
| 5 | +element(S::ContinuousState, inds...) = ContinuousState(element(S.state, inds...)) |
| 6 | +element(S::ManifoldState, inds...) = ManifoldState(S.M, element(S.state, inds...)) |
| 7 | +element(S::DiscreteState, inds...) = DiscreteState(S.K, element(S.state, inds...)) |
| 8 | +element(S::Tuple{Vararg{Flowfusion.UState}}, inds...) = element.(S, inds...) |
| 9 | + |
| 10 | + |
| 11 | +#Create a "zero" state appropriate for the type. Tricky for manifolds, but we just want rotations working for now I think. |
| 12 | +zerostate(element::T, expandsize...) where T <: ContinuousState = T(similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0) |
| 13 | +zerostate(element::DiscreteState{<:AbstractArray{<:Signed}}, expandsize...) = DiscreteState(element.K,similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= element.K) |
| 14 | +zerostate(element::DiscreteState, expandsize...) = Flowfusion.onehot(DiscreteState(element.K,zeros(Int,expandsize...) .= element.K)) |
| 15 | +function zerostate(element::T, expandsize...) where T <: Union{ManifoldState{<:Rotations},ManifoldState{<:SpecialOrthogonal}} |
| 16 | + newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0 |
| 17 | + for i in 1:manifold_dimension(element.M) |
| 18 | + selectdim(selectdim(newtensor, 1,i),1,i) .= 1 |
| 19 | + end |
| 20 | + return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize)))) |
| 21 | +end |
| 22 | +#Pls test this general version with other manifolds? Not sure this will handle the various underlying representations |
| 23 | +function zerostate(element::ManifoldState, expandsize...) |
| 24 | + newtensor = similar(tensor(element), size(tensor(element))[1:end-1]..., expandsize...) .= 0 |
| 25 | + return ManifoldState(element.M, eachslice(newtensor, dims=ntuple(i -> 2+i, length(expandsize)))) |
| 26 | +end |
| 27 | + |
| 28 | +#In general, these will be different lengths, so we use an array of arrays as input. |
| 29 | +#Doesn't work for onehot states yet. |
| 30 | +function regroup(elarray::AbstractArray{<:AbstractArray}) |
| 31 | + example_tuple = elarray[1][1] |
| 32 | + maxlen = maximum(length.(elarray)) |
| 33 | + b = length(elarray) |
| 34 | + newstates = [zerostate(example_tuple[i],maxlen,b) for i in 1:length(example_tuple)] |
| 35 | + for i in 1:b |
| 36 | + for j in 1:length(elarray[i]) |
| 37 | + for k in 1:length(example_tuple) |
| 38 | + element(tensor(newstates[k]),j,i) .= tensor(elarray[i][j][k]) |
| 39 | + end |
| 40 | + end |
| 41 | + end |
| 42 | + return Tuple(newstates) |
| 43 | +end |
0 commit comments