mirror of
https://github.com/microsoft/TRELLIS.2
synced 2026-04-25 17:15:37 +02:00
update texturing pipeline
This commit is contained in:
@@ -49,7 +49,7 @@ Data processing is streamlined for instant conversions that are fully **renderin
|
|||||||
- [x] Release image-to-3D inference code
|
- [x] Release image-to-3D inference code
|
||||||
- [x] Release pretrained checkpoints (4B)
|
- [x] Release pretrained checkpoints (4B)
|
||||||
- [x] Hugging Face Spaces demo
|
- [x] Hugging Face Spaces demo
|
||||||
- [ ] Release shape-conditioned texture generation inference code (Current schedule: before 12/24/2025)
|
- [x] Release shape-conditioned texture generation inference code
|
||||||
- [ ] Release training code (Current schedule: before 12/31/2025)
|
- [ ] Release training code (Current schedule: before 12/31/2025)
|
||||||
|
|
||||||
|
|
||||||
@@ -184,7 +184,7 @@ Then, you can access the demo at the address shown in the terminal.
|
|||||||
|
|
||||||
### 2. PBR Texture Generation
|
### 2. PBR Texture Generation
|
||||||
|
|
||||||
Will be released soon. Please stay tuned!
|
Please refer to the [example_texturing.py](example_texturing.py) for an example of how to generate PBR textures for a given 3D shape. Also, you can use the [app_texturing.py](app_texturing.py) to run a web demo for PBR texture generation.
|
||||||
|
|
||||||
## 🧩 Related Packages
|
## 🧩 Related Packages
|
||||||
|
|
||||||
|
|||||||
151
app_texturing.py
Normal file
151
app_texturing.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
from datetime import datetime
|
||||||
|
import shutil
|
||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import trimesh
|
||||||
|
from PIL import Image
|
||||||
|
from trellis2.pipelines import Trellis2TexturingPipeline
|
||||||
|
|
||||||
|
|
||||||
|
MAX_SEED = np.iinfo(np.int32).max
|
||||||
|
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
||||||
|
|
||||||
|
|
||||||
|
def start_session(req: gr.Request):
    """Create the per-session scratch directory when a browser session opens.

    Args:
        req (gr.Request): Gradio request carrying the unique session hash.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # exist_ok: a reconnecting session reuses its directory.
    os.makedirs(session_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def end_session(req: gr.Request):
    """Delete the per-session scratch directory when the session closes.

    The directory may already be missing (e.g. a duplicate unload event or an
    external cleanup removed it first); `ignore_errors=True` keeps the unload
    handler from raising `FileNotFoundError` in that case.

    Args:
        req (gr.Request): Gradio request carrying the unique session hash.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Delegates to the global texturing pipeline's preprocessing (background
    removal / cropping) so the user sees the exact prompt the model receives.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    return pipeline.preprocess_image(image)
|
||||||
|
|
||||||
|
|
||||||
|
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Get the random seed.

    Returns a freshly drawn seed in [0, MAX_SEED) when randomization is
    enabled, otherwise echoes the user-provided seed.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
|
||||||
|
|
||||||
|
|
||||||
|
def shapeimage_to_tex(
    mesh_file: str,
    image: Image.Image,
    seed: int,
    resolution: str,
    texture_size: int,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """Generate PBR textures for an uploaded mesh from an image prompt.

    Args:
        mesh_file (str): Path of the uploaded mesh file.
        image (Image.Image): Reference image (already preprocessed on upload,
            hence `preprocess_image=False` below).
        seed (int): Random seed for sampling.
        resolution (str): Voxel resolution as a string from the radio widget.
        texture_size (int): Output texture resolution in pixels.
        tex_slat_guidance_strength (float): CFG strength for the texture latent.
        tex_slat_guidance_rescale (float): CFG rescale factor.
        tex_slat_sampling_steps (int): Number of sampler steps.
        tex_slat_rescale_t (float): Timestep rescale parameter.
        req (gr.Request): Gradio request, used for the per-session output dir.

    Returns:
        Tuple[str, str]: The GLB path twice — once for the 3D viewer and once
        for the download button.
    """
    mesh = trimesh.load(mesh_file)
    # A multi-geometry scene (e.g. from GLB) is flattened to a single mesh.
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.to_mesh()
    output = pipeline.run(
        mesh,
        image,
        seed=seed,
        preprocess_image=False,
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        resolution=int(resolution),
        texture_size=texture_size,
    )
    # Millisecond-resolution timestamp keeps repeated generations from
    # overwriting each other within a session.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    output.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path
|
||||||
|
|
||||||
|
|
||||||
|
# Build the Gradio UI. delete_cache=(600, 600): Gradio sweeps cached files
# every 600 s and deletes those older than 600 s.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
    """)

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1, min_width=360):
            mesh_file = gr.File(label="Upload Mesh", file_types=[".ply", ".obj", ".glb", ".gltf"], file_count="single")
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

        # Right column: 3D preview and download of the generated GLB.
        with gr.Column(scale=10):
            glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB")

    # Handlers
    # Per-session temp dirs are created on connect and removed on disconnect.
    demo.load(start_session)
    demo.unload(end_session)

    # Preprocess (background-removal/crop) as soon as an image is uploaded,
    # replacing the displayed image with the model's actual input.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Two-step chain: resolve the seed first (so a randomized seed is shown
    # in the UI), then run the texturing pipeline with it.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        shapeimage_to_tex,
        inputs=[
            mesh_file, image_prompt, seed, resolution, texture_size,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
        ],
        outputs=[glb_output, download_btn],
    )


# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # The texturing pipeline is loaded once at startup and shared by all
    # sessions via the module-level `pipeline` name used in the handlers.
    pipeline = Trellis2TexturingPipeline.from_pretrained('microsoft/TRELLIS.2-4B', config_file="texturing_pipeline.json")
    pipeline.cuda()

    demo.launch()
|
||||||
BIN
assets/example_texturing/image.webp
Normal file
BIN
assets/example_texturing/image.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
11
assets/example_texturing/readme
Normal file
11
assets/example_texturing/readme
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
## Asset Information
|
||||||
|
|
||||||
|
* Title: The Forgotten Knight
|
||||||
|
* Author: dark_igorek
|
||||||
|
* Source: https://sketchfab.com/3d-models/the-forgotten-knight-d14eb14d83bd4e7ba7cbe443d76a10fd
|
||||||
|
* License: Creative Commons Attribution (CC BY)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
The asset is used for research purposes only.
|
||||||
|
Please credit the original author and include the Sketchfab link when using or redistributing this model.
|
||||||
BIN
assets/example_texturing/the_forgotten_knight.ply
Normal file
BIN
assets/example_texturing/the_forgotten_knight.ply
Normal file
Binary file not shown.
17
example_texturing.py
Normal file
17
example_texturing.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Minimal end-to-end example: texture a given mesh from an image prompt.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # Can save GPU memory
import trimesh
from PIL import Image
from trellis2.pipelines import Trellis2TexturingPipeline


# 1. Load Pipeline
# The texturing-specific config selects only the models needed for texturing.
pipeline = Trellis2TexturingPipeline.from_pretrained("microsoft/TRELLIS.2-4B", config_file="texturing_pipeline.json")
pipeline.cuda()

# 2. Load Mesh, image & Run
mesh = trimesh.load("assets/example_texturing/the_forgotten_knight.ply")
image = Image.open("assets/example_texturing/image.webp")
output = pipeline.run(mesh, image)

# 3. Export the textured mesh as a GLB (WebP-compressed textures)
output.export("textured.glb", extension_webp=True)
|
||||||
@@ -2,8 +2,7 @@ import importlib
|
|||||||
|
|
||||||
__attributes = {
|
__attributes = {
|
||||||
"Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
|
"Trellis2ImageTo3DPipeline": "trellis2_image_to_3d",
|
||||||
"Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade",
|
"Trellis2TexturingPipeline": "trellis2_texturing",
|
||||||
"Trellis2ImageToTexturePipeline": "trellis2_image_to_tex",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__submodules = ['samplers', 'rembg']
|
__submodules = ['samplers', 'rembg']
|
||||||
@@ -49,7 +48,5 @@ def from_pretrained(path: str):
|
|||||||
# For PyLance
|
# For PyLance
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from . import samplers, rembg
|
from . import samplers, rembg
|
||||||
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
|
||||||
from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
|
from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline
|
||||||
from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline
|
from .trellis2_texturing import Trellis2TexturingPipeline
|
||||||
from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline
|
|
||||||
|
|||||||
@@ -18,32 +18,34 @@ class Pipeline:
|
|||||||
for model in self.models.values():
|
for model in self.models.values():
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_pretrained(path: str) -> "Pipeline":
|
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Pipeline":
|
||||||
"""
|
"""
|
||||||
Load a pretrained model.
|
Load a pretrained model.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
is_local = os.path.exists(f"{path}/pipeline.json")
|
is_local = os.path.exists(f"{path}/{config_file}")
|
||||||
|
|
||||||
if is_local:
|
if is_local:
|
||||||
config_file = f"{path}/pipeline.json"
|
config_file = f"{path}/{config_file}"
|
||||||
else:
|
else:
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
config_file = hf_hub_download(path, "pipeline.json")
|
config_file = hf_hub_download(path, config_file)
|
||||||
|
|
||||||
with open(config_file, 'r') as f:
|
with open(config_file, 'r') as f:
|
||||||
args = json.load(f)['args']
|
args = json.load(f)['args']
|
||||||
|
|
||||||
_models = {}
|
_models = {}
|
||||||
for k, v in args['models'].items():
|
for k, v in args['models'].items():
|
||||||
|
if hasattr(cls, 'model_names_to_load') and k not in cls.model_names_to_load:
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
_models[k] = models.from_pretrained(f"{path}/{v}")
|
_models[k] = models.from_pretrained(f"{path}/{v}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_models[k] = models.from_pretrained(v)
|
_models[k] = models.from_pretrained(v)
|
||||||
|
|
||||||
new_pipeline = Pipeline(_models)
|
new_pipeline = cls(_models)
|
||||||
new_pipeline._pretrained_args = args
|
new_pipeline._pretrained_args = args
|
||||||
return new_pipeline
|
return new_pipeline
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,17 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|||||||
rembg_model (Callable): The model for removing background.
|
rembg_model (Callable): The model for removing background.
|
||||||
low_vram (bool): Whether to use low-VRAM mode.
|
low_vram (bool): Whether to use low-VRAM mode.
|
||||||
"""
|
"""
|
||||||
|
model_names_to_load = [
|
||||||
|
'sparse_structure_flow_model',
|
||||||
|
'sparse_structure_decoder',
|
||||||
|
'shape_slat_flow_model_512',
|
||||||
|
'shape_slat_flow_model_1024',
|
||||||
|
'shape_slat_decoder',
|
||||||
|
'tex_slat_flow_model_512',
|
||||||
|
'tex_slat_flow_model_1024',
|
||||||
|
'tex_slat_decoder',
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
models: dict[str, nn.Module] = None,
|
models: dict[str, nn.Module] = None,
|
||||||
@@ -67,45 +78,43 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|||||||
}
|
}
|
||||||
self._device = 'cpu'
|
self._device = 'cpu'
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline":
|
def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2ImageTo3DPipeline":
|
||||||
"""
|
"""
|
||||||
Load a pretrained model.
|
Load a pretrained model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): The path to the model. Can be either local path or a Hugging Face repository.
|
path (str): The path to the model. Can be either local path or a Hugging Face repository.
|
||||||
"""
|
"""
|
||||||
pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path)
|
pipeline = super().from_pretrained(path, config_file)
|
||||||
new_pipeline = Trellis2ImageTo3DPipeline()
|
|
||||||
new_pipeline.__dict__ = pipeline.__dict__
|
|
||||||
args = pipeline._pretrained_args
|
args = pipeline._pretrained_args
|
||||||
|
|
||||||
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
|
||||||
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|
pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
|
||||||
|
|
||||||
new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
|
pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
|
||||||
new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
|
pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
|
||||||
|
|
||||||
new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
|
pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
|
||||||
new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
|
pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
|
||||||
|
|
||||||
new_pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
pipeline.shape_slat_normalization = args['shape_slat_normalization']
|
||||||
new_pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
pipeline.tex_slat_normalization = args['tex_slat_normalization']
|
||||||
|
|
||||||
new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
|
pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
|
||||||
new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
|
pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
|
||||||
|
|
||||||
new_pipeline.low_vram = args.get('low_vram', True)
|
pipeline.low_vram = args.get('low_vram', True)
|
||||||
new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
||||||
new_pipeline.pbr_attr_layout = {
|
pipeline.pbr_attr_layout = {
|
||||||
'base_color': slice(0, 3),
|
'base_color': slice(0, 3),
|
||||||
'metallic': slice(3, 4),
|
'metallic': slice(3, 4),
|
||||||
'roughness': slice(4, 5),
|
'roughness': slice(4, 5),
|
||||||
'alpha': slice(5, 6),
|
'alpha': slice(5, 6),
|
||||||
}
|
}
|
||||||
new_pipeline._device = 'cpu'
|
pipeline._device = 'cpu'
|
||||||
|
|
||||||
return new_pipeline
|
return pipeline
|
||||||
|
|
||||||
def to(self, device: torch.device) -> None:
|
def to(self, device: torch.device) -> None:
|
||||||
self._device = device
|
self._device = device
|
||||||
@@ -364,7 +373,6 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
slat (SparseTensor): The structured latent.
|
slat (SparseTensor): The structured latent.
|
||||||
formats (List[str]): The formats to decode the structured latent to.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Mesh]: The decoded meshes.
|
List[Mesh]: The decoded meshes.
|
||||||
@@ -433,10 +441,9 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
slat (SparseTensor): The structured latent.
|
slat (SparseTensor): The structured latent.
|
||||||
formats (List[str]): The formats to decode the structured latent to.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[SparseTensor]: The decoded texture voxels
|
SparseTensor: The decoded texture voxels
|
||||||
"""
|
"""
|
||||||
if self.low_vram:
|
if self.low_vram:
|
||||||
self.models['tex_slat_decoder'].to(self.device)
|
self.models['tex_slat_decoder'].to(self.device)
|
||||||
|
|||||||
408
trellis2/pipelines/trellis2_texturing.py
Executable file
408
trellis2/pipelines/trellis2_texturing.py
Executable file
@@ -0,0 +1,408 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import trimesh
|
||||||
|
from .base import Pipeline
|
||||||
|
from . import samplers, rembg
|
||||||
|
from ..modules.sparse import SparseTensor
|
||||||
|
from ..modules import image_feature_extractor
|
||||||
|
import o_voxel
|
||||||
|
import cumesh
|
||||||
|
import nvdiffrast.torch as dr
|
||||||
|
import cv2
|
||||||
|
import flex_gemm
|
||||||
|
|
||||||
|
|
||||||
|
class Trellis2TexturingPipeline(Pipeline):
    """
    Pipeline for generating PBR textures for a given 3D mesh from an image prompt.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        tex_slat_sampler (samplers.Sampler): The sampler for the texture latent.
        tex_slat_sampler_params (dict): The parameters for the texture latent sampler.
        shape_slat_normalization (dict): The normalization parameters for the structured latent.
        tex_slat_normalization (dict): The normalization parameters for the texture latent.
        image_cond_model (Callable): The image conditioning model.
        rembg_model (Callable): The model for removing background.
        low_vram (bool): Whether to use low-VRAM mode. When True, each model is
            moved to the device only while it is used and back to CPU afterwards.
    """
    # Subset of checkpoint models needed for texturing only; the base-class
    # loader skips models not listed here.
    model_names_to_load = [
        'shape_slat_encoder',
        'tex_slat_decoder',
        'tex_slat_flow_model_512',
        'tex_slat_flow_model_1024'
    ]

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        tex_slat_sampler: samplers.Sampler = None,
        tex_slat_sampler_params: dict = None,
        shape_slat_normalization: dict = None,
        tex_slat_normalization: dict = None,
        image_cond_model: Callable = None,
        rembg_model: Callable = None,
        low_vram: bool = True,
    ):
        # models=None yields a bare, uninitialized instance — presumably for
        # internal construction paths; verify against Pipeline.from_pretrained.
        if models is None:
            return
        super().__init__(models)
        self.tex_slat_sampler = tex_slat_sampler
        self.tex_slat_sampler_params = tex_slat_sampler_params
        self.shape_slat_normalization = shape_slat_normalization
        self.tex_slat_normalization = tex_slat_normalization
        self.image_cond_model = image_cond_model
        self.rembg_model = rembg_model
        self.low_vram = low_vram
        # Channel layout of the decoded PBR voxel features.
        self.pbr_attr_layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }
        self._device = 'cpu'

    @classmethod
    def from_pretrained(cls, path: str, config_file: str = "pipeline.json") -> "Trellis2TexturingPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.
            config_file (str): Name of the pipeline config JSON inside `path`.
        """
        # Base class loads the models listed in `model_names_to_load`; the
        # remaining attributes are instantiated here from the config args.
        pipeline = super().from_pretrained(path, config_file)
        args = pipeline._pretrained_args

        pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
        pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']

        pipeline.shape_slat_normalization = args['shape_slat_normalization']
        pipeline.tex_slat_normalization = args['tex_slat_normalization']

        pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
        pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])

        pipeline.low_vram = args.get('low_vram', True)
        pipeline.pbr_attr_layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }
        pipeline._device = 'cpu'
        return pipeline

    def to(self, device: torch.device) -> None:
        # In low-VRAM mode models stay on CPU here; they are moved to the
        # device lazily inside each stage.
        self._device = device
        if not self.low_vram:
            super().to(device)
            self.image_cond_model.to(device)
            if self.rembg_model is not None:
                self.rembg_model.to(device)

    def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        """
        Preprocess the input mesh.

        Centers the mesh and scales it to fit (just inside) the unit cube
        [-0.5, 0.5]^3, then swaps the Y/Z axes with a sign flip — presumably
        converting between y-up (GLB) and the pipeline's internal convention;
        `postprocess_mesh` applies the inverse swap. TODO confirm.
        """
        vertices = mesh.vertices
        vertices_min = vertices.min(axis=0)
        vertices_max = vertices.max(axis=0)
        center = (vertices_min + vertices_max) / 2
        # 0.99999 keeps vertices strictly inside the box so the range assert
        # below holds after floating-point rounding.
        scale = 0.99999 / (vertices_max - vertices_min).max()
        vertices = (vertices - center) * scale
        tmp = vertices[:, 1].copy()
        vertices[:, 1] = -vertices[:, 2]
        vertices[:, 2] = tmp
        assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range'
        # process=False: keep vertex/face order untouched.
        return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False)

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image.

        Uses a non-trivial alpha channel as the foreground mask if present,
        otherwise runs background removal; then crops a square around the
        foreground and composites onto black.
        """
        # if has alpha channel, use it directly; otherwise, remove background
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, 3]
            # An all-255 alpha channel carries no mask information.
            if not np.all(alpha == 255):
                has_alpha = True
        # Downscale so the longest side is at most 1024 px.
        max_size = max(input.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            if self.low_vram:
                self.rembg_model.to(self.device)
            output = self.rembg_model(input)
            if self.low_vram:
                self.rembg_model.cpu()
        # Square crop centered on the bounding box of confident foreground
        # pixels (alpha > 0.8).
        output_np = np.array(output)
        alpha = output_np[:, :, 3]
        bbox = np.argwhere(alpha > 0.8 * 255)
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1)
        bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
        output = output.crop(bbox)  # type: ignore
        # Premultiply by alpha: background becomes black, output is RGB.
        output = np.array(output).astype(np.float32) / 255
        output = output[:, :, :3] * output[:, :, 3:4]
        output = Image.fromarray((output * 255).astype(np.uint8))
        return output

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
            resolution (int): Input size given to the image conditioning model.
            include_neg_cond (bool): Whether to also return the (zero) negative
                conditioning used for classifier-free guidance.

        Returns:
            dict: The conditioning information
        """
        self.image_cond_model.image_size = resolution
        if self.low_vram:
            self.image_cond_model.to(self.device)
        cond = self.image_cond_model(image)
        if self.low_vram:
            self.image_cond_model.cpu()
        if not include_neg_cond:
            return {'cond': cond}
        # Zero embedding serves as the unconditional branch for CFG.
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def encode_shape_slat(
        self,
        mesh: trimesh.Trimesh,
        resolution: int = 1024,
    ) -> SparseTensor:
        """
        Encode the meshes to structured latent.

        Args:
            mesh (trimesh.Trimesh): The mesh to encode.
            resolution (int): The voxel grid resolution used for dualization.

        Returns:
            SparseTensor: The encoded structured latent.
        """
        vertices = torch.from_numpy(mesh.vertices).float()
        faces = torch.from_numpy(mesh.faces).long()

        # Voxelize the mesh into a flexible dual grid (runs on CPU): per
        # occupied voxel, a dual vertex and an edge-intersection mask.
        voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid(
            vertices.cpu(), faces.cpu(),
            grid_size=resolution,
            aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]],
            face_weight=1.0,
            boundary_weight=0.2,
            regularization_weight=1e-2,
            timing=True,
        )

        # Features are the dual-vertex offsets within each voxel; coords get a
        # zero batch index prepended.
        vertices = SparseTensor(
            feats=dual_vertices * resolution - voxel_indices,
            coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1)
        ).to(self.device)
        intersected = vertices.replace(intersected).to(self.device)

        if self.low_vram:
            self.models['shape_slat_encoder'].to(self.device)
        shape_slat = self.models['shape_slat_encoder'](vertices, intersected)
        if self.low_vram:
            self.models['shape_slat_encoder'].cpu()
        return shape_slat

    def sample_tex_slat(
        self,
        cond: dict,
        flow_model,
        shape_slat: SparseTensor,
        sampler_params: dict = {},
    ) -> SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            flow_model: Flow model (or sequence of models) for the texture latent.
            shape_slat (SparseTensor): The structured latent for shape
            sampler_params (dict): Additional parameters for the sampler.

        Returns:
            SparseTensor: The sampled (de-normalized) texture latent.
        """
        # Sample structured latent
        # Normalize the shape latent with its stored statistics before it is
        # used as the concat condition.
        std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
        mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
        shape_slat = (shape_slat - mean) / std

        # The flow model consumes [noise, shape_slat] along channels, so the
        # noise gets the remaining channel budget.
        in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
        noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
        # Call-site params override the pipeline defaults.
        sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
        if self.low_vram:
            flow_model.to(self.device)
        slat = self.tex_slat_sampler.sample(
            flow_model,
            noise,
            concat_cond=shape_slat,
            **cond,
            **sampler_params,
            verbose=True,
            tqdm_desc="Sampling texture SLat",
        ).samples
        if self.low_vram:
            flow_model.cpu()

        # De-normalize the sampled texture latent.
        std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    def decode_tex_slat(
        self,
        slat: SparseTensor,
    ) -> SparseTensor:
        """
        Decode the structured latent.

        Args:
            slat (SparseTensor): The structured latent.

        Returns:
            SparseTensor: The decoded texture voxels
        """
        if self.low_vram:
            self.models['tex_slat_decoder'].to(self.device)
        # Map decoder output from [-1, 1] to [0, 1] attribute range.
        ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5
        if self.low_vram:
            self.models['tex_slat_decoder'].cpu()
        return ret

    def postprocess_mesh(
        self,
        mesh: trimesh.Trimesh,
        pbr_voxel: SparseTensor,
        resolution: int = 1024,
        texture_size: int = 1024,
    ) -> trimesh.Trimesh:
        """Bake the voxel PBR attributes into UV textures and build the final mesh.

        Args:
            mesh (trimesh.Trimesh): The (preprocessed) mesh to texture.
            pbr_voxel (SparseTensor): Decoded per-voxel PBR attributes.
            resolution (int): Voxel grid resolution used for sampling.
            texture_size (int): Side length of the baked texture images.

        Returns:
            trimesh.Trimesh: Mesh with baked PBR material, converted back to
            the exporter's axis convention.
        """
        vertices = mesh.vertices
        faces = mesh.faces
        normals = mesh.vertex_normals
        vertices_torch = torch.from_numpy(vertices).float().cuda()
        faces_torch = torch.from_numpy(faces).int().cuda()
        # Reuse existing UVs if the mesh has them; otherwise unwrap with cumesh.
        if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None:
            uvs = mesh.visual.uv.copy()
            # Flip V: image row 0 is the top, UV v=0 is the bottom.
            uvs[:, 1] = 1 - uvs[:, 1]
            uvs_torch = torch.from_numpy(uvs).float().cuda()
        else:
            _cumesh = cumesh.CuMesh()
            _cumesh.init(vertices_torch, faces_torch)
            # vmap maps new (possibly duplicated at seams) vertices back to the
            # originals so normals can be carried over.
            vertices_torch, faces_torch, uvs_torch, vmap = _cumesh.uv_unwrap(return_vmaps=True)
            vertices_torch = vertices_torch.cuda()
            faces_torch = faces_torch.cuda()
            uvs_torch = uvs_torch.cuda()
            vertices = vertices_torch.cpu().numpy()
            faces = faces_torch.cpu().numpy()
            uvs = uvs_torch.cpu().numpy()
            normals = normals[vmap.cpu().numpy()]

        # rasterize
        # Rasterize the mesh in UV space so each texel knows its 3D position.
        ctx = dr.RasterizeCudaContext()
        uvs_torch = torch.cat([uvs_torch * 2 - 1, torch.zeros_like(uvs_torch[:, :1]), torch.ones_like(uvs_torch[:, :1])], dim=-1).unsqueeze(0)
        rast, _ = dr.rasterize(
            ctx, uvs_torch, faces_torch,
            resolution=[texture_size, texture_size],
        )
        mask = rast[0, ..., 3] > 0
        pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0]

        # Trilinearly sample the sparse PBR voxels at each covered texel's
        # 3D position (mapped from [-0.5, 0.5] to grid coordinates).
        attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device)
        attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d(
            pbr_voxel.feats,
            pbr_voxel.coords,
            shape=torch.Size([*pbr_voxel.shape, *pbr_voxel.spatial_shape]),
            grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3),
            mode='trilinear',
        )

        # construct mesh
        mask = mask.cpu().numpy()
        base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # extend
        # Inpaint uncovered texels so bilinear texture filtering does not bleed
        # background color across UV island borders.
        mask = (~mask).astype(np.uint8)
        base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA)
        metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None]
        roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None]
        alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None]

        # glTF convention: roughness in G, metallic in B of the same texture.
        material = trimesh.visual.material.PBRMaterial(
            baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)),
            baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
            metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)),
            metallicFactor=1.0,
            roughnessFactor=1.0,
            alphaMode='OPAQUE',
            doubleSided=True,
        )

        # Swap Y and Z axes, invert Y (common conversion for GLB compatibility)
        vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1]
        normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1]
        uvs[:, 1] = 1 - uvs[:, 1]  # Flip UV V-coordinate

        textured_mesh = trimesh.Trimesh(
            vertices=vertices,
            faces=faces,
            vertex_normals=normals,
            process=False,
            visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)
        )

        return textured_mesh


    @torch.no_grad()
    def run(
        self,
        mesh: trimesh.Trimesh,
        image: Image.Image,
        seed: int = 42,
        tex_slat_sampler_params: dict = {},
        preprocess_image: bool = True,
        resolution: int = 1024,
        texture_size: int = 2048,
    ) -> trimesh.Trimesh:
        """
        Run the pipeline.

        Args:
            mesh (trimesh.Trimesh): The mesh to texture.
            image (Image.Image): The image prompt.
            seed (int): The random seed.
            tex_slat_sampler_params (dict): Additional parameters for the texture latent sampler.
            preprocess_image (bool): Whether to preprocess the image.
            resolution (int): Voxel grid resolution. Only 512 selects the 512
                flow model; any other value uses the 1024 model.
            texture_size (int): Side length of the baked textures in pixels.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        mesh = self.preprocess_mesh(mesh)
        torch.manual_seed(seed)
        # Conditioning and flow model are picked by resolution (512 vs. other).
        cond = self.get_cond([image], 512) if resolution == 512 else self.get_cond([image], 1024)
        shape_slat = self.encode_shape_slat(mesh, resolution)
        tex_model = self.models['tex_slat_flow_model_512'] if resolution == 512 else self.models['tex_slat_flow_model_1024']
        tex_slat = self.sample_tex_slat(
            cond, tex_model,
            shape_slat, tex_slat_sampler_params
        )
        pbr_voxel = self.decode_tex_slat(tex_slat)
        out_mesh = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size)
        return out_mesh
|
||||||
Reference in New Issue
Block a user