Commit 945096f

Pre-training supports 8*V100 (32G) gpus

1 parent c12c256 commit 945096f

File tree

README.md
llava/train/llama_xformers_attn_monkey_patch.py
llava/train/train_xformers.py
pyproject.toml
scripts/pretrain_xformers.sh

5 files changed: +192 -1 lines

README.md

Lines changed: 4 additions & 0 deletions

@@ -210,7 +210,11 @@ python llava/train/train_mem.py \
     --report_to wandb
 ```
 </details>
+<details>
+<summary>Pretrain: LLaVA-7B, 8x V100 (32G). Time: ~20 hours.</summary>

+We provide a training script with DeepSpeed [here](https://github.com/haotian-liu/LLaVA/blob/main/scripts/pretrain_xformers.sh).
+</details>

 ### Visual Instruction Tuning

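A note on why this variant exists: V100s (compute capability 7.0) have no bfloat16 or TF32 support, which is presumably why the linked script runs with `--bf16 False` and `--tf32 False` and leans on xformers attention to fit pre-training into 32 GB. The following is a small optional check of what the local GPU supports; it is not part of the commit and only assumes a working PyTorch CUDA install.

```python
# Optional local check (not part of the commit): V100s report compute capability 7.0
# and no bfloat16 support, which is why bf16/tf32 are disabled in the linked script.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name(0)} (compute capability {major}.{minor})")
    print("bf16 supported:", torch.cuda.is_bf16_supported())
else:
    print("No CUDA device visible")
```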
llava/train/llama_xformers_attn_monkey_patch.py

Lines changed: 129 additions & 0 deletions

"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""

import logging
import math
from typing import Optional, Tuple

import torch
import transformers.models.llama.modeling_llama
from torch import nn

try:
    import xformers.ops
except ImportError:
    logging.error("xformers not found! Please install it before trying to use it.")


def replace_llama_attn_with_xformers_attn():
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


def xformers_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # pylint: disable=duplicate-code
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    (
        query_states,
        key_states,
    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # We only apply xformers optimizations if we don't need to output the whole attention matrix
    if not output_attentions:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=None
            )
        else:
            # input and output should be of form (bsz, q_len, num_heads, head_dim)
            attn_output = xformers.ops.memory_efficient_attention(
                query_states,
                key_states,
                value_states,
                attn_bias=xformers.ops.LowerTriangularMask(),
            )
        attn_weights = None
    else:
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights, past_key_value

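As a quick sanity check on what the patched forward computes, here is a minimal sketch comparing `xformers.ops.memory_efficient_attention` with a `LowerTriangularMask` against plain causal softmax attention. It is not part of the commit and assumes a CUDA GPU with xformers installed; the tensor shapes and tolerances are illustrative.

```python
import math

import torch
import xformers.ops as xops

# Small random inputs in the (bsz, seq_len, num_heads, head_dim) layout the patch uses.
torch.manual_seed(0)
bsz, seq_len, num_heads, head_dim = 2, 16, 4, 32
q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Memory-efficient causal attention, as in the not-output_attentions branch above.
out_xformers = xops.memory_efficient_attention(
    q, k, v, attn_bias=xops.LowerTriangularMask()
)

# Reference: naive softmax attention with an explicit causal mask, computed in fp32.
qh, kh, vh = (t.transpose(1, 2).float() for t in (q, k, v))  # (bsz, heads, seq, dim)
scores = qh @ kh.transpose(-1, -2) / math.sqrt(head_dim)
causal = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1
)
scores = scores.masked_fill(causal, float("-inf"))
out_ref = (scores.softmax(dim=-1) @ vh).transpose(1, 2)  # back to (bsz, seq, heads, dim)

# The two paths should agree up to fp16 precision.
print(torch.allclose(out_xformers.float(), out_ref, atol=1e-2, rtol=1e-2))
```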
llava/train/train_xformers.py

Lines changed: 13 additions & 0 deletions

# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.

# Need to call this before importing transformers.
from llava.train.llama_xformers_attn_monkey_patch import (
    replace_llama_attn_with_xformers_attn,
)

replace_llama_attn_with_xformers_attn()

from llava.train.train import train

if __name__ == "__main__":
    train()

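The import order here matters: the patch has to run before `llava.train.train` constructs the model. A hypothetical standalone check, not part of the commit, that the patch actually took effect could look like this:

```python
# Hypothetical sanity check (not part of the commit): confirm the monkey patch
# replaced LlamaAttention.forward before any model is built.
from llava.train.llama_xformers_attn_monkey_patch import (
    replace_llama_attn_with_xformers_attn,
    xformers_forward,
)

replace_llama_attn_with_xformers_attn()

import transformers.models.llama.modeling_llama as llama_modeling

assert llama_modeling.LlamaAttention.forward is xformers_forward
print("LlamaAttention now uses the xformers forward pass")
```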
pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -25,7 +25,8 @@ dependencies = [
     "scikit-learn==1.2.2",
     "sentencepiece==0.1.99",
     "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
-    "gradio_client==0.2.9"
+    "gradio_client==0.2.9",
+    "xformers"
 ]

 [project.urls]

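Since `xformers` is added without a version pin, it may be worth confirming after `pip install -e .` that the package resolves against the installed PyTorch. A minimal check, assuming it is run inside the project environment:

```python
# Minimal post-install check (assumption: run inside the LLaVA environment).
import torch
import xformers
import xformers.ops  # the module the monkey patch relies on

print("torch:", torch.__version__)
print("xformers:", getattr(xformers, "__version__", "unknown"))
```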
scripts/pretrain_xformers.sh

Lines changed: 44 additions & 0 deletions

#!/bin/bash

# Uncomment and set the following variables correspondingly to run this script:

# MODEL_VERSION=vicuna-v1-3-7b
# MODEL_VERSION=llama-2-7b-chat

########### DO NOT CHANGE ###########
########### USE THIS FOR BOTH ###########
PROMPT_VERSION=plain
########### DO NOT CHANGE ###########

deepspeed llava/train/train_xformers.py \
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path ./checkpoints/$MODEL_VERSION \
    --version $PROMPT_VERSION \
    --data_path /path/to/pretrain_data.json \
    --image_folder /path/to/images \
    --vision_tower openai/clip-vit-large-patch14 \
    --tune_mm_mlp_adapter True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 False \
    --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 24000 \
    --save_total_limit 1 \
    --learning_rate 2e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 False \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

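One hyperparameter detail worth noting: with 8 GPUs (as in the commit title), the smaller per-device batch size is compensated by gradient accumulation, so the effective global batch size for pre-training works out to 128. A quick arithmetic sketch:

```python
# Effective global batch size for this script (8 GPUs assumed from the commit title).
num_gpus = 8
per_device_train_batch_size = 4
gradient_accumulation_steps = 4

effective_batch_size = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 128
```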