mirror of
https://github.com/Yuki-Kokomi/OpenECAD_Project.git
synced 2026-02-03 16:23:24 -05:00
add Jupyter Notebook to help running the models
This commit is contained in:
213
run_model.ipynb
Normal file
213
run_model.ipynb
Normal file
@@ -0,0 +1,213 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tinyllava.eval.run_tiny_llava import *\n",
|
||||
"\n",
|
||||
"## Change the model path here\n",
|
||||
"## You need to change the type of the model\n",
|
||||
"## 0.55B, 0.89B: llama\n",
|
||||
"## 2.4B: gemma\n",
|
||||
"## 3.1B: phi\n",
|
||||
"model_path = \"<MODEL_PATH>\"\n",
|
||||
"conv_mode = \"llama\" # or llama, gemma, phi\n",
|
||||
"\n",
|
||||
"## You need to change the \"max_new_tokens\" if the model can't deal with long tokens.\n",
|
||||
"## possible values: 1024, 1152, 1536, 2048, 3072\n",
|
||||
"args = type('Args', (), {\n",
|
||||
" \"model_path\": model_path,\n",
|
||||
" \"model_base\": None,\n",
|
||||
" \"conv_mode\": conv_mode,\n",
|
||||
" \"sep\": \",\",\n",
|
||||
" \"temperature\": 0,\n",
|
||||
" \"top_p\": None,\n",
|
||||
" \"num_beams\": 1,\n",
|
||||
" \"max_new_tokens\": 2048\n",
|
||||
"})()\n",
|
||||
"\n",
|
||||
"# Model\n",
|
||||
"disable_torch_init()\n",
|
||||
"\n",
|
||||
"if args.model_path is not None:\n",
|
||||
" model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)\n",
|
||||
"else:\n",
|
||||
" assert args.model is not None, 'model_path or model must be provided'\n",
|
||||
" model = args.model\n",
|
||||
" if hasattr(model.config, \"max_sequence_length\"):\n",
|
||||
" context_len = model.config.max_sequence_length\n",
|
||||
" else:\n",
|
||||
" context_len = 2048\n",
|
||||
" tokenizer = model.tokenizer\n",
|
||||
" image_processor = model.vision_tower._image_processor\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"text_processor = TextPreprocess(tokenizer, args.conv_mode)\n",
|
||||
"data_args = model.config\n",
|
||||
"image_processor = ImagePreprocess(image_processor, data_args)\n",
|
||||
"model.cuda()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"def ensure_dir(path):\n",
|
||||
" \"\"\"\n",
|
||||
" create path by first checking its existence,\n",
|
||||
" :param paths: path\n",
|
||||
" :return:\n",
|
||||
" \"\"\"\n",
|
||||
" if not os.path.exists(path):\n",
|
||||
" os.makedirs(path)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"import signal\n",
|
||||
"\n",
|
||||
"class TimeoutException(Exception):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"def handler(signum, frame):\n",
|
||||
" raise TimeoutException()\n",
|
||||
"\n",
|
||||
"# Set timeout (unit: s)\n",
|
||||
"timeout = 300\n",
|
||||
"\n",
|
||||
"def timeout_decorator(func):\n",
|
||||
" def wrapper(*args, **kwargs):\n",
|
||||
" signal.signal(signal.SIGALRM, handler)\n",
|
||||
" signal.alarm(timeout)\n",
|
||||
" try:\n",
|
||||
" result = func(*args, **kwargs)\n",
|
||||
" except TimeoutException:\n",
|
||||
" print(\"Function timed out!\")\n",
|
||||
" raise TimeoutException\n",
|
||||
" result = None\n",
|
||||
" finally:\n",
|
||||
" signal.alarm(0)\n",
|
||||
" return result\n",
|
||||
" return wrapper"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@timeout_decorator\n",
|
||||
"def process_image(qs, path):\n",
|
||||
" qs = DEFAULT_IMAGE_TOKEN + \"\\n\" + qs\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" msg = Message()\n",
|
||||
" msg.add_message(qs)\n",
|
||||
"\n",
|
||||
" result = text_processor(msg.messages, mode='eval')\n",
|
||||
" input_ids = result['input_ids']\n",
|
||||
" prompt = result['prompt']\n",
|
||||
" input_ids = input_ids.unsqueeze(0).cuda()\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" image_files = [path]\n",
|
||||
" images = load_images(image_files)[0]\n",
|
||||
" images_tensor = image_processor(images)\n",
|
||||
" images_tensor = images_tensor.unsqueeze(0).half().cuda()\n",
|
||||
"\n",
|
||||
" stop_str = text_processor.template.separator.apply()[1]\n",
|
||||
" keywords = [stop_str]\n",
|
||||
" stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
|
||||
"\n",
|
||||
" with torch.inference_mode():\n",
|
||||
" output_ids = model.generate(\n",
|
||||
" input_ids,\n",
|
||||
" images=images_tensor,\n",
|
||||
" do_sample=True if args.temperature > 0 else False,\n",
|
||||
" temperature=args.temperature,\n",
|
||||
" top_p=args.top_p,\n",
|
||||
" num_beams=args.num_beams,\n",
|
||||
" pad_token_id=tokenizer.pad_token_id,\n",
|
||||
" max_new_tokens=args.max_new_tokens,\n",
|
||||
" use_cache=True,\n",
|
||||
" stopping_criteria=[stopping_criteria],\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" outputs = tokenizer.batch_decode(\n",
|
||||
" output_ids, skip_special_tokens=True\n",
|
||||
" )[0]\n",
|
||||
" outputs = outputs.strip()\n",
|
||||
" if outputs.endswith(stop_str):\n",
|
||||
" outputs = outputs[: -len(stop_str)]\n",
|
||||
" outputs = outputs.strip()\n",
|
||||
" return outputs\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import glob\n",
|
||||
"import traceback\n",
|
||||
"errors = []\n",
|
||||
"\n",
|
||||
"ts = [\"Default\", \"Transparent\", \"Orthographic\"]\n",
|
||||
"# Change the src_b to the path of \"Input Pictures and Reference Codes\" in \n",
|
||||
"src_b = \"<PATH of Input Pictures and Reference Codes>\"\n",
|
||||
"out_b = \"<ANY OUTPUT PATH YOU LIKE>\"\n",
|
||||
"for index in range(len(ts)):\n",
|
||||
" tt = ts[index]\n",
|
||||
" src = src_b + tt\n",
|
||||
" out = out_b + tt\n",
|
||||
" ensure_dir(out)\n",
|
||||
" out_paths = sorted(glob.glob(os.path.join(src, \"*.{}\".format(\"jpg\"))))\n",
|
||||
" if tt == \"Orthographic\" :\n",
|
||||
" qs = \"This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model.\"\n",
|
||||
" else: \n",
|
||||
" qs = \"This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model.\"\n",
|
||||
"\n",
|
||||
" for i in range(len(out_paths)):\n",
|
||||
" path = out_paths[i]\n",
|
||||
" print(f\"{tt}: {i + 1}/{len(out_paths)}\", end='\\r')\n",
|
||||
" name = path.split(\"/\")[-1].split(\".\")[0]\n",
|
||||
" save_path = os.path.join(out, f'{name}.py')\n",
|
||||
" if os.path.isfile(save_path): continue\n",
|
||||
" try:\n",
|
||||
" outputs = process_image(qs, path)\n",
|
||||
" with open(save_path, 'w', encoding='utf-8') as file:\n",
|
||||
" file.write(outputs)\n",
|
||||
" file.close()\n",
|
||||
" except:\n",
|
||||
" errors.append(f\"{tt}: {name}\")\n",
|
||||
" print(f\"gen error: {name}\")\n",
|
||||
" traceback.print_exc()\n",
|
||||
" print()\n",
|
||||
"\n",
|
||||
"print(errors)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "tinyllava_factory",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Reference in New Issue
Block a user