Skip to content

Commit 698a83f

Browse files
committed
fix greedy refinement - try to refine for total memory
1 parent 2720242 commit 698a83f

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
5050
if target_resource_utilization.bops_restricted():
5151
Logger.info(f'Target resource utilization constraint BOPs - Skipping MP greedy solution refinement')
5252
return mp_solution
53-
5453
assert search_manager.using_virtual_graph is False
5554

55+
tru = target_resource_utilization
56+
activation_restricted = tru.activation_restricted() or tru.total_mem_restricted() or tru.bops_restricted()
57+
weights_restricted = tru.weight_restricted() or tru.total_mem_restricted() or tru.bops_restricted()
58+
5659
new_solution = mp_solution.copy()
5760
changed = True
5861

@@ -62,7 +65,7 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
6265
nodes_next_candidate = {}
6366

6467
for node in search_manager.mp_topo_configurable_nodes:
65-
if new_solution[node] == 0:
68+
if new_solution[node] == node.find_max_candidate_index():
6669
# layer has max config in the given solution, nothing to optimize
6770
continue
6871

@@ -71,9 +74,8 @@ def greedy_solution_refinement_procedure(mp_solution: Dict[BaseNode, int],
7174
# only weights kernel attribute is quantized with weights mixed precision
7275
valid_candidates = _get_valid_candidates_indices(node_candidates,
7376
new_solution[node],
74-
target_resource_utilization.activation_restricted(),
75-
target_resource_utilization.weight_restricted()
76-
)
77+
activation_restricted,
78+
weights_restricted)
7779

7880
# Create a list of ru for the valid candidates.
7981
updated_ru = []

0 commit comments

Comments
 (0)