@@ -50,9 +50,12 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
50
50
if target_resource_utilization .bops_restricted ():
51
51
Logger .info (f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement' )
52
52
return mp_solution
53
-
54
53
assert search_manager .using_virtual_graph is False
55
54
55
+ tru = target_resource_utilization
56
+ activation_restricted = tru .activation_restricted () or tru .total_mem_restricted () or tru .bops_restricted ()
57
+ weights_restricted = tru .weight_restricted () or tru .total_mem_restricted () or tru .bops_restricted ()
58
+
56
59
new_solution = mp_solution .copy ()
57
60
changed = True
58
61
@@ -62,7 +65,7 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
62
65
nodes_next_candidate = {}
63
66
64
67
for node in search_manager .mp_topo_configurable_nodes :
65
- if new_solution [node ] == 0 :
68
+ if new_solution [node ] == node . find_max_candidate_index () :
66
69
# layer has max config in the given solution, nothing to optimize
67
70
continue
68
71
@@ -71,9 +74,8 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
71
74
# only weights kernel attribute is quantized with weights mixed precision
72
75
valid_candidates = _get_valid_candidates_indices (node_candidates ,
73
76
new_solution [node ],
74
- target_resource_utilization .activation_restricted (),
75
- target_resource_utilization .weight_restricted ()
76
- )
77
+ activation_restricted ,
78
+ weights_restricted )
77
79
78
80
# Create a list of ru for the valid candidates.
79
81
updated_ru = []
0 commit comments