Skip to content

Commit f1a7b36

Browse files
committed
Add script to extract projector
1 parent 0eeaf44 commit f1a7b36

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

scripts/extract_mm_projector.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
This is just a utility that I use to extract the projector for quantized models.
3+
It is NOT necessary at all to train, or run inference/serve demos.
4+
Use this script ONLY if you fully understand its implications.
5+
"""
6+
7+
8+
import os
9+
import argparse
10+
import torch
11+
import json
12+
from collections import defaultdict
13+
14+
15+
def parse_args():
16+
parser = argparse.ArgumentParser(description='Extract MMProjector weights')
17+
parser.add_argument('--model-path', type=str, help='model folder')
18+
parser.add_argument('--output', type=str, help='output file')
19+
args = parser.parse_args()
20+
return args
21+
22+
23+
if __name__ == '__main__':
24+
args = parse_args()
25+
26+
keys_to_match = ['mm_projector']
27+
ckpt_to_key = defaultdict(list)
28+
try:
29+
model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
30+
for k, v in model_indices['weight_map'].items():
31+
if any(key_match in k for key_match in keys_to_match):
32+
ckpt_to_key[v].append(k)
33+
except FileNotFoundError:
34+
# Smaller models or model checkpoints saved by DeepSpeed.
35+
v = 'pytorch_model.bin'
36+
for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
37+
if any(key_match in k for key_match in keys_to_match):
38+
ckpt_to_key[v].append(k)
39+
40+
loaded_weights = {}
41+
42+
for ckpt_name, weight_keys in ckpt_to_key.items():
43+
ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
44+
for k in weight_keys:
45+
loaded_weights[k] = ckpt[k]
46+
47+
torch.save(loaded_weights, args.output)

0 commit comments

Comments
 (0)