add Jupyter Notebook to help run the models

This commit is contained in:
Yuki-Kokomi
2024-10-23 21:35:56 +08:00
committed by GitHub
parent a3fb458265
commit 2cd6d16fac

213
run_model.ipynb Normal file
View File

@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tinyllava.eval.run_tiny_llava import *\n",
"\n",
"## Change the model path here.\n",
"## Set conv_mode to match the model size:\n",
"## 0.55B, 0.89B: llama\n",
"## 2.4B: gemma\n",
"## 3.1B: phi\n",
"model_path = \"<MODEL_PATH>\"\n",
"conv_mode = \"llama\"  # one of: llama, gemma, phi\n",
"\n",
"## Lower \"max_new_tokens\" if the model can't deal with long token sequences.\n",
"## Possible values: 1024, 1152, 1536, 2048, 3072\n",
"args = type('Args', (), {\n",
"    \"model_path\": model_path,\n",
"    \"model\": None,  # optional preloaded model object; used only when model_path is None\n",
"    \"model_base\": None,\n",
"    \"conv_mode\": conv_mode,\n",
"    \"sep\": \",\",\n",
"    \"temperature\": 0,\n",
"    \"top_p\": None,\n",
"    \"num_beams\": 1,\n",
"    \"max_new_tokens\": 2048\n",
"})()\n",
"\n",
"# Model\n",
"disable_torch_init()\n",
"\n",
"if args.model_path is not None:\n",
"    model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)\n",
"else:\n",
"    # Fall back to a preloaded model object when no checkpoint path is given.\n",
"    assert args.model is not None, 'model_path or model must be provided'\n",
"    model = args.model\n",
"    if hasattr(model.config, \"max_sequence_length\"):\n",
"        context_len = model.config.max_sequence_length\n",
"    else:\n",
"        context_len = 2048\n",
"    tokenizer = model.tokenizer\n",
"    image_processor = model.vision_tower._image_processor\n",
"\n",
"\n",
"text_processor = TextPreprocess(tokenizer, args.conv_mode)\n",
"data_args = model.config\n",
"image_processor = ImagePreprocess(image_processor, data_args)\n",
"model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import signal\n",
"\n",
"def ensure_dir(path):\n",
"    \"\"\"Create directory `path` (and any missing parents) if it does not exist.\n",
"\n",
"    :param path: directory path to create\n",
"    \"\"\"\n",
"    # exist_ok avoids a check-then-create race when the notebook is re-run.\n",
"    os.makedirs(path, exist_ok=True)\n",
"\n",
"\n",
"class TimeoutException(Exception):\n",
"    \"\"\"Raised when a wrapped call exceeds the allotted wall-clock time.\"\"\"\n",
"    pass\n",
"\n",
"def handler(signum, frame):\n",
"    # SIGALRM handler: convert the alarm into a Python exception.\n",
"    raise TimeoutException()\n",
"\n",
"# Timeout per generation call (unit: s)\n",
"timeout = 300\n",
"\n",
"def timeout_decorator(func):\n",
"    \"\"\"Abort `func` with TimeoutException if it runs longer than `timeout` seconds.\n",
"\n",
"    Uses SIGALRM, so this only works on Unix and in the main thread.\n",
"    \"\"\"\n",
"    def wrapper(*args, **kwargs):\n",
"        signal.signal(signal.SIGALRM, handler)\n",
"        signal.alarm(timeout)\n",
"        try:\n",
"            result = func(*args, **kwargs)\n",
"        except TimeoutException:\n",
"            print(\"Function timed out!\")\n",
"            # Re-raise the caught exception (preserves the traceback).\n",
"            raise\n",
"        finally:\n",
"            # Always cancel the pending alarm so it cannot fire later.\n",
"            signal.alarm(0)\n",
"        return result\n",
"    return wrapper"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import glob\n",
"import traceback\n",
"\n",
"@timeout_decorator\n",
"def process_image(qs, path):\n",
"    \"\"\"Run the model on one image and return the generated text.\n",
"\n",
"    :param qs: prompt text (the image token is prepended here)\n",
"    :param path: path of the input image file\n",
"    :return: decoded model output with the stop string stripped\n",
"    \"\"\"\n",
"    qs = DEFAULT_IMAGE_TOKEN + \"\\n\" + qs\n",
"\n",
"    msg = Message()\n",
"    msg.add_message(qs)\n",
"\n",
"    result = text_processor(msg.messages, mode='eval')\n",
"    input_ids = result['input_ids']\n",
"    prompt = result['prompt']\n",
"    input_ids = input_ids.unsqueeze(0).cuda()\n",
"\n",
"    image_files = [path]\n",
"    images = load_images(image_files)[0]\n",
"    images_tensor = image_processor(images)\n",
"    images_tensor = images_tensor.unsqueeze(0).half().cuda()\n",
"\n",
"    # Stop generation at the conversation separator of the active template.\n",
"    stop_str = text_processor.template.separator.apply()[1]\n",
"    keywords = [stop_str]\n",
"    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
"\n",
"    with torch.inference_mode():\n",
"        output_ids = model.generate(\n",
"            input_ids,\n",
"            images=images_tensor,\n",
"            do_sample=True if args.temperature > 0 else False,\n",
"            temperature=args.temperature,\n",
"            top_p=args.top_p,\n",
"            num_beams=args.num_beams,\n",
"            pad_token_id=tokenizer.pad_token_id,\n",
"            max_new_tokens=args.max_new_tokens,\n",
"            use_cache=True,\n",
"            stopping_criteria=[stopping_criteria],\n",
"        )\n",
"\n",
"    outputs = tokenizer.batch_decode(\n",
"        output_ids, skip_special_tokens=True\n",
"    )[0]\n",
"    outputs = outputs.strip()\n",
"    if outputs.endswith(stop_str):\n",
"        outputs = outputs[: -len(stop_str)]\n",
"    outputs = outputs.strip()\n",
"    return outputs\n",
"\n",
"errors = []\n",
"\n",
"ts = [\"Default\", \"Transparent\", \"Orthographic\"]\n",
"# Change src_b to the path of \"Input Pictures and Reference Codes\".\n",
"src_b = \"<PATH of Input Pictures and Reference Codes>\"\n",
"out_b = \"<ANY OUTPUT PATH YOU LIKE>\"\n",
"for tt in ts:\n",
"    src = src_b + tt\n",
"    out = out_b + tt\n",
"    ensure_dir(out)\n",
"    out_paths = sorted(glob.glob(os.path.join(src, \"*.{}\".format(\"jpg\"))))\n",
"    if tt == \"Orthographic\":\n",
"        qs = \"This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model.\"\n",
"    else:\n",
"        qs = \"This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model.\"\n",
"\n",
"    for i, path in enumerate(out_paths):\n",
"        print(f\"{tt}: {i + 1}/{len(out_paths)}\", end='\\r')\n",
"        # Platform-safe basename (the old path.split(\"/\") broke on Windows).\n",
"        name = os.path.splitext(os.path.basename(path))[0]\n",
"        save_path = os.path.join(out, f'{name}.py')\n",
"        # Skip images whose output already exists, so runs are resumable.\n",
"        if os.path.isfile(save_path): continue\n",
"        try:\n",
"            outputs = process_image(qs, path)\n",
"            with open(save_path, 'w', encoding='utf-8') as file:\n",
"                file.write(outputs)\n",
"        except Exception:\n",
"            # Record the failure (incl. timeouts) and keep going; bare except\n",
"            # would also swallow KeyboardInterrupt.\n",
"            errors.append(f\"{tt}: {name}\")\n",
"            print(f\"gen error: {name}\")\n",
"            traceback.print_exc()\n",
"    print()\n",
"\n",
"print(errors)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tinyllava_factory",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}