Commit 95fefdd

docs: qwen-vl-2b latexocr

1 parent 29d325b

File tree

20 files changed: 1132 additions, 0 deletions

models/Qwen2-VL/06-Qwen2-VL-2B-Instruct Lora 微调案例 - LaTexOCR.md

Lines changed: 712 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
import pandas as pd
import json

csv_path = './latex_ocr_train.csv'
train_json_path = './latex_ocr_train.json'
val_json_path = './latex_ocr_val.json'

df = pd.read_csv(csv_path)

# Create conversation format
conversations = []

# Add image conversations
for i in range(len(df)):
    conversations.append({
        "id": f"identity_{i+1}",
        "conversations": [
            {
                "role": "user",
                "value": f"{df.iloc[i]['image_path']}"
            },
            {
                "role": "assistant",
                "value": str(df.iloc[i]['text'])
            }
        ]
    })

# print(conversations)

# Split into train and validation sets: hold out the last 4 samples for validation
train_conversations = conversations[:-4]
val_conversations = conversations[-4:]

# Save train set
with open(train_json_path, 'w', encoding='utf-8') as f:
    json.dump(train_conversations, f, ensure_ascii=False, indent=2)

# Save validation set
with open(val_json_path, 'w', encoding='utf-8') as f:
    json.dump(val_conversations, f, ensure_ascii=False, indent=2)
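
As a quick sanity check (a sketch, not part of the commit), the generated validation file can be loaded back to confirm the conversation structure; this assumes the script above has already been run in the same directory:

import json

# Load the validation split written by the conversion script
with open('./latex_ocr_val.json', 'r', encoding='utf-8') as f:
    val = json.load(f)

print(len(val))                             # expected: 4 held-out samples
print(val[0]['conversations'][0]['value'])  # user turn: image path
print(val[0]['conversations'][1]['value'])  # assistant turn: LaTeX ground truth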
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Import required libraries
from modelscope.msdatasets import MsDataset
import os
import pandas as pd

MAX_DATA_NUMBER = 1000
dataset_id = 'AI-ModelScope/LaTeX_OCR'
subset_name = 'default'
split = 'train'

dataset_dir = 'LaTeX_OCR'
csv_path = './latex_ocr_train.csv'


# Check whether the dataset directory already exists
if not os.path.exists(dataset_dir):
    # Download the LaTeX_OCR dataset from ModelScope
    ds = MsDataset.load(dataset_id, subset_name=subset_name, split=split)
    print(len(ds))
    # Cap the number of images to process
    total = min(MAX_DATA_NUMBER, len(ds))

    # Create the directory for saved images
    os.makedirs(dataset_dir, exist_ok=True)

    # Initialize lists for image paths and LaTeX texts
    image_paths = []
    texts = []

    for i in range(total):
        # Fetch one sample
        item = ds[i]
        text = item['text']
        image = item['image']

        # Save the image and record its path
        image_path = os.path.abspath(f'{dataset_dir}/{i}.jpg')
        image.save(image_path)

        # Append the path and text to the lists
        image_paths.append(image_path)
        texts.append(text)

        # Print progress every 50 images
        if (i + 1) % 50 == 0:
            print(f'Processing {i+1}/{total} images ({(i+1)/total*100:.1f}%)')

    # Collect image paths and texts into a DataFrame
    df = pd.DataFrame({
        'image_path': image_paths,
        'text': texts,
    })

    # Save the data as a CSV file
    df.to_csv(csv_path, index=False)

    print(f'Data processing complete: {total} images processed')
else:
    print(f'{dataset_dir} directory already exists; skipping data processing')
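
A minimal check of the CSV the script produces (a sketch, assuming the download above completed with the default MAX_DATA_NUMBER of 1000):

import pandas as pd

df = pd.read_csv('./latex_ocr_train.csv')
print(df.shape)                  # expected: (1000, 2)
print(df.columns.tolist())       # ['image_path', 'text']
print(df.iloc[0]['image_path'])  # absolute path to LaTeX_OCR/0.jpg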
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from peft import PeftModel, LoraConfig, TaskType

prompt = "You are a LaTeX OCR assistant; read the image the user provides and convert it into a LaTeX formula."
local_model_path = "./Qwen/Qwen2-VL-2B-Instruct"
lora_model_path = "./output/Qwen2-VL-2B-LatexOCR/checkpoint-124"
test_image_path = "./LaTeX_OCR/997.jpg"

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=True,
    r=64,  # LoRA rank
    lora_alpha=16,  # LoRA alpha; see the LoRA paper for its effect
    lora_dropout=0.05,  # dropout rate
    bias="none",
)

# Load the base model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    local_model_path, torch_dtype="auto", device_map="auto"
)

# Attach the fine-tuned LoRA adapter
model = PeftModel.from_pretrained(model, model_id=f"{lora_model_path}", config=config)
processor = AutoProcessor.from_pretrained(local_model_path)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": test_image_path,
                "resized_height": 100,
                "resized_width": 500,
            },
            {"type": "text", "text": f"{prompt}"},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# Inference: generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=8192)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

print(output_text[0])
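
To spot-check the adapter on more than one image, the same pipeline can be wrapped in a helper and run over the four held-out validation samples. This is a sketch, not part of the commit: it reuses model, processor, and prompt from the script above and assumes the latex_ocr_val.json file written by the conversion script exists:

import json

def latex_ocr(image_path: str) -> str:
    # Build the same single-image chat message used above
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image_path, "resized_height": 100, "resized_width": 500},
            {"type": "text", "text": prompt},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs,
                       padding=True, return_tensors="pt").to("cuda")
    generated_ids = model.generate(**inputs, max_new_tokens=8192)
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]

# Compare predictions against the held-out references
with open('./latex_ocr_val.json', 'r', encoding='utf-8') as f:
    for sample in json.load(f):
        image_path = sample['conversations'][0]['value']  # user turn holds the image path
        reference = sample['conversations'][1]['value']   # assistant turn holds the LaTeX
        print('pred:', latex_ocr(image_path))
        print('ref :', reference)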
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
torch==2.3.0
torchvision==0.18.0
swanlab==0.3.27
transformers==4.46.2
accelerate==1.1.1
pandas==2.2.2
modelscope==1.15.0
qwen-vl-utils==0.0.8
datasets==2.18.0
peft==0.13.2
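
These pins capture the environment the walkthrough presumably targets; they can be installed into a fresh virtual environment with pip install -r requirements.txt.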
