"""Run a TinyLLaVA model over directories of rendered 3D-model views and save
the generated Python rendering code for each image.

Cleaned-up version of the code cells of ``run_model.ipynb`` (cell boundaries
marked with ``# %%``). Fixes:
  * fallback branch read the nonexistent ``args.model`` attribute;
  * ``timeout_decorator`` had unreachable code after its re-raise;
  * ``ensure_dir`` docstring named a nonexistent parameter and had a
    check-then-create race;
  * redundant ``file.close()`` inside a ``with`` block;
  * bare ``except:`` that also swallowed ``KeyboardInterrupt``.
"""

# %% Cell 1: configuration and model loading
from tinyllava.eval.run_tiny_llava import *

import functools
import glob
import os
import signal
import traceback

# Change the model path here.
# conv_mode must match the model family:
#   0.55B, 0.89B -> "llama"
#   2.4B         -> "gemma"
#   3.1B         -> "phi"
model_path = ""
conv_mode = "llama"  # one of: "llama", "gemma", "phi"

# Lower "max_new_tokens" if the model can't deal with long token sequences.
# Possible values: 1024, 1152, 1536, 2048, 3072.
args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "conv_mode": conv_mode,
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 2048,
})()

disable_torch_init()

if args.model_path is not None:
    model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)
else:
    # Fall back to a pre-instantiated model object.
    # BUGFIX: the original read args.model directly, which raised
    # AttributeError (the Args namespace defines no "model" attribute)
    # instead of the intended, informative AssertionError.
    model = getattr(args, "model", None)
    assert model is not None, 'model_path or model must be provided'
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048
    tokenizer = model.tokenizer
    image_processor = model.vision_tower._image_processor

text_processor = TextPreprocess(tokenizer, args.conv_mode)
data_args = model.config
image_processor = ImagePreprocess(image_processor, data_args)
model.cuda()


# %% Cell 2: filesystem helper and timeout machinery
def ensure_dir(path):
    """Create directory *path* (and parents) if it does not already exist.

    :param path: directory path to create
    :return: None
    """
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(path, exist_ok=True)


class TimeoutException(Exception):
    """Raised when a decorated call exceeds the `timeout` budget."""
    pass


def handler(signum, frame):
    # SIGALRM handler: convert the alarm into a Python exception.
    raise TimeoutException()


# Per-call time budget (unit: seconds).
timeout = 300


def timeout_decorator(func):
    """Abort *func* with TimeoutException if it runs longer than `timeout` s.

    NOTE(review): relies on SIGALRM, so it only works on the main thread of a
    Unix process — confirm the execution environment.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)
        try:
            return func(*args, **kwargs)
        except TimeoutException:
            # BUGFIX: the original had an unreachable `result = None` after
            # its re-raise; a bare `raise` propagates the same exception.
            print("Function timed out!")
            raise
        finally:
            # Always cancel the pending alarm so it cannot fire later.
            signal.alarm(0)
    return wrapper


# %% Cell 3: per-image generation and the batch driver loop
@timeout_decorator
def process_image(qs, path):
    """Generate rendering code for one input image.

    :param qs: text prompt (the image token is prepended here)
    :param path: path to the input image file
    :return: the model's decoded output string, stripped of the stop token
    :raises TimeoutException: if generation exceeds the `timeout` budget
    """
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    msg = Message()
    msg.add_message(qs)

    result = text_processor(msg.messages, mode='eval')
    input_ids = result['input_ids'].unsqueeze(0).cuda()

    images = load_images([path])[0]
    images_tensor = image_processor(images).unsqueeze(0).half().cuda()

    # Stop generation at the conversation separator of the active template.
    stop_str = text_processor.template.separator.apply()[1]
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0].strip()
    if outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)].strip()
    return outputs


errors = []

ts = ["Default", "Transparent", "Orthographic"]
# Change src_b to the path of "Input Pictures and Reference Codes".
src_b = ""
out_b = ""
for tt in ts:
    src = src_b + tt
    out = out_b + tt
    ensure_dir(out)
    out_paths = sorted(glob.glob(os.path.join(src, "*.{}".format("jpg"))))
    if tt == "Orthographic":
        qs = "This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model."
    else:
        qs = "This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model."

    for i, path in enumerate(out_paths):
        print(f"{tt}: {i + 1}/{len(out_paths)}", end='\r')
        # Portable basename extraction; keeps the original "up to the first
        # dot" semantics of split(".")[0].
        name = os.path.basename(path).split(".")[0]
        save_path = os.path.join(out, f'{name}.py')
        if os.path.isfile(save_path):
            continue  # already generated — skip
        try:
            outputs = process_image(qs, path)
            # BUGFIX: dropped the redundant file.close() — the with-block
            # already closes the file.
            with open(save_path, 'w', encoding='utf-8') as file:
                file.write(outputs)
        except Exception:
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and made the loop hard to abort.
            errors.append(f"{tt}: {name}")
            print(f"gen error: {name}")
            traceback.print_exc()
    print()

print(errors)