
Commit ce1aa08

Release evaluation scripts.
1 parent a967492 commit ce1aa08

30 files changed: +1721 -22 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 
 ## Release
-- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), with evaluation scripts coming this week!
+- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)!
 - [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks with just simple modifications to the original LLaVA, it utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md).
 - [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) for better fact grounding and reduced hallucination. Check out the new SFT and RLHF checkpoints at the [[LLaVA-RLHF]](https://llava-rlhf.github.io/) project.
 - [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**.

docs/Evaluation.md

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Evaluation

In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not use beam search, so that the inference process stays consistent with the real-time outputs of the chat demo.

Currently, we mostly utilize the official toolkit or server for the evaluation.
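For readers unfamiliar with the distinction, here is a minimal sketch of greedy decoding versus beam search in Hugging Face `generate` terms. It uses a placeholder model and is an illustrative assumption, not the generation code used by the evaluation scripts, which may set different flags.

```python
# Illustrative sketch only: greedy decoding vs. beam search with Hugging Face
# transformers, using a small placeholder model. The LLaVA evaluation scripts
# may configure generation differently.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")   # placeholder model

inputs = tok("Answer the question using a single word or phrase.", return_tensors="pt")

# Greedy decoding: no sampling, a single beam (the setting used for LLaVA-1.5 evaluation).
greedy = model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=16)

# Beam search (not used here, to stay consistent with the streaming chat demo).
beam = model.generate(**inputs, do_sample=False, num_beams=5, max_new_tokens=16)

print(tok.decode(greedy[0], skip_special_tokens=True))
```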
## Evaluate on Custom Datasets

You can evaluate LLaVA on your custom datasets by converting them to LLaVA's jsonl format and evaluating with [`model_vqa.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa.py).

Below we provide a general guideline for evaluating datasets with some common formats (a minimal conversion sketch follows the list).

1. Short-answer (e.g. VQAv2, MME).

```
<question>
Answer the question using a single word or phrase.
```

2. Option-only for multiple-choice (e.g. MMBench, SEED-Bench).

```
<question>
A. <option_1>
B. <option_2>
C. <option_3>
D. <option_4>
Answer with the option's letter from the given choices directly.
```

3. Natural QA (e.g. LLaVA-Bench, MM-Vet).

No postprocessing is needed.
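As a rough illustration of the conversion mentioned above, the sketch below writes a short-answer dataset into a question jsonl for `model_vqa.py`. The record fields (`question_id`, `image`, `text`) and the toy dataset are assumptions for illustration; check `llava/eval/model_vqa.py` for the exact schema it reads.

```python
# Minimal conversion sketch (not an official converter): write a custom
# short-answer dataset as a jsonl question file for model_vqa.py.
# The field names (question_id, image, text) are assumptions; verify them
# against llava/eval/model_vqa.py.
import json

# Hypothetical custom dataset: one image and one question per example.
custom_dataset = [
    {"id": 0, "image": "images/0001.jpg", "question": "What is the main object in the image?"},
    {"id": 1, "image": "images/0002.jpg", "question": "How many people are visible?"},
]

# Suffix from the short-answer guideline above.
SHORT_ANSWER_SUFFIX = "\nAnswer the question using a single word or phrase."

with open("custom_questions.jsonl", "w") as f:
    for ex in custom_dataset:
        record = {
            "question_id": ex["id"],
            "image": ex["image"],
            "text": ex["question"] + SHORT_ANSWER_SUFFIX,
        }
        f.write(json.dumps(record) + "\n")
```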
## Scripts

Before preparing task-specific data, download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing). It contains custom annotations, scripts, and the prediction files of LLaVA v1.5. Extract it to `./playground/data/eval`. This also provides a general structure for all datasets.
### VQAv2

1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`.
2. Multi-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/vqav2.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/vqav2/answers_upload`.
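The multi-GPU scripts typically shard the question list across the visible GPUs and merge the per-chunk answers afterwards; the exact mechanics live in `scripts/v1_5/eval/vqav2.sh`, which is not shown in this commit. A rough sketch of the chunking idea, mirroring the `split_list`/`get_chunk` helpers that appear in `llava/eval/eval_mmbench.py` further down:

```python
# Sketch of sharding a question list across N GPUs. This mirrors the
# split_list/get_chunk helpers in llava/eval/eval_mmbench.py below; it is
# not the actual contents of vqav2.sh.
import math

def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return the k-th of n chunks."""
    return split_list(lst, n)[k]

questions = list(range(10))  # stand-in for jsonl question records
num_gpus = 4
for gpu_id in range(num_gpus):
    # Each inference process handles only its own chunk of questions.
    print(gpu_id, get_chunk(questions, num_gpus, gpu_id))
```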
### GQA

1. Download the data following the official instructions [here](https://cs.stanford.edu/people/dorarad/gqa/download.html) and put it under `./playground/data/eval/gqa/data`.
2. Multi-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/gqa.sh
```

### VisWiz

1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/vizwiz.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/vizwiz/answers_upload`.

### ScienceQA

1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, and `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA).
2. Single-GPU inference and evaluation.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh
```

### TextVQA

1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and the [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip), and extract them to `./playground/data/eval/textvqa`.
2. Single-GPU inference and evaluation.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh
```

### POPE

1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put it under `./playground/data/eval/pope`.
2. Single-GPU inference and evaluation.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/pope.sh
```
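POPE is a yes/no object-hallucination probe, so scoring reduces to binary-classification metrics over the model's yes/no answers. The sketch below only illustrates that computation under assumed answer formats; the script invoked by `pope.sh` is the authoritative implementation.

```python
# Hedged illustration of POPE-style yes/no scoring; the evaluation code
# shipped with LLaVA may differ in details (e.g. answer normalization).
def pope_metrics(predictions, labels):
    """predictions/labels: lists of 'yes'/'no' answer strings."""
    preds = [p.strip().lower().startswith("yes") for p in predictions]
    gts = [g.strip().lower() == "yes" for g in labels]

    tp = sum(p and g for p, g in zip(preds, gts))
    fp = sum(p and not g for p, g in zip(preds, gts))
    fn = sum(not p and g for p, g in zip(preds, gts))
    tn = sum(not p and not g for p, g in zip(preds, gts))

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "accuracy": (tp + tn) / len(gts),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "yes_ratio": sum(preds) / len(preds),
    }

print(pope_metrics(["Yes", "No", "yes"], ["yes", "no", "no"]))
```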
### MME

1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation).
2. Put the downloaded images under `MME_Benchmark_release_version`.
3. Put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`.
4. Single-GPU inference and evaluation.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh
```

### MMBench

1. Download `mmbench_dev_20230712.tsv` from the official [website](https://github.com/open-compass/MMBench) and put it under `./playground/data/eval/mmbench`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_20230712`.

### MMBench-CN

1. Download `mmbench_dev_cn_20231003.tsv` from the official [website](https://github.com/open-compass/MMBench) and put it under `./playground/data/eval/mmbench`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh
```
3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`.
### SEED-Bench

1. Follow the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and the videos. Put the images under `./playground/data/eval/seed_bench/SEED-Bench-image`.
2. Extract the middle frame from each downloaded video and put the frames under `./playground/data/eval/seed_bench/SEED-Bench-video-image`. We provide our script `extract_video_frames.py`, modified from the official one (a hedged sketch of the idea follows this list).
3. Multi-GPU inference and evaluation.
```Shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/seed.sh
```
4. Optionally, submit the results to the leaderboard (`./playground/data/eval/seed_bench/answers_upload`) using the official jupyter notebook.
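For reference, a minimal sketch of middle-frame extraction with OpenCV. This is an assumption-laden stand-in, not the repo's `extract_video_frames.py`, which remains the authoritative version.

```python
# Hedged sketch: grab the middle frame of each video with OpenCV.
# The repo's extract_video_frames.py (modified from the official SEED-Bench
# script) is the authoritative version and may work differently.
import os
import cv2  # pip install opencv-python

def extract_middle_frame(video_path, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if n_frames <= 0:
        cap.release()
        return None
    # Seek to the middle frame and decode it.
    cap.set(cv2.CAP_PROP_POS_FRAMES, n_frames // 2)
    ok, frame = cap.read()
    cap.release()
    if not ok:
        return None
    name = os.path.splitext(os.path.basename(video_path))[0] + ".png"
    out_path = os.path.join(out_dir, name)
    cv2.imwrite(out_path, frame)
    return out_path

# Example call (paths are placeholders):
# extract_middle_frame("videos/clip_0001.mp4",
#                      "./playground/data/eval/seed_bench/SEED-Bench-video-image")
```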
### LLaVA-Bench-in-the-Wild

1. Extract the contents of [`llava-bench-in-the-wild`](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to `./playground/data/eval/llava-bench-in-the-wild`.
2. Single-GPU inference and evaluation.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh
```

### MM-Vet

1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`.
2. Single-GPU inference.
```Shell
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh
```
3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook.

llava/eval/eval_mmbench.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# Parse LLaVA answers to MMBench-style multiple-choice questions: try the
# rule-based matcher (auto_parse_answer) first, and fall back to asking
# ChatGPT which option the free-form answer corresponds to.
import argparse
import os
import json
import pandas as pd
from tqdm import tqdm
import openai
from concurrent.futures import ThreadPoolExecutor, as_completed
import math
import time


all_options = ['A', 'B', 'C', 'D']


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th of n chunks (unused in this script)."""
    chunks = split_list(lst, n)
    return chunks[k]


def get_row(df, colname, value):
    """Return the unique row of df whose `colname` equals `value`."""
    assert (df[colname] == value).sum() == 1
    return df[df[colname] == value].iloc[0]


def encode_query(question, options, answer):
    """Format a (question, options, answer) triple as the text sent to ChatGPT."""
    query = ""
    query += "Question: " + question + "\n"
    query += "Options: " + "\n".join([f"{option_char}. {option}" for option_char, option in zip(all_options[:len(options)], options)]) + "\n"
    query += "Answer: " + answer + "\n"
    return query


def get_openai_api():
    """Build OpenAI API kwargs from environment variables (Azure or standard endpoint)."""
    api_type = os.environ.get('API_TYPE', 'azure')

    if api_type == 'azure':
        api_key = os.environ.get('API_KEY', 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
        engine = os.environ.get('ENGINE', 'chatgpt-turbo')
        api_host = os.environ.get('API_BASE')
        return {
            'api_type': 'azure',
            'api_version': '2023-06-01-preview',
            'engine': engine,
            'api_key': api_key,
            'api_base': f'https://{api_host}.openai.azure.com',
        }
    else:
        api_key = os.environ.get('API_KEY', 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
        model = os.environ.get('MODEL', 'gpt-3.5-turbo-0301')

        return {
            'model': model,
            'api_key': api_key,
        }


def chatgpt_extract_answer(
        question, options, answer, max_tokens=64, temperature=0.2, top_p=0.9, frequency_penalty=0, presence_penalty=0,
        request_timeout=None, num_retry=1):
    """Ask ChatGPT which option letter best matches a free-form answer.

    Uses two few-shot exemplars and retries transient API errors; returns None on failure.
    """
    api_kwargs = get_openai_api()

    system_message = """You are an AI assistant to help me matching an answer with several options of a multiple choice question.
You are provided with a question, several options, and an answer, and you need to find which option is most similar to the answer.
If the meaning of all options are significantly different from the answer, output X.
You should output a single uppercase character in A, B, C, D, if they are valid options, and X otherwise."""
    exemplers = [
        {
            "question": "What is the main object in image?",
            "options": ["teddy bear", "rabbit", "cat", "dog"],
            "answer": "a cute teddy bear",
            "output": "A",
        },
        {
            "question": "What is the main object in image?",
            "options": ["teddy bear", "rabbit", "cat", "dog"],
            "answer": "Spider",
            "output": "X",
        },
    ]

    messages = [
        {"role": "system", "content": system_message},
    ]
    for exempler in exemplers:
        messages.append({"role": "user", "content": encode_query(exempler['question'], exempler['options'], exempler['answer'])})
        messages.append({"role": "assistant", "content": exempler['output']})
    messages.append({"role": "user", "content": encode_query(question, options, answer)})

    response = None
    attempts = []
    for i in range(num_retry):
        try:
            response = openai.ChatCompletion.create(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                request_timeout=request_timeout,
                **api_kwargs
            )
        except Exception as e:
            if type(e) in [openai.error.RateLimitError, openai.error.APIError, openai.error.APIConnectionError, openai.error.Timeout]:
                # Transient errors: retry after a short pause.
                pass
            elif type(e) in [openai.error.AuthenticationError, openai.error.InvalidRequestError]:
                # Unrecoverable errors: give up immediately.
                print(e)
                return None
            else:
                print(type(e), e)
            attempts.append(e.__class__.__name__)
            time.sleep(1)
        else:
            time.sleep(1)
            break

    if response is None:
        print(f'All {num_retry} attempts failed: {attempts}. Returning None.')
        return None

    content = response['choices'][0]['message']['content']
    content = content.strip()
    return content


def is_none(value):
    """Treat None, NaN, 'nan' and 'none' cells as missing options."""
    if value is None:
        return True
    if type(value) is float and math.isnan(value):
        return True
    if type(value) is str and value.lower() == 'nan':
        return True
    if type(value) is str and value.lower() == 'none':
        return True
    return False


def get_options(row, options):
    """Collect the option texts present in a question row, stopping at the first missing one."""
    parsed_options = []
    for option in options:
        option_value = row[option]
        if is_none(option_value):
            break
        parsed_options.append(option_value)
    return parsed_options


def auto_parse_answer(question, options, answer):
    """Rule-based extraction: bare letter, 'The answer is X.', or a unique substring match.

    Returns None when no rule applies, so the caller can fall back to ChatGPT.
    """
    if answer.strip('.').strip().upper() in all_options[:len(options)]:
        return answer.strip('.').strip().upper()
    expand_option_valid = [f'The answer is {option}.'.lower() in answer.lower() for option in all_options[:len(options)]]
    if any(expand_option_valid):
        return all_options[expand_option_valid.index(True)]

    matched_ops = [all_options[_i] for _i, option in enumerate(options) if answer.lower() in option.lower()]
    if len(matched_ops) == 1:
        return matched_ops[0]
    return None


def eval_results(args):
    """Parse every answer and append the parsed records to results_file (jsonl), resuming if it exists."""
    questions = pd.read_table(os.path.expanduser(args.question_file))
    answers = [json.loads(line) for line in open(os.path.expanduser(args.answers_file))]
    answers = {(row['question_id'], row.get('round_id', 0)): row for row in answers}
    results_file = os.path.expanduser(args.results_file)
    if os.path.exists(results_file):
        results = [json.loads(line) for line in open(results_file)]
        results = {(row['question_id'], row.get('round_id', 0)): row for row in results}
    else:
        results = {}
    results_writer = open(results_file, 'a')

    def process_answer(idx, answer):
        if idx in results:
            return None
        question_id, round_id = idx
        question_data = get_row(questions, 'index', question_id)
        if 'options' in answer:
            options = answer['options']
            option_char = answer['option_char']
        else:
            assert round_id == 0, "round_id must be 0 when options are not provided"
            options = get_options(question_data, all_options)
            option_char = all_options[:len(options)]
        # Map canonical letters (A, B, ...) to the option characters recorded with this answer.
        option_map = {all_options[i]: option_char[i] for i in range(len(options))}
        option_map['X'] = 'X'
        parsed_answer = auto_parse_answer(question_data['question'], options, answer['text'])
        if parsed_answer is None:
            parsed_answer = chatgpt_extract_answer(
                question_data['question'], options, answer['text'],
                request_timeout=args.request_timeout, num_retry=args.num_retry)
        if parsed_answer is None:
            return None
        if parsed_answer not in option_map:
            print(f'Invalid parsed answer: {parsed_answer}')
            return None
        answer['parsed_answer'] = option_map[parsed_answer]
        return answer

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit all tasks to the executor
        futures = {executor.submit(process_answer, key, value): key for key, value in answers.items()}

        # Process results as they become available
        for future in tqdm(as_completed(futures), total=len(answers)):
            answer = future.result()
            if answer is not None:
                results_writer.write(json.dumps(answer) + '\n')
                results_writer.flush()

    results_writer.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--results-file", type=str, default="results.jsonl")
    parser.add_argument("--max-workers", type=int, default=1)
    parser.add_argument("--num-retry", type=int, default=3)
    parser.add_argument("--request-timeout", type=int, default=None)
    args = parser.parse_args()

    eval_results(args)
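To make the parsing behaviour above concrete, here is a small usage sketch of the rule-based `auto_parse_answer` path. The import path is an assumption about how the checkout is laid out, so adjust it to your setup; the script itself is driven by the argparse flags at the bottom (`--question-file`, `--answers-file`, `--results-file`, plus the worker and retry options).

```python
# Illustrative only: exercise the rule-based parser defined above.
# The import path is an assumption; adjust it to your checkout.
from llava.eval.eval_mmbench import auto_parse_answer

options = ["teddy bear", "rabbit", "cat", "dog"]
question = "What is the main object in image?"

print(auto_parse_answer(question, options, "B."))                # 'B'  (bare letter)
print(auto_parse_answer(question, options, "The answer is C."))  # 'C'  (expanded form)
print(auto_parse_answer(question, options, "dog"))               # 'D'  (unique substring match)
print(auto_parse_answer(question, options, "a cute teddy bear")) # None (falls back to ChatGPT matching)
```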
