import os
import time
import subprocess
from io import BytesIO

# Redirect the Hugging Face cache to a local "weights" directory. This must be set
# before transformers/huggingface_hub are imported, since they read it at import time.
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(os.getcwd(), "weights")

import torch
import requests
from PIL import Image
from cog import BasePredictor, Input, Path

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

# URL for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"
# files to download from the weights mirror
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # git commit hash from huggingface
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin",
        ],
    },
]
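
# Layout note (an assumption based on the manifest above, not verified against the mirror):
# after download, the working directory contains
#   liuhaotian/llava-v1.5-13b/         -- LLaVA checkpoint shards + tokenizer
#   openai/clip-vit-large-patch14-336/ -- CLIP vision tower used by LLaVA-1.5
# so load_pretrained_model() below can resolve "liuhaotian/llava-v1.5-13b" as a local path.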

def download_weights(baseurl, basedest, files):
    start = time.time()
    print("downloading to: ", basedest)
    os.makedirs(basedest, exist_ok=True)
    for f in files:
        dest = os.path.join(basedest, f)
        url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f)
        if not os.path.exists(dest):
            print("downloading url: ", url)
            # pget is Replicate's parallel downloader, assumed to be on PATH in the
            # container; files already present on disk are skipped above.
            subprocess.check_call(["pget", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            "liuhaotian/llava-v1.5-13b",
            model_name="llava-v1.5-13b",
            model_base=None,
            load_8bit=False,
            load_4bit=False,
        )

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        temperature: float = Input(description="Temperature for text generation; 0 gives greedy decoding", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate", default=1024, ge=0),
    ) -> str:
        """Run a single prediction on the model"""

        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        image_data = load_image(str(image))
        image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()

        # single-turn conversation: always prepend the image token to the user prompt
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                # transformers rejects temperature=0 with do_sample=True,
                # so fall back to greedy decoding in that case
                do_sample=temperature > 0,
                temperature=temperature,
                max_new_tokens=max_tokens,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        # decode only the newly generated tokens (the prompt is echoed back in output_ids)
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
        conv.messages[-1][-1] = outputs

        # strip the trailing stop string if the model emitted it
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)].strip()
        return outputs


def load_image(image_file):
    # accept either a URL or a local file path
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
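
# Example usage (a sketch; assumes this file is wired up as the predictor in cog.yaml):
#   cog predict -i image=@example.jpg -i prompt="What is shown in this image?"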