diff --git a/README.md b/README.md index f55e315..9adefa6 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ Data processing is streamlined for instant conversions that are fully **renderin - [x] Release image-to-3D inference code - [x] Release pretrained checkpoints (4B) - [x] Hugging Face Spaces demo -- [ ] Release shape-conditioned texture generation inference code (Current schdule: before 12/24/2025) +- [x] Release shape-conditioned texture generation inference code - [ ] Release training code (Current schdule: before 12/31/2025) @@ -184,7 +184,7 @@ Then, you can access the demo at the address shown in the terminal. ### 2. PBR Texture Generation -Will be released soon. Please stay tuned! +Please refer to the [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. Also, you can use the [app_texturing.py](app_texturing.py) to run a web demo for PBR texture generation. ## 🧩 Related Packages diff --git a/app_texturing.py b/app_texturing.py new file mode 100644 index 0000000..50b2873 --- /dev/null +++ b/app_texturing.py @@ -0,0 +1,151 @@ +import gradio as gr + +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +from datetime import datetime +import shutil +from typing import * +import torch +import numpy as np +import trimesh +from PIL import Image +from trellis2.pipelines import Trellis2TexturingPipeline + + +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') + + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shutil.rmtree(user_dir) + + +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image. + + Args: + image (Image.Image): The input image. + + Returns: + Image.Image: The preprocessed image. 
+ """ + processed_image = pipeline.preprocess_image(image) + return processed_image + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """ + Get the random seed. + """ + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +def shapeimage_to_tex( + mesh_file: str, + image: Image.Image, + seed: int, + resolution: str, + texture_size: int, + tex_slat_guidance_strength: float, + tex_slat_guidance_rescale: float, + tex_slat_sampling_steps: int, + tex_slat_rescale_t: float, + req: gr.Request, + progress=gr.Progress(track_tqdm=True), +) -> Tuple[str, str]: + mesh = trimesh.load(mesh_file) + if isinstance(mesh, trimesh.Scene): + mesh = mesh.to_mesh() + output = pipeline.run( + mesh, + image, + seed=seed, + preprocess_image=False, + tex_slat_sampler_params={ + "steps": tex_slat_sampling_steps, + "guidance_strength": tex_slat_guidance_strength, + "guidance_rescale": tex_slat_guidance_rescale, + "rescale_t": tex_slat_rescale_t, + }, + resolution=int(resolution), + texture_size=texture_size, + ) + now = datetime.now() + timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') + output.export(glb_path, extension_webp=True) + torch.cuda.empty_cache() + return glb_path, glb_path + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) + * Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset. 
+ """) + + with gr.Row(): + with gr.Column(scale=1, min_width=360): + mesh_file = gr.File(label="Upload Mesh", file_types=[".ply", ".obj", ".glb", ".gltf"], file_count="single") + image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) + + resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024") + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) + + generate_btn = gr.Button("Generate") + + with gr.Accordion(label="Advanced Settings", open=False): + with gr.Row(): + tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) + tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) + tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + + with gr.Column(scale=10): + glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0)) + download_btn = gr.DownloadButton(label="Download GLB") + + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + image_prompt.upload( + preprocess_image, + inputs=[image_prompt], + outputs=[image_prompt], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + shapeimage_to_tex, + inputs=[ + mesh_file, image_prompt, seed, resolution, texture_size, + tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, + ], + outputs=[glb_output, download_btn], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + os.makedirs(TMP_DIR, exist_ok=True) + + pipeline = Trellis2TexturingPipeline.from_pretrained('microsoft/TRELLIS.2-4B', 
config_file="texturing_pipeline.json") + pipeline.cuda() + + demo.launch() diff --git a/assets/example_texturing/image.webp b/assets/example_texturing/image.webp new file mode 100644 index 0000000..a69f81d Binary files /dev/null and b/assets/example_texturing/image.webp differ diff --git a/assets/example_texturing/readme b/assets/example_texturing/readme new file mode 100644 index 0000000..7f6c799 --- /dev/null +++ b/assets/example_texturing/readme @@ -0,0 +1,11 @@ +## Asset Information + +* Title: The Forgotten Knight +* Author: dark_igorek +* Source: https://sketchfab.com/3d-models/the-forgotten-knight-d14eb14d83bd4e7ba7cbe443d76a10fd +* License: Creative Commons Attribution (CC BY) + +## Usage + +The asset is used for research purposes only. +Please credit the original author and include the Sketchfab link when using or redistributing this model. \ No newline at end of file diff --git a/assets/example_texturing/the_forgotten_knight.ply b/assets/example_texturing/the_forgotten_knight.ply new file mode 100644 index 0000000..b243a8f Binary files /dev/null and b/assets/example_texturing/the_forgotten_knight.ply differ diff --git a/example_texturing.py b/example_texturing.py new file mode 100644 index 0000000..5cb5ffc --- /dev/null +++ b/example_texturing.py @@ -0,0 +1,17 @@ +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Can save GPU memory +import trimesh +from PIL import Image +from trellis2.pipelines import Trellis2TexturingPipeline + +# 1. Load Pipeline +pipeline = Trellis2TexturingPipeline.from_pretrained("microsoft/TRELLIS.2-4B", config_file="texturing_pipeline.json") +pipeline.cuda() + +# 2. Load Mesh, image & Run +mesh = trimesh.load("assets/example_texturing/the_forgotten_knight.ply") +image = Image.open("assets/example_texturing/image.webp") +output = pipeline.run(mesh, image) + +# 3. 
Export Mesh +output.export("textured.glb", extension_webp=True) \ No newline at end of file diff --git a/trellis2/pipelines/__init__.py b/trellis2/pipelines/__init__.py index 53d8917..9bb6835 100644 --- a/trellis2/pipelines/__init__.py +++ b/trellis2/pipelines/__init__.py @@ -2,8 +2,7 @@ import importlib __attributes = { "Trellis2ImageTo3DPipeline": "trellis2_image_to_3d", - "Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade", - "Trellis2ImageToTexturePipeline": "trellis2_image_to_tex", + "Trellis2TexturingPipeline": "trellis2_texturing", } __submodules = ['samplers', 'rembg'] @@ -49,7 +48,5 @@ def from_pretrained(path: str): # For PyLance if __name__ == '__main__': from . import samplers, rembg - from .trellis_image_to_3d import TrellisImageTo3DPipeline from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline - from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline - from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline + from .trellis2_texturing import Trellis2TexturingPipeline diff --git a/trellis2/pipelines/base.py b/trellis2/pipelines/base.py index d897825..331e1ed 100644 --- a/trellis2/pipelines/base.py +++ b/trellis2/pipelines/base.py @@ -18,32 +18,34 @@ class Pipeline: for model in self.models.values(): model.eval() - @staticmethod - def from_pretrained(path: str) -> "Pipeline": + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline": """ Load a pretrained model. 
""" import os import json - is_local = os.path.exists(f"{path}/pipeline.json") + is_local = os.path.exists(f"{path}/{config_file}") if is_local: - config_file = f"{path}/pipeline.json" + config_file = f"{path}/{config_file}" else: from huggingface_hub import hf_hub_download - config_file = hf_hub_download(path, "pipeline.json") + config_file = hf_hub_download(path, config_file) with open(config_file, 'r') as f: args = json.load(f)['args'] _models = {} for k, v in args['models'].items(): + if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load: + continue try: _models[k] = models.from_pretrained(f"{path}/{v}") except Exception as e: _models[k] = models.from_pretrained(v) - new_pipeline = Pipeline(_models) + new_pipeline = cls(_models) new_pipeline._pretrained_args = args return new_pipeline diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py index 8d7afd5..a7b84b9 100644 --- a/trellis2/pipelines/trellis2_image_to_3d.py +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -28,6 +28,17 @@ class Trellis2ImageTo3DPipeline(Pipeline): rembg_model (Callable): The model for removing background. low_vram (bool): Whether to use low-VRAM mode. """ + model_names_to_load = [ + 'sparse_structure_flow_model', + 'sparse_structure_decoder', + 'shape_slat_flow_model_512', + 'shape_slat_flow_model_1024', + 'shape_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024', + 'tex_slat_decoder', + ] + def __init__( self, models: dict[str, nn.Module] = None, @@ -67,45 +78,43 @@ class Trellis2ImageTo3DPipeline(Pipeline): } self._device = 'cpu' - @staticmethod - def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline": + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline": """ Load a pretrained model. Args: path (str): The path to the model. Can be either local path or a Hugging Face repository. 
""" - pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path) - new_pipeline = Trellis2ImageTo3DPipeline() - new_pipeline.__dict__ = pipeline.__dict__ + pipeline = super().from_pretrained(path, config_file) args = pipeline._pretrained_args - new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) - new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] - new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) - new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] + pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) + pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] - new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) - new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] - new_pipeline.shape_slat_normalization = args['shape_slat_normalization'] - new_pipeline.tex_slat_normalization = args['tex_slat_normalization'] + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] - new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) - new_pipeline.rembg_model = 
getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) - new_pipeline.low_vram = args.get('low_vram', True) - new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') - new_pipeline.pbr_attr_layout = { + pipeline.low_vram = args.get('low_vram', True) + pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') + pipeline.pbr_attr_layout = { 'base_color': slice(0, 3), 'metallic': slice(3, 4), 'roughness': slice(4, 5), 'alpha': slice(5, 6), } - new_pipeline._device = 'cpu' + pipeline._device = 'cpu' - return new_pipeline + return pipeline def to(self, device: torch.device) -> None: self._device = device @@ -364,7 +373,6 @@ class Trellis2ImageTo3DPipeline(Pipeline): Args: slat (SparseTensor): The structured latent. - formats (List[str]): The formats to decode the structured latent to. Returns: List[Mesh]: The decoded meshes. @@ -433,10 +441,9 @@ class Trellis2ImageTo3DPipeline(Pipeline): Args: slat (SparseTensor): The structured latent. - formats (List[str]): The formats to decode the structured latent to. Returns: - List[SparseTensor]: The decoded texture voxels + SparseTensor: The decoded texture voxels """ if self.low_vram: self.models['tex_slat_decoder'].to(self.device) diff --git a/trellis2/pipelines/trellis2_texturing.py b/trellis2/pipelines/trellis2_texturing.py new file mode 100755 index 0000000..c184b5e --- /dev/null +++ b/trellis2/pipelines/trellis2_texturing.py @@ -0,0 +1,408 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +import trimesh +from .base import Pipeline +from . 
import samplers, rembg +from ..modules.sparse import SparseTensor +from ..modules import image_feature_extractor +import o_voxel +import cumesh +import nvdiffrast.torch as dr +import cv2 +import flex_gemm + + +class Trellis2TexturingPipeline(Pipeline): + """ + Pipeline for generating PBR textures for a given mesh with Trellis2 models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (Callable): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. + """ + model_names_to_load = [ + 'shape_slat_encoder', + 'tex_slat_decoder', + 'tex_slat_flow_model_512', + 'tex_slat_flow_model_1024' + ] + + def __init__( + self, + models: dict[str, nn.Module] = None, + tex_slat_sampler: samplers.Sampler = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + ): + if models is None: + return + super().__init__(models) + self.tex_slat_sampler = tex_slat_sampler + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @classmethod + def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> 
"Trellis2TexturingPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super().from_pretrained(path, config_file) + args = pipeline._pretrained_args + + pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + pipeline.shape_slat_normalization = args['shape_slat_normalization'] + pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + pipeline.low_vram = args.get('low_vram', True) + pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + pipeline._device = 'cpu' + return pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: + """ + Preprocess the input mesh. + """ + vertices = mesh.vertices + vertices_min = vertices.min(axis=0) + vertices_max = vertices.max(axis=0) + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + tmp = vertices[:, 1].copy() + vertices[:, 1] = -vertices[:, 2] + vertices[:, 2] = tmp + assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range' + return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False) + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. 
+ """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. 
+ + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def encode_shape_slat( + self, + mesh: trimesh.Trimesh, + resolution: int = 1024, + ) -> SparseTensor: + """ + Encode the meshes to structured latent. + + Args: + mesh (trimesh.Trimesh): The mesh to encode. + resolution (int): The resolution of mesh + + Returns: + SparseTensor: The encoded structured latent. + """ + vertices = torch.from_numpy(mesh.vertices).float() + faces = torch.from_numpy(mesh.faces).long() + + voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( + vertices.cpu(), faces.cpu(), + grid_size=resolution, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + face_weight=1.0, + boundary_weight=0.2, + regularization_weight=1e-2, + timing=True, + ) + + vertices = SparseTensor( + feats=dual_vertices * resolution - voxel_indices, + coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) + ).to(self.device) + intersected = vertices.replace(intersected).to(self.device) + + if self.low_vram: + self.models['shape_slat_encoder'].to(self.device) + shape_slat = self.models['shape_slat_encoder'](vertices, intersected) + if self.low_vram: + self.models['shape_slat_encoder'].cpu() + return shape_slat + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. 
+ """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. 
+ + Returns: + SparseTensor: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + def postprocess_mesh( + self, + mesh: trimesh.Trimesh, + pbr_voxel: SparseTensor, + resolution: int = 1024, + texture_size: int = 1024, + ) -> trimesh.Trimesh: + vertices = mesh.vertices + faces = mesh.faces + normals = mesh.vertex_normals + vertices_torch = torch.from_numpy(vertices).float().cuda() + faces_torch = torch.from_numpy(faces).int().cuda() + if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None: + uvs = mesh.visual.uv.copy() + uvs[:, 1] = 1 - uvs[:, 1] + uvs_torch = torch.from_numpy(uvs).float().cuda() + else: + _cumesh = cumesh.CuMesh() + _cumesh.init(vertices_torch, faces_torch) + vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True) + vertices_torch = vertices_torch.cuda() + faces_torch = faces_torch.cuda() + uvs_torch = uvs_torch.cuda() + vertices = vertices_torch.cpu().numpy() + faces = faces_torch.cpu().numpy() + uvs = uvs_torch.cpu().numpy() + normals = normals[vmap.cpu().numpy()] + + # rasterize + ctx = dr.RasterizeCudaContext() + uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0) + rast, _ = dr.rasterize( + ctx, uvs_torch, faces_torch, + resolution=[texture_size, texture_size], + ) + mask = rast[0, ..., 3] > 0 + pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0] + + attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device) + attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d( + pbr_voxel.feats, + pbr_voxel.coords, + shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]), + grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3), + mode='trilinear', + ) + + # construct 
mesh + mask = mask.cpu().numpy() + base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # extend + mask = (~mask).astype(np.uint8) + base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA) + metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None] + roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None] + alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None] + + material = trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)), + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), + metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)), + metallicFactor=1.0, + roughnessFactor=1.0, + alphaMode='OPAQUE', + doubleSided=True, + ) + + # Swap Y and Z axes, invert Y (common conversion for GLB compatibility) + vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1] + normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1] + uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate + + textured_mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_normals=normals, + process=False, + visual=trimesh.visual.TextureVisuals(uv=uvs, material=material) + ) + + return textured_mesh + + + @torch.no_grad() + def run( + self, + mesh: trimesh.Trimesh, + image: Image.Image, + seed: int = 42, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + resolution: int = 1024, + texture_size: int = 2048, + ) -> trimesh.Trimesh: + """ + Run the pipeline. 
+ + Args: + mesh (trimesh.Trimesh): The mesh to texture. + image (Image.Image): The image prompt. + seed (int): The random seed. + tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + image = self.preprocess_image(image) + mesh = self.preprocess_mesh(mesh) + torch.manual_seed(seed) + cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024) + shape_slat = self.encode_shape_slat(mesh, resolution) + tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024'] + tex_slat = self.sample_tex_slat( + cond, tex_model, + shape_slat, tex_slat_sampler_params + ) + pbr_voxel = self.decode_tex_slat(tex_slat) + out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size) + return out_mesh