From 0b09f51a72f2a6fbc090ff8401613b6c29d714de Mon Sep 17 00:00:00 2001
From: Yuki-Kokomi
Date: Wed, 6 Nov 2024 17:35:42 +0800
Subject: [PATCH] add run_model.py

---
 run_model.py | 190 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 190 insertions(+)
 create mode 100644 run_model.py

diff --git a/run_model.py b/run_model.py
new file mode 100644
index 0000000..0f723e4
--- /dev/null
+++ b/run_model.py
@@ -0,0 +1,190 @@
+from tinyllava.eval.run_tiny_llava import *
+import argparse
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_path', type=str, required=True, help="Path to the model.")
+parser.add_argument('--src', type=str, required=True, help="Path prefix of the input pictures and reference code.")
+parser.add_argument('--out', type=str, required=True, help="Path prefix for the generated output.")
+
+# Parse the command-line arguments
+cli_args = parser.parse_args()
+
+# Copy the command-line values into the corresponding variables
+model_path = cli_args.model_path
+src_base = cli_args.src
+out_base = cli_args.out
+
+input_types = ["Default"]  # full set: ["Default", "Transparent", "Orthographic"]
+conv_mode = "gemma"  # options: "llama", "gemma", "phi"
+
+## Lower "max_new_tokens" if the model cannot handle long sequences.
+## Possible values: 1024, 1152, 1536, 2048, 3072.
+args = type('Args', (), {
+    "model_path": model_path,
+    "model_base": None,
+    "conv_mode": conv_mode,
+    "sep": ",",
+    "temperature": 0,
+    "top_p": None,
+    "num_beams": 1,
+    "max_new_tokens": 2048
+})()
+
+# Model
+disable_torch_init()
+
+if args.model_path is not None:
+    model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)
+else:
+    # Fallback for a pre-loaded model object; unreachable here because --model_path is required.
+    model = getattr(args, 'model', None)
+    assert model is not None, 'model_path or model must be provided'
+    if hasattr(model.config, "max_sequence_length"):
+        context_len = model.config.max_sequence_length
+    else:
+        context_len = 2048
+    tokenizer = model.tokenizer
+    image_processor = model.vision_tower._image_processor
+
+
+text_processor = TextPreprocess(tokenizer, args.conv_mode)
+data_args = model.config
+image_processor = ImagePreprocess(image_processor, data_args)
+model.cuda()
+
+import os
+
+def ensure_dir(path):
+    """
+    Create the directory at `path` if it does not already exist.
+    :param path: directory path
+    :return: None
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+import signal
+
+class TimeoutException(Exception):
+    pass
+
+def handler(signum, frame):
+    raise TimeoutException()
+
+# Timeout for a single generation (unit: seconds); SIGALRM requires a Unix system.
+timeout = 300
+
+def timeout_decorator(func):
+    def wrapper(*args, **kwargs):
+        signal.signal(signal.SIGALRM, handler)
+        signal.alarm(timeout)
+        try:
+            result = func(*args, **kwargs)
+        except TimeoutException:
+            print("Function timed out!")
+            raise
+        finally:
+            signal.alarm(0)
+        return result
+    return wrapper

+@timeout_decorator
+def process_image(qs, path):
+    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+    msg = Message()
+    msg.add_message(qs)
+
+    result = text_processor(msg.messages, mode='eval')
+    input_ids = result['input_ids']
+    prompt = result['prompt']
+    input_ids = input_ids.unsqueeze(0).cuda()
+
+    image_files = [path]
+    images = load_images(image_files)[0]
+    images_tensor = image_processor(images)
+    images_tensor = images_tensor.unsqueeze(0).half().cuda()
+
+    stop_str = text_processor.template.separator.apply()[1]
+    keywords = [stop_str]
+    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=images_tensor,
+            do_sample=args.temperature > 0,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            num_beams=args.num_beams,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=args.max_new_tokens,
+            use_cache=True,
+            stopping_criteria=[stopping_criteria],
+        )
+
+    outputs = tokenizer.batch_decode(
+        output_ids, skip_special_tokens=True
+    )[0]
+    outputs = outputs.strip()
+    if outputs.endswith(stop_str):
+        outputs = outputs[: -len(stop_str)]
+    outputs = outputs.strip()
+    return outputs
+
+import re
+
+def extract_python_code(input_str):
+    # Match content that starts with ```python and ends with ```.
+    match = re.search(r'```python(.*?)```', input_str, re.DOTALL)
+
+    if match:
+        # A closed fence was found: return the content between ```python and ```.
+        return match.group(1)
+    else:
+        # No closed fence: return everything after the ```python marker if it
+        # is present, otherwise return the input unchanged.
+        match = re.search(r'```python(.*)', input_str, re.DOTALL)
+        if match:
+            return match.group(1)
+        else:
+            return input_str
+
+import glob
+import traceback
+
+errors = []
+for cur_type in input_types:
+    src = src_base + cur_type
+    out = out_base + cur_type
+    ensure_dir(out)
+    out_paths = sorted(glob.glob(os.path.join(src, "*.{}".format("jpg"))))
+    if cur_type == "Orthographic":
+        qs = "This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model."
+    else:
+        qs = "This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model."
+
+    for i, path in enumerate(out_paths):
+        print(f"{cur_type}: {i + 1}/{len(out_paths)}", end='\r')
+        name = os.path.basename(path).split(".")[0]
+        save_path = os.path.join(out, f'{name}.py')
+        if os.path.isfile(save_path): continue
+        try:
+            outputs = process_image(qs, path)
+            outputs = extract_python_code(outputs)
+            with open(save_path, 'w', encoding='utf-8') as file:
+                file.write(outputs)
+        except Exception:
+            errors.append(f"{cur_type}: {name}")
+            print(f"gen error: {name}")
+            traceback.print_exc()
+    print()
+
+print("Could not generate outputs for these inputs:")
+print(errors)
\ No newline at end of file
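Usage note (not part of the patch): a minimal invocation sketch is shown below; all paths are placeholders. Note that --src and --out are used as plain string prefixes to which the input type name (e.g. "Default") is appended directly, so include a trailing separator if you want per-type subdirectories.

    python run_model.py --model_path ./TinyLLaVA-checkpoint --src ./images/ --out ./generated/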