# %% Cell 1: load the TinyLLaVA checkpoint and build the text/image preprocessors
# NOTE(review): the star import is kept on purpose -- later cells use names that are
# not defined anywhere in this notebook (torch, Message, load_images,
# DEFAULT_IMAGE_TOKEN, KeywordsStoppingCriteria, ...), so they presumably arrive
# through it. Replace with explicit imports once the export list is confirmed.
from tinyllava.eval.run_tiny_llava import *

## Change the model path here
## You need to change the type of the model
## 0.55B, 0.89B: llama
## 2.4B: gemma
## 3.1B: phi
model_path = ""
conv_mode = "llama"  # one of: llama, gemma, phi

## You need to change the "max_new_tokens" if the model can't deal with long tokens.
## possible values: 1024, 1152, 1536, 2048, 3072
# Lightweight settings object: an instance of an ad-hoc class whose class
# attributes are the generation/config options read by the cells below.
args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "conv_mode": conv_mode,
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 2048
})()

# Model
disable_torch_init()

# An empty string is the unedited placeholder above, so treat it like "not set"
# instead of handing it to the loader.
if args.model_path:
    model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)
else:
    # `Args` does not define `model` by default; look it up defensively so the
    # user gets the explanatory message rather than an AttributeError.
    model = getattr(args, "model", None)
    if model is None:
        raise ValueError('model_path or model must be provided')
    # Fall back to 2048 when the config does not expose a maximum sequence length.
    context_len = getattr(model.config, "max_sequence_length", 2048)
    tokenizer = model.tokenizer
    image_processor = model.vision_tower._image_processor


text_processor = TextPreprocess(tokenizer, args.conv_mode)
data_args = model.config
# From here on `image_processor` is the ImagePreprocess wrapper around the raw
# processor obtained above; `process_image` calls it directly on loaded images.
image_processor = ImagePreprocess(image_processor, data_args)
model.cuda()

# %% Cell 2: filesystem and timeout helpers
import os


def ensure_dir(path):
    """
    Create the directory ``path``, including missing parents, if it is absent.

    Uses ``exist_ok=True`` so an already existing directory is not an error and
    there is no window between an existence check and the creation.

    :param path: directory path to create
    :return: None
    """
    os.makedirs(path, exist_ok=True)


import signal


class TimeoutException(Exception):
    """Raised by ``handler`` to abort a call that exceeded its time budget."""
    pass


def handler(signum, frame):
    """
    Signal handler that turns a delivered signal into a ``TimeoutException``.

    :param signum: number of the signal being handled (unused)
    :param frame: stack frame that was interrupted (unused)
    :raises TimeoutException: always
    """
    raise TimeoutException()
import functools

# Set timeout (unit: s)
timeout = 300


def timeout_decorator(func):
    """
    Wrap ``func`` so that a call is aborted once it runs longer than ``timeout``.

    The module-level ``timeout`` (seconds) is read each time the wrapped function
    is called, so changing it later affects subsequent calls. The limit is
    enforced with ``signal.alarm``/``SIGALRM``, which is only available on Unix
    and only works when the call happens in the main thread.

    :param func: callable to guard
    :return: wrapper with the same signature and metadata as ``func``
    :raises TimeoutException: from the wrapper, when the alarm fires during the call
    """
    @functools.wraps(func)  # keep the wrapped function's __name__/__doc__
    def wrapper(*args, **kwargs):
        # Install our handler and remember whatever was registered before.
        previous = signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)
        try:
            return func(*args, **kwargs)
        except TimeoutException:
            print("Function timed out!")
            raise
        finally:
            # Always cancel the pending alarm, then put the old handler back.
            # signal.signal() reports None for handlers not installed from
            # Python; those cannot be re-registered, so use the default.
            signal.alarm(0)
            signal.signal(
                signal.SIGALRM,
                signal.SIG_DFL if previous is None else previous,
            )
    return wrapper


# %% Cell 3: per-image inference and the batch generation loop
@timeout_decorator
def process_image(qs, path):
    """
    Ask the loaded model one question about one image and return its answer.

    Relies on the notebook-level objects ``text_processor``, ``image_processor``,
    ``tokenizer``, ``model`` and ``args`` created in the first cell.

    :param qs: question text; ``DEFAULT_IMAGE_TOKEN`` and a newline are prepended
    :param path: path of the image file, loaded through ``load_images``
    :return: decoded text of the first returned sequence, stripped of surrounding
        whitespace and of a trailing stop string
    :raises TimeoutException: when the call exceeds the limit set by
        ``timeout_decorator``
    """
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    msg = Message()
    msg.add_message(qs)

    # Tokenise the conversation and add a leading batch dimension.
    result = text_processor(msg.messages, mode='eval')
    input_ids = result['input_ids'].unsqueeze(0).cuda()

    # Single image -> preprocess -> add batch dimension, cast with .half(), move to GPU.
    images = load_images([path])[0]
    images_tensor = image_processor(images).unsqueeze(0).half().cuda()

    # NOTE(review): element [1] of the template separator is used as the stop
    # string -- presumably the end-of-turn marker; confirm against the template.
    stop_str = text_processor.template.separator.apply()[1]
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0].strip()
    # Guard against an empty stop string: "x"[:-0] would be "" and drop everything.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    return outputs.strip()


import os
import glob
import traceback

# Collects "<view type>: <image name>" for every image whose generation failed.
errors = []

# View types; each one is a sub-directory of both the input and the output root.
ts = ["Default", "Transparent", "Orthographic"]
# Set src_b to the directory that holds the "Input Pictures and Reference Codes"
# data (one sub-directory per view type) and out_b to the directory that should
# receive the generated code. A trailing path separator is optional.
src_b = ""
out_b = ""
for tt in ts:
    src = os.path.join(src_b, tt)
    out = os.path.join(out_b, tt)
    ensure_dir(out)
    out_paths = sorted(glob.glob(os.path.join(src, "*.jpg")))
    if tt == "Orthographic":
        qs = "This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model."
    else:
        qs = "This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model."

    total = len(out_paths)
    for i, path in enumerate(out_paths, start=1):
        print(f"{tt}: {i}/{total}", end='\r')
        # Strip only the final extension so names containing dots stay distinct.
        name = os.path.splitext(os.path.basename(path))[0]
        save_path = os.path.join(out, f'{name}.py')
        # Results from an earlier run are kept, which makes the loop resumable.
        if os.path.isfile(save_path):
            continue
        try:
            outputs = process_image(qs, path)
            with open(save_path, 'w', encoding='utf-8') as file:
                file.write(outputs)
        except Exception:
            # Best effort: record the failure (timeouts included) and move on.
            # KeyboardInterrupt is deliberately not caught so the run can be stopped.
            errors.append(f"{tt}: {name}")
            print(f"gen error: {name}")
            traceback.print_exc()
    # Finish the '\r' progress line of this view type.
    print()

print(errors)