Files
OpenECAD_Project/run_model.py
2024-11-06 20:35:42 +08:00

191 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from tinyllava.eval.run_tiny_llava import *
import argparse
# Command-line interface: every option is mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_path', type=str, required=True,
    help="Path to the model.",
)
parser.add_argument(
    '--src', type=str, required=True,
    help="Path of Input Pictures and Reference Codes.",
)
parser.add_argument(
    '--out', type=str, required=True,
    help="Any output path you like.",
)
# Parse the command-line arguments.
args = parser.parse_args()
# Hand the parsed values over to the module-level names used by the rest of the script.
model_path, src_base, out_base = args.model_path, args.src, args.out
input_types = ["Default"]  # or ["Default", "Transparent", "Orthographic"]
conv_mode = "gemma"  # or llama, gemma, phi
## You need to change the "max_new_tokens" if the model can't deal with long tokens.
## possible values: 1024, 1152, 1536, 2048, 3072
# From here on `args` no longer refers to the argparse namespace: it is rebound
# to an instance of an ad-hoc class whose class attributes are the generation settings.
args = type('Args', (), dict(
    model_path=model_path,
    model_base=None,
    conv_mode=conv_mode,
    sep=",",
    temperature=0,
    top_p=None,
    num_beams=1,
    max_new_tokens=2048,
))()
# Model
disable_torch_init()
if args.model_path is not None:
    # Normal path: the loader returns the model together with its tokenizer,
    # raw image processor and context length.
    model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path)
else:
    # Fallback path: use an already-instantiated model supplied as `args.model`.
    # The settings object may not define `model` at all, so look it up defensively
    # and fail with an explicit error (an `assert` would be stripped under -O).
    model = getattr(args, "model", None)
    if model is None:
        raise ValueError('model_path or model must be provided')
    # Use the model's own limit when it declares one, otherwise default to 2048.
    context_len = getattr(model.config, "max_sequence_length", 2048)
    tokenizer = model.tokenizer
    image_processor = model.vision_tower._image_processor
# Wrap the raw tokenizer / image processor in the project's preprocessing helpers.
text_processor = TextPreprocess(tokenizer, args.conv_mode)
data_args = model.config
image_processor = ImagePreprocess(image_processor, data_args)
# NOTE(review): presumably moves the model to the GPU (torch `.cuda()`) — requires CUDA.
model.cuda()
import os
def ensure_dir(path):
    """
    Create the directory ``path`` (including missing parents) if it does not exist.

    Calling it on an existing directory is a no-op. The creation is done in a
    single ``os.makedirs(..., exist_ok=True)`` call so there is no window
    between an existence check and the creation in which another process
    could create the directory and make this call fail.

    :param path: directory path to create
    :return: None
    :raises FileExistsError: if ``path`` exists but is not a directory
    """
    os.makedirs(path, exist_ok=True)
import signal
class TimeoutException(Exception):
    """Raised when a call wrapped by ``timeout_decorator`` runs past ``timeout``."""


def handler(signum, frame):
    """SIGALRM handler: abort whatever is running by raising TimeoutException.

    :param signum: signal number passed by the signal machinery (unused)
    :param frame: interrupted stack frame passed by the signal machinery (unused)
    :raises TimeoutException: always
    """
    raise TimeoutException()


# Set timeout (unit: s)
timeout = 300


def timeout_decorator(func):
    """Decorate ``func`` so that each call is limited to ``timeout`` seconds.

    The limit is enforced with ``signal.alarm`` / ``SIGALRM``, so the wrapped
    function must be called where that signal is available and where signal
    handlers may be installed. The module-level ``timeout`` value is read at
    call time, not at decoration time.

    :param func: callable to guard
    :return: wrapper with the same call signature and metadata as ``func``
    :raises TimeoutException: from the wrapper, when the alarm fires before
        ``func`` returns
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Remember whatever handler was installed so it can be put back afterwards.
        previous = signal.signal(signal.SIGALRM, handler)
        signal.alarm(timeout)
        try:
            return func(*args, **kwargs)
        except TimeoutException:
            print("Function timed out!")
            # Bare raise keeps the original exception and its traceback.
            raise
        finally:
            # Cancel the pending alarm first so it cannot fire into the
            # restored handler, then restore the previous handler.
            signal.alarm(0)
            if previous is not None:
                signal.signal(signal.SIGALRM, previous)
    return wrapper
@timeout_decorator
def process_image(qs, path):
    """Ask the loaded model one question about one image and return its answer.

    Uses the module-level ``model``, ``tokenizer``, ``text_processor``,
    ``image_processor`` and ``args`` objects. The call runs under
    ``timeout_decorator``.

    :param qs: question text; the image placeholder token and a newline are
        prepended before it is tokenized
    :param path: image file to load (passed to ``load_images`` as a one-item list)
    :return: decoded model output, stripped of surrounding whitespace and of a
        trailing stop string if present
    """
    # Build the single-turn conversation: image token first, then the question.
    message = Message()
    message.add_message(DEFAULT_IMAGE_TOKEN + "\n" + qs)
    encoded = text_processor(message.messages, mode='eval')
    token_ids = encoded['input_ids']
    prompt_text = encoded['prompt']  # looked up as in the original; not used below
    # Add a leading batch dimension and move the ids to the GPU.
    prompt_ids = token_ids.unsqueeze(0).cuda()

    # Load and preprocess the one image, then batch it, cast to half and move to GPU.
    picture = load_images([path])[0]
    images_tensor = image_processor(picture).unsqueeze(0).half().cuda()

    # Generation stops once the template's stop string shows up in the output.
    stop_str = text_processor.template.separator.apply()[1]
    stopper = KeywordsStoppingCriteria([stop_str], tokenizer, prompt_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            prompt_ids,
            images=images_tensor,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopper],
        )

    # Only the first (and only) sequence of the batch is of interest.
    text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    if text.endswith(stop_str):
        text = text[: -len(stop_str)]
    return text.strip()
import re
def extract_python_code(input_str):
    """Pull the contents of the first ```python fenced block out of ``input_str``.

    Three cases, tried in order:

    1. A complete ```python ... ``` block exists: return the text between the
       opening fence and the first closing fence.
    2. Only an opening ```python fence exists: return everything after it.
    3. No fence at all: return ``input_str`` unchanged.

    The returned text is not stripped, so it normally starts with the newline
    that follows the opening fence.

    :param input_str: raw model output
    :return: extracted code text
    """
    # Patterns are tried from most to least specific; the first hit wins.
    fence_patterns = (
        r'```python(.*?)```',  # closed block: content up to the first closing fence
        r'```python(.*)',      # unclosed block: everything after the opening fence
    )
    for pattern in fence_patterns:
        found = re.search(pattern, input_str, re.DOTALL)
        if found is not None:
            return found.group(1)
    # No ```python fence anywhere: hand the text back as-is.
    return input_str
import os
import glob
import traceback
# Batch driver: for every input type, run the model on each *.jpg under
# <src_base>/<type> and write the extracted code to <out_base>/<type>/<name>.py.
# Inputs that fail (including timeouts) are collected and reported at the end.
errors = []
for cur_type in input_types:
    src = os.path.join(src_base, cur_type)
    out = os.path.join(out_base, cur_type)
    ensure_dir(out)
    out_paths = sorted(glob.glob(os.path.join(src, "*.jpg")))
    if cur_type == "Orthographic":
        qs = "This image is 4 views of a 3D model from certain angles. Please try to use Python-style APIs to render this model."
    else:
        qs = "This image is a view of a 3D model from a certain angle. Please try to use Python-style APIs to render this model."
    for i, path in enumerate(out_paths):
        # Progress indicator, rewritten in place on the same terminal line.
        print(f"{cur_type}: {i + 1}/{len(out_paths)}", end='\r')
        # NOTE(review): the name is cut at the FIRST dot, so "a.b.jpg" and
        # "a.c.jpg" both map to "a.py" — kept as-is to preserve output naming.
        name = os.path.basename(path).split(".")[0]
        save_path = os.path.join(out, f'{name}.py')
        # An existing output means this input was already handled; skipping
        # makes the script resumable.
        if os.path.isfile(save_path):
            continue
        try:
            outputs = process_image(qs, path)
            outputs = extract_python_code(outputs)
            with open(save_path, 'w', encoding='utf-8') as file:
                file.write(outputs)
        except Exception:
            # Best effort: record the failure and move on to the next image.
            # Only Exception is caught so Ctrl-C / SystemExit still stop the run.
            errors.append(f"{cur_type}: {name}")
            print(f"gen error: {name}")
            traceback.print_exc()
print()
print("Can't Generate these inputs:")
print(errors)