@@ -19,24 +19,37 @@ function PyBoltz.predict(input, ::Type{MolecularStructure}; options...)
19
19
options... )
20
20
prediction_paths = readdir (joinpath (out_dir, only (readdir (out_dir)), " predictions" ); join= true )
21
21
prediction_names = basename .(prediction_paths)
22
- structures = MolecularStructure[]
23
- perm, structure_names = if all (name -> startswith (name, PyBoltz. PYBOLTZ_INPUT_INDEX_PREFIX), prediction_names)
24
- indices = Int[]
25
- structure_names = String[]
22
+
23
+ local results
24
+ if all (name -> startswith (name, PyBoltz. PYBOLTZ_INPUT_INDEX_PREFIX), prediction_names)
25
+ # output vector needs to match input vector (with possible missing values)
26
+ @assert input isa AbstractVector{PyBoltz. Schema. MolecularInput}
27
+ results = Union{MolecularStructure,Missing}[fill (missing , length (input))... ]
26
28
for prediction_name in prediction_names
27
29
index, name = split (split (prediction_name, PyBoltz. PYBOLTZ_INPUT_INDEX_PREFIX, limit= 2 )[2 ], " _" , limit= 2 )
28
- push! (indices, parse (Int, index))
29
- push! (structure_names, name)
30
+ idx = parse (Int, index)
31
+ prediction_path = joinpath (joinpath (out_dir, only (readdir (out_dir)), " predictions" ), prediction_name)
32
+ cif_path = joinpath (prediction_path, basename (prediction_path)* " _model_0.cif" )
33
+ try
34
+ results[idx] = read_boltz_cif (cif_path, name)
35
+ catch e
36
+ @warn e
37
+ results[idx] = missing
38
+ end
30
39
end
31
- sortperm (indices), structure_names
32
40
else
33
- collect (1 : length (prediction_names)), prediction_names
34
- end
35
- for (structure_name, prediction_path) in zip (structure_names, prediction_paths)
36
- cif_path = joinpath (prediction_path, basename (prediction_path)* " _model_0.cif" )
37
- push! (structures, read_boltz_cif (cif_path, structure_name))
41
+ results = Union{MolecularStructure,Missing}[]
42
+ for prediction_path in prediction_paths
43
+ cif_path = joinpath (prediction_path, basename (prediction_path)* " _model_0.cif" )
44
+ try
45
+ push! (results, read_boltz_cif (cif_path, basename (prediction_path)))
46
+ catch e
47
+ @warn e
48
+ push! (results, missing )
49
+ end
50
+ end
38
51
end
39
- return structures[perm]
52
+ return results
40
53
end
41
54
end
42
55
0 commit comments