Skip to content

Commit df52cb0

Browse files
committed
Small fixes for transformers test runner
1 parent cf1912d commit df52cb0

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

olmocr/bench/runners/run_transformers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ def run_transformers(
5151
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5252

5353
if _cached_model is None:
54-
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval()
54+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
55+
model_name,
56+
torch_dtype=torch.bfloat16,
57+
device_map="auto",
58+
attn_implementation="flash_attention_2"
59+
).eval()
5560
processor = AutoProcessor.from_pretrained(model_name)
5661

5762
model = model.to(device)

scripts/run_transformers_benchmark.sh

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ set -e
55

66

77
# Check for uncommitted changes
8-
if ! git diff-index --quiet HEAD --; then
9-
echo "Error: There are uncommitted changes in the repository."
10-
echo "Please commit or stash your changes before running the benchmark."
11-
echo ""
12-
echo "Uncommitted changes:"
13-
git status --short
14-
exit 1
15-
fi
8+
# if ! git diff-index --quiet HEAD --; then
9+
# echo "Error: There are uncommitted changes in the repository."
10+
# echo "Please commit or stash your changes before running the benchmark."
11+
# echo ""
12+
# echo "Uncommitted changes:"
13+
# git status --short
14+
# exit 1
15+
# fi
1616

1717
# Use conda environment Python if available, otherwise use system Python
1818
if [ -n "$CONDA_PREFIX" ]; then
@@ -90,7 +90,9 @@ if has_aws_creds:
9090
commands.extend([
9191
"git clone https://huggingface.co/datasets/allenai/olmOCR-bench",
9292
"cd olmOCR-bench && git lfs pull && cd ..",
93-
"python -m olmocr.bench.convert transformers:target_longest_image_dim=1288:prompt_template=yaml:response_template=yaml: --dir ./olmOCR-bench/bench_data",
93+
"pip install accelerate",
94+
"pip install flash-attn==2.8.0.post2 --no-build-isolation",
95+
"python -m olmocr.bench.convert transformers:target_longest_image_dim=1288:prompt_template=yaml:response_template=yaml --dir ./olmOCR-bench/bench_data",
9496
"python -m olmocr.bench.benchmark --dir ./olmOCR-bench/bench_data"
9597
])
9698
@@ -107,7 +109,7 @@ task_spec_args = {
107109
preemptible=True,
108110
),
109111
"resources": TaskResources(gpu_count=1),
110-
"constraints": Constraints(cluster=["ai2/ceres-cirrascale", "ai2/jupiter-cirrascale-2"]),
112+
"constraints": Constraints(cluster=["ai2/titan-cirrascale"]),
111113
"result": ResultSpec(path="/noop-results"),
112114
}
113115

0 commit comments

Comments
 (0)