From 839a126aedc5a8e19518f6f94c1c681ef8653f34 Mon Sep 17 00:00:00 2001
From: Rajat Mishra <77861069+mst-rajatmishra@users.noreply.github.com>
Date: Fri, 8 Nov 2024 16:51:11 +0530
Subject: [PATCH] Update main.py

---
 main.py | 202 +++++++++++++++++++++++++++++---------------------------
 1 file changed, 104 insertions(+), 98 deletions(-)

diff --git a/main.py b/main.py
index 01744fe..8b48463 100644
--- a/main.py
+++ b/main.py
@@ -24,123 +24,129 @@ import torch
 from PIL import Image
 import argparse
-
+from datetime import datetime
+from tqdm import tqdm
 from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
 
+# ---- Define Functions ----
 def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--use_lite", default=False, action="store_true"
-    )
-    parser.add_argument(
-        "--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str
-    )
-    parser.add_argument(
-        "--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str
-    )
-    parser.add_argument(
-        "--text2image_path", default="weights/hunyuanDiT", type=str
-    )
-    parser.add_argument(
-        "--save_folder", default="./outputs/test/", type=str
-    )
-    parser.add_argument(
-        "--text_prompt", default="", type=str,
-    )
-    parser.add_argument(
-        "--image_prompt", default="", type=str
-    )
-    parser.add_argument(
-        "--device", default="cuda:0", type=str
-    )
-    parser.add_argument(
-        "--t2i_seed", default=0, type=int
-    )
-    parser.add_argument(
-        "--t2i_steps", default=25, type=int
-    )
-    parser.add_argument(
-        "--gen_seed", default=0, type=int
-    )
-    parser.add_argument(
-        "--gen_steps", default=50, type=int
-    )
-    parser.add_argument(
-        "--max_faces_num", default=80000, type=int,
-        help="max num of face, suggest 80000 for effect, 10000 for speed"
-    )
-    parser.add_argument(
-        "--save_memory", default=False, action="store_true"
-    )
-    parser.add_argument(
-        "--do_texture_mapping", default=False, action="store_true"
-    )
-    parser.add_argument(
-        "--do_render", default=False, action="store_true"
-    )
+    parser = argparse.ArgumentParser(description="Pipeline for generating 3D models from text or images.")
+
+    # General arguments
+    parser.add_argument("--use_lite", default=False, action="store_true", help="Use the lite version of models (saves memory).")
+    parser.add_argument("--save_folder", default="./outputs/test/", type=str, help="Folder to save output files.")
+    parser.add_argument("--device", default="cuda:0", type=str, help="Device for running the model (e.g., cuda:0 or cpu).")
+
+    # Model paths and configuration
+    parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str, help="Path to the SVRM config file.")
+    parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str, help="Path to the SVRM checkpoint file.")
+    parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str, help="Path to the text-to-image pre-trained model.")
+
+    # Inputs
+    parser.add_argument("--text_prompt", default="", type=str, help="Text prompt for image generation.")
+    parser.add_argument("--image_prompt", default="", type=str, help="Image prompt for generating views.")
+
+    # Randomness and steps control
+    parser.add_argument("--t2i_seed", default=0, type=int, help="Seed for text-to-image generation.")
+    parser.add_argument("--t2i_steps", default=25, type=int, help="Steps for generating the image from text.")
+    parser.add_argument("--gen_seed", default=0, type=int, help="Seed for generating 3D mesh from views.")
type=int, help="Steps for generating 3D mesh.") + + # Mesh generation settings + parser.add_argument("--max_faces_num", default=80000, type=int, help="Max number of faces for the generated 3D mesh.") + parser.add_argument("--do_texture_mapping", default=False, action="store_true", help="Apply texture mapping to the 3D mesh.") + parser.add_argument("--do_render", default=False, action="store_true", help="Render a rotating gif of the 3D model.") + + # Memory and output settings + parser.add_argument("--save_memory", default=False, action="store_true", help="Save memory by optimizing model inference.") + parser.add_argument("--save_intermediate", default=False, action="store_true", help="Save intermediate steps for debugging.") + parser.add_argument("--verbose", default=False, action="store_true", help="Enable verbose output during processing.") + + # Custom output filename + parser.add_argument("--output_name", type=str, default="output", help="Base name for output files (e.g., output_mesh.obj, output.gif).") + return parser.parse_args() +def check_paths(args): + """Validate file paths before starting the pipeline.""" + assert args.text_prompt or args.image_prompt, "You must provide either a text prompt or an image prompt." + assert not (args.text_prompt and args.image_prompt), "You cannot provide both text and image prompts." + + # Check model and config paths + if not os.path.exists(args.mv23d_cfg_path): + raise FileNotFoundError(f"Configuration file not found: {args.mv23d_cfg_path}") + if not os.path.exists(args.mv23d_ckt_path): + raise FileNotFoundError(f"Checkpoint file not found: {args.mv23d_ckt_path}") + if not os.path.exists(args.text2image_path): + raise FileNotFoundError(f"Text-to-image model not found: {args.text2image_path}") -if __name__ == "__main__": + os.makedirs(args.save_folder, exist_ok=True) + + # Ensure save folder exists + save_subfolders = ["images", "models", "renders"] + for subfolder in save_subfolders: + os.makedirs(os.path.join(args.save_folder, subfolder), exist_ok=True) + +def save_image(image, filename, folder): + """Save images to the specified folder.""" + image.save(os.path.join(folder, filename)) + +def main(): args = get_args() - - assert not (args.text_prompt and args.image_prompt), "Text and image can only be given to one" - assert args.text_prompt or args.image_prompt, "Text and image can only be given to one" - # init model + # Check and validate paths + check_paths(args) + + # Initialize models rembg_model = Removebg() image_to_views_model = Image2Views(device=args.device, use_lite=args.use_lite) views_to_mesh_model = Views2Mesh(args.mv23d_cfg_path, args.mv23d_ckt_path, args.device, use_lite=args.use_lite) + if args.text_prompt: - text_to_image_model = Text2Image( - pretrain = args.text2image_path, - device = args.device, - save_memory = args.save_memory - ) + text_to_image_model = Text2Image(pretrain=args.text2image_path, device=args.device, save_memory=args.save_memory) if args.do_render: gif_renderer = GifRenderer(device=args.device) - # ---- ----- ---- ---- ---- ---- - - os.makedirs(args.save_folder, exist_ok=True) - - # stage 1, text to image + # ---- Stage 1: Text-to-Image Generation ---- if args.text_prompt: - res_rgb_pil = text_to_image_model( - args.text_prompt, - seed=args.t2i_seed, - steps=args.t2i_steps - ) - res_rgb_pil.save(os.path.join(args.save_folder, "img.jpg")) + if args.verbose: + print("Generating image from text prompt...") + res_rgb_pil = text_to_image_model(args.text_prompt, seed=args.t2i_seed, steps=args.t2i_steps) + 
+        save_image(res_rgb_pil, f"{args.output_name}_img.jpg", os.path.join(args.save_folder, "images"))
+
     elif args.image_prompt:
+        if args.verbose:
+            print("Loading provided image...")
         res_rgb_pil = Image.open(args.image_prompt)
 
-    # stage 2, remove back ground
+    # ---- Stage 2: Background Removal ----
+    if args.verbose:
+        print("Removing background from image...")
     res_rgba_pil = rembg_model(res_rgb_pil)
-    res_rgb_pil.save(os.path.join(args.save_folder, "img_nobg.png"))
-
-    # stage 3, image to views
-    (views_grid_pil, cond_img), view_pil_list = image_to_views_model(
-        res_rgba_pil,
-        seed = args.gen_seed,
-        steps = args.gen_steps
-    )
-    views_grid_pil.save(os.path.join(args.save_folder, "views.jpg"))
-
-    # stage 4, views to mesh
-    views_to_mesh_model(
-        views_grid_pil,
-        cond_img,
-        seed = args.gen_seed,
-        target_face_count = args.max_faces_num,
-        save_folder = args.save_folder,
-        do_texture_mapping = args.do_texture_mapping
-    )
-
-    # stage 5, render gif
+    save_image(res_rgba_pil, f"{args.output_name}_img_nobg.png", os.path.join(args.save_folder, "images"))
+
+    # ---- Stage 3: Image to Views Generation ----
+    if args.verbose:
+        print("Generating views from image...")
+    (views_grid_pil, cond_img), view_pil_list = image_to_views_model(res_rgba_pil, seed=args.gen_seed, steps=args.gen_steps)
+    save_image(views_grid_pil, f"{args.output_name}_views.jpg", os.path.join(args.save_folder, "images"))
+
+    # ---- Stage 4: Views to Mesh ----
+    if args.verbose:
+        print("Generating 3D mesh from views...")
+    views_to_mesh_model(views_grid_pil, cond_img, seed=args.gen_seed, target_face_count=args.max_faces_num,
+                        save_folder=os.path.join(args.save_folder, "models"), do_texture_mapping=args.do_texture_mapping)
+
+    # ---- Stage 5: Render GIF ----
     if args.do_render:
-        gif_renderer(
-            os.path.join(args.save_folder, 'mesh.obj'),
-            gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
-        )
+        if args.verbose:
+            print("Rendering gif of 3D model...")
+        gif_renderer(os.path.join(args.save_folder, "models", "mesh.obj"),
+                     gif_dst_path=os.path.join(args.save_folder, "renders", f"{args.output_name}.gif"))
+
+    if args.verbose:
+        print(f"Process complete. Output saved in {args.save_folder}.")
+
+if __name__ == "__main__":
+    main()
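
For reference, a minimal usage sketch of the refactored entry point follows. It is illustrative only and not part of the patch: it assumes the updated file is importable as main, that the weight paths checked by check_paths() are already downloaded, and the prompt and folder values are made-up examples. Only the flag names defined in the new get_args() are used.

    # Illustrative sketch only: drive the refactored helpers from Python
    # instead of the shell. Assumes model weights are already in place.
    import sys
    import main

    sys.argv = [
        "main.py",
        "--text_prompt", "a small wooden chair",
        "--save_folder", "./outputs/chair/",
        "--output_name", "chair",
        "--do_render",
        "--verbose",
    ]

    args = main.get_args()   # parses the flags defined in the patch above
    main.check_paths(args)   # creates images/, models/ and renders/ subfolders
    main.main()              # runs stages 1-5 end to end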