
Commit bee1999

Add Replicate demo and API
1 parent fda0665 · commit bee1999

File tree: 4 files changed, +185 −1 lines changed

.dockerignore · README.md · cog.yaml · predict.py

.dockerignore

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv

# Exclude some weights
/openai
/liuhaotian

README.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 *Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.*
 
-[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)]
+[[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] [![Replicate](https://replicate.com/yorickvp/llava-13b/badge)](https://replicate.com/yorickvp/llava-13b)
 
 **Improved Baselines with Visual Instruction Tuning** [[Paper](https://arxiv.org/abs/2310.03744)] <br>
 [Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/)
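The only README change is the Replicate badge added above. As a rough companion to that badge, the sketch below shows how the hosted model could be queried from Python with the `replicate` client; it is not part of the commit. It assumes `pip install replicate` and a `REPLICATE_API_TOKEN` in the environment, the input names mirror the `Predictor.predict` arguments defined in `predict.py` below, and the model reference may also require an explicit version hash.

```python
# Hedged sketch (not in this commit): query the hosted llava-13b model on Replicate.
import replicate

output = replicate.run(
    "yorickvp/llava-13b",                 # may need ":<version-hash>" appended
    input={
        "image": open("view.jpg", "rb"),  # hypothetical local image file
        "prompt": "What is unusual about this image?",
        "temperature": 0.2,               # same defaults as predict.py
        "max_tokens": 1024,
    },
)
print(output)  # predict() returns a plain string
```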

cog.yaml

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true

  python_version: "3.11"

  python_packages:
    - "torch==2.0.1"
    - "accelerate==0.21.0"
    - "bitsandbytes==0.41.0"
    - "deepspeed==0.9.5"
    - "einops-exts==0.0.4"
    - "einops==0.6.1"
    - "gradio==3.35.2"
    - "gradio_client==0.2.9"
    - "httpx==0.24.0"
    - "markdown2==2.4.10"
    - "numpy==1.26.0"
    - "peft==0.4.0"
    - "scikit-learn==1.2.2"
    - "sentencepiece==0.1.99"
    - "shortuuid==1.0.11"
    - "timm==0.6.13"
    - "tokenizers==0.13.3"
    - "torch==2.0.1"
    - "torchvision==0.15.2"
    - "transformers==4.31.0"
    - "wandb==0.15.12"
    - "wavedrom==2.0.3.post3"
    - "Pygments==2.16.1"
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"

predict.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

from PIL import Image

import requests
from io import BytesIO

from cog import BasePredictor, Input, Path
import time, subprocess

import os
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"

# url for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"
# files to download from the weights mirrors
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # git commit hash from huggingface
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ]
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin"
        ],
    }
]

def download_weights(baseurl, basedest, files):
    start = time.time()
    print("downloading to: ", basedest)
    os.makedirs(basedest, exist_ok=True)
    for f in files:
        dest = os.path.join(basedest, f)
        url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f)
        if not os.path.exists(dest):
            print("downloading url: ", url)
            subprocess.check_call(["pget", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model("liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False)

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        temperature: float = Input(description="Temperature for text generation", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate", default=1024, ge=0),
    ) -> str:
        """Run a single prediction on the model"""

        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        image_data = load_image(str(image))
        image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()

        # loop start

        # just one turn, always prepend image token
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)

        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_tokens,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
        conv.messages[-1][-1] = outputs

        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)].strip()
        return outputs


def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
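`download_weights` above shells out to the `pget` binary installed in `cog.yaml` to copy each file from the weights mirror into its `dest` directory, skipping files that already exist. As a hedged illustration of what that step amounts to, here is a plain-`requests` stand-in (slower, no parallelism) that uses the same URL and destination layout; it is not part of the committed code.

```python
# Hedged sketch (not in this commit): a requests-based stand-in for the
# `pget url dest` call in download_weights(), with the same skip-if-present logic.
import os
import requests

REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

def fetch_file(src: str, dest_dir: str, filename: str) -> None:
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, filename)
    if os.path.exists(dest):
        return
    url = f"{REPLICATE_WEIGHTS_URL}/{src}/{filename}"
    with requests.get(url, stream=True) as resp:
        resp.raise_for_status()
        with open(dest, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)

# Example (same layout as the `weights` list above):
# fetch_file("clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
#            "openai/clip-vit-large-patch14-336", "config.json")
```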
