diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d11da70 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "o-voxel/third_party/eigen"] + path = o-voxel/third_party/eigen + url = https://gitlab.com/libeigen/eigen.git diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7965606 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e97aee1 --- /dev/null +++ b/README.md @@ -0,0 +1,222 @@ +![](assets/teaser.webp) + +# Native and Compact Structured Latents for 3D Generation + +Paper +Hugging Face +Project Page +License + +https://github.com/user-attachments/assets/5ee056e4-73a9-4fd8-bf60-59cae90d3dfc + +*(Compressed version due to GitHub size limits. See the full-quality video on our project page!)* + +**TRELLIS.2** is a state-of-the-art large 3D generative model (4B parameters) designed for high-fidelity **image-to-3D** generation. It leverages a novel "field-free" sparse voxel structure termed **O-Voxel** to reconstruct and generate arbitrary 3D assets with complex topologies, sharp features, and full PBR materials. + + +## ✨ Features + +### 1. High Quality, Resolution & Efficiency +Our 4B-parameter model generates high-resolution fully textured assets with exceptional fidelity and efficiency using vanilla DiTs. It utilizes a Sparse 3D VAE with 16× spatial downsampling to encode assets into a compact latent space. + +| Resolution | Total Time* | Breakdown (Shape + Mat) | +| :--- | :--- | :--- | +| **512³** | **~3s** | 2s + 1s | +| **1024³** | **~17s** | 10s + 7s | +| **1536³** | **~60s** | 35s + 25s | + +*Tested on NVIDIA H100 GPU. + +### 2. Arbitrary Topology Handling +The **O-Voxel** representation breaks the limits of iso-surface fields. It robustly handles complex structures without lossy conversion: +* ✅ **Open Surfaces** (e.g., clothing, leaves) +* ✅ **Non-manifold Geometry** +* ✅ **Internal Enclosed Structures** + +### 3. Rich Texture Modeling +Beyond basic colors, TRELLIS.2 models arbitrary surface attributes including **Base Color, Roughness, Metallic, and Opacity**, enabling photorealistic rendering and transparency support. + +### 4. Minimalist Processing +Data processing is streamlined for instant conversions that are fully **rendering-free** and **optimization-free**. +* **< 10s** (Single CPU): Textured Mesh → O-Voxel +* **< 100ms** (CUDA): O-Voxel → Textured Mesh + + +## 🗺️ Roadmap + +- [x] Paper release +- [x] Release image-to-3D inference code +- [x] Release pretrained checkpoints (4B) +- [x] Hugging Face Spaces demo +- [ ] Release shape-conditioned texture generation inference code (Current schdule: before 12/24/2025) +- [ ] Release training code (Current schdule: before 12/31/2025) + + +## 🛠️ Installation + +### Prerequisites +- **System**: The code is currently tested only on **Linux**. +- **Hardware**: An NVIDIA GPU with at least 24GB of memory is necessary. The code has been verified on NVIDIA A100 and H100 GPUs. +- **Software**: + - The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain packages. Recommended version is 12.4. + - [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is recommended for managing dependencies. + - Python version 3.8 or higher is required. + +### Installation Steps +1. Clone the repo: + ```sh + git clone -b main https://github.com/microsoft/TRELLIS.2.git --recursive + cd TRELLIS.2 + ``` + +2. Install the dependencies: + + **Before running the following command there are somethings to note:** + - By adding `--new-env`, a new conda environment named `trellis2` will be created. If you want to use an existing conda environment, please remove this flag. + - By default the `trellis2` environment will use pytorch 2.6.0 with CUDA 12.4. If you want to use a different version of CUDA, you can remove the `--new-env` flag and manually install the required dependencies. Refer to [PyTorch](https://pytorch.org/get-started/previous-versions/) for the installation command. + - If you have multiple CUDA Toolkit versions installed, `CUDA_HOME` should be set to the correct version before running the command. For example, if you have CUDA Toolkit 12.4 and 13.0 installed, you can run `export CUDA_HOME=/usr/local/cuda-12.4` before running the command. + - By default, the code uses the `flash-attn` backend for attention. For GPUs do not support `flash-attn` (e.g., NVIDIA V100), you can install `xformers` manually and set the `ATTN_BACKEND` environment variable to `xformers` before running the code. See the [Minimal Example](#minimal-example) for more details. + - The installation may take a while due to the large number of dependencies. Please be patient. If you encounter any issues, you can try to install the dependencies one by one, specifying one flag at a time. + - If you encounter any issues during the installation, feel free to open an issue or contact us. + + Create a new conda environment named `trellis2` and install the dependencies: + ```sh + . ./setup.sh --new-env --basic --flash-attn --nvdiffrast --nvdiffrec --cumesh --o-voxel --flexgemm + ``` + The detailed usage of `setup.sh` can be found by running `. ./setup.sh --help`. + ```sh + Usage: setup.sh [OPTIONS] + Options: + -h, --help Display this help message + --new-env Create a new conda environment + --basic Install basic dependencies + --flash-attn Install flash-attention + --cumesh Install cumesh + --o-voxel Install o-voxel + --flexgemm Install flexgemm + --nvdiffrast Install nvdiffrast + --nvdiffrec Install nvdiffrec + ``` + +## 📦 Pretrained Weights + +The pretrained model **TRELLIS.2-4B** is available on Hugging Face. Please refer to the model card there for more details. + +| Model | Parameters | Resolution | Link | +| :--- | :--- | :--- | :--- | +| **TRELLIS.2-4B** | 4 Billion | 512³ - 1536³ | [Hugging Face](https://huggingface.co/microsoft/TRELLIS.2-4B) | + + +## 🚀 Usage + +### 1. Image to 3D Generation + +#### Minimal Example + +Here is an [example](example.py) of how to use the pretrained models for 3D asset generation. + +```python +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Can save GPU memory +import cv2 +import imageio +from PIL import Image +import torch +from trellis2.pipelines import Trellis2ImageTo3DPipeline +from trellis2.utils import render_utils +from trellis2.renderers import EnvMap +import o_voxel + +# 1. Setup Environment Map +envmap = EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' +)) + +# 2. Load Pipeline +pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B") +pipeline.cuda() + +# 3. Load Image & Run +image = Image.open("assets/example_image/T.png") +mesh = pipeline.run(image)[0] +mesh.simplify(16777216) # nvdiffrast limit + +# 4. Render Video +video = render_utils.make_pbr_vis_frames(render_utils.render_video(mesh, envmap=envmap)) +imageio.mimsave("sample.mp4", video, fps=15) + +# 5. Export to GLB +glb = o_voxel.postprocess.to_glb( + vertices = mesh.vertices, + faces = mesh.faces, + attr_volume = mesh.attrs, + coords = mesh.coords, + attr_layout = mesh.layout, + voxel_size = mesh.voxel_size, + aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + decimation_target = 1000000, + texture_size = 4096, + remesh = True, + remesh_band = 1, + remesh_project = 0, + verbose = True +) +glb.export("sample.glb", extension_webp=True) +``` + +Upon execution, the script generates the following files: + - `sample.mp4`: A video visualizing the generated 3D asset with PBR materials and environmental lighting. + - `sample.glb`: The extracted PBR-ready 3D asset in GLB format. + +**Note:** The `.glb` file is exported in `OPAQUE` mode by default. Although the alpha channel is preserved within the texture map, it is not active initially. To enable transparency, import the asset into your 3D software and manually connect the texture's alpha channel to the material's opacity or alpha input. + +#### Web Demo + +[app.py](app.py) provides a simple web demo for image to 3D asset generation. you can run the demo with the following command: +```sh +python app.py +``` + +Then, you can access the demo at the address shown in the terminal. + +### 2. PBR Texture Generation + +Will be released soon. Please stay tuned! + +## 🧩 Related Packages + +TRELLIS.2 is built upon several specialized high-performance packages developed by our team: + +* **[O-Voxel](o-voxel):** + Core library handling the logic for converting between textured meshes and the O-Voxel representation, ensuring instant bidirectional transformation. +* **[FlexGEMM](https://github.com/JeffreyXiang/FlexGEMM):** + Efficient sparse convolution implementation based on Triton, enabling rapid processing of sparse voxel structures. +* **[CuMesh](https://github.com/JeffreyXiang/CuMesh):** + CUDA-accelerated mesh utilities used for high-speed post-processing, remeshing, decimation, and UV-unwrapping. + + +## ⚖️ License + +This model and code are released under the **[MIT License](LICENSE)**. + +Please note that certain dependencies operate under separate license terms: + +- [**nvdiffrast**](https://github.com/NVlabs/nvdiffrast): Utilized for rendering generated 3D assets. This package is governed by its own [License](https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt). + +- [**nvdiffrec**](https://github.com/NVlabs/nvdiffrec): Implements the split-sum renderer for PBR materials. This package is governed by its own [License](https://github.com/NVlabs/nvdiffrec/blob/main/LICENSE.txt). + +## 📚 Citation + +If you find this model useful for your research, please cite our work: + +```bibtex +@article{ + xiang2025trellis2, + title={Native and Compact Structured Latents for 3D Generation}, + author={Xiang, Jianfeng and Chen, Xiaoxue and Xu, Sicheng and Wang, Ruicheng and Lv, Zelong and Deng, Yu and Zhu, Hongyuan and Dong, Yue and Zhao, Hao and Yuan, Nicholas Jing and Yang, Jiaolong}, + journal={Tech report}, + year={2025} +} +``` diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..e751608 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,14 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which +includes all source code repositories in our GitHub organizations. + +**Please do not report security vulnerabilities through public GitHub issues.** + +For security reporting information, locations, contact information, and policies, +please review the latest guidance for Microsoft repositories at +[https://aka.ms/SECURITY.md](https://aka.ms/SECURITY.md). + + \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..b6dea25 --- /dev/null +++ b/app.py @@ -0,0 +1,645 @@ +import gradio as gr + +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +from datetime import datetime +import shutil +import cv2 +from typing import * +import torch +import numpy as np +from PIL import Image +import base64 +import io +from trellis2.modules.sparse import SparseTensor +from trellis2.pipelines import Trellis2ImageTo3DPipeline +from trellis2.renderers import EnvMap +from trellis2.utils import render_utils +import o_voxel + + +MAX_SEED = np.iinfo(np.int32).max +TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') +MODES = [ + {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"}, + {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"}, + {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"}, + {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"}, + {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"}, + {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"}, +] +STEPS = 8 +DEFAULT_MODE = 3 +DEFAULT_STEP = 3 + + +css = """ +/* Overwrite Gradio Default Style */ +.stepper-wrapper { + padding: 0; +} + +.stepper-container { + padding: 0; + align-items: center; +} + +.step-button { + flex-direction: row; +} + +.step-connector { + transform: none; +} + +.step-number { + width: 16px; + height: 16px; +} + +.step-label { + position: relative; + bottom: 0; +} + +.wrap.center.full { + inset: 0; + height: 100%; +} + +.wrap.center.full.translucent { + background: var(--block-background-fill); +} + +.meta-text-center { + display: block !important; + position: absolute !important; + top: unset !important; + bottom: 0 !important; + right: 0 !important; + transform: unset !important; +} + +/* Previewer */ +.previewer-container { + position: relative; + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; + width: 100%; + height: 722px; + margin: 0 auto; + padding: 20px; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; +} + +.previewer-container .tips-icon { + position: absolute; + right: 10px; + top: 10px; + z-index: 10; + border-radius: 10px; + color: #fff; + background-color: var(--color-accent); + padding: 3px 6px; + user-select: none; +} + +.previewer-container .tips-text { + position: absolute; + right: 10px; + top: 50px; + color: #fff; + background-color: var(--color-accent); + border-radius: 10px; + padding: 6px; + text-align: left; + max-width: 300px; + z-index: 10; + transition: all 0.3s; + opacity: 0%; + user-select: none; +} + +.previewer-container .tips-text p { + font-size: 14px; + line-height: 1.2; +} + +.tips-icon:hover + .tips-text { + display: block; + opacity: 100%; +} + +/* Row 1: Display Modes */ +.previewer-container .mode-row { + width: 100%; + display: flex; + gap: 8px; + justify-content: center; + margin-bottom: 20px; + flex-wrap: wrap; +} +.previewer-container .mode-btn { + width: 24px; + height: 24px; + border-radius: 50%; + cursor: pointer; + opacity: 0.5; + transition: all 0.2s; + border: 2px solid #ddd; + object-fit: cover; +} +.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); } +.previewer-container .mode-btn.active { + opacity: 1; + border-color: var(--color-accent); + transform: scale(1.1); +} + +/* Row 2: Display Image */ +.previewer-container .display-row { + margin-bottom: 20px; + min-height: 400px; + width: 100%; + flex-grow: 1; + display: flex; + justify-content: center; + align-items: center; +} +.previewer-container .previewer-main-image { + max-width: 100%; + max-height: 100%; + flex-grow: 1; + object-fit: contain; + display: none; +} +.previewer-container .previewer-main-image.visible { + display: block; +} + +/* Row 3: Custom HTML Slider */ +.previewer-container .slider-row { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; + gap: 10px; + padding: 0 10px; +} + +.previewer-container input[type=range] { + -webkit-appearance: none; + width: 100%; + max-width: 400px; + background: transparent; +} +.previewer-container input[type=range]::-webkit-slider-runnable-track { + width: 100%; + height: 8px; + cursor: pointer; + background: #ddd; + border-radius: 5px; +} +.previewer-container input[type=range]::-webkit-slider-thumb { + height: 20px; + width: 20px; + border-radius: 50%; + background: var(--color-accent); + cursor: pointer; + -webkit-appearance: none; + margin-top: -6px; + box-shadow: 0 2px 5px rgba(0,0,0,0.2); + transition: transform 0.1s; +} +.previewer-container input[type=range]::-webkit-slider-thumb:hover { + transform: scale(1.2); +} + +/* Overwrite Previewer Block Style */ +.gradio-container .padded:has(.previewer-container) { + padding: 0 !important; +} + +.gradio-container:has(.previewer-container) [data-testid="block-label"] { + position: absolute; + top: 0; + left: 0; +} +""" + + +head = """ + +""" + + +empty_html = f""" +
+ +
+""" + + +def image_to_base64(image): + buffered = io.BytesIO() + image = image.convert("RGB") + image.save(buffered, format="jpeg", quality=85) + img_str = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/jpeg;base64,{img_str}" + + +def start_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + os.makedirs(user_dir, exist_ok=True) + + +def end_session(req: gr.Request): + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shutil.rmtree(user_dir) + + +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image. + + Args: + image (Image.Image): The input image. + + Returns: + Image.Image: The preprocessed image. + """ + processed_image = pipeline.preprocess_image(image) + return processed_image + + +def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict: + shape_slat, tex_slat, res = latents + return { + 'shape_slat_feats': shape_slat.feats.cpu().numpy(), + 'tex_slat_feats': tex_slat.feats.cpu().numpy(), + 'coords': shape_slat.coords.cpu().numpy(), + 'res': res, + } + + +def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]: + shape_slat = SparseTensor( + feats=torch.from_numpy(state['shape_slat_feats']).cuda(), + coords=torch.from_numpy(state['coords']).cuda(), + ) + tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) + return shape_slat, tex_slat, state['res'] + + +def get_seed(randomize_seed: bool, seed: int) -> int: + """ + Get the random seed. + """ + return np.random.randint(0, MAX_SEED) if randomize_seed else seed + + +def image_to_3d( + image: Image.Image, + seed: int, + resolution: str, + ss_guidance_strength: float, + ss_guidance_rescale: float, + ss_sampling_steps: int, + ss_rescale_t: float, + shape_slat_guidance_strength: float, + shape_slat_guidance_rescale: float, + shape_slat_sampling_steps: int, + shape_slat_rescale_t: float, + tex_slat_guidance_strength: float, + tex_slat_guidance_rescale: float, + tex_slat_sampling_steps: int, + tex_slat_rescale_t: float, + req: gr.Request, + progress=gr.Progress(track_tqdm=True), +) -> str: + # --- Sampling --- + outputs, latents = pipeline.run( + image, + seed=seed, + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "guidance_strength": ss_guidance_strength, + "guidance_rescale": ss_guidance_rescale, + "rescale_t": ss_rescale_t, + }, + shape_slat_sampler_params={ + "steps": shape_slat_sampling_steps, + "guidance_strength": shape_slat_guidance_strength, + "guidance_rescale": shape_slat_guidance_rescale, + "rescale_t": shape_slat_rescale_t, + }, + tex_slat_sampler_params={ + "steps": tex_slat_sampling_steps, + "guidance_strength": tex_slat_guidance_strength, + "guidance_rescale": tex_slat_guidance_rescale, + "rescale_t": tex_slat_rescale_t, + }, + pipeline_type={ + "512": "512", + "1024": "1024_cascade", + "1536": "1536_cascade", + }[resolution], + return_latent=True, + ) + mesh = outputs[0] + mesh.simplify(16777216) # nvdiffrast limit + images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap) + state = pack_state(latents) + torch.cuda.empty_cache() + + # --- HTML Construction --- + # The Stack of 48 Images + images_html = "" + for m_idx, mode in enumerate(MODES): + for s_idx in range(STEPS): + # ID Naming Convention: view-m{mode}-s{step} + unique_id = f"view-m{m_idx}-s{s_idx}" + + # Logic: Only Mode 0, Step 0 is visible initially + is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP) + vis_class = "visible" if is_visible else "" + + # Image Source + img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx])) + + # Render the Tag + images_html += f""" + + """ + + # Button Row HTML + btns_html = "" + for idx, mode in enumerate(MODES): + active_class = "active" if idx == DEFAULT_MODE else "" + # Note: onclick calls the JS function defined in Head + btns_html += f""" + + """ + + # Assemble the full component + full_html = f""" +
+
+
💡Tips
+
+

Render Mode - Click on the circular buttons to switch between different render modes.

+

View Angle - Drag the slider to change the view angle.

+
+
+ + +
+ {images_html} +
+ + +
+ {btns_html} +
+ + +
+ +
+
+ """ + + return state, full_html + + +def extract_glb( + state: dict, + decimation_target: int, + texture_size: int, + req: gr.Request, + progress=gr.Progress(track_tqdm=True), +) -> Tuple[str, str]: + """ + Extract a GLB file from the 3D model. + + Args: + state (dict): The state of the generated 3D model. + decimation_target (int): The target face count for decimation. + texture_size (int): The texture resolution. + + Returns: + str: The path to the extracted GLB file. + """ + user_dir = os.path.join(TMP_DIR, str(req.session_hash)) + shape_slat, tex_slat, res = unpack_state(state) + mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] + glb = o_voxel.postprocess.to_glb( + vertices=mesh.vertices, + faces=mesh.faces, + attr_volume=mesh.attrs, + coords=mesh.coords, + attr_layout=pipeline.pbr_attr_layout, + grid_size=res, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + decimation_target=decimation_target, + texture_size=texture_size, + remesh=True, + remesh_band=1, + remesh_project=0, + use_tqdm=True, + ) + now = datetime.now() + timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" + os.makedirs(user_dir, exist_ok=True) + glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') + glb.export(glb_path, extension_webp=True) + torch.cuda.empty_cache() + return glb_path, glb_path + + +with gr.Blocks(delete_cache=(600, 600)) as demo: + gr.Markdown(""" + ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/trellis.2) + * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset. + * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time. + """) + + with gr.Row(): + with gr.Column(scale=1, min_width=360): + image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) + + resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024") + seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) + randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) + decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=500000, step=10000) + texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) + + generate_btn = gr.Button("Generate") + + with gr.Accordion(label="Advanced Settings", open=False): + gr.Markdown("Stage 1: Sparse Structure Generation") + with gr.Row(): + ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01) + ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) + gr.Markdown("Stage 2: Shape Generation") + with gr.Row(): + shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) + shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) + shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + gr.Markdown("Stage 3: Material Generation") + with gr.Row(): + tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) + tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) + tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) + tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) + + with gr.Column(scale=10): + with gr.Walkthrough(selected=0) as walkthrough: + with gr.Step("Preview", id=0): + preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True) + extract_btn = gr.Button("Extract GLB") + with gr.Step("Extract", id=1): + glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0)) + download_btn = gr.DownloadButton(label="Download GLB") + + with gr.Column(scale=1, min_width=172): + examples = gr.Examples( + examples=[ + f'assets/example_image/{image}' + for image in os.listdir("assets/example_image") + ], + inputs=[image_prompt], + fn=preprocess_image, + outputs=[image_prompt], + run_on_click=True, + examples_per_page=18, + ) + + output_buf = gr.State() + + + # Handlers + demo.load(start_session) + demo.unload(end_session) + + image_prompt.upload( + preprocess_image, + inputs=[image_prompt], + outputs=[image_prompt], + ) + + generate_btn.click( + get_seed, + inputs=[randomize_seed, seed], + outputs=[seed], + ).then( + lambda: gr.Walkthrough(selected=0), outputs=walkthrough + ).then( + image_to_3d, + inputs=[ + image_prompt, seed, resolution, + ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, + shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, + tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, + ], + outputs=[output_buf, preview_output], + ) + + extract_btn.click( + lambda: gr.Walkthrough(selected=1), outputs=walkthrough + ).then( + extract_glb, + inputs=[output_buf, decimation_target, texture_size], + outputs=[glb_output, download_btn], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + os.makedirs(TMP_DIR, exist_ok=True) + + # Construct ui components + btn_img_base64_strs = {} + for i in range(len(MODES)): + icon = Image.open(MODES[i]['icon']) + MODES[i]['icon_base64'] = image_to_base64(icon) + + pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B') + pipeline.cuda() + + envmap = { + 'forest': EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' + )), + 'sunset': EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' + )), + 'courtyard': EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' + )), + } + + demo.launch(css=css, head=head) diff --git a/assets/app/basecolor.png b/assets/app/basecolor.png new file mode 100644 index 0000000..e7dbeaf Binary files /dev/null and b/assets/app/basecolor.png differ diff --git a/assets/app/clay.png b/assets/app/clay.png new file mode 100644 index 0000000..e02866a Binary files /dev/null and b/assets/app/clay.png differ diff --git a/assets/app/hdri_city.png b/assets/app/hdri_city.png new file mode 100644 index 0000000..43e1c2a Binary files /dev/null and b/assets/app/hdri_city.png differ diff --git a/assets/app/hdri_courtyard.png b/assets/app/hdri_courtyard.png new file mode 100644 index 0000000..4261ad6 Binary files /dev/null and b/assets/app/hdri_courtyard.png differ diff --git a/assets/app/hdri_forest.png b/assets/app/hdri_forest.png new file mode 100644 index 0000000..7617fe1 Binary files /dev/null and b/assets/app/hdri_forest.png differ diff --git a/assets/app/hdri_interior.png b/assets/app/hdri_interior.png new file mode 100644 index 0000000..e00c1d6 Binary files /dev/null and b/assets/app/hdri_interior.png differ diff --git a/assets/app/hdri_night.png b/assets/app/hdri_night.png new file mode 100644 index 0000000..f0423d2 Binary files /dev/null and b/assets/app/hdri_night.png differ diff --git a/assets/app/hdri_studio.png b/assets/app/hdri_studio.png new file mode 100644 index 0000000..0f5a4e8 Binary files /dev/null and b/assets/app/hdri_studio.png differ diff --git a/assets/app/hdri_sunrise.png b/assets/app/hdri_sunrise.png new file mode 100644 index 0000000..9cee3bb Binary files /dev/null and b/assets/app/hdri_sunrise.png differ diff --git a/assets/app/hdri_sunset.png b/assets/app/hdri_sunset.png new file mode 100644 index 0000000..bd67070 Binary files /dev/null and b/assets/app/hdri_sunset.png differ diff --git a/assets/app/normal.png b/assets/app/normal.png new file mode 100644 index 0000000..352e92b Binary files /dev/null and b/assets/app/normal.png differ diff --git a/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp b/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp new file mode 100644 index 0000000..4522b59 Binary files /dev/null and b/assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp differ diff --git a/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp b/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp new file mode 100644 index 0000000..712c0e0 Binary files /dev/null and b/assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp differ diff --git a/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp b/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp new file mode 100644 index 0000000..d565e59 Binary files /dev/null and b/assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp differ diff --git a/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp b/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp new file mode 100644 index 0000000..07fa20c Binary files /dev/null and b/assets/example_image/130c2b18f1651a70f8aa15b2c99f8dba29bb943044d92871f9223bd3e989e8b1.webp differ diff --git a/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp b/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp new file mode 100644 index 0000000..3c3ac5d Binary files /dev/null and b/assets/example_image/154c88671d9e8785bd909e9283bc87fb2709ac7ce13890832603ea7533981a46.webp differ diff --git a/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp b/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp new file mode 100644 index 0000000..e596271 Binary files /dev/null and b/assets/example_image/1c359e94f2d699055c78487c90626cf5f1d7460c8fc04e60a286507e5286a28d.webp differ diff --git a/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp b/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp new file mode 100644 index 0000000..679c0e7 Binary files /dev/null and b/assets/example_image/22a868bac8e62511fccd2bc82ed31ae77ed31ae2a8a149be7150957f11b30c9b.webp differ diff --git a/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp b/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp new file mode 100644 index 0000000..79631f5 Binary files /dev/null and b/assets/example_image/25d412fe36aab9f33913bc9f5e2fb1ff6458bdb286bf14397162c672c95d3697.webp differ diff --git a/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp b/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp new file mode 100644 index 0000000..b833f8e Binary files /dev/null and b/assets/example_image/26717a7dad644a5cf7554e8e6d06cf82d3dd9bbae31620b36cc7eb38b8de7ac9.webp differ diff --git a/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp b/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp new file mode 100644 index 0000000..c9c24a4 Binary files /dev/null and b/assets/example_image/290af2dd390c95db88a35b8062fdd2ac1a9c28edc6533bc6a26ab2c83c523c61.webp differ diff --git a/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp b/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp new file mode 100644 index 0000000..58d29d8 Binary files /dev/null and b/assets/example_image/2bb0932314bae71eec94d0d01a20d3f761ade9664e013b9a9a43c00a2f44163a.webp differ diff --git a/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp b/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp new file mode 100644 index 0000000..a05f7c6 Binary files /dev/null and b/assets/example_image/3723615e3766742ae35b09517152a58c36d62b707bc60d7f76f8a6c922add2c0.webp differ diff --git a/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp b/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp new file mode 100644 index 0000000..7411bc6 Binary files /dev/null and b/assets/example_image/3903b87907a6b4947006e6fc7c0c64f40cd98932a02bf0ecf7d6dfae776f3a38.webp differ diff --git a/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp b/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp new file mode 100644 index 0000000..5f00cc2 Binary files /dev/null and b/assets/example_image/39488b45bb4820ff0f31bb07cb8d0a19ebd991adbcb22a10fc89ee41c59219ee.webp differ diff --git a/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp b/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp new file mode 100644 index 0000000..84c7d46 Binary files /dev/null and b/assets/example_image/454e7d8a30486c0635369936e7bec5677b78ae5f436d0e46af0d533738be859f.webp differ diff --git a/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp b/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp new file mode 100644 index 0000000..8f5902c Binary files /dev/null and b/assets/example_image/4bc7abe209c8673dd3766ee4fad14d40acbed02d118e7629f645c60fd77313f1.webp differ diff --git a/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp b/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp new file mode 100644 index 0000000..016b45b Binary files /dev/null and b/assets/example_image/4dae7ef0224e9305533c4801ce8144d5b3a89d883ca5d35bdb0aebb860ff705f.webp differ diff --git a/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp b/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp new file mode 100644 index 0000000..2dd5403 Binary files /dev/null and b/assets/example_image/50b70c5f88a5961d2c786158655d2fce5c3b214b2717956500a66a4e5b5fbe37.webp differ diff --git a/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp b/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp new file mode 100644 index 0000000..9886218 Binary files /dev/null and b/assets/example_image/51b1b31d40476b123db70a51ae0b5f8b8d0db695b616bc2ec4e6324eb178fc14.webp differ diff --git a/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp b/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp new file mode 100644 index 0000000..fd577dc Binary files /dev/null and b/assets/example_image/52284bf45134c59a94be150a5b18b9cc3619ada4b30ded8d8d0288383b8c016f.webp differ diff --git a/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp b/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp new file mode 100644 index 0000000..1b66d25 Binary files /dev/null and b/assets/example_image/5a020584b95cf3db3b6420e9b09fb93e7c0f4046e61076e5b4c65c63dc1f5837.webp differ diff --git a/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp b/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp new file mode 100644 index 0000000..4c4cb39 Binary files /dev/null and b/assets/example_image/5a6c81d3b2afca4323e4b8b379e2cf06d18371a57fc8c5dc24b57e60e3216690.webp differ diff --git a/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp b/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp new file mode 100644 index 0000000..79b9221 Binary files /dev/null and b/assets/example_image/5c80e5e03a3b60b6f03eaf555ba1dafc0e4230c472d7e8c8e2c5ca0a0dfcef10.webp differ diff --git a/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp b/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp new file mode 100644 index 0000000..57b1643 Binary files /dev/null and b/assets/example_image/61fea9d08e0bd9a067c9f696621dc89165afb5aab318d0701bc025d7863dabf0.webp differ diff --git a/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp b/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp new file mode 100644 index 0000000..5204c86 Binary files /dev/null and b/assets/example_image/65433d02fc56dae164719ec29cb9646c0383aa1d0e24f0bb592899f08428d68e.webp differ diff --git a/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp b/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp new file mode 100644 index 0000000..8f787a2 Binary files /dev/null and b/assets/example_image/6b6d89d46d7f53e6409dbe695a9ef8f97c5257e641da35015a78579e903acdad.webp differ diff --git a/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp b/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp new file mode 100644 index 0000000..54a78c8 Binary files /dev/null and b/assets/example_image/74fe541e8c8eac8d0b5d8ba144307f6c07ed832cd19bf1d431c74292002028cd.webp differ diff --git a/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp b/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp new file mode 100644 index 0000000..f25880a Binary files /dev/null and b/assets/example_image/799ab13a23fe319a6876b8bf48007d0374d514f5e7aa31210e9b2cecfbace082.webp differ diff --git a/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp b/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp new file mode 100644 index 0000000..b7a521f Binary files /dev/null and b/assets/example_image/7b540da337f576ffce2adc36c7459b9bbbfd845ab2160a6abbe986f1f906f6cd.webp differ diff --git a/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp b/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp new file mode 100644 index 0000000..baba6ce Binary files /dev/null and b/assets/example_image/7baa867b4790b8596ee120f9b171b727fd9428c41980577a518505507c99d8a0.webp differ diff --git a/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp b/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp new file mode 100644 index 0000000..0de4b3c Binary files /dev/null and b/assets/example_image/7bd0521d20ee4805d1462a0ffb7d9aacc15180c2b741c9ac42a0d82ad3d340cb.webp differ diff --git a/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp b/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp new file mode 100644 index 0000000..883d922 Binary files /dev/null and b/assets/example_image/7d585a8475db078593486367d98b5efa9368a60a3528c555b96026a1a674aa54.webp differ diff --git a/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp b/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp new file mode 100644 index 0000000..2843ff2 Binary files /dev/null and b/assets/example_image/7d6f4da4eafcc60243daf6ed210853df394a8bad7e701cadf551e21abcc77869.webp differ diff --git a/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp b/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp new file mode 100644 index 0000000..7a90b8a Binary files /dev/null and b/assets/example_image/7d7659d5943e85a73a4ffe33c6dd48f5d79601e9bf11b103516f419ce9fbf713.webp differ diff --git a/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp b/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp new file mode 100644 index 0000000..f268f16 Binary files /dev/null and b/assets/example_image/80ad7988fc2ce62fc655b21a8950865566ec3f5a8b4398f2502db6414a3e6834.webp differ diff --git a/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp b/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp new file mode 100644 index 0000000..0e5c91a Binary files /dev/null and b/assets/example_image/8aa698c59aab48d4ce69a558d9159107890e3d64e522af404d9635ad0be21f88.webp differ diff --git a/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp b/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp new file mode 100644 index 0000000..f0aa5be Binary files /dev/null and b/assets/example_image/8ce83f6a28910e755902de10918672e77dd23476f43f0f1521c48667de6cea84.webp differ diff --git a/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp b/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp new file mode 100644 index 0000000..be60e0f Binary files /dev/null and b/assets/example_image/8e12cf0977c0476396e7112f04b73d4d73569421173fcb553213d45030bddec3.webp differ diff --git a/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp b/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp new file mode 100644 index 0000000..bcee0ee Binary files /dev/null and b/assets/example_image/901d8de4c2011a8502a0decd0adec0fc7418f26165cd52ced64fd44f720353ef.webp differ diff --git a/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp b/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp new file mode 100644 index 0000000..2b782b4 Binary files /dev/null and b/assets/example_image/95db3c13622788ec311ae4dfa24dd88732c66ca5e340a0bf3465d2a528204037.webp differ diff --git a/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp b/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp new file mode 100644 index 0000000..9b82e77 Binary files /dev/null and b/assets/example_image/9c306c7bd0e857285f536fb500c0828e5fad4e23c3ceeab92c888c568fa19101.webp differ diff --git a/assets/example_image/T.png b/assets/example_image/T.png new file mode 100755 index 0000000..187c772 Binary files /dev/null and b/assets/example_image/T.png differ diff --git a/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp b/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp new file mode 100644 index 0000000..112d8a2 Binary files /dev/null and b/assets/example_image/a13d176cd7a7d457b42d1b32223bcff1a45dafbbb42c6a272b97d65ac2f2eb52.webp differ diff --git a/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp b/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp new file mode 100644 index 0000000..bf6e730 Binary files /dev/null and b/assets/example_image/a306e2ee5cbc3da45e7db48d75a0cade0bb7eee263a74bc6820c617afaba1302.webp differ diff --git a/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp b/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp new file mode 100644 index 0000000..677c79b Binary files /dev/null and b/assets/example_image/a3d0c28c7d9c6f23adb941c4def2523572c903a94469abcaa7dd1398d28af8f1.webp differ diff --git a/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp b/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp new file mode 100644 index 0000000..5123114 Binary files /dev/null and b/assets/example_image/a63d2595e10229067b19cb167fe2bdc152dabfd8b62ae45fc1655a4cf66509bc.webp differ diff --git a/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp b/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp new file mode 100644 index 0000000..6e8a099 Binary files /dev/null and b/assets/example_image/ab3bb3e183991253ae66c06d44dc6105f3c113a1a1f819ab57a93c6f60b0d32b.webp differ diff --git a/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp b/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp new file mode 100644 index 0000000..fe4f257 Binary files /dev/null and b/assets/example_image/b205f4483c47bd1fec8e229163361e4fdff9f77923c5e968343b8f1dd76b61dc.webp differ diff --git a/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp b/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp new file mode 100644 index 0000000..bbff41b Binary files /dev/null and b/assets/example_image/b358d0eb96a68ac4ba1f2fb6d44ea2225f95fdfbf9cf4e0da08650c3704f1d23.webp differ diff --git a/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp b/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp new file mode 100644 index 0000000..dcb1ada Binary files /dev/null and b/assets/example_image/bb3190891dd8341c9d6d3d4faa6525c6ecdac19945526904928f6bcd2f3f45f1.webp differ diff --git a/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp b/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp new file mode 100644 index 0000000..3fcff55 Binary files /dev/null and b/assets/example_image/be7deb26f4fdd2080d4288668af4c39e526564282c579559ff8a4126ca4ed6c1.webp differ diff --git a/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp b/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp new file mode 100644 index 0000000..ada14ed Binary files /dev/null and b/assets/example_image/c2125d086c2529638841f38918ae1defbf33e6796d827253885b4c51e601034f.webp differ diff --git a/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp b/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp new file mode 100644 index 0000000..213fe68 Binary files /dev/null and b/assets/example_image/c3d714bc125f06ce1187799d5ca10736b4064a24c141e627089aad2bdedf7aa5.webp differ diff --git a/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp b/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp new file mode 100644 index 0000000..6c33a0b Binary files /dev/null and b/assets/example_image/c9340e744541f310bf89838f652602961d3e5950b31cd349bcbfc7e59e15cd2e.webp differ diff --git a/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp b/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp new file mode 100644 index 0000000..6532f7c Binary files /dev/null and b/assets/example_image/cd3c309f17eee5ad6afe4e001765893ade20b653f611365c93d158286b4cee96.webp differ diff --git a/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp b/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp new file mode 100644 index 0000000..a5d7046 Binary files /dev/null and b/assets/example_image/cdf996a6cc218918eeb90209891ce306a230e6d9cca2a3d9bbb37c6d7b6bd318.webp differ diff --git a/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp b/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp new file mode 100644 index 0000000..dc233f0 Binary files /dev/null and b/assets/example_image/d39c2bd426456bd686de33f924524d18eb47343a5f080826aa3cb8e77de5147b.webp differ diff --git a/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp b/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp new file mode 100644 index 0000000..e78c058 Binary files /dev/null and b/assets/example_image/d64c94dffdadf82d46004d11412b5a3b2a17f1b4ddb428477a7ba38652adf973.webp differ diff --git a/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp b/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp new file mode 100644 index 0000000..65b852d Binary files /dev/null and b/assets/example_image/dd4c51c13a996b9eec9c954a45cd5cd457059bf9f030aadde48d88225a9f3321.webp differ diff --git a/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp b/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp new file mode 100644 index 0000000..a773e5c Binary files /dev/null and b/assets/example_image/e10465728ebea1e055524f97ac5d47cebf82a672f07a05409aa07d826c9d9f37.webp differ diff --git a/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp b/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp new file mode 100644 index 0000000..6430da9 Binary files /dev/null and b/assets/example_image/e134444178eae855cfdefb9e5259d076df5e34f780ee44d4ad604483ff69cc74.webp differ diff --git a/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp b/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp new file mode 100644 index 0000000..9492191 Binary files /dev/null and b/assets/example_image/e3c57169ce3d5ce10b3c10acef20b81ca774b54a17aabe74e8aca320c7b07b55.webp differ diff --git a/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp b/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp new file mode 100644 index 0000000..689df60 Binary files /dev/null and b/assets/example_image/e4d6b2f3a18c3e0f5146a5b40cda6c95d7f69372b2e741c023e5ec9661deda2b.webp differ diff --git a/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp b/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp new file mode 100644 index 0000000..e64acde Binary files /dev/null and b/assets/example_image/e513fcd6c897b249fc4bff54268b4d0bbab6403503ecf3846d92feb892536e5e.webp differ diff --git a/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp b/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp new file mode 100644 index 0000000..c9fb3fb Binary files /dev/null and b/assets/example_image/ebd09565cf0b6593aced573dffdfff34915aa359c60ec5dd0b30cd91a7f153c8.webp differ diff --git a/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp b/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp new file mode 100644 index 0000000..2e161b2 Binary files /dev/null and b/assets/example_image/ee8ecf658fde9c58830c021b2e30d0d5e7e492ef52febe7192a6c74fbf1b0472.webp differ diff --git a/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp b/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp new file mode 100644 index 0000000..bb2ea88 Binary files /dev/null and b/assets/example_image/f351569ddc61116da4a7b929bccdab144d011f56b9603e6e72abea05236160f4.webp differ diff --git a/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp b/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp new file mode 100644 index 0000000..73e6a55 Binary files /dev/null and b/assets/example_image/f5332118a0cda9cd13fe13d4be2b00437e702d1f9af51ebb6b75219a572a6ce9.webp differ diff --git a/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp b/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp new file mode 100644 index 0000000..4f6cbd0 Binary files /dev/null and b/assets/example_image/f8920788b704531f7a7e875afd7c5c423d62e0a987e9495c63893c2cb4d2b5dc.webp differ diff --git a/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp b/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp new file mode 100644 index 0000000..2e144a5 Binary files /dev/null and b/assets/example_image/f8a7eafe26a4f3ebd26a9e7d0289e4a40b5a93e9234e94ec3e1071c352acc65a.webp differ diff --git a/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp b/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp new file mode 100644 index 0000000..9c24d0e Binary files /dev/null and b/assets/example_image/f94e2b76494ce2cf1874611273e5fb3d76b395793bb5647492fa85c2ce0a248b.webp differ diff --git a/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp b/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp new file mode 100644 index 0000000..f90db41 Binary files /dev/null and b/assets/example_image/fdf979f5227f24b554fca28aa71c351beb7b1be2be236b50bbe07f59e9b8a50c.webp differ diff --git a/assets/hdri/city.exr b/assets/hdri/city.exr new file mode 100644 index 0000000..d922066 Binary files /dev/null and b/assets/hdri/city.exr differ diff --git a/assets/hdri/courtyard.exr b/assets/hdri/courtyard.exr new file mode 100644 index 0000000..b70a0e7 Binary files /dev/null and b/assets/hdri/courtyard.exr differ diff --git a/assets/hdri/forest.exr b/assets/hdri/forest.exr new file mode 100644 index 0000000..846b87d Binary files /dev/null and b/assets/hdri/forest.exr differ diff --git a/assets/hdri/interior.exr b/assets/hdri/interior.exr new file mode 100644 index 0000000..92d40ca Binary files /dev/null and b/assets/hdri/interior.exr differ diff --git a/assets/hdri/license.txt b/assets/hdri/license.txt new file mode 100644 index 0000000..eba5cb8 --- /dev/null +++ b/assets/hdri/license.txt @@ -0,0 +1,15 @@ +All HDRIs are licensed as CC0. + +These were created by Greg Zaal (Poly Haven https://polyhaven.com). +Originals used for each HDRI: +- City: https://polyhaven.com/a/portland_landing_pad +- Courtyard: https://polyhaven.com/a/courtyard +- Forest: https://polyhaven.com/a/ninomaru_teien +- Interior: https://polyhaven.com/a/hotel_room +- Night: Probably https://polyhaven.com/a/moonless_golf +- Studio: Probably https://polyhaven.com/a/studio_small_01 +- Sunrise: https://polyhaven.com/a/spruit_sunrise +- Sunset: https://polyhaven.com/a/venice_sunset + +1K resolution of each was taken, and compressed with oiiotool: +oiiotool input.exr --ch R,G,B -d float --compression dwab:300 --clamp:min=0.0:max=32000.0 -o output.exr diff --git a/assets/hdri/night.exr b/assets/hdri/night.exr new file mode 100644 index 0000000..a207d99 Binary files /dev/null and b/assets/hdri/night.exr differ diff --git a/assets/hdri/studio.exr b/assets/hdri/studio.exr new file mode 100644 index 0000000..baf478d Binary files /dev/null and b/assets/hdri/studio.exr differ diff --git a/assets/hdri/sunrise.exr b/assets/hdri/sunrise.exr new file mode 100644 index 0000000..985a9a5 Binary files /dev/null and b/assets/hdri/sunrise.exr differ diff --git a/assets/hdri/sunset.exr b/assets/hdri/sunset.exr new file mode 100644 index 0000000..e86206e Binary files /dev/null and b/assets/hdri/sunset.exr differ diff --git a/assets/teaser.webp b/assets/teaser.webp new file mode 100644 index 0000000..40a0431 Binary files /dev/null and b/assets/teaser.webp differ diff --git a/example.py b/example.py new file mode 100644 index 0000000..475a3b1 --- /dev/null +++ b/example.py @@ -0,0 +1,48 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Can save GPU memory +import cv2 +import imageio +from PIL import Image +import torch +from trellis2.pipelines import Trellis2ImageTo3DPipeline +from trellis2.utils import render_utils +from trellis2.renderers import EnvMap +import o_voxel + +# 1. Setup Environment Map +envmap = EnvMap(torch.tensor( + cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), + dtype=torch.float32, device='cuda' +)) + +# 2. Load Pipeline +pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B") +pipeline.cuda() + +# 3. Load Image & Run +image = Image.open("assets/example_image/T.png") +mesh = pipeline.run(image)[0] +mesh.simplify(16777216) # nvdiffrast limit + +# 4. Render Video +video = render_utils.make_pbr_vis_frames(render_utils.render_video(mesh, envmap=envmap)) +imageio.mimsave("sample.mp4", video, fps=15) + +# 5. Export to GLB +glb = o_voxel.postprocess.to_glb( + vertices = mesh.vertices, + faces = mesh.faces, + attr_volume = mesh.attrs, + coords = mesh.coords, + attr_layout = mesh.layout, + voxel_size = mesh.voxel_size, + aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + decimation_target = 1000000, + texture_size = 4096, + remesh = True, + remesh_band = 1, + remesh_project = 0, + verbose = True +) +glb.export("sample.glb", extension_webp=True) \ No newline at end of file diff --git a/o-voxel/README.md b/o-voxel/README.md new file mode 100644 index 0000000..0dba0bf --- /dev/null +++ b/o-voxel/README.md @@ -0,0 +1,174 @@ +# O-Voxel: A Native 3D Representation + +**O-Voxel** is a sparse, voxel-based native 3D representation designed for high-quality 3D generation and reconstruction. Unlike traditional methods that rely on fields (e.g., Occupancy fields, SDFs), O-Voxel utilizes a **Flexible Dual Grid** formulation to robustly represent surfaces with arbitrary topology (including non-manifold and open surfaces) and **volumetric surface properties** such as Physically-Based Rendering (PBR) material attributes. + +This library provides an efficient implementation for the instant bidirectional conversion between Meshes and O-Voxels, along with tools for sparse voxel compression, serialization, and rendering. + +![Overview](assets/overview.webp) + +## Key Features + +- **🧱 Flexible Dual Grid**: A geometry representation that solves a enhanced QEF (Quadratic Error Function) to accurately capture sharp features and open boundaries without requiring watertight meshes. +- **🎨 Volumetric PBR Attributes**: Native support for physically-based rendering properties (Base Color, Metallic, Roughness, Opacity) aligned with the sparse voxel grid. +- **⚡ Instant Bidirectional Conversion**: Rapid `Mesh <-> O-Voxel` conversion without expensive SDF evaluation, flood-filling, or iterative optimization. +- **💾 Efficient Compression**: Supports custom `.vxz` format for compact storage of sparse voxel structures using Z-order/Hilbert curve encoding. +- **🛠️ Production Ready**: Tools to export converted assets directly to `.glb` with UV unwrapping and texture baking. + +## Installation + +```bash +git clone -b main https://github.com/microsoft/TRELLIS.2.git --recursive +pip install TRELLIS.2/o_voxel --no-build-isolation +``` + +## Quick Start + +> See also the [examples](examples) directory for more detailed usage. + +### 1. Convert Mesh to O-Voxel [[link]](examples/mesh2ovox.py) +Convert a standard 3D mesh (with textures) into the O-Voxel representation. + +```python +asset = trimesh.load("path/to/mesh.glb") + +# 1. Geometry Voxelization (Flexible Dual Grid) +# Returns: occupied indices, dual vertices (QEF solution), and edge intersected +mesh = asset.to_mesh() +vertices = torch.from_numpy(mesh.vertices).float() +faces = torch.from_numpy(mesh.faces).long() +voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( + vertices, faces, + grid_size=RES, # Resolution + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], # Axis-aligned bounding box + face_weight=1.0, # Face term weight in QEF + boundary_weight=0.2, # Boundary term weight in QEF + regularization_weight=1e-2, # Regularization term weight in QEF + timing=True +) +## sort to ensure align between geometry and material voxelization +vid = o_voxel.serialize.encode_seq(voxel_indices) +mapping = torch.argsort(vid) +voxel_indices = voxel_indices[mapping] +dual_vertices = dual_vertices[mapping] +intersected = intersected[mapping] + +# 2. Material Voxelization (Volumetric Attributes) +# Returns: dict containing 'base_color', 'metallic', 'roughness', etc. +voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr( + asset, + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + timing=True +) +## sort to ensure align between geometry and material voxelization +vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat) +mapping_mat = torch.argsort(vid_mat) +attributes = {k: v[mapping_mat] for k, v in attributes.items()} + +# Save to compressed .vxz format +## packing +dual_vertices = dual_vertices * RES - voxel_indices +dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8) +intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8) +attributes['dual_vertices'] = dual_vertices +attributes['intersected'] = intersected +o_voxel.io.write("ovoxel_helmet.vxz", voxel_indices, attributes) +``` + +### 2. Recover Mesh from O-Voxel [[link]](examples/ovox2mesh.py) +Reconstruct the surface mesh from the sparse voxel data. + +```python +# Load data +coords, data = o_voxel.io.read("path/to/ovoxel.vxz") +dual_vertices = data['dual_vertices'] +intersected = data['intersected'] +base_color = data['base_color'] +## ... other attributes omitted for brevity + +# Depack +dual_vertices = dual_vertices / 255 +intersected = torch.cat([ + intersected % 2, + intersected // 2 % 2, + intersected // 4 % 2, +], dim=-1).bool() + +# Extract Mesh +# O-Voxel connects dual vertices to form quads, optionally splitting them +# based on geometric features. +rec_verts, rec_faces = o_voxel.convert.flexible_dual_grid_to_mesh( + coords.cuda(), + dual_vertices.cuda(), + intersected.cuda(), + split_weight=None, # Auto-split based on min angle if None + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], +) +``` + +### 3. Export to GLB [[link]](examples/ovox2glb.py) +For visualization in standard 3D viewers, you can clean, UV-unwrap, and bake the volumetric attributes into textures. + +```python +# Assuming you have the reconstructed verts/faces and volume attributes +mesh = o_voxel.postprocess.to_glb( + vertices=rec_verts, + faces=rec_faces, + attr_volume=attr_tensor, # Concatenated attributes + coords=coords, + attr_layout={'base_color': slice(0,3), 'metallic': slice(3,4), ...}, + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + decimation_target=100000, + texture_size=2048, + verbose=True, +) +mesh.export("rec_helmet.glb") +``` + +### 4. Voxel Rendering [[link]](examples/render_ovox.py) +Render the voxel representation directly. + +```python +# Load data +coords, data = o_voxel.io.read("ovoxel_helmet.vxz") +position = (coords / RES - 0.5).cuda() +base_color = (data['base_color'] / 255).cuda() + +# Render +renderer = o_voxel.rasterize.VoxelRenderer( + rendering_options={"resolution": 512, "ssaa": 2} +) +output = renderer.render( + position=position, # Voxel centers + attrs=base_color, # Color/Opacity etc. + voxel_size=1.0/RES, + extrinsics=extr, + intrinsics=intr +) +# output.attr contains the rendered image (C, H, W) +``` + +## API Overview + +### `o_voxel.convert` +Core algorithms for the conversion between meshes and O-Voxels. +* `mesh_to_flexible_dual_grid`: Determines the active sparse voxels and solves the QEF to determine dual vertex positions within voxels based on mesh-voxel grid intersections. +* `flexible_dual_grid_to_mesh`: Reconnects dual vertices to form a surface. +* `textured_mesh_to_volumetric_attr`: Samples texture maps into voxel space. + +### `o_voxel.io` +Handles sparse voxel file I/O operations. +* **Formats**: `.npz` (NumPy), `.ply` (Point Cloud), `.vxz` (Custom compressed, recommended). +* **Functions**: `read()`, `write()`. + +### `o_voxel.serialize` +Utilities for spatial hashing and ordering. +* `encode_seq` / `decode_seq`: Converts 3D coordinates to/from Morton codes (Z-order) or Hilbert curves for efficient storage and processing. + +### `o_voxel.rasterize` +* `VoxelRenderer`: A lightweight renderer for sparse voxel visualization during training. + +### `o_voxel.postprocess` +* `to_glb`: A comprehensive pipeline for mesh cleaning, remeshing, UV unwrapping, and texture baking. diff --git a/o-voxel/assets/overview.webp b/o-voxel/assets/overview.webp new file mode 100644 index 0000000..0e8fb75 Binary files /dev/null and b/o-voxel/assets/overview.webp differ diff --git a/o-voxel/examples/mesh2ovox.py b/o-voxel/examples/mesh2ovox.py new file mode 100644 index 0000000..54d0c9b --- /dev/null +++ b/o-voxel/examples/mesh2ovox.py @@ -0,0 +1,57 @@ +import torch +import o_voxel +import utils + +RES = 512 + +asset = utils.get_helmet() + +# 0. Normalize asset to unit cube +aabb = asset.bounding_box.bounds +center = (aabb[0] + aabb[1]) / 2 +scale = 0.99999 / (aabb[1] - aabb[0]).max() # To avoid numerical issues +asset.apply_translation(-center) +asset.apply_scale(scale) + +# 1. Geometry Voxelization (Flexible Dual Grid) +# Returns: occupied indices, dual vertices (QEF solution), and edge intersected +mesh = asset.to_mesh() +vertices = torch.from_numpy(mesh.vertices).float() +faces = torch.from_numpy(mesh.faces).long() +voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( + vertices, faces, + grid_size=RES, # Resolution + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], # Axis-aligned bounding box + face_weight=1.0, # Face term weight in QEF + boundary_weight=0.2, # Boundary term weight in QEF + regularization_weight=1e-2, # Regularization term weight in QEF + timing=True +) +## sort to ensure align between geometry and material voxelization +vid = o_voxel.serialize.encode_seq(voxel_indices) +mapping = torch.argsort(vid) +voxel_indices = voxel_indices[mapping] +dual_vertices = dual_vertices[mapping] +intersected = intersected[mapping] + +# 2. Material Voxelization (Volumetric Attributes) +# Returns: dict containing 'base_color', 'metallic', 'roughness', etc. +voxel_indices_mat, attributes = o_voxel.convert.textured_mesh_to_volumetric_attr( + asset, + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + timing=True +) +## sort to ensure align between geometry and material voxelization +vid_mat = o_voxel.serialize.encode_seq(voxel_indices_mat) +mapping_mat = torch.argsort(vid_mat) +attributes = {k: v[mapping_mat] for k, v in attributes.items()} + +# Save to compressed .vxz format +## packing +dual_vertices = dual_vertices * RES - voxel_indices +dual_vertices = (torch.clamp(dual_vertices, 0, 1) * 255).type(torch.uint8) +intersected = (intersected[:, 0:1] + 2 * intersected[:, 1:2] + 4 * intersected[:, 2:3]).type(torch.uint8) +attributes['dual_vertices'] = dual_vertices +attributes['intersected'] = intersected +o_voxel.io.write("ovoxel_helmet.vxz", voxel_indices, attributes) \ No newline at end of file diff --git a/o-voxel/examples/ovox2glb.py b/o-voxel/examples/ovox2glb.py new file mode 100644 index 0000000..2306310 --- /dev/null +++ b/o-voxel/examples/ovox2glb.py @@ -0,0 +1,52 @@ +import torch +import o_voxel + +RES = 512 + +# Load data +coords, data = o_voxel.io.read("ovoxel_helmet.vxz") +dual_vertices = data['dual_vertices'] +intersected = data['intersected'] +base_color = data['base_color'] +metallic = data['metallic'] +roughness = data['roughness'] +alpha = data['alpha'] + +# Depack +dual_vertices = dual_vertices / 255 +intersected = torch.cat([ + intersected % 2, + intersected // 2 % 2, + intersected // 4 % 2, +], dim=-1).bool() + +# Extract Mesh +# O-Voxel connects dual vertices to form quads, optionally splitting them +# based on geometric features. +rec_verts, rec_faces = o_voxel.convert.flexible_dual_grid_to_mesh( + coords.cuda(), + dual_vertices.cuda(), + intersected.cuda(), + split_weight=None, # Auto-split based on min angle if None + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], +) + +# Post-process +attr_volume = torch.cat([base_color.cuda(), metallic.cuda(), roughness.cuda(), alpha.cuda()], dim=-1) / 255 +attr_layout = {'base_color': slice(0,3), 'metallic': slice(3,4), 'roughness': slice(4,5), 'alpha': slice(5,6)} +mesh = o_voxel.postprocess.to_glb( + vertices=rec_verts, + faces=rec_faces, + attr_volume=attr_volume, + coords=coords.cuda(), + attr_layout=attr_layout, + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], + decimation_target=100000, + texture_size=2048, + verbose=True, +) + +# Save as glb +mesh.export("rec_helmet.glb") diff --git a/o-voxel/examples/ovox2mesh.py b/o-voxel/examples/ovox2mesh.py new file mode 100644 index 0000000..eb5a6be --- /dev/null +++ b/o-voxel/examples/ovox2mesh.py @@ -0,0 +1,45 @@ +import torch +import o_voxel +import trimesh +import trimesh.visual + +RES = 512 + +# Load data +coords, data = o_voxel.io.read("ovoxel_helmet.vxz") +dual_vertices = data['dual_vertices'] +intersected = data['intersected'] +base_color = data['base_color'] +metallic = data['metallic'] +roughness = data['roughness'] +alpha = data['alpha'] + +# Depack +dual_vertices = dual_vertices / 255 +intersected = torch.cat([ + intersected % 2, + intersected // 2 % 2, + intersected // 4 % 2, +], dim=-1).bool() + +# Extract Mesh +# O-Voxel connects dual vertices to form quads, optionally splitting them +# based on geometric features. +rec_verts, rec_faces = o_voxel.convert.flexible_dual_grid_to_mesh( + coords.cuda(), + dual_vertices.cuda(), + intersected.cuda(), + split_weight=None, # Auto-split based on min angle if None + grid_size=RES, + aabb=[[-0.5,-0.5,-0.5],[0.5,0.5,0.5]], +) + +# Save as ply +visual = trimesh.visual.ColorVisuals( + vertex_colors=base_color, +) +mesh = trimesh.Trimesh( + vertices=rec_verts.cpu(), faces=rec_faces.cpu(), visual=visual, + process=False +) +mesh.export("rec_helmet.ply") diff --git a/o-voxel/examples/render_ovox.py b/o-voxel/examples/render_ovox.py new file mode 100644 index 0000000..a4a27c0 --- /dev/null +++ b/o-voxel/examples/render_ovox.py @@ -0,0 +1,39 @@ +import torch +import numpy as np +import imageio +import o_voxel +import utils3d + +RES = 512 + +# Load data +coords, data = o_voxel.io.read("ovoxel_helmet.vxz") +position = (coords / RES - 0.5).cuda() +base_color = (data['base_color'] / 255).cuda() + +# Setup camera +extr = utils3d.extrinsics_look_at( + eye=torch.tensor([1.2, 0.5, 1.2]), + look_at=torch.tensor([0.0, 0.0, 0.0]), + up=torch.tensor([0.0, 1.0, 0.0]) +).cuda() +intr = utils3d.intrinsics_from_fov_xy( + fov_x=torch.deg2rad(torch.tensor(45.0)), + fov_y=torch.deg2rad(torch.tensor(45.0)), +).cuda() + +# Render +renderer = o_voxel.rasterize.VoxelRenderer( + rendering_options={"resolution": 512, "ssaa": 2} +) +output = renderer.render( + position=position, # Voxel centers + attrs=base_color, # Color/Opacity etc. + voxel_size=1.0/RES, + extrinsics=extr, + intrinsics=intr +) +image = np.clip( + output.attr.permute(1, 2, 0).cpu().numpy() * 255, 0, 255 +).astype(np.uint8) +imageio.imwrite("ovoxel_helmet_visualization.png", image) diff --git a/o-voxel/examples/utils.py b/o-voxel/examples/utils.py new file mode 100644 index 0000000..f9a8a6a --- /dev/null +++ b/o-voxel/examples/utils.py @@ -0,0 +1,27 @@ +import os +import requests +import tarfile +import trimesh + +HELMET_URL = "https://raw.githubusercontent.com/KhronosGroup/glTF-Sample-Models/refs/heads/main/2.0/DamagedHelmet/glTF-Binary/DamagedHelmet.glb" +CACHE_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cache") + + +def download_file(url, path): + print(f"Downloading from {url} ...") + resp = requests.get(url, stream=True) + resp.raise_for_status() + + with open(path, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + + print(f"Saved to {path}") + + +def get_helmet() -> trimesh.Trimesh: + HELMET_PATH = os.path.join(CACHE_DIR, "helmet.glb") + if not os.path.exists(HELMET_PATH): + os.makedirs(CACHE_DIR, exist_ok=True) + download_file(HELMET_URL, HELMET_PATH) + return trimesh.load(HELMET_PATH) diff --git a/o-voxel/o_voxel/__init__.py b/o-voxel/o_voxel/__init__.py new file mode 100644 index 0000000..263b753 --- /dev/null +++ b/o-voxel/o_voxel/__init__.py @@ -0,0 +1,7 @@ +from . import ( + convert, + io, + postprocess, + rasterize, + serialize +) \ No newline at end of file diff --git a/o-voxel/o_voxel/convert/__init__.py b/o-voxel/o_voxel/convert/__init__.py new file mode 100644 index 0000000..f7b6d8c --- /dev/null +++ b/o-voxel/o_voxel/convert/__init__.py @@ -0,0 +1,2 @@ +from .flexible_dual_grid import * +from .volumetic_attr import * \ No newline at end of file diff --git a/o-voxel/o_voxel/convert/flexible_dual_grid.py b/o-voxel/o_voxel/convert/flexible_dual_grid.py new file mode 100644 index 0000000..7cf1397 --- /dev/null +++ b/o-voxel/o_voxel/convert/flexible_dual_grid.py @@ -0,0 +1,283 @@ +from typing import * +import numpy as np +import torch +from .. import _C + +__all__ = [ + "mesh_to_flexible_dual_grid", + "flexible_dual_grid_to_mesh", +] + + +def _init_hashmap(grid_size, capacity, device): + VOL = (grid_size[0] * grid_size[1] * grid_size[2]).item() + + # If the number of elements in the tensor is less than 2^32, use uint32 as the hashmap type, otherwise use uint64. + if VOL < 2**32: + hashmap_keys = torch.full((capacity,), torch.iinfo(torch.uint32).max, dtype=torch.uint32, device=device) + elif VOL < 2**64: + hashmap_keys = torch.full((capacity,), torch.iinfo(torch.uint64).max, dtype=torch.uint64, device=device) + else: + raise ValueError(f"The spatial size is too large to fit in a hashmap. Get volumn {VOL} > 2^64.") + + hashmap_vals = torch.empty((capacity,), dtype=torch.uint32, device=device) + + return hashmap_keys, hashmap_vals + + +@torch.no_grad() +def mesh_to_flexible_dual_grid( + vertices: torch.Tensor, + faces: torch.Tensor, + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + aabb: Union[list, tuple, np.ndarray, torch.Tensor] = None, + face_weight: float = 1.0, + boundary_weight: float = 1.0, + regularization_weight: float = 0.1, + timing: bool = False, +) -> Union[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Voxelize a mesh into a sparse voxel grid. + + Args: + vertices (torch.Tensor): The vertices of the mesh. + faces (torch.Tensor): The faces of the mesh. + voxel_size (float, list, tuple, np.ndarray, torch.Tensor): The size of each voxel. + grid_size (int, list, tuple, np.ndarray, torch.Tensor): The size of the grid. + NOTE: One of voxel_size and grid_size must be provided. + aabb (list, tuple, np.ndarray, torch.Tensor): The axis-aligned bounding box of the mesh. + If not provided, it will be computed automatically. + face_weight (float): The weight of the face term in the QEF when solving the dual vertices. + boundary_weight (float): The weight of the boundary term in the QEF when solving the dual vertices. + regularization_weight (float): The weight of the regularization term in the QEF when solving the dual vertices. + timing (bool): Whether to time the voxelization process. + + Returns: + torch.Tensor: The indices of the voxels that are occupied by the mesh. + The shape of the tensor is (N, 3), where N is the number of occupied voxels. + torch.Tensor: The dual vertices of the mesh. + torch.Tensor: The intersected flag of each voxel. + """ + + # Load mesh + vertices = vertices.float() + faces = faces.int() + + # Voxelize settings + assert voxel_size is not None or grid_size is not None, "Either voxel_size or grid_size must be provided" + + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + assert isinstance(voxel_size, torch.Tensor), f"voxel_size must be a float, list, tuple, np.ndarray, or torch.Tensor, but got {type(voxel_size)}" + assert voxel_size.dim() == 1, f"voxel_size must be a 1D tensor, but got {voxel_size.shape}" + assert voxel_size.size(0) == 3, f"voxel_size must have 3 elements, but got {voxel_size.size(0)}" + + if grid_size is not None: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32) + assert isinstance(grid_size, torch.Tensor), f"grid_size must be an int, list, tuple, np.ndarray, or torch.Tensor, but got {type(grid_size)}" + assert grid_size.dim() == 1, f"grid_size must be a 1D tensor, but got {grid_size.shape}" + assert grid_size.size(0) == 3, f"grid_size must have 3 elements, but got {grid_size.size(0)}" + + if aabb is not None: + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Auto adjust aabb + if aabb is None: + min_xyz = vertices.min(dim=0).values + max_xyz = vertices.max(dim=0).values + + if voxel_size is not None: + padding = torch.ceil((max_xyz - min_xyz) / voxel_size) * voxel_size - (max_xyz - min_xyz) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + if grid_size is not None: + padding = (max_xyz - min_xyz) / (grid_size - 1) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + + aabb = torch.stack([min_xyz, max_xyz], dim=0).float().cuda() + + # Fill voxel size or grid size + if voxel_size is None: + voxel_size = (aabb[1] - aabb[0]) / grid_size + if grid_size is None: + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + + # subdivide mesh + vertices = vertices - aabb[0].reshape(1, 3) + grid_range = torch.stack([torch.zeros_like(grid_size), grid_size], dim=0).int() + + ret = _C.mesh_to_flexible_dual_grid_cpu( + vertices, + faces, + voxel_size, + grid_range, + face_weight, + boundary_weight, + regularization_weight, + timing, + ) + + return ret + + +def flexible_dual_grid_to_mesh( + coords: torch.Tensor, + dual_vertices: torch.Tensor, + intersected_flag: torch.Tensor, + split_weight: Union[torch.Tensor, None], + aabb: Union[list, tuple, np.ndarray, torch.Tensor], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + train: bool = False, +): + """ + Extract mesh from sparse voxel structures using flexible dual grid. + + Args: + coords (torch.Tensor): The coordinates of the voxels. + dual_vertices (torch.Tensor): The dual vertices. + intersected_flag (torch.Tensor): The intersected flag. + split_weight (torch.Tensor): The split weight of each dual quad. If None, the algorithm + will split based on minimum angle. + aabb (list, tuple, np.ndarray, torch.Tensor): The axis-aligned bounding box of the mesh. + voxel_size (float, list, tuple, np.ndarray, torch.Tensor): The size of each voxel. + grid_size (int, list, tuple, np.ndarray, torch.Tensor): The size of the grid. + NOTE: One of voxel_size and grid_size must be provided. + train (bool): Whether to use training mode. + + Returns: + vertices (torch.Tensor): The vertices of the mesh. + faces (torch.Tensor): The faces of the mesh. + """ + # Static variables + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"): + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ + [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis + [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis + [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis + ], dtype=torch.int, device=coords.device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"): + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"): + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"): + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False) + + # AABB + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Voxel size + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + else: + assert grid_size is not None, "Either voxel_size or grid_size must be provided" + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + voxel_size = (aabb[1] - aabb[0]) / grid_size + assert isinstance(voxel_size, torch.Tensor), f"voxel_size must be a float, list, tuple, np.ndarray, or torch.Tensor, but got {type(voxel_size)}" + assert voxel_size.dim() == 1, f"voxel_size must be a 1D tensor, but got {voxel_size.shape}" + assert voxel_size.size(0) == 3, f"voxel_size must have 3 elements, but got {voxel_size.size(0)}" + assert isinstance(grid_size, torch.Tensor), f"grid_size must be an int, list, tuple, np.ndarray, or torch.Tensor, but got {type(grid_size)}" + assert grid_size.dim() == 1, f"grid_size must be a 1D tensor, but got {grid_size.shape}" + assert grid_size.size(0) == 3, f"grid_size must have 3 elements, but got {grid_size.size(0)}" + + # Extract mesh + N = dual_vertices.shape[0] + mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 + + # Store active voxels into hashmap + hashmap = _init_hashmap(grid_size, 2 * N, device=coords.device) + _C.hashmap_insert_3d_idx_as_val_cuda(*hashmap, torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=-1), *grid_size.tolist()) + + # Find connected voxels + edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) + connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) + M = connected_voxel.shape[0] + connected_voxel_hash_key = torch.cat([ + torch.zeros((M * 4, 1), dtype=torch.int, device=coords.device), + connected_voxel.reshape(-1, 3) + ], dim=1) + connected_voxel_indices = _C.hashmap_lookup_3d_cuda(*hashmap, connected_voxel_hash_key, *grid_size.tolist()).reshape(M, 4).int() + connected_voxel_valid = (connected_voxel_indices != 0xffffffff).all(dim=1) + quad_indices = connected_voxel_indices[connected_voxel_valid].int() # (L, 4) + L = quad_indices.shape[0] + + # Construct triangles + if not train: + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + if split_weight is None: + # if split 1 + atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1] + normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]]) + align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # if split 2 + atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]]) + align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # select split + mesh_triangles = torch.where(align0 > align1, atempt_triangles_0, atempt_triangles_1).reshape(-1, 3) + else: + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mesh_triangles = torch.where( + split_weight_ws_02 > split_weight_ws_13, + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1], + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + ).reshape(-1, 3) + else: + assert split_weight is not None, "split_weight must be provided in training mode" + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + quad_vs = mesh_vertices[quad_indices] + mean_v02 = (quad_vs[:, 0] + quad_vs[:, 2]) / 2 + mean_v13 = (quad_vs[:, 1] + quad_vs[:, 3]) / 2 + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mid_vertices = ( + split_weight_ws_02 * mean_v02 + + split_weight_ws_13 * mean_v13 + ) / (split_weight_ws_02 + split_weight_ws_13) + mesh_vertices = torch.cat([mesh_vertices, mid_vertices], dim=0) + quad_indices = torch.cat([quad_indices, torch.arange(N, N + L, device='cuda').unsqueeze(1)], dim=1) + mesh_triangles = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_train].reshape(-1, 3) + + return mesh_vertices, mesh_triangles diff --git a/o-voxel/o_voxel/convert/volumetic_attr.py b/o-voxel/o_voxel/convert/volumetic_attr.py new file mode 100644 index 0000000..77c71c1 --- /dev/null +++ b/o-voxel/o_voxel/convert/volumetic_attr.py @@ -0,0 +1,583 @@ +from typing import * +import io +from PIL import Image +import torch +import numpy as np +from tqdm import tqdm +import trimesh +import trimesh.visual + +from .. import _C + +__all__ = [ + "textured_mesh_to_volumetric_attr", + "blender_dump_to_volumetric_attr" +] + + +ALPHA_MODE_ENUM = { + "OPAQUE": 0, + "MASK": 1, + "BLEND": 2, +} + + +def is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def nearest_power_of_two(n: int) -> int: + if n < 1: + raise ValueError("n must be >= 1") + if is_power_of_two(n): + return n + lower = 2 ** (n.bit_length() - 1) + upper = 2 ** n.bit_length() + if n - lower < upper - n: + return lower + else: + return upper + + +def textured_mesh_to_volumetric_attr( + mesh: Union[trimesh.Scene, trimesh.Trimesh, str], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + aabb: Union[list, tuple, np.ndarray, torch.Tensor] = None, + mip_level_offset: float = 0.0, + verbose: bool = False, + timing: bool = False, +) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Voxelize a mesh into a sparse voxel grid with PBR properties. + + Args: + mesh (trimesh.Scene, trimesh.Trimesh, str): The input mesh. + If a string is provided, it will be loaded as a mesh using trimesh.load(). + voxel_size (float, list, tuple, np.ndarray, torch.Tensor): The size of each voxel. + grid_size (int, list, tuple, np.ndarray, torch.Tensor): The size of the grid. + NOTE: One of voxel_size and grid_size must be provided. + aabb (list, tuple, np.ndarray, torch.Tensor): The axis-aligned bounding box of the mesh. + If not provided, it will be computed automatically. + tile_size (int): The size of the tiles used for each individual voxelization. + mip_level_offset (float): The mip level offset for texture mip level selection. + verbose (bool): Whether to print the settings. + timing (bool): Whether to print the timing information. + + Returns: + torch.Tensor: The indices of the voxels that are occupied by the mesh. + Dict[str, torch.Tensor]: A dictionary containing the following keys: + - "base_color": The base color of the occupied voxels. + - "metallic": The metallic value of the occupied voxels. + - "roughness": The roughness value of the occupied voxels. + - "emissive": The emissive value of the occupied voxels. + - "alpha": The alpha value of the occupied voxels. + - "normal": The normal of the occupied voxels. + """ + + # Load mesh + if isinstance(mesh, str): + mesh = trimesh.load(mesh) + if isinstance(mesh, trimesh.Scene): + groups = mesh.dump() + if isinstance(mesh, trimesh.Trimesh): + groups = [mesh] + scene = trimesh.Scene(groups) + + # Voxelize settings + assert voxel_size is not None or grid_size is not None, "Either voxel_size or grid_size must be provided" + + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + assert isinstance(voxel_size, torch.Tensor), f"voxel_size must be a float, list, tuple, np.ndarray, or torch.Tensor, but got {type(voxel_size)}" + assert voxel_size.dim() == 1, f"voxel_size must be a 1D tensor, but got {voxel_size.shape}" + assert voxel_size.size(0) == 3, f"voxel_size must have 3 elements, but got {voxel_size.size(0)}" + + if grid_size is not None: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32) + assert isinstance(grid_size, torch.Tensor), f"grid_size must be an int, list, tuple, np.ndarray, or torch.Tensor, but got {type(grid_size)}" + assert grid_size.dim() == 1, f"grid_size must be a 1D tensor, but got {grid_size.shape}" + assert grid_size.size(0) == 3, f"grid_size must have 3 elements, but got {grid_size.size(0)}" + + if aabb is not None: + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Auto adjust aabb + if aabb is None: + aabb = scene.bounds + min_xyz = aabb[0] + max_xyz = aabb[1] + + if voxel_size is not None: + padding = torch.ceil((max_xyz - min_xyz) / voxel_size) * voxel_size - (max_xyz - min_xyz) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + if grid_size is not None: + padding = (max_xyz - min_xyz) / (grid_size - 1) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + + aabb = torch.stack([min_xyz, max_xyz], dim=0).float() + + # Fill voxel size or grid size + if voxel_size is None: + voxel_size = (aabb[1] - aabb[0]) / grid_size + if grid_size is None: + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + + grid_range = torch.stack([torch.zeros_like(grid_size), grid_size], dim=0).int() + + # Print settings + if verbose: + print(f"Voxelize settings:") + print(f" Voxel size: {voxel_size}") + print(f" Grid size: {grid_size}") + print(f" AABB: {aabb}") + + # Load Scene + scene_buffers = { + 'triangles': [], + 'normals': [], + 'uvs': [], + 'material_ids': [], + 'base_color_factor': [], + 'base_color_texture': [], + 'metallic_factor': [], + 'metallic_texture': [], + 'roughness_factor': [], + 'roughness_texture': [], + 'emissive_factor': [], + 'emissive_texture': [], + 'alpha_mode': [], + 'alpha_cutoff': [], + 'alpha_factor': [], + 'alpha_texture': [], + 'normal_texture': [], + } + for sid, (name, g) in tqdm(enumerate(scene.geometry.items()), total=len(scene.geometry), desc="Loading Scene", disable=not verbose): + if verbose: + print(f"Geometry: {name}") + print(f" Visual: {g.visual}") + print(f" Triangles: {g.triangles.shape[0]}") + print(f" Vertices: {g.vertices.shape[0]}") + print(f" Normals: {g.vertex_normals.shape[0]}") + if g.visual.material.baseColorFactor is not None: + print(f" Base color factor: {g.visual.material.baseColorFactor}") + if g.visual.material.baseColorTexture is not None: + print(f" Base color texture: {g.visual.material.baseColorTexture.size} {g.visual.material.baseColorTexture.mode}") + if g.visual.material.metallicFactor is not None: + print(f" Metallic factor: {g.visual.material.metallicFactor}") + if g.visual.material.roughnessFactor is not None: + print(f" Roughness factor: {g.visual.material.roughnessFactor}") + if g.visual.material.metallicRoughnessTexture is not None: + print(f" Metallic roughness texture: {g.visual.material.metallicRoughnessTexture.size} {g.visual.material.metallicRoughnessTexture.mode}") + if g.visual.material.emissiveFactor is not None: + print(f" Emissive factor: {g.visual.material.emissiveFactor}") + if g.visual.material.emissiveTexture is not None: + print(f" Emissive texture: {g.visual.material.emissiveTexture.size} {g.visual.material.emissiveTexture.mode}") + if g.visual.material.alphaMode is not None: + print(f" Alpha mode: {g.visual.material.alphaMode}") + if g.visual.material.alphaCutoff is not None: + print(f" Alpha cutoff: {g.visual.material.alphaCutoff}") + if g.visual.material.normalTexture is not None: + print(f" Normal texture: {g.visual.material.normalTexture.size} {g.visual.material.normalTexture.mode}") + + assert isinstance(g, trimesh.Trimesh), f"Only trimesh.Trimesh is supported, but got {type(g)}" + assert isinstance(g.visual, trimesh.visual.TextureVisuals), f"Only trimesh.visual.TextureVisuals is supported, but got {type(g.visual)}" + assert isinstance(g.visual.material, trimesh.visual.material.PBRMaterial), f"Only trimesh.visual.material.PBRMaterial is supported, but got {type(g.visual.material)}" + triangles = torch.tensor(g.triangles, dtype=torch.float32) - aabb[0].reshape(1, 1, 3) # [N, 3, 3] + normals = torch.tensor(g.vertex_normals[g.faces], dtype=torch.float32) # [N, 3, 3] + uvs = torch.tensor(g.visual.uv[g.faces], dtype=torch.float32) if g.visual.uv is not None \ + else torch.zeros(g.triangles.shape[0], 3, 2, dtype=torch.float32) # [N, 3, 2] + baseColorFactor = torch.tensor(g.visual.material.baseColorFactor / 255, dtype=torch.float32) if g.visual.material.baseColorFactor is not None \ + else torch.ones(3, dtype=torch.float32) # [3] + baseColorTexture = torch.tensor(np.array(g.visual.material.baseColorTexture.convert('RGBA'))[..., :3], dtype=torch.uint8) if g.visual.material.baseColorTexture is not None \ + else torch.tensor([]) # [H, W, 3] + metallicFactor = g.visual.material.metallicFactor if g.visual.material.metallicFactor is not None else 1.0 + metallicTexture = torch.tensor(np.array(g.visual.material.metallicRoughnessTexture.convert('RGB'))[..., 2], dtype=torch.uint8) if g.visual.material.metallicRoughnessTexture is not None \ + else torch.tensor([]) # [H, W] + roughnessFactor = g.visual.material.roughnessFactor if g.visual.material.roughnessFactor is not None else 1.0 + roughnessTexture = torch.tensor(np.array(g.visual.material.metallicRoughnessTexture.convert('RGB'))[..., 1], dtype=torch.uint8) if g.visual.material.metallicRoughnessTexture is not None \ + else torch.tensor([]) # [H, W] + emissiveFactor = torch.tensor(g.visual.material.emissiveFactor, dtype=torch.float32) if g.visual.material.emissiveFactor is not None \ + else torch.zeros(3, dtype=torch.float32) # [3] + emissiveTexture = torch.tensor(np.array(g.visual.material.emissiveTexture.convert('RGB'))[..., :3], dtype=torch.uint8) if g.visual.material.emissiveTexture is not None \ + else torch.tensor([]) # [H, W, 3] + alphaMode = ALPHA_MODE_ENUM[g.visual.material.alphaMode] if g.visual.material.alphaMode in ALPHA_MODE_ENUM else 0 + alphaCutoff = g.visual.material.alphaCutoff if g.visual.material.alphaCutoff is not None else 0.5 + alphaFactor = g.visual.material.baseColorFactor[3] / 255 if g.visual.material.baseColorFactor is not None else 1.0 + alphaTexture = torch.tensor(np.array(g.visual.material.baseColorTexture.convert('RGBA'))[..., 3], dtype=torch.uint8) if g.visual.material.baseColorTexture is not None and alphaMode != 0 \ + else torch.tensor([]) # [H, W] + normalTexture = torch.tensor(np.array(g.visual.material.normalTexture.convert('RGB'))[..., :3], dtype=torch.uint8) if g.visual.material.normalTexture is not None \ + else torch.tensor([]) # [H, W, 3] + + scene_buffers['triangles'].append(triangles) + scene_buffers['normals'].append(normals) + scene_buffers['uvs'].append(uvs) + scene_buffers['material_ids'].append(torch.full((triangles.shape[0],), sid, dtype=torch.int32)) + scene_buffers['base_color_factor'].append(baseColorFactor) + scene_buffers['base_color_texture'].append(baseColorTexture) + scene_buffers['metallic_factor'].append(metallicFactor) + scene_buffers['metallic_texture'].append(metallicTexture) + scene_buffers['roughness_factor'].append(roughnessFactor) + scene_buffers['roughness_texture'].append(roughnessTexture) + scene_buffers['emissive_factor'].append(emissiveFactor) + scene_buffers['emissive_texture'].append(emissiveTexture) + scene_buffers['alpha_mode'].append(alphaMode) + scene_buffers['alpha_cutoff'].append(alphaCutoff) + scene_buffers['alpha_factor'].append(alphaFactor) + scene_buffers['alpha_texture'].append(alphaTexture) + scene_buffers['normal_texture'].append(normalTexture) + + scene_buffers['triangles'] = torch.cat(scene_buffers['triangles'], dim=0) # [N, 3, 3] + scene_buffers['normals'] = torch.cat(scene_buffers['normals'], dim=0) # [N, 3, 3] + scene_buffers['uvs'] = torch.cat(scene_buffers['uvs'], dim=0) # [N, 3, 2] + scene_buffers['material_ids'] = torch.cat(scene_buffers['material_ids'], dim=0) # [N] + + # Voxelize + out_tuple = _C.textured_mesh_to_volumetric_attr_cpu( + voxel_size, + grid_range, + scene_buffers["triangles"], + scene_buffers["normals"], + scene_buffers["uvs"], + scene_buffers["material_ids"], + scene_buffers["base_color_factor"], + scene_buffers["base_color_texture"], + [1] * len(scene_buffers["base_color_texture"]), + [0] * len(scene_buffers["base_color_texture"]), + scene_buffers["metallic_factor"], + scene_buffers["metallic_texture"], + [1] * len(scene_buffers["metallic_texture"]), + [0] * len(scene_buffers["metallic_texture"]), + scene_buffers["roughness_factor"], + scene_buffers["roughness_texture"], + [1] * len(scene_buffers["roughness_texture"]), + [0] * len(scene_buffers["roughness_texture"]), + scene_buffers["emissive_factor"], + scene_buffers["emissive_texture"], + [1] * len(scene_buffers["emissive_texture"]), + [0] * len(scene_buffers["emissive_texture"]), + scene_buffers["alpha_mode"], + scene_buffers["alpha_cutoff"], + scene_buffers["alpha_factor"], + scene_buffers["alpha_texture"], + [1] * len(scene_buffers["alpha_texture"]), + [0] * len(scene_buffers["alpha_texture"]), + scene_buffers["normal_texture"], + [1] * len(scene_buffers["normal_texture"]), + [0] * len(scene_buffers["normal_texture"]), + mip_level_offset, + timing, + ) + + # Post process + coord = out_tuple[0] + attr = { + "base_color": torch.clamp(out_tuple[1] * 255, 0, 255).byte().reshape(-1, 3), + "metallic": torch.clamp(out_tuple[2] * 255, 0, 255).byte().reshape(-1, 1), + "roughness": torch.clamp(out_tuple[3] * 255, 0, 255).byte().reshape(-1, 1), + "emissive": torch.clamp(out_tuple[4] * 255, 0, 255).byte().reshape(-1, 3), + "alpha": torch.clamp(out_tuple[5] * 255, 0, 255).byte().reshape(-1, 1), + "normal": torch.clamp((out_tuple[6] * 0.5 + 0.5) * 255, 0, 255).byte().reshape(-1, 3), + } + + return coord, attr + + +def blender_dump_to_volumetric_attr( + dump: Dict[str, Any], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + aabb: Union[list, tuple, np.ndarray, torch.Tensor] = None, + mip_level_offset: float = 0.0, + verbose: bool = False, + timing: bool = False, +) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Voxelize a mesh into a sparse voxel grid with PBR properties. + + Args: + dump (Dict[str, Any]): Dumped data from a blender scene. + voxel_size (float, list, tuple, np.ndarray, torch.Tensor): The size of each voxel. + grid_size (int, list, tuple, np.ndarray, torch.Tensor): The size of the grid. + NOTE: One of voxel_size and grid_size must be provided. + aabb (list, tuple, np.ndarray, torch.Tensor): The axis-aligned bounding box of the mesh. + If not provided, it will be computed automatically. + mip_level_offset (float): The mip level offset for texture mip level selection. + verbose (bool): Whether to print the settings. + timing (bool): Whether to print the timing information. + + Returns: + torch.Tensor: The indices of the voxels that are occupied by the mesh. + Dict[str, torch.Tensor]: A dictionary containing the following keys: + - "base_color": The base color of the occupied voxels. + - "metallic": The metallic value of the occupied voxels. + - "roughness": The roughness value of the occupied voxels. + - "emissive": The emissive value of the occupied voxels. + - "alpha": The alpha value of the occupied voxels. + - "normal": The normal of the occupied voxels. + """ + # Voxelize settings + assert voxel_size is not None or grid_size is not None, "Either voxel_size or grid_size must be provided" + + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32) + assert isinstance(voxel_size, torch.Tensor), f"voxel_size must be a float, list, tuple, np.ndarray, or torch.Tensor, but got {type(voxel_size)}" + assert voxel_size.dim() == 1, f"voxel_size must be a 1D tensor, but got {voxel_size.shape}" + assert voxel_size.size(0) == 3, f"voxel_size must have 3 elements, but got {voxel_size.size(0)}" + + if grid_size is not None: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32) + assert isinstance(grid_size, torch.Tensor), f"grid_size must be an int, list, tuple, np.ndarray, or torch.Tensor, but got {type(grid_size)}" + assert grid_size.dim() == 1, f"grid_size must be a 1D tensor, but got {grid_size.shape}" + assert grid_size.size(0) == 3, f"grid_size must have 3 elements, but got {grid_size.size(0)}" + + if aabb is not None: + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Auto adjust aabb + if aabb is None: + min_xyz = np.min([ + object['vertices'].min(axis=0) + for object in dump['objects'] + ], axis=0) + max_xyz = np.max([ + object['vertices'].max(axis=0) + for object in dump['objects'] + ], axis=0) + + if voxel_size is not None: + padding = torch.ceil((max_xyz - min_xyz) / voxel_size) * voxel_size - (max_xyz - min_xyz) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + if grid_size is not None: + padding = (max_xyz - min_xyz) / (grid_size - 1) + min_xyz -= padding * 0.5 + max_xyz += padding * 0.5 + + aabb = torch.stack([min_xyz, max_xyz], dim=0).float() + + # Fill voxel size or grid size + if voxel_size is None: + voxel_size = (aabb[1] - aabb[0]) / grid_size + if grid_size is None: + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + + grid_range = torch.stack([torch.zeros_like(grid_size), grid_size], dim=0).int() + + # Print settings + if verbose: + print(f"Voxelize settings:") + print(f" Voxel size: {voxel_size}") + print(f" Grid size: {grid_size}") + print(f" AABB: {aabb}") + + # Load Scene + scene_buffers = { + 'triangles': [], + 'normals': [], + 'uvs': [], + 'material_ids': [], + 'base_color_factor': [], + 'base_color_texture': [], + 'base_color_texture_filter': [], + 'base_color_texture_wrap': [], + 'metallic_factor': [], + 'metallic_texture': [], + 'metallic_texture_filter': [], + 'metallic_texture_wrap': [], + 'roughness_factor': [], + 'roughness_texture': [], + 'roughness_texture_filter': [], + 'roughness_texture_wrap': [], + 'alpha_mode': [], + 'alpha_cutoff': [], + 'alpha_factor': [], + 'alpha_texture': [], + 'alpha_texture_filter': [], + 'alpha_texture_wrap': [], + } + + def load_texture(pack): + png_bytes = pack['image'] + image = Image.open(io.BytesIO(png_bytes)) + if image.width != image.height or not is_power_of_two(image.width): + size = nearest_power_of_two(max(image.width, image.height)) + image = image.resize((size, size), Image.LANCZOS) + texture = torch.tensor(np.array(image), dtype=torch.uint8) + filter_mode = { + 'Linear': 1, + 'Closest': 0, + 'Cubic': 1, + 'Smart': 1, + }[pack['interpolation']] + wrap_mode = { + 'REPEAT': 0, + 'EXTEND': 1, + 'CLIP': 1, + 'MIRROR': 2, + }[pack['extension']] + return texture, filter_mode, wrap_mode + + for material in dump['materials']: + baseColorFactor = torch.tensor(material['baseColorFactor'][:3], dtype=torch.float32) + if material['baseColorTexture'] is not None: + baseColorTexture, baseColorTextureFilter, baseColorTextureWrap = \ + load_texture(material['baseColorTexture']) + assert baseColorTexture.shape[2] == 3, f"Base color texture must have 3 channels, but got {baseColorTexture.shape[2]}" + else: + baseColorTexture = torch.tensor([]) + baseColorTextureFilter = 0 + baseColorTextureWrap = 0 + scene_buffers['base_color_factor'].append(baseColorFactor) + scene_buffers['base_color_texture'].append(baseColorTexture) + scene_buffers['base_color_texture_filter'].append(baseColorTextureFilter) + scene_buffers['base_color_texture_wrap'].append(baseColorTextureWrap) + + metallicFactor = material['metallicFactor'] + if material['metallicTexture'] is not None: + metallicTexture, metallicTextureFilter, metallicTextureWrap = \ + load_texture(material['metallicTexture']) + assert metallicTexture.dim() == 2, f"Metallic roughness texture must have 2 dimensions, but got {metallicTexture.dim()}" + else: + metallicTexture = torch.tensor([]) + metallicTextureFilter = 0 + metallicTextureWrap = 0 + scene_buffers['metallic_factor'].append(metallicFactor) + scene_buffers['metallic_texture'].append(metallicTexture) + scene_buffers['metallic_texture_filter'].append(metallicTextureFilter) + scene_buffers['metallic_texture_wrap'].append(metallicTextureWrap) + + roughnessFactor = material['roughnessFactor'] + if material['roughnessTexture'] is not None: + roughnessTexture, roughnessTextureFilter, roughnessTextureWrap = \ + load_texture(material['roughnessTexture']) + assert roughnessTexture.dim() == 2, f"Metallic roughness texture must have 2 dimensions, but got {roughnessTexture.dim()}" + else: + roughnessTexture = torch.tensor([]) + roughnessTextureFilter = 0 + roughnessTextureWrap = 0 + scene_buffers['roughness_factor'].append(roughnessFactor) + scene_buffers['roughness_texture'].append(roughnessTexture) + scene_buffers['roughness_texture_filter'].append(roughnessTextureFilter) + scene_buffers['roughness_texture_wrap'].append(roughnessTextureWrap) + + alphaMode = ALPHA_MODE_ENUM[material['alphaMode']] + alphaCutoff = material['alphaCutoff'] + alphaFactor = material['alphaFactor'] + if material['alphaTexture'] is not None: + alphaTexture, alphaTextureFilter, alphaTextureWrap = \ + load_texture(material['alphaTexture']) + assert alphaTexture.dim() == 2, f"Alpha texture must have 2 dimensions, but got {alphaTexture.dim()}" + else: + alphaTexture = torch.tensor([]) + alphaTextureFilter = 0 + alphaTextureWrap = 0 + scene_buffers['alpha_mode'].append(alphaMode) + scene_buffers['alpha_cutoff'].append(alphaCutoff) + scene_buffers['alpha_factor'].append(alphaFactor) + scene_buffers['alpha_texture'].append(alphaTexture) + scene_buffers['alpha_texture_filter'].append(alphaTextureFilter) + scene_buffers['alpha_texture_wrap'].append(alphaTextureWrap) + + for object in dump['objects']: + triangles = torch.tensor(object['vertices'][object['faces']], dtype=torch.float32).reshape(-1, 3, 3) - aabb[0].reshape(1, 1, 3) + normails = torch.tensor(object['normals'], dtype=torch.float32) + uvs = torch.tensor(object['uvs'], dtype=torch.float32) if object['uvs'] is not None else torch.zeros(triangles.shape[0], 3, 2, dtype=torch.float32) + material_id = torch.tensor(object['mat_ids'], dtype=torch.int32) + scene_buffers['triangles'].append(triangles) + scene_buffers['normals'].append(normails) + scene_buffers['uvs'].append(uvs) + scene_buffers['material_ids'].append(material_id) + + scene_buffers['triangles'] = torch.cat(scene_buffers['triangles'], dim=0) # [N, 3, 3] + scene_buffers['normals'] = torch.cat(scene_buffers['normals'], dim=0) # [N, 3, 3] + scene_buffers['uvs'] = torch.cat(scene_buffers['uvs'], dim=0) # [N, 3, 2] + scene_buffers['material_ids'] = torch.cat(scene_buffers['material_ids'], dim=0) # [N] + + scene_buffers['uvs'][:, :, 1] = 1 - scene_buffers['uvs'][:, :, 1] # Flip v coordinate + + # Voxelize + out_tuple = _C.textured_mesh_to_volumetric_attr_cpu( + voxel_size, + grid_range, + scene_buffers["triangles"], + scene_buffers["normals"], + scene_buffers["uvs"], + scene_buffers["material_ids"], + scene_buffers["base_color_factor"], + scene_buffers["base_color_texture"], + scene_buffers["base_color_texture_filter"], + scene_buffers["base_color_texture_wrap"], + scene_buffers["metallic_factor"], + scene_buffers["metallic_texture"], + scene_buffers["metallic_texture_filter"], + scene_buffers["metallic_texture_wrap"], + scene_buffers["roughness_factor"], + scene_buffers["roughness_texture"], + scene_buffers["roughness_texture_filter"], + scene_buffers["roughness_texture_wrap"], + [torch.zeros(3, dtype=torch.float32) for _ in range(len(scene_buffers["base_color_texture"]))], + [torch.tensor([]) for _ in range(len(scene_buffers["base_color_texture"]))], + [0] * len(scene_buffers["base_color_texture"]), + [0] * len(scene_buffers["base_color_texture"]), + scene_buffers["alpha_mode"], + scene_buffers["alpha_cutoff"], + scene_buffers["alpha_factor"], + scene_buffers["alpha_texture"], + scene_buffers["alpha_texture_filter"], + scene_buffers["alpha_texture_wrap"], + [torch.tensor([]) for _ in range(len(scene_buffers["base_color_texture"]))], + [0] * len(scene_buffers["base_color_texture"]), + [0] * len(scene_buffers["base_color_texture"]), + mip_level_offset, + timing, + ) + + # Post process + coord = out_tuple[0] + attr = { + "base_color": torch.clamp(out_tuple[1] * 255, 0, 255).byte().reshape(-1, 3), + "metallic": torch.clamp(out_tuple[2] * 255, 0, 255).byte().reshape(-1, 1), + "roughness": torch.clamp(out_tuple[3] * 255, 0, 255).byte().reshape(-1, 1), + "emissive": torch.clamp(out_tuple[4] * 255, 0, 255).byte().reshape(-1, 3), + "alpha": torch.clamp(out_tuple[5] * 255, 0, 255).byte().reshape(-1, 1), + "normal": torch.clamp((out_tuple[6] * 0.5 + 0.5) * 255, 0, 255).byte().reshape(-1, 3), + } + + return coord, attr \ No newline at end of file diff --git a/o-voxel/o_voxel/io/__init__.py b/o-voxel/o_voxel/io/__init__.py new file mode 100644 index 0000000..f29541b --- /dev/null +++ b/o-voxel/o_voxel/io/__init__.py @@ -0,0 +1,45 @@ +from typing import Dict, Union +import torch +from .ply import * +from .npz import * +from .vxz import * + + +def read(file_path: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Read a file containing voxels. + + Args: + file_path: Path to the file. + + Returns: + torch.Tensor: the coordinates of the voxels. + Dict[str, torch.Tensor]: the attributes of the voxels. + """ + if file_path.endswith('.npz'): + return read_npz(file_path) + elif file_path.endswith('.ply'): + return read_ply(file_path) + elif file_path.endswith('.vxz'): + return read_vxz(file_path) + else: + raise ValueError(f"Unsupported file type {file_path}") + + +def write(file_path: str, coord: torch.Tensor, attr: Dict[str, torch.Tensor], **kwargs): + """ + Write a file containing voxels. + + Args: + file_path: Path to the file. + coord: the coordinates of the voxels. + attr: the attributes of the voxels. + """ + if file_path.endswith('.npz'): + write_npz(file_path, coord, attr, **kwargs) + elif file_path.endswith('.ply'): + write_ply(file_path, coord, attr, **kwargs) + elif file_path.endswith('.vxz'): + write_vxz(file_path, coord, attr, **kwargs) + else: + raise ValueError(f"Unsupported file type {file_path}") diff --git a/o-voxel/o_voxel/io/npz.py b/o-voxel/o_voxel/io/npz.py new file mode 100644 index 0000000..9009045 --- /dev/null +++ b/o-voxel/o_voxel/io/npz.py @@ -0,0 +1,43 @@ +from typing import * +import torch +import numpy as np + + +__all__ = [ + "read_npz", + "write_npz", +] + + +def read_npz(file) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Read a NPZ file containing voxels. + + Args: + file_path: Path or file object from which to read the NPZ file. + + Returns: + torch.Tensor: the coordinates of the voxels. + Dict[str, torch.Tensor]: the attributes of the voxels. + """ + data = np.load(file) + coord = torch.from_numpy(data['coord']).int() + attr = {k: torch.from_numpy(v) for k, v in data.items() if k!= 'coord'} + return coord, attr + + +def write_npz(file, coord: torch.Tensor, attr: Dict[str, torch.Tensor], compress=True): + """ + Write a NPZ file containing voxels. + + Args: + file_path: Path or file object to which to write the NPZ file. + coord: the coordinates of the voxels. + attr: the attributes of the voxels. + """ + data = {'coord': coord.cpu().numpy().astype(np.uint16)} + data.update({k: v.cpu().numpy() for k, v in attr.items()}) + if compress: + np.savez_compressed(file, **data) + else: + np.savez(file, **data) diff --git a/o-voxel/o_voxel/io/ply.py b/o-voxel/o_voxel/io/ply.py new file mode 100644 index 0000000..c8e7fb8 --- /dev/null +++ b/o-voxel/o_voxel/io/ply.py @@ -0,0 +1,72 @@ +from typing import * +import io +import torch +import numpy as np +import plyfile + + +__all__ = [ + "read_ply", + "write_ply", +] + + +DTYPE_MAP = { + torch.uint8: 'u1', + torch.uint16: 'u2', + torch.uint32: 'u4', + torch.int8: 'i1', + torch.int16: 'i2', + torch.int32: 'i4', + torch.float32: 'f4', + torch.float64: 'f8' +} + + +def read_ply(file) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Read a PLY file containing voxels. + + Args: + file: Path or file-like object of the PLY file. + + Returns: + torch.Tensor: the coordinates of the voxels. + Dict[str, torch.Tensor]: the attributes of the voxels. + """ + plydata = plyfile.PlyData.read(file) + xyz = np.stack([plydata.elements[0][k] for k in ['x', 'y', 'z']], axis=1) + coord = np.round(xyz).astype(int) + coord = torch.from_numpy(coord) + + attr_keys = [k for k in plydata.elements[0].data.dtype.names if k not in ['x', 'y', 'z']] + attr_names = ['_'.join(k.split('_')[:-1]) for k in attr_keys] + attr_chs = [sum([1 for k in attr_keys if k.startswith(f'{name}_')]) for name in attr_names] + + attr = {} + for i, name in enumerate(attr_names): + attr[name] = np.stack([plydata.elements[0][f'{name}_{j}'] for j in range(attr_chs[i])], axis=1) + attr = {k: torch.from_numpy(v) for k, v in attr.items()} + + return coord, attr + + +def write_ply(file, coord: torch.Tensor, attr: Dict[str, torch.Tensor]): + """ + Write a PLY file containing voxels. + + Args: + file: Path or file-like object of the PLY file. + coord: the coordinates of the voxels. + attr: the attributes of the voxels. + """ + dtypes = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] + for k, v in attr.items(): + for j in range(v.shape[-1]): + assert v.dtype in DTYPE_MAP, f"Unsupported data type {v.dtype} for attribute {k}" + dtypes.append((f'{k}_{j}', DTYPE_MAP[v.dtype])) + data = np.empty(len(coord), dtype=dtypes) + all_chs = np.concatenate([coord.cpu().numpy().astype(np.float32)] + [v.cpu().numpy() for v in attr.values()], axis=1) + data[:] = list(map(tuple, all_chs)) + plyfile.PlyData([plyfile.PlyElement.describe(data, 'vertex')]).write(file) + \ No newline at end of file diff --git a/o-voxel/o_voxel/io/vxz.py b/o-voxel/o_voxel/io/vxz.py new file mode 100644 index 0000000..bc34cda --- /dev/null +++ b/o-voxel/o_voxel/io/vxz.py @@ -0,0 +1,365 @@ +from typing import * +import os +import json +import struct +import torch +import numpy as np +import zlib +import lzma +import zstandard +from concurrent.futures import ThreadPoolExecutor +from ..serialize import encode_seq, decode_seq +from .. import _C + + +__all__ = [ + "read_vxz", + "read_vxz_info", + "write_vxz", +] + + +""" +VXZ format + +Header: +- file type (3 bytes) - 'VXZ' +- version (1 byte) - 0 +- binary start offset (4 bytes) +- structure (json) - +{ + "num_voxel": int, + "chunk_size": int, + "filter": str, + "compression": str, + "compression_level": int, + "raw_size": int, + "compressed_size": int, + "compress_ratio": float, + "attr_interleave": str, + "attr": [ + {"name": str, "chs": int}, + ... + ] + "chunks": [ + { + "ptr": [offset, length], # offset from global binary start + "svo": [offset, length], # offset from this chunk start + "attr": [offset, length], # offset from this chunk start + }, + ... + ] +} +- binary data +""" + +DEFAULT_COMPRESION_LEVEL = { + 'none': 0, + 'deflate': 9, + 'lzma': 9, + 'zstd': 22, +} + + +def _compress(data: bytes, algo: Literal['none', 'deflate', 'lzma', 'zstd'], level: int) -> bytes: + if algo == 'none': + return data + if level is None: + level = DEFAULT_COMPRESION_LEVEL[algo] + if algo == 'deflate': + compresser = zlib.compressobj(level, wbits=-15) + return compresser.compress(data) + compresser.flush() + if algo == 'lzma': + compresser = lzma.LZMACompressor(format=lzma.FORMAT_RAW, filters=[{'id': lzma.FILTER_LZMA2, 'preset': level}]) + return compresser.compress(data) + compresser.flush() + if algo == 'zstd': + compresser = zstandard.ZstdCompressor(level=level, write_checksum=False, write_content_size=True, threads=-1) + return compresser.compress(data) + raise ValueError(f"Invalid compression algorithm: {algo}") + + +def _decompress(data: bytes, algo: Literal['none', 'deflate', 'lzma', 'zstd'], level: int) -> bytes: + if algo == 'none': + return data + if level is None: + level = DEFAULT_COMPRESION_LEVEL[algo] + if algo == 'deflate': + decompresser = zlib.decompressobj(wbits=-15) + return decompresser.decompress(data) + decompresser.flush() + if algo == 'lzma': + decompresser = lzma.LZMADecompressor(format=lzma.FORMAT_RAW, filters=[{'id': lzma.FILTER_LZMA2, 'preset': level}]) + return decompresser.decompress(data) + if algo == 'zstd': + decompresser = zstandard.ZstdDecompressor(format=zstandard.FORMAT_ZSTD1) + return decompresser.decompress(data) + raise ValueError(f"Invalid compression algorithm: {algo}") + + +def read_vxz_info(file) -> Dict: + """ + Read the header of a VXZ file without decompressing the binary data. + + Args: + file_path: Path or file-like object to the VXZ file. + + Returns: + Dict: the header of the VXZ file. + """ + if isinstance(file, str): + with open(file, 'rb') as f: + file_data = f.read() + else: + file_data = file.read() + + assert file_data[:3] == b'VXZ', "Invalid file type" + version = file_data[3] + assert version == 0, "Invalid file version" + + bin_start = struct.unpack('>I', file_data[4:8])[0] + structure_data = json.loads(file_data[8:bin_start].decode()) + return structure_data + + +def read_vxz(file, num_threads: int = -1) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Read a VXZ file containing voxels. + + Args: + file_path: Path or file-like object to the VXZ file. + num_threads: the number of threads to use for reading the file. + + Returns: + torch.Tensor: the coordinates of the voxels. + Dict[str, torch.Tensor]: the attributes of the voxels. + """ + if isinstance(file, str): + with open(file, 'rb') as f: + file_data = f.read() + else: + file_data = file.read() + + num_threads = num_threads if num_threads > 0 else os.cpu_count() + + # Parse header + assert file_data[:3] == b'VXZ', "Invalid file type" + version = file_data[3] + assert version == 0, "Invalid file version" + + bin_start = struct.unpack('>I', file_data[4:8])[0] + structure_data = json.loads(file_data[8:bin_start].decode()) + bin_data = file_data[bin_start:] + + # Decode chunks + chunk_size = structure_data['chunk_size'] + chunk_depth = np.log2(chunk_size) + assert chunk_depth.is_integer(), f"Chunk size must be a power of 2, got {chunk_size}" + chunk_depth = int(chunk_depth) + + def worker(chunk_info): + decompressed = {} + chunk_data = bin_data[chunk_info['ptr'][0]:chunk_info['ptr'][0]+chunk_info['ptr'][1]] + for k, v in chunk_info.items(): + if k in ['ptr', 'idx']: + continue + decompressed[k] = np.frombuffer(_decompress(chunk_data[v[0]:v[0]+v[1]], structure_data['compression'], structure_data['compression_level']), dtype=np.uint8) + svo = torch.tensor(np.frombuffer(decompressed['svo'], dtype=np.uint8)) + morton_code = _C.decode_sparse_voxel_octree_cpu(svo, chunk_depth) + coord = decode_seq(morton_code.int()).cpu() + + # deinterleave attributes + if structure_data['attr_interleave'] == 'none': + all_attr = [] + for k, chs in structure_data['attr']: + for i in range(chs): + all_attr.append(torch.tensor(decompressed[f'{k}_{i}'])) + all_attr = torch.stack(all_attr, dim=1) + elif structure_data['attr_interleave'] == 'as_is': + all_attr = [] + for k, chs in structure_data['attr']: + all_attr.append(torch.tensor(decompressed[k].reshape(-1, chs))) + all_attr = torch.cat(all_attr, dim=1) + elif structure_data['attr_interleave'] == 'all': + all_chs = sum(chs for k, chs in structure_data['attr']) + all_attr = decompressed['attr'].reshape(-1, all_chs) + + # unfilter + if structure_data['filter'] == 'none': + pass + elif structure_data['filter'] == 'parent': + all_attr = _C.decode_sparse_voxel_octree_attr_parent_cpu(svo, chunk_depth, all_attr) + elif structure_data['filter'] == 'neighbor': + all_attr = _C.decode_sparse_voxel_octree_attr_neighbor_cpu(coord, chunk_size, all_attr) + + # final + attr = {} + ch = 0 + for k, chs in structure_data['attr']: + attr[k] = all_attr[:, ch:ch+chs] + ch += chs + return { + 'coord': coord, + 'attr': attr, + } + + if num_threads == 1: + chunks = [worker(info) for info in structure_data['chunks']] + else: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + chunks = list(executor.map(worker, structure_data['chunks'])) + + # Combine chunks + coord = [] + attr = {k: [] for k, _ in structure_data['attr']} + for info, chunk in zip(structure_data['chunks'], chunks): + coord.append(chunk['coord'] + torch.tensor([[info['idx'][0] * chunk_size, info['idx'][1] * chunk_size, info['idx'][2] * chunk_size]]).int()) + for k, v in chunk['attr'].items(): + attr[k].append(v) + coord = torch.cat(coord, dim=0) + for k, v in attr.items(): + attr[k] = torch.cat(v, dim=0) + return coord, attr + + +def write_vxz( + file, + coord: torch.Tensor, + attr: Dict[str, torch.Tensor], + chunk_size: int = 256, + filter: Literal['none', 'parent', 'neighbor'] = 'none', + compression: Literal['none', 'deflate', 'lzma', 'zstd'] = 'lzma', + compression_level: Optional[int] = None, + attr_interleave: Literal['none', 'as_is', 'all'] = 'as_is', + num_threads: int = -1, +): + """ + Write a VXZ file containing voxels. + + Args: + file: Path or file-like object to the VXZ file. + coord: the coordinates of the voxels. + attr: the attributes of the voxels. + chunk_size: the size of each chunk. + filter: the filter to apply to the voxels. + compression: the compression algorithm to use. + compression_level: the level of compression. + attr_interleave: how to interleave the attributes. + num_threads: the number of threads to use for compression. + """ + # Check + for k, v in attr.items(): + assert coord.shape[0] == v.shape[0], f"Number of coordinates and attributes do not match for key {k}" + assert v.dtype == torch.uint8, f"Attributes must be uint8, got {v.dtype} for key {k}" + assert attr_interleave in ['none', 'as_is', 'all'], f"Invalid attr_interleave value: {attr_interleave}" + + compression_level = compression_level or DEFAULT_COMPRESION_LEVEL[compression] + num_threads = num_threads if num_threads > 0 else os.cpu_count() + + file_info = { + 'num_voxel': coord.shape[0], + 'chunk_size': chunk_size, + 'filter': filter, + 'compression': compression, + 'compression_level': compression_level, + 'raw_size': sum([coord.numel() * 4] + [v.numel() for v in attr.values()]), + 'compressed_size': 0, + 'compress_ratio': 0.0, + 'attr_interleave': attr_interleave, + 'attr': [[k, v.shape[1]] for k, v in attr.items()], + 'chunks': [], + } + bin_data = b'' + + # Split into chunks + chunk_depth = np.log2(chunk_size) + assert chunk_depth.is_integer(), f"Chunk size must be a power of 2, got {chunk_size}" + chunk_depth = int(chunk_depth) + + chunk_coord = coord // chunk_size + coord = coord % chunk_size + unique_chunk_coord, inverse = torch.unique(chunk_coord, dim=0, return_inverse=True) + + chunks = [] + for idx, chunk_xyz in enumerate(unique_chunk_coord.tolist()): + chunk_mask = (inverse == idx) + chunks.append({ + 'idx': chunk_xyz, + 'coord': coord[chunk_mask], + 'attr': {k: v[chunk_mask] for k, v in attr.items()}, + }) + + # Compress each chunk + with ThreadPoolExecutor(max_workers=num_threads) as executor: + def worker(chunk): + ## compress to binary + coord = chunk['coord'] + morton_code = encode_seq(coord) + sorted_idx = morton_code.argsort().cpu() + coord = coord.cpu()[sorted_idx] + morton_code = morton_code.cpu()[sorted_idx] + attr = torch.cat([v.cpu()[sorted_idx] for v in chunk['attr'].values()], dim=1) + svo = _C.encode_sparse_voxel_octree_cpu(morton_code, chunk_depth) + svo_bytes = _compress(svo.numpy().tobytes(), compression, compression_level) + + # filter + if filter == 'none': + attr = attr.numpy() + elif filter == 'parent': + attr = _C.encode_sparse_voxel_octree_attr_parent_cpu(svo, chunk_depth, attr).numpy() + elif filter == 'neighbor': + attr = _C.encode_sparse_voxel_octree_attr_neighbor_cpu(coord, chunk_size, attr).numpy() + + # interleave attributes + attr_bytes = {} + if attr_interleave == 'none': + ch = 0 + for k, chs in file_info['attr']: + for i in range(chs): + attr_bytes[f'{k}_{i}'] = _compress(attr[:, ch].tobytes(), compression, compression_level) + ch += 1 + elif attr_interleave == 'as_is': + ch = 0 + for k, chs in file_info['attr']: + attr_bytes[k] = _compress(attr[:, ch:ch+chs].tobytes(), compression, compression_level) + ch += chs + elif attr_interleave == 'all': + attr_bytes['attr'] = _compress(attr.tobytes(), compression, compression_level) + + ## buffer for each chunk + chunk_info = {'idx': chunk['idx']} + bin_data = b'' + + ### svo + chunk_info['svo'] = [len(bin_data), len(svo_bytes)] + bin_data += svo_bytes + + ### attr + for k, v in attr_bytes.items(): + chunk_info[k] = [len(bin_data), len(v)] + bin_data += v + + return chunk_info, bin_data + + chunks = list(executor.map(worker, chunks)) + + for chunk_info, chunk_data in chunks: + chunk_info['ptr'] = [len(bin_data), len(chunk_data)] + bin_data += chunk_data + file_info['chunks'].append(chunk_info) + + file_info['compressed_size'] = len(bin_data) + file_info['compress_ratio'] = file_info['raw_size'] / file_info['compressed_size'] + + # File parts + structure_data = json.dumps(file_info).encode() + header = b'VXZ\x00' + struct.pack('>I', len(structure_data) + 8) + + # Write to file + if isinstance(file, str): + with open(file, 'wb') as f: + f.write(header) + f.write(structure_data) + f.write(bin_data) + else: + file.write(header) + file.write(structure_data) + file.write(bin_data) diff --git a/o-voxel/o_voxel/postprocess.py b/o-voxel/o_voxel/postprocess.py new file mode 100644 index 0000000..1ce8227 --- /dev/null +++ b/o-voxel/o_voxel/postprocess.py @@ -0,0 +1,331 @@ +from typing import * +from tqdm import tqdm +import numpy as np +import torch +import cv2 +from PIL import Image +import trimesh +import trimesh.visual +from flex_gemm.ops.grid_sample import grid_sample_3d +import nvdiffrast.torch as dr +import cumesh + + +def to_glb( + vertices: torch.Tensor, + faces: torch.Tensor, + attr_volume: torch.Tensor, + coords: torch.Tensor, + attr_layout: Dict[str, slice], + aabb: Union[list, tuple, np.ndarray, torch.Tensor], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + decimation_target: int = 1000000, + texture_size: int = 2048, + remesh: bool = False, + remesh_band: float = 1, + remesh_project: float = 0.9, + mesh_cluster_threshold_cone_half_angle_rad=np.radians(90.0), + mesh_cluster_refine_iterations=0, + mesh_cluster_global_iterations=1, + mesh_cluster_smooth_strength=1, + verbose: bool = False, + use_tqdm: bool = False, +): + """ + Convert an extracted mesh to a GLB file. + Performs cleaning, optional remeshing, UV unwrapping, and texture baking from a volume. + + Args: + vertices: (N, 3) tensor of vertex positions + faces: (M, 3) tensor of vertex indices + attr_volume: (L, C) features of a sprase tensor for attribute interpolation + coords: (L, 3) tensor of coordinates for each voxel + attr_layout: dictionary of slice objects for each attribute + aabb: (2, 3) tensor of minimum and maximum coordinates of the volume + voxel_size: (3,) tensor of size of each voxel + grid_size: (3,) tensor of number of voxels in each dimension + decimation_target: target number of vertices for mesh simplification + texture_size: size of the texture for baking + remesh: whether to perform remeshing + remesh_band: size of the remeshing band + remesh_project: projection factor for remeshing + mesh_cluster_threshold_cone_half_angle_rad: threshold for cone-based clustering in uv unwrapping + mesh_cluster_refine_iterations: number of iterations for refining clusters in uv unwrapping + mesh_cluster_global_iterations: number of global iterations for clustering in uv unwrapping + mesh_cluster_smooth_strength: strength of smoothing for clustering in uv unwrapping + verbose: whether to print verbose messages + use_tqdm: whether to use tqdm to display progress bar + """ + # --- Input Normalization (AABB, Voxel Size, Grid Size) --- + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + assert isinstance(aabb, torch.Tensor), f"aabb must be a list, tuple, np.ndarray, or torch.Tensor, but got {type(aabb)}" + assert aabb.dim() == 2, f"aabb must be a 2D tensor, but got {aabb.shape}" + assert aabb.size(0) == 2, f"aabb must have 2 rows, but got {aabb.size(0)}" + assert aabb.size(1) == 3, f"aabb must have 3 columns, but got {aabb.size(1)}" + + # Calculate grid dimensions based on AABB and voxel size + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + else: + assert grid_size is not None, "Either voxel_size or grid_size must be provided" + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + voxel_size = (aabb[1] - aabb[0]) / grid_size + + # Assertions for dimensions + assert isinstance(voxel_size, torch.Tensor) + assert voxel_size.dim() == 1 and voxel_size.size(0) == 3 + assert isinstance(grid_size, torch.Tensor) + assert grid_size.dim() == 1 and grid_size.size(0) == 3 + + if use_tqdm: + pbar = tqdm(total=6, desc="Extracting GLB") + if verbose: + print(f"Original mesh: {vertices.shape[0]} vertices, {faces.shape[0]} faces") + + # Move data to GPU + vertices = vertices.cuda() + faces = faces.cuda() + + # Initialize CUDA mesh handler + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + + # --- Initial Mesh Cleaning --- + # Fills holes as much as we can before processing + mesh.fill_holes(max_hole_perimeter=3e-2) + if verbose: + print(f"After filling holes: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + vertices, faces = mesh.read() + if use_tqdm: + pbar.update(1) + + # Build BVH for the current mesh to guide remeshing + if use_tqdm: + pbar.set_description("Building BVH") + if verbose: + print(f"Building BVH for current mesh...", end='', flush=True) + bvh = cumesh.cuBVH(vertices, faces) + if use_tqdm: + pbar.update(1) + if verbose: + print("Done") + + if use_tqdm: + pbar.set_description("Cleaning mesh") + if verbose: + print("Cleaning mesh...") + + # --- Branch 1: Standard Pipeline (Simplification & Cleaning) --- + if not remesh: + # Step 1: Aggressive simplification (3x target) + mesh.simplify(decimation_target * 3, verbose=verbose) + if verbose: + print(f"After inital simplification: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + + # Step 2: Clean up topology (duplicates, non-manifolds, isolated parts) + mesh.remove_duplicate_faces() + mesh.repair_non_manifold_edges() + mesh.remove_small_connected_components(1e-5) + mesh.fill_holes(max_hole_perimeter=3e-2) + if verbose: + print(f"After initial cleanup: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + + # Step 3: Final simplification to target count + mesh.simplify(decimation_target, verbose=verbose) + if verbose: + print(f"After final simplification: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + + # Step 4: Final Cleanup loop + mesh.remove_duplicate_faces() + mesh.repair_non_manifold_edges() + mesh.remove_small_connected_components(1e-5) + mesh.fill_holes(max_hole_perimeter=3e-2) + if verbose: + print(f"After final cleanup: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + + # Step 5: Unify face orientations + mesh.unify_face_orientations() + + # --- Branch 2: Remeshing Pipeline --- + else: + center = aabb.mean(dim=0) + scale = (aabb[1] - aabb[0]).max().item() + resolution = grid_size.max().item() + + # Perform Dual Contouring remeshing (rebuilds topology) + mesh.init(*cumesh.remeshing.remesh_narrow_band_dc( + vertices, faces, + center = center, + scale = (resolution + 3 * remesh_band) / resolution * scale, + resolution = resolution, + band = remesh_band, + project_back = remesh_project, # Snaps vertices back to original surface + verbose = verbose, + bvh = bvh, + )) + if verbose: + print(f"After remeshing: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + + # Simplify and clean the remeshed result (similar logic to above) + mesh.simplify(decimation_target, verbose=verbose) + if verbose: + print(f"After simplifying: {mesh.num_vertices} vertices, {mesh.num_faces} faces") + + if use_tqdm: + pbar.update(1) + if verbose: + print("Done") + + + # --- UV Parameterization --- + if use_tqdm: + pbar.set_description("Parameterizing new mesh") + if verbose: + print("Parameterizing new mesh...") + + out_vertices, out_faces, out_uvs, out_vmaps = mesh.uv_unwrap( + compute_charts_kwargs={ + "threshold_cone_half_angle_rad": mesh_cluster_threshold_cone_half_angle_rad, + "refine_iterations": mesh_cluster_refine_iterations, + "global_iterations": mesh_cluster_global_iterations, + "smooth_strength": mesh_cluster_smooth_strength, + }, + return_vmaps=True, + verbose=verbose, + ) + out_vertices = out_vertices.cuda() + out_faces = out_faces.cuda() + out_uvs = out_uvs.cuda() + out_vmaps = out_vmaps.cuda() + mesh.compute_vertex_normals() + out_normals = mesh.read_vertex_normals()[out_vmaps] + + if use_tqdm: + pbar.update(1) + if verbose: + print("Done") + + # --- Texture Baking (Attribute Sampling) --- + if use_tqdm: + pbar.set_description("Sampling attributes") + if verbose: + print("Sampling attributes...", end='', flush=True) + + # Setup differentiable rasterizer context + ctx = dr.RasterizeCudaContext() + # Prepare UV coordinates for rasterization (rendering in UV space) + uvs_rast = torch.cat([out_uvs * 2 - 1, torch.zeros_like(out_uvs[:, :1]), torch.ones_like(out_uvs[:, :1])], dim=-1).unsqueeze(0) + rast = torch.zeros((1, texture_size, texture_size, 4), device='cuda', dtype=torch.float32) + + # Rasterize in chunks to save memory + for i in range(0, out_faces.shape[0], 100000): + rast_chunk, _ = dr.rasterize( + ctx, uvs_rast, out_faces[i:i+100000], + resolution=[texture_size, texture_size], + ) + mask_chunk = rast_chunk[..., 3:4] > 0 + rast_chunk[..., 3:4] += i # Store face ID in alpha channel + rast = torch.where(mask_chunk, rast_chunk, rast) + + # Mask of valid pixels in texture + mask = rast[0, ..., 3] > 0 + + # Interpolate 3D positions in UV space (finding 3D coord for every texel) + pos = dr.interpolate(out_vertices.unsqueeze(0), rast, out_faces)[0][0] + valid_pos = pos[mask] + + # Map these positions back to the *original* high-res mesh to get accurate attributes + # This corrects geometric errors introduced by simplification/remeshing + _, face_id, uvw = bvh.unsigned_distance(valid_pos, return_uvw=True) + orig_tri_verts = vertices[faces[face_id.long()]] # (N_new, 3, 3) + valid_pos = (orig_tri_verts * uvw.unsqueeze(-1)).sum(dim=1) + + # Trilinear sampling from the attribute volume (Color, Material props) + attrs = torch.zeros(texture_size, texture_size, attr_volume.shape[1], device='cuda') + attrs[mask] = grid_sample_3d( + attr_volume, + torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=-1), + shape=torch.Size([1, attr_volume.shape[1], *grid_size.tolist()]), + grid=((valid_pos - aabb[0]) / voxel_size).reshape(1, -1, 3), + mode='trilinear', + ) + if use_tqdm: + pbar.update(1) + if verbose: + print("Done") + + # --- Texture Post-Processing & Material Construction --- + if use_tqdm: + pbar.set_description("Finalizing mesh") + if verbose: + print("Finalizing mesh...", end='', flush=True) + + mask = mask.cpu().numpy() + + # Extract channels based on layout (BaseColor, Metallic, Roughness, Alpha) + base_color = np.clip(attrs[..., attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + metallic = np.clip(attrs[..., attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + roughness = np.clip(attrs[..., attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + alpha = np.clip(attrs[..., attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) + alpha_mode = 'OPAQUE' + + # Inpainting: fill gaps (dilation) to prevent black seams at UV boundaries + mask_inv = (~mask).astype(np.uint8) + base_color = cv2.inpaint(base_color, mask_inv, 3, cv2.INPAINT_TELEA) + metallic = cv2.inpaint(metallic, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] + roughness = cv2.inpaint(roughness, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] + alpha = cv2.inpaint(alpha, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] + + # Create PBR material + # Standard PBR packs Metallic and Roughness into Blue and Green channels + material = trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(np.concatenate([base_color, alpha], axis=-1)), + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), + metallicRoughnessTexture=Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)), + metallicFactor=1.0, + roughnessFactor=1.0, + alphaMode=alpha_mode, + doubleSided=True if not remesh else False, + ) + + # --- Coordinate System Conversion & Final Object --- + vertices_np = out_vertices.cpu().numpy() + faces_np = out_faces.cpu().numpy() + uvs_np = out_uvs.cpu().numpy() + normals_np = out_normals.cpu().numpy() + + # Swap Y and Z axes, invert Y (common conversion for GLB compatibility) + vertices_np[:, 1], vertices_np[:, 2] = vertices_np[:, 2], -vertices_np[:, 1] + normals_np[:, 1], normals_np[:, 2] = normals_np[:, 2], -normals_np[:, 1] + uvs_np[:, 1] = 1 - uvs_np[:, 1] # Flip UV V-coordinate + + textured_mesh = trimesh.Trimesh( + vertices=vertices_np, + faces=faces_np, + vertex_normals=normals_np, + process=False, + visual=trimesh.visual.TextureVisuals(uv=uvs_np, material=material) + ) + + if use_tqdm: + pbar.update(1) + pbar.close() + if verbose: + print("Done") + + return textured_mesh \ No newline at end of file diff --git a/o-voxel/o_voxel/rasterize.py b/o-voxel/o_voxel/rasterize.py new file mode 100644 index 0000000..c27134a --- /dev/null +++ b/o-voxel/o_voxel/rasterize.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F +from easydict import EasyDict as edict +from . import _C + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class VoxelRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.rendering_options = edict({ + "resolution": None, + "near": 0.1, + "far": 10.0, + "ssaa": 1, + }) + self.rendering_options.update(rendering_options) + + def render( + self, + position: torch.Tensor, + attrs: torch.Tensor, + voxel_size: float, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + ) -> edict: + """ + Render the octree. + + Args: + position (torch.Tensor): (N, 3) xyz positions + attrs (torch.Tensor): (N, C) attributes + voxel_size (float): voxel size + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + + Returns: + edict containing: + attr (torch.Tensor): (C, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + args = ( + position, + attrs, + voxel_size, + view.T.contiguous(), + (perspective @ view).T.contiguous(), + camera, + 0.5 / focalx, + 0.5 / focaly, + resolution * ssaa, + resolution * ssaa, + ) + color, depth, alpha = _C.rasterize_voxels_cuda(*args) + + if ssaa > 1: + color = F.interpolate(color[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + depth = F.interpolate(depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + alpha = F.interpolate(alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'attr': color, + 'depth': depth, + 'alpha': alpha, + }) + return ret + \ No newline at end of file diff --git a/o-voxel/o_voxel/serialize.py b/o-voxel/o_voxel/serialize.py new file mode 100644 index 0000000..452b5dc --- /dev/null +++ b/o-voxel/o_voxel/serialize.py @@ -0,0 +1,68 @@ +from typing import * +import torch +from . import _C + + +@torch.no_grad() +def encode_seq(coords: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Encodes 3D coordinates into a 30-bit code. + + Args: + coords: a tensor of shape [N, 3] containing the 3D coordinates. + permute: the permutation of the coordinates. + mode: the encoding mode to use. + """ + assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]" + x = coords[:, permute[0]].int() + y = coords[:, permute[1]].int() + z = coords[:, permute[2]].int() + if mode == 'z_order': + if coords.device.type == 'cpu': + return _C.z_order_encode_cpu(x, y, z) + elif coords.device.type == 'cuda': + return _C.z_order_encode_cuda(x, y, z) + else: + raise ValueError(f"Unsupported device type: {coords.device.type}") + elif mode == 'hilbert': + if coords.device.type == 'cpu': + return _C.hilbert_encode_cpu(x, y, z) + elif coords.device.type == 'cuda': + return _C.hilbert_encode_cuda(x, y, z) + else: + raise ValueError(f"Unsupported device type: {coords.device.type}") + else: + raise ValueError(f"Unknown encoding mode: {mode}") + + +@torch.no_grad() +def decode_seq(code: torch.Tensor, permute: List[int] = [0, 1, 2], mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor: + """ + Decodes a 30-bit code into 3D coordinates. + + Args: + code: a tensor of shape [N] containing the 30-bit code. + permute: the permutation of the coordinates. + mode: the decoding mode to use. + """ + assert code.ndim == 1, "Input code must be of shape [N]" + if mode == 'z_order': + if code.device.type == 'cpu': + coords = _C.z_order_decode_cpu(code) + elif code.device.type == 'cuda': + coords = _C.z_order_decode_cuda(code) + else: + raise ValueError(f"Unsupported device type: {code.device.type}") + elif mode == 'hilbert': + if code.device.type == 'cpu': + coords = _C.hilbert_decode_cpu(code) + elif code.device.type == 'cuda': + coords = _C.hilbert_decode_cuda(code) + else: + raise ValueError(f"Unsupported device type: {code.device.type}") + else: + raise ValueError(f"Unknown decoding mode: {mode}") + x = coords[permute.index(0)] + y = coords[permute.index(1)] + z = coords[permute.index(2)] + return torch.stack([x, y, z], dim=-1) diff --git a/o-voxel/pyproject.toml b/o-voxel/pyproject.toml new file mode 100644 index 0000000..6b13d43 --- /dev/null +++ b/o-voxel/pyproject.toml @@ -0,0 +1,34 @@ +[build-system] +requires = [ + "setuptools>=64", + "wheel", + "torch", + "numpy", + "plyfile", + "trimesh", + "tqdm", + "zstandard", + "easydict" +] +build-backend = "setuptools.build_meta" + + +[project] +name = "o_voxel" +version = "0.0.1" +description = "All about voxel." +requires-python = ">=3.8" +authors = [ + { name = "Jianfeng Xiang", email = "belljig@outlook.com" } +] +dependencies = [ + "torch", + "numpy", + "plyfile", + "trimesh", + "tqdm", + "zstandard", + "easydict", + "cumesh @ git+https://github.com/JeffreyXiang/CuMesh.git", + "flex_gemm @ git+https://github.com/JeffreyXiang/FlexGEMM.git", +] diff --git a/o-voxel/setup.py b/o-voxel/setup.py new file mode 100644 index 0000000..91cb5ce --- /dev/null +++ b/o-voxel/setup.py @@ -0,0 +1,67 @@ +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension, IS_HIP_EXTENSION +import os + +ROOT = os.path.dirname(os.path.abspath(__file__)) +BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") + +if BUILD_TARGET == "auto": + if IS_HIP_EXTENSION: + IS_HIP = True + else: + IS_HIP = False +else: + if BUILD_TARGET == "cuda": + IS_HIP = False + elif BUILD_TARGET == "rocm": + IS_HIP = True + +if not IS_HIP: + cc_flag = [] +else: + archs = os.getenv("GPU_ARCHS", "native").split(";") + cc_flag = [f"--offload-arch={arch}" for arch in archs] + +setup( + name="o_voxel", + packages=[ + 'o_voxel', + 'o_voxel.convert', + 'o_voxel.io', + ], + ext_modules=[ + CUDAExtension( + name="o_voxel._C", + sources=[ + # Hashmap functions + "src/hash/hash.cu", + # Convert functions + "src/convert/flexible_dual_grid.cpp", + "src/convert/volumetic_attr.cpp", + ## Serialization functions + "src/serialize/api.cu", + "src/serialize/hilbert.cu", + "src/serialize/z_order.cu", + # IO functions + "src/io/svo.cpp", + "src/io/filter_parent.cpp", + "src/io/filter_neighbor.cpp", + # Rasterization functions + "src/rasterize/rasterize.cu", + + # main + "src/ext.cpp", + ], + include_dirs=[ + os.path.join(ROOT, "third_party/eigen"), + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": ["-O3","-std=c++17"] + cc_flag, + } + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/o-voxel/src/convert/api.h b/o-voxel/src/convert/api.h new file mode 100644 index 0000000..b70551c --- /dev/null +++ b/o-voxel/src/convert/api.h @@ -0,0 +1,122 @@ +/* + * O-Voxel Convertion API + * + * Copyright (C) 2025, Jianfeng XIANG + * All rights reserved. + * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include + + +/** + * Extract flexible dual grid from a triangle mesh. + * + * @param vertices: Tensor of shape (N, 3) containing vertex positions. + * @param faces: Tensor of shape (M, 3) containing triangle vertex indices. + * @param voxel_size: Tensor of shape (3,) containing the voxel size in each dimension. + * @param grid_range: Tensor of shape (2, 3) containing the minimum and maximum coordinates of the grid range. + * @param face_weight: Weight for the face edges in the QEM computation. + * @param boundary_weight: Weight for the boundary edges in the QEM computation. + * @param regularization_weight: Regularization factor to apply to the QEM matrices. + * @param timing: Boolean flag to indicate whether to print timing information. + * + * @return a tuple ((x, y, z), vertices, intersected, faces) containing the remeshed vertices and the corresponding voxel grid. + */ +std::tuple mesh_to_flexible_dual_grid_cpu( + const torch::Tensor& vertices, + const torch::Tensor& faces, + const torch::Tensor& voxel_size, + const torch::Tensor& grid_range, + float face_weight, + float boundary_weight, + float regularization_weight, + bool timing +); + + +/** + * Voxelizes a triangle mesh with PBR materials + * + * @param voxel_size [3] tensor containing the size of a voxel + * @param grid_range [6] tensor containing the size of the grid + * @param vertices [N_tri, 3, 3] array containing the triangle vertices + * @param normals [N_tri, 3, 3] array containing the triangle vertex normals + * @param uvs [N_tri, 3, 2] tensor containing the texture coordinates + * @param materialIds [N_tri] tensor containing the material ids + * @param baseColorFactor list of [3] tensor containing the base color factor + * @param baseColorTexture list of [H, W, 3] tensor containing the base color texture + * @param baseColorTextureFilter list of int indicating the base color texture filter (0: NEAREST, 1: LINEAR) + * @param baseColorTextureWrap list of int indicating the base color texture wrap (0: REPEAT, 1: CLAMP_TO_EDGE, 2: MIRRORED_REPEAT) + * @param metallicFactor list of float containing the metallic factor + * @param metallicTexture list of [H, W] tensor containing the metallic texture + * @param metallicTextureFilter list of int indicating the metallic texture filter (0: NEAREST, 1: LINEAR) + * @param metallicTextureWrap list of int indicating the metallic texture wrap (0: REPEAT, 1: CLAMP_TO_EDGE, 2: MIRRORED_REPEAT) + * @param roughnessFactor list of float containing the roughness factor + * @param roughnessTexture list of [H, W] tensor containing the roughness texture + * @param roughnessTextureFilter list of int indicating the roughness texture filter (0: NEAREST, 1: LINEAR) + * @param roughnessTextureWrap list of int indicating the roughness texture wrap (0: REPEAT, 1: CLAMP_TO_EDGE, 2: MIRRORED_REPEAT) + * @param emissiveFactor list of [3] tensor containing the emissive factor + * @param emissiveTexture list of [H, W, 3] tensor containing the emissive texture + * @param emissiveTextureFilter list of int indicating the emissive texture filter (0: NEAREST, 1: LINEAR) + * @param emissiveTextureWrap list of int indicating the emissive texture wrap (0: REPEAT, 1: CLAMP_TO_EDGE, 2: MIRRORED_REPEAT) + * @param alphaMode list of int indicating the alpha mode (0: OPAQUE, 1: MASK, 2: BLEND) + * @param alphaCutoff list of float containing the alpha cutoff + * @param alphaFactor list of float containing the alpha factor + * @param alphaTexture list of [H, W] tensor containing the alpha texture + * @param alphaTextureFilter list of int indicating the alpha texture filter (0: NEAREST, 1: LINEAR) + * @param alphaTextureWrap list of int indicating the alpha texture wrap (0: REPEAT, 1: CLAMP_TO_EDGE, 2: MIRRORED_REPEAT) + * @param normalTexture list of [H, W, 3] tensor containing the normal texture + * @param normalTextureFilter list of int indicating the normal texture filter (0: NEAREST, 1: LINEAR) + * @param normalTextureWrap list of int indicating the normal texture wrap (0: REPEAT, 1: CLAMP_TO_EDGE, 2: MIRRORED_REPEAT) + * @param mipLevelOffset float indicating the mip level offset for texture mipmap + * + * @return tuple containing: + * - coords: tensor of shape [N, 3] containing the voxel coordinates + * - out_baseColor: tensor of shape [N, 3] containing the base color of each voxel + * - out_metallic: tensor of shape [N, 1] containing the metallic of each voxel + * - out_roughness: tensor of shape [N, 1] containing the roughness of each voxel + * - out_emissive: tensor of shape [N, 3] containing the emissive of each voxel + * - out_alpha: tensor of shape [N, 1] containing the alpha of each voxel + * - out_normal: tensor of shape [N, 3] containing the normal of each voxel + */ +std::tuple +textured_mesh_to_volumetric_attr_cpu( + const torch::Tensor& voxel_size, + const torch::Tensor& grid_range, + const torch::Tensor& vertices, + const torch::Tensor& normals, + const torch::Tensor& uvs, + const torch::Tensor& materialIds, + const std::vector& baseColorFactor, + const std::vector& baseColorTexture, + const std::vector& baseColorTextureFilter, + const std::vector& baseColorTextureWrap, + const std::vector& metallicFactor, + const std::vector& metallicTexture, + const std::vector& metallicTextureFilter, + const std::vector& metallicTextureWrap, + const std::vector& roughnessFactor, + const std::vector& roughnessTexture, + const std::vector& roughnessTextureFilter, + const std::vector& roughnessTextureWrap, + const std::vector& emissiveFactor, + const std::vector& emissiveTexture, + const std::vector& emissiveTextureFilter, + const std::vector& emissiveTextureWrap, + const std::vector& alphaMode, + const std::vector& alphaCutoff, + const std::vector& alphaFactor, + const std::vector& alphaTexture, + const std::vector& alphaTextureFilter, + const std::vector& alphaTextureWrap, + const std::vector& normalTexture, + const std::vector& normalTextureFilter, + const std::vector& normalTextureWrap, + const float mipLevelOffset, + const bool timing +); diff --git a/o-voxel/src/convert/flexible_dual_grid.cpp b/o-voxel/src/convert/flexible_dual_grid.cpp new file mode 100644 index 0000000..ad89edc --- /dev/null +++ b/o-voxel/src/convert/flexible_dual_grid.cpp @@ -0,0 +1,775 @@ +#include +#include +#include +#include +#include + +#include "api.h" + + +constexpr size_t kInvalidIndex = std::numeric_limits::max(); + + +struct float3 {float x, y, z; float& operator[](int i) {return (&x)[i];}}; +struct int3 {int x, y, z; int& operator[](int i) {return (&x)[i];}}; +struct int4 {int x, y, z, w; int& operator[](int i) {return (&x)[i];}}; +struct bool3 {bool x, y, z; bool& operator[](int i) {return (&x)[i];}}; + + +template +static inline U lerp(const T& a, const T& b, const T& t, const U& val_a, const U& val_b) { + if (a == b) return val_a; // Avoid divide by zero + T alpha = (t - a) / (b - a); + return (1 - alpha) * val_a + alpha * val_b; +} + + +template +static auto get_or_default(const Map& map, const Key& key, const Default& default_val) -> typename Map::mapped_type { + auto it = map.find(key); + return (it != map.end()) ? it->second : default_val; +} + + +// 3D voxel coordinate +struct VoxelCoord { + int x, y, z; + + int& operator[](int i) { + return (&x)[i]; + } + + bool operator==(const VoxelCoord& other) const { + return x == other.x && y == other.y && z == other.z; + } +}; + +// Hash function for VoxelCoord to use in unordered_map +namespace std { +template <> +struct hash { + size_t operator()(const VoxelCoord& v) const { + const std::size_t p1 = 73856093; + const std::size_t p2 = 19349663; + const std::size_t p3 = 83492791; + return (std::size_t)(v.x) * p1 ^ (std::size_t)(v.y) * p2 ^ (std::size_t)(v.z) * p3; + } +}; +} + + +void intersect_qef( + const Eigen::Vector3f& voxel_size, + const Eigen::Vector3i& grid_min, + const Eigen::Vector3i& grid_max, + const std::vector& triangles, // 3 vertices per triangle + std::unordered_map& hash_table, // Hash table for voxel lookup + std::vector& voxels, // Output: Voxel coordinates + std::vector& means, // Output: Mean vertex positions for each voxel + std::vector& cnt, // Output: Number of intersections for each voxel + std::vector& intersected, // Output: Whether edge of voxel intersects with triangle + std::vector& qefs // Output: QEF matrices for each voxel +) { + const size_t N_tri = triangles.size() / 3; + + for (size_t i = 0; i < N_tri; ++i) { + const Eigen::Vector3f& v0 = triangles[i * 3 + 0]; + const Eigen::Vector3f& v1 = triangles[i * 3 + 1]; + const Eigen::Vector3f& v2 = triangles[i * 3 + 2]; + + // Compute edge vectors and face normal + Eigen::Vector3f e0 = v1 - v0; + Eigen::Vector3f e1 = v2 - v1; + Eigen::Vector3f n = e0.cross(e1).normalized(); + Eigen::Vector4f plane; + plane << n.x(), n.y(), n.z(), -n.dot(v0); + auto Q = plane * plane.transpose(); + + // Scan-line algorithm to find intersections with the voxel grid from three directions + /* + t0 + | \ + | t1 + | / + t2 + */ + auto scan_line_fill = [&] (const int ax2) { + int ax0 = (ax2 + 1) % 3; + int ax1 = (ax2 + 2) % 3; + + // Canonical question + std::array t = { + Eigen::Vector3d(v0[ax0], v0[ax1], v0[ax2]), + Eigen::Vector3d(v1[ax0], v1[ax1], v1[ax2]), + Eigen::Vector3d(v2[ax0], v2[ax1], v2[ax2]) + }; + std::sort(t.begin(), t.end(), [](const Eigen::Vector3d& a, const Eigen::Vector3d& b) { return a.y() < b.y(); }); + + // Scan-line algorithm + int start = std::clamp(int(t[0].y() / voxel_size[ax1]), grid_min[ax1], grid_max[ax1] - 1); + int mid = std::clamp(int(t[1].y() / voxel_size[ax1]), grid_min[ax1], grid_max[ax1] - 1); + int end = std::clamp(int(t[2].y() / voxel_size[ax1]), grid_min[ax1], grid_max[ax1] - 1); + + auto scan_line_half = [&] (const int row_start, const int row_end, const Eigen::Vector3d t0, const Eigen::Vector3d t1, const Eigen::Vector3d t2) { + /* + t0 + | \ + t3-t4 + | \ + t1---t2 + */ + for (int y_idx = row_start; y_idx < row_end; ++y_idx) { + double y = (y_idx + 1) * voxel_size[ax1]; + Eigen::Vector2d t3 = lerp(t0.y(), t1.y(), y, Eigen::Vector2d(t0.x(), t0.z()), Eigen::Vector2d(t1.x(), t1.z())); + Eigen::Vector2d t4 = lerp(t0.y(), t2.y(), y, Eigen::Vector2d(t0.x(), t0.z()), Eigen::Vector2d(t2.x(), t2.z())); + if (t3.x() > t4.x()) std::swap(t3, t4); + int line_start = std::clamp(int(t3.x() / voxel_size[ax0]), grid_min[ax0], grid_max[ax0] - 1); + int line_end = std::clamp(int(t4.x() / voxel_size[ax0]), grid_min[ax0], grid_max[ax0] - 1); + for (int x_idx = line_start; x_idx < line_end; ++x_idx) { + double x = (x_idx + 1) * voxel_size[ax0]; + double z = lerp(t3.x(), t4.x(), x, t3.y(), t4.y()); + int z_idx = int(z / voxel_size[ax2]); + if (z_idx >= grid_min[ax2] && z_idx < grid_max[ax2]) { + // For 4-connected voxels + for (int dx = 0; dx < 2; ++dx) { + for (int dy = 0; dy < 2; ++dy) { + VoxelCoord coord; + coord[ax0] = x_idx + dx; coord[ax1] = y_idx + dy; coord[ax2] = z_idx; + Eigen::Vector3d intersect; + intersect[ax0] = x; intersect[ax1] = y; intersect[ax2] = z; + auto kv = hash_table.find(coord); + if (kv == hash_table.end()) { + hash_table[coord] = voxels.size(); + voxels.push_back({coord.x, coord.y, coord.z}); + means.push_back(intersect.cast()); + cnt.push_back(1); + intersected.push_back({false, false, false}); + qefs.push_back(Q); + if (dx == 0 && dy == 0) + intersected.back()[ax2] = true; + } + else { + auto i = kv->second; + means[i] += intersect.cast(); + cnt[i] += 1; + if (dx == 0 && dy == 0) + intersected[i][ax2] = true; + qefs[i] += Q; + } + } + } + } + } + } + }; + scan_line_half(start, mid, t[0], t[1], t[2]); + scan_line_half(mid, end, t[2], t[1], t[0]); + }; + scan_line_fill(0); + scan_line_fill(1); + scan_line_fill(2); + } +} + + +void face_qef( + const Eigen::Vector3f& voxel_size, + const Eigen::Vector3i& grid_min, + const Eigen::Vector3i& grid_max, + const std::vector& triangles, // 3 vertices per triangle + std::unordered_map& hash_table, // Hash table for voxel lookup + std::vector& qefs // Output: QEF matrices for each voxel +) { + const size_t N_tri = triangles.size() / 3; + + for (size_t i = 0; i < N_tri; ++i) { + const Eigen::Vector3f& v0 = triangles[i * 3 + 0]; + const Eigen::Vector3f& v1 = triangles[i * 3 + 1]; + const Eigen::Vector3f& v2 = triangles[i * 3 + 2]; + + // Compute edge vectors and face normal + Eigen::Vector3f e0 = v1 - v0; + Eigen::Vector3f e1 = v2 - v1; + Eigen::Vector3f e2 = v0 - v2; + Eigen::Vector3f n = e0.cross(e1).normalized(); + Eigen::Vector4f plane; + plane << n.x(), n.y(), n.z(), -n.dot(v0); + auto Q = plane * plane.transpose(); + + // Compute triangle bounding box in voxel coordinates + Eigen::Vector3f bb_min_f = v0.cwiseMin(v1).cwiseMin(v2).cwiseQuotient(voxel_size); + Eigen::Vector3f bb_max_f = v0.cwiseMax(v1).cwiseMax(v2).cwiseQuotient(voxel_size); + + Eigen::Vector3i bb_min(std::max(static_cast(bb_min_f.x()), grid_min.x()), + std::max(static_cast(bb_min_f.y()), grid_min.y()), + std::max(static_cast(bb_min_f.z()), grid_min.z())); + Eigen::Vector3i bb_max(std::min(static_cast(bb_max_f.x() + 1), grid_max.x()), + std::min(static_cast(bb_max_f.y() + 1), grid_max.y()), + std::min(static_cast(bb_max_f.z() + 1), grid_max.z())); + + // Plane test setup + Eigen::Vector3f c( + n.x() > 0.0f ? voxel_size.x() : 0.0f, + n.y() > 0.0f ? voxel_size.y() : 0.0f, + n.z() > 0.0f ? voxel_size.z() : 0.0f + ); + float d1 = n.dot(c - v0); + float d2 = n.dot(voxel_size - c - v0); + + // XY plane projection test setup + int mul_xy = (n.z() < 0.0f) ? -1 : 1; + Eigen::Vector2f n_xy_e0(-mul_xy * e0.y(), mul_xy * e0.x()); + Eigen::Vector2f n_xy_e1(-mul_xy * e1.y(), mul_xy * e1.x()); + Eigen::Vector2f n_xy_e2(-mul_xy * e2.y(), mul_xy * e2.x()); + + float d_xy_e0 = -n_xy_e0.dot(v0.head<2>()) + n_xy_e0.cwiseMax(0.0f).dot(voxel_size.head<2>()); + float d_xy_e1 = -n_xy_e1.dot(v1.head<2>()) + n_xy_e1.cwiseMax(0.0f).dot(voxel_size.head<2>()); + float d_xy_e2 = -n_xy_e2.dot(v2.head<2>()) + n_xy_e2.cwiseMax(0.0f).dot(voxel_size.head<2>()); + + // YZ plane projection test setup + int mul_yz = (n.x() < 0.0f) ? -1 : 1; + Eigen::Vector2f n_yz_e0(-mul_yz * e0.z(), mul_yz * e0.y()); + Eigen::Vector2f n_yz_e1(-mul_yz * e1.z(), mul_yz * e1.y()); + Eigen::Vector2f n_yz_e2(-mul_yz * e2.z(), mul_yz * e2.y()); + + float d_yz_e0 = -n_yz_e0.dot(Eigen::Vector2f(v0.y(), v0.z())) + n_yz_e0.cwiseMax(0.0f).dot(Eigen::Vector2f(voxel_size.y(), voxel_size.z())); + float d_yz_e1 = -n_yz_e1.dot(Eigen::Vector2f(v1.y(), v1.z())) + n_yz_e1.cwiseMax(0.0f).dot(Eigen::Vector2f(voxel_size.y(), voxel_size.z())); + float d_yz_e2 = -n_yz_e2.dot(Eigen::Vector2f(v2.y(), v2.z())) + n_yz_e2.cwiseMax(0.0f).dot(Eigen::Vector2f(voxel_size.y(), voxel_size.z())); + + // ZX plane projection test setup + int mul_zx = (n.y() < 0.0f) ? -1 : 1; + Eigen::Vector2f n_zx_e0(-mul_zx * e0.x(), mul_zx * e0.z()); + Eigen::Vector2f n_zx_e1(-mul_zx * e1.x(), mul_zx * e1.z()); + Eigen::Vector2f n_zx_e2(-mul_zx * e2.x(), mul_zx * e2.z()); + + float d_zx_e0 = -n_zx_e0.dot(Eigen::Vector2f(v0.z(), v0.x())) + n_zx_e0.cwiseMax(0.0f).dot(Eigen::Vector2f(voxel_size.z(), voxel_size.x())); + float d_zx_e1 = -n_zx_e1.dot(Eigen::Vector2f(v1.z(), v1.x())) + n_zx_e1.cwiseMax(0.0f).dot(Eigen::Vector2f(voxel_size.z(), voxel_size.x())); + float d_zx_e2 = -n_zx_e2.dot(Eigen::Vector2f(v2.z(), v2.x())) + n_zx_e2.cwiseMax(0.0f).dot(Eigen::Vector2f(voxel_size.z(), voxel_size.x())); + + // Loop over candidate voxels inside bounding box + for (int z = bb_min.z(); z < bb_max.z(); ++z) { + for (int y = bb_min.y(); y < bb_max.y(); ++y) { + for (int x = bb_min.x(); x < bb_max.x(); ++x) { + // Voxel center + Eigen::Vector3f p = voxel_size.cwiseProduct(Eigen::Vector3f(x, y, z)); + + // Plane through box test + float nDOTp = n.dot(p); + if (((nDOTp + d1) * (nDOTp + d2)) > 0.0f) continue; + + // XY projection test + Eigen::Vector2f p_xy(p.x(), p.y()); + if (n_xy_e0.dot(p_xy) + d_xy_e0 < 0) continue; + if (n_xy_e1.dot(p_xy) + d_xy_e1 < 0) continue; + if (n_xy_e2.dot(p_xy) + d_xy_e2 < 0) continue; + + // YZ projection test + Eigen::Vector2f p_yz(p.y(), p.z()); + if (n_yz_e0.dot(p_yz) + d_yz_e0 < 0) continue; + if (n_yz_e1.dot(p_yz) + d_yz_e1 < 0) continue; + if (n_yz_e2.dot(p_yz) + d_yz_e2 < 0) continue; + + // ZX projection test + Eigen::Vector2f p_zx(p.z(), p.x()); + if (n_zx_e0.dot(p_zx) + d_zx_e0 < 0) continue; + if (n_zx_e1.dot(p_zx) + d_zx_e1 < 0) continue; + if (n_zx_e2.dot(p_zx) + d_zx_e2 < 0) continue; + + // Passed all tests — mark voxel + auto coord = VoxelCoord{x, y, z}; + auto kv = hash_table.find(coord); + if (kv != hash_table.end()) { + qefs[kv->second] += Q; + } + } + } + } + } +} + + +void boundry_qef( + const Eigen::Vector3f& voxel_size, + const Eigen::Vector3i& grid_min, + const Eigen::Vector3i& grid_max, + const std::vector& boundries, // 2 vertices per segment + const float boundary_weight, // Weight for boundary edges + std::unordered_map& hash_table, // Hash table for voxel lookup + std::vector& qefs // Output: QEF matrices for each voxel +) { + for (size_t i = 0; i < boundries.size() / 2; ++i) { + const Eigen::Vector3f& v0 = boundries[i * 2 + 0]; + const Eigen::Vector3f& v1 = boundries[i * 2 + 1]; + + // Calculate the QEF for the edge (boundary) defined by v0 and v1 + Eigen::Vector3d dir(v1.x() - v0.x(), v1.y() - v0.y(), v1.z() - v0.z()); + double segment_length = dir.norm(); + if (segment_length < 1e-6d) continue; // Skip degenerate edges (zero-length) + dir.normalize(); // unit direction vector + + // Projection matrix orthogonal to the direction: I - d d^T + Eigen::Matrix3f A = Eigen::Matrix3f::Identity() - (dir * dir.transpose()).cast(); + + // b = -A * v0 + Eigen::Vector3f b = -A * v0; + + // c = v0^T * A * v0 + float c = v0.transpose() * A * v0; + + // Now pack this into a 4x4 QEF matrix + Eigen::Matrix4f Q = Eigen::Matrix4f::Zero(); + Q.block<3, 3>(0, 0) = A; + Q.block<3, 1>(0, 3) = b; + Q.block<1, 3>(3, 0) = b.transpose(); + Q(3, 3) = c; + + // DDA Traversal logic directly inside the function + + // Starting and ending voxel coordinates + Eigen::Vector3i v0_voxel = (v0.cwiseQuotient(voxel_size)).array().floor().cast(); + Eigen::Vector3i v1_voxel = (v1.cwiseQuotient(voxel_size)).array().floor().cast(); + + // Determine step direction for each axis based on the line direction + Eigen::Vector3i step = (dir.array() > 0).select(Eigen::Vector3i(1, 1, 1), Eigen::Vector3i(-1, -1, -1)); + + Eigen::Vector3d tMax, tDelta; + for (int axis = 0; axis < 3; ++axis) { + if (dir[axis] == 0.0d) { + tMax[axis] = std::numeric_limits::infinity(); + tDelta[axis] = std::numeric_limits::infinity(); + } else { + float voxel_border = voxel_size[axis] * (v0_voxel[axis] + (step[axis] > 0 ? 1 : 0)); + tMax[axis] = (voxel_border - v0[axis]) / dir[axis]; + tDelta[axis] = voxel_size[axis] / std::abs(dir[axis]); + } + } + + // Current voxel position + Eigen::Vector3i current = v0_voxel; + + // Store the voxel we start at + std::vector voxels; + voxels.push_back({current.x(), current.y(), current.z()}); + + // Traverse the voxels + while (true) { + int axis; + if (tMax.x() < tMax.y()) { + axis = (tMax.x() < tMax.z()) ? 0 : 2; + } else { + axis = (tMax.y() < tMax.z()) ? 1 : 2; + } + + if (tMax[axis] > segment_length) break; + + current[axis] += step[axis]; + tMax[axis] += tDelta[axis]; + + voxels.push_back({current.x(), current.y(), current.z()}); + } + + // Accumulate QEF for each voxel passed through + for (const auto& coord : voxels) { + // Make sure the voxel is within bounds + if ((coord.x < grid_min.x() || coord.x >= grid_max.x()) || + (coord.y < grid_min.y() || coord.y >= grid_max.y()) || + (coord.z < grid_min.z() || coord.z >= grid_max.z())) continue; + if (!hash_table.count(coord)) continue; // Skip if voxel not in hash table + + // Accumulate the QEF for this voxel + qefs[hash_table[coord]] += boundary_weight * Q; // Scale by boundary weight + } + } +} + + +std::array quad_to_2tri( + const std::vector& vertices, + const int4& quad_indices +) { + int ia = quad_indices.x; + int ib = quad_indices.y; + int ic = quad_indices.z; + int id = quad_indices.w; + + Eigen::Vector3f a(vertices[ia].x, vertices[ia].y, vertices[ia].z); + Eigen::Vector3f b(vertices[ib].x, vertices[ib].y, vertices[ib].z); + Eigen::Vector3f c(vertices[ic].x, vertices[ic].y, vertices[ic].z); + Eigen::Vector3f d(vertices[id].x, vertices[id].y, vertices[id].z); + + // diagonal AC + Eigen::Vector3f n_abc = (b - a).cross(c - a).normalized(); + Eigen::Vector3f n_acd = (c - a).cross(d - a).normalized(); + float angle_ac = std::acos(std::clamp(n_abc.dot(n_acd), -1.0f, 1.0f)); + + // diagonal BD + Eigen::Vector3f n_abd = (b - a).cross(d - a).normalized(); + Eigen::Vector3f n_bcd = (c - b).cross(d - b).normalized(); + float angle_bd = std::acos(std::clamp(n_abd.dot(n_bcd), -1.0f, 1.0f)); + + if (angle_ac <= angle_bd) { + return {int3{ia, ib, ic}, int3{ia, ic, id}}; + } else { + return {int3{ia, ib, id}, int3{ib, ic, id}}; + } +} + + +void face_from_dual_vertices( + const std::unordered_map& hash_table, + const std::vector& voxels, + const std::vector& dual_vertices, + const std::vector& intersected, + std::vector& face_indices +) { + for (int i = 0; i < dual_vertices.size(); ++i) { + int3 coord = voxels[i]; + bool3 is_intersected = intersected[i]; + + // Check existence of neighboring 6 voxels + size_t neigh_indices[6] = { + get_or_default(hash_table, VoxelCoord{coord.x + 1, coord.y, coord.z}, kInvalidIndex), + get_or_default(hash_table, VoxelCoord{coord.x, coord.y + 1, coord.z}, kInvalidIndex), + get_or_default(hash_table, VoxelCoord{coord.x + 1, coord.y + 1, coord.z}, kInvalidIndex), + get_or_default(hash_table, VoxelCoord{coord.x, coord.y, coord.z + 1}, kInvalidIndex), + get_or_default(hash_table, VoxelCoord{coord.x + 1, coord.y, coord.z + 1}, kInvalidIndex), + get_or_default(hash_table, VoxelCoord{coord.x, coord.y + 1, coord.z + 1}, kInvalidIndex) + }; + + // xy-plane + if (is_intersected[2] && neigh_indices[0] != kInvalidIndex && neigh_indices[1] != kInvalidIndex && neigh_indices[2] != kInvalidIndex) { + int4 quad_indices{i, neigh_indices[0], neigh_indices[2], neigh_indices[1]}; + auto tri_indices = quad_to_2tri(dual_vertices, quad_indices); + face_indices.insert(face_indices.end(), tri_indices.begin(), tri_indices.end()); + } + // yz-plane + if (is_intersected[0] && neigh_indices[1] != kInvalidIndex && neigh_indices[3] != kInvalidIndex && neigh_indices[5] != kInvalidIndex) { + int4 quad_indices{i, neigh_indices[1], neigh_indices[5], neigh_indices[3]}; + auto tri_indices = quad_to_2tri(dual_vertices, quad_indices); + face_indices.insert(face_indices.end(), tri_indices.begin(), tri_indices.end()); + } + // xz-plane + if (is_intersected[1] && neigh_indices[0] != kInvalidIndex && neigh_indices[3] != kInvalidIndex && neigh_indices[4] != kInvalidIndex) { + int4 quad_indices{i, neigh_indices[0], neigh_indices[4], neigh_indices[3]}; + auto tri_indices = quad_to_2tri(dual_vertices, quad_indices); + face_indices.insert(face_indices.end(), tri_indices.begin(), tri_indices.end()); + } + } +} + +/** + * Extract flexible dual grid from a triangle mesh. + * + * @param vertices: Tensor of shape (N, 3) containing vertex positions. + * @param faces: Tensor of shape (M, 3) containing triangle vertex indices. + * @param voxel_size: Tensor of shape (3,) containing the voxel size in each dimension. + * @param grid_range: Tensor of shape (2, 3) containing the minimum and maximum coordinates of the grid range. + * @param face_weight: Weight for the face edges in the QEF computation. + * @param boundary_weight: Weight for the boundary edges in the QEF computation. + * @param regularization_weight: Regularization factor to apply to the QEF matrices. + * @param timing: Boolean flag to indicate whether to print timing information. + * + * @return a tuple ((x, y, z), vertices, intersected, faces) containing the remeshed vertices and the corresponding voxel grid. + */ +std::tuple mesh_to_flexible_dual_grid_cpu( + const torch::Tensor& vertices, + const torch::Tensor& faces, + const torch::Tensor& voxel_size, + const torch::Tensor& grid_range, + float face_weight, + float boundary_weight, + float regularization_weight, + bool timing +) { + const int F = faces.size(0); + const float* v_ptr = vertices.data_ptr(); + const int* f_ptr = faces.data_ptr(); + const float* voxel_size_ptr = voxel_size.data_ptr(); + const int* grid_range_ptr = grid_range.data_ptr(); + clock_t start, end; + std::unordered_map hash_table; + std::vector voxels; // Voxel coordinates + std::vector means; // Mean vertex positions for each voxel + std::vector cnt; // Number of intersections for each voxel + std::vector intersected; // Indicate whether edges of voxels intersect with surface + std::vector qefs; // QEF matrices for each voxel + + // Convert tensors to Eigen types + Eigen::Vector3f e_voxel_size(voxel_size_ptr[0], voxel_size_ptr[1], voxel_size_ptr[2]); + Eigen::Vector3i e_grid_min(grid_range_ptr[0], grid_range_ptr[1], grid_range_ptr[2]); + Eigen::Vector3i e_grid_max(grid_range_ptr[3], grid_range_ptr[4], grid_range_ptr[5]); + + // Intersect QEF computation + start = clock(); + std::vector triangles; + triangles.reserve(F * 3); + for (int f = 0; f < F; ++f) { + for (int v = 0; v < 3; ++v) { + triangles.push_back(Eigen::Vector3f( + v_ptr[f_ptr[f * 3 + v] * 3 + 0], + v_ptr[f_ptr[f * 3 + v] * 3 + 1], + v_ptr[f_ptr[f * 3 + v] * 3 + 2] + )); + } + } + intersect_qef(e_voxel_size, e_grid_min, e_grid_max, triangles, hash_table, voxels, means, cnt, intersected, qefs); + end = clock(); + if (timing) std::cout << "Intersect QEF computation took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + + // Face QEF computation + if (face_weight > 0.0f) { + start = clock(); + face_qef(e_voxel_size, e_grid_min, e_grid_max, triangles, hash_table, qefs); + end = clock(); + if (timing) std::cout << "Face QEF computation took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + } + + // Boundary QEF computation + if (boundary_weight > 0.0f) { + start = clock(); + std::map, int> edge_count; + for (int f = 0; f < F; ++f) { + for (int v0 = 0; v0 < 3; ++v0) { + int e0 = f_ptr[f * 3 + v0]; + int e1 = f_ptr[f * 3 + (v0 + 1) % 3]; + if (e0 > e1) std::swap(e0, e1); + edge_count[std::make_pair(e0, e1)]++; + } + } + std::vector boundries; + for (const auto& e : edge_count) { + if (e.second == 1) { + int v0 = e.first.first; + int v1 = e.first.second; + boundries.push_back(Eigen::Vector3f( + v_ptr[v0 * 3 + 0], + v_ptr[v0 * 3 + 1], + v_ptr[v0 * 3 + 2] + )); + boundries.push_back(Eigen::Vector3f( + v_ptr[v1 * 3 + 0], + v_ptr[v1 * 3 + 1], + v_ptr[v1 * 3 + 2] + )); + } + } + boundry_qef(e_voxel_size, e_grid_min, e_grid_max, boundries, boundary_weight, hash_table, qefs); + end = clock(); + if (timing) std::cout << "Boundary QEF computation took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + } + + // Solve the QEF system to obtain final dual vertices + start = clock(); + std::vector dual_vertices(voxels.size()); + for (int i = 0; i < voxels.size(); ++i) { + int3 coord = voxels[i]; + Eigen::Matrix4f Q = qefs[i]; + float min_corner[3] = { + coord.x * e_voxel_size.x(), + coord.y * e_voxel_size.y(), + coord.z * e_voxel_size.z() + }; + float max_corner[3] = { + (coord.x + 1) * e_voxel_size.x(), + (coord.y + 1) * e_voxel_size.y(), + (coord.z + 1) * e_voxel_size.z() + }; + + // Add regularization term + if (regularization_weight > 0.0f) { + Eigen::Vector3f p = means[i] / cnt[i]; + + // Construct the QEF matrix for this vertex + Eigen::Matrix4f Qreg = Eigen::Matrix4f::Zero(); + Qreg.topLeftCorner<3,3>() = Eigen::Matrix3f::Identity(); + Qreg.block<3,1>(0,3) = -p; + Qreg.block<1,3>(3,0) = -p.transpose(); + Qreg(3,3) = p.dot(p); + + Q += regularization_weight * cnt[i] * Qreg; // Scale by regularization weight + } + + // Solve unconstrained + Eigen::Matrix3f A = Q.topLeftCorner<3, 3>(); + Eigen::Vector3f b = -Q.block<3, 1>(0, 3); + Eigen::Vector3f v_new = A.colPivHouseholderQr().solve(b); + + if (!( + v_new.x() >= min_corner[0] && v_new.x() <= max_corner[0] && + v_new.y() >= min_corner[1] && v_new.y() <= max_corner[1] && + v_new.z() >= min_corner[2] && v_new.z() <= max_corner[2] + )) { + // Starting enumeration of constraints + float best = std::numeric_limits::infinity(); + + // Solve single-constraint + auto solve_single_constraint = [&](int fixed_axis) { + int ax1 = (fixed_axis + 1) % 3; + int ax2 = (fixed_axis + 2) % 3; + + Eigen::Matrix2f A; + Eigen::Matrix2f B; + Eigen::Vector2f q, b, x; + + A << Q(ax1, ax1), Q(ax1, ax2), + Q(ax2, ax1), Q(ax2, ax2); + B << Q(ax1, fixed_axis), Q(ax1, 3), + Q(ax2, fixed_axis), Q(ax2, 3); + auto Asol = A.colPivHouseholderQr(); + + // if lower bound + q << min_corner[fixed_axis], 1; + b = -B * q; + x = Asol.solve(b); + if ( + x.x() >= min_corner[ax1] && x.x() <= max_corner[ax1] && + x.y() >= min_corner[ax2] && x.y() <= max_corner[ax2] + ) { + Eigen::Vector4f p; + p[fixed_axis] = min_corner[fixed_axis]; + p[ax1] = x.x(); + p[ax2] = x.y(); + p[3] = 1.0f; + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + + // if upper bound + q << max_corner[fixed_axis], 1; + b = -B * q; + x = Asol.solve(b); + if ( + x.x() >= min_corner[ax1] && x.x() <= max_corner[ax1] && + x.y() >= min_corner[ax2] && x.y() <= max_corner[ax2] + ) { + Eigen::Vector4f p; + p[fixed_axis] = max_corner[fixed_axis]; + p[ax1] = x.x(); + p[ax2] = x.y(); + p[3] = 1.0f; + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + }; + solve_single_constraint(0); // fix x + solve_single_constraint(1); // fix y + solve_single_constraint(2); // fix z + + // Solve two-constraint + auto solve_two_constraint = [&](int free_axis) { + int ax1 = (free_axis + 1) % 3; + int ax2 = (free_axis + 2) % 3; + + float a, x; + Eigen::Vector3f b, q; + + a = Q(free_axis, free_axis); + b << Q(free_axis, ax1), Q(free_axis, ax2), Q(free_axis, 3); + + // if lower-lower bound + q << min_corner[ax1], min_corner[ax2], 1; + x = -(b.dot(q)) / a; + if (x >= min_corner[free_axis] && x <= max_corner[free_axis]) { + Eigen::Vector4f p; + p[free_axis] = x; + p[ax1] = min_corner[ax1]; + p[ax2] = min_corner[ax2]; + p[3] = 1.0f; + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + + // if lower-upper bound + q << min_corner[ax1], max_corner[ax2], 1; + x = -(b.dot(q)) / a; + if (x >= min_corner[free_axis] && x <= max_corner[free_axis]) { + Eigen::Vector4f p; + p[free_axis] = x; + p[ax1] = min_corner[ax1]; + p[ax2] = max_corner[ax2]; + p[3] = 1.0f; + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + + // if upper-lower bound + q << max_corner[ax1], min_corner[ax2], 1; + x = -(b.dot(q)) / a; + if (x >= min_corner[free_axis] && x <= max_corner[free_axis]) { + Eigen::Vector4f p; + p[free_axis] = x; + p[ax1] = max_corner[ax1]; + p[ax2] = min_corner[ax2]; + p[3] = 1.0f; + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + + // if upper-upper bound + q << max_corner[ax1], max_corner[ax2], 1; + x = -(b.dot(q)) / a; + if (x >= min_corner[free_axis] && x <= max_corner[free_axis]) { + Eigen::Vector4f p; + p[free_axis] = x; + p[ax1] = max_corner[ax1]; + p[ax2] = max_corner[ax2]; + p[3] = 1.0f; + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + }; + solve_two_constraint(0); // free x + solve_two_constraint(1); // free y + solve_two_constraint(2); // free z + + // Solve three-constraint + for (int x_constraint = 0; x_constraint < 2; ++x_constraint) { + for (int y_constraint = 0; y_constraint < 2; ++y_constraint) { + for (int z_constraint = 0; z_constraint < 2; ++z_constraint) { + Eigen::Vector4f p; + p[0] = x_constraint ? min_corner[0] : max_corner[0]; + p[1] = y_constraint ? min_corner[1] : max_corner[1]; + p[2] = z_constraint ? min_corner[2] : max_corner[2]; + p[3] = 1.0f; + + float err = p.transpose() * Q * p; + if (err < best) { + best = err; + v_new << p[0], p[1], p[2]; + } + } + } + } + } + + // Store the dual vertex and voxel grid coordinates + dual_vertices[i] = float3{v_new.x(), v_new.y(), v_new.z()}; + } + end = clock(); + if (timing) std::cout << "Dual vertices computation took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + + return std::make_tuple( + torch::from_blob(voxels.data(), {int(voxels .size()), 3}, torch::kInt32).clone(), + torch::from_blob(dual_vertices.data(), {int(dual_vertices.size()), 3}, torch::kFloat32).clone(), + torch::from_blob(intersected.data(), {int(intersected.size()), 3}, torch::kBool).clone() + ); +} + diff --git a/o-voxel/src/convert/volumetic_attr.cpp b/o-voxel/src/convert/volumetic_attr.cpp new file mode 100644 index 0000000..64c3d09 --- /dev/null +++ b/o-voxel/src/convert/volumetic_attr.cpp @@ -0,0 +1,872 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "api.h" + + +constexpr size_t kInvalidIndex = std::numeric_limits::max(); + + +static bool is_power_of_two(int n) { + return n > 0 && (n & (n - 1)) == 0; +} + + +template +static inline U lerp(const T& a, const T& b, const T& t, const U& val_a, const U& val_b) { + if (a == b) return val_a; // Avoid divide by zero + T alpha = (t - a) / (b - a); + return (1 - alpha) * val_a + alpha * val_b; +} + + +template +static auto get_or_default(const Map& map, const Key& key, const Default& default_val) -> typename Map::mapped_type { + auto it = map.find(key); + return (it != map.end()) ? it->second : default_val; +} + + +// 3D voxel coordinate +struct VoxelCoord { + int x, y, z; + + int& operator[](int i) { + return (&x)[i]; + } + + bool operator==(const VoxelCoord& other) const { + return x == other.x && y == other.y && z == other.z; + } +}; + +// Hash function for VoxelCoord to use in unordered_map +namespace std { +template <> +struct hash { + size_t operator()(const VoxelCoord& v) const { + const std::size_t p1 = 73856093; + const std::size_t p2 = 19349663; + const std::size_t p3 = 83492791; + return (std::size_t)(v.x) * p1 ^ (std::size_t)(v.y) * p2 ^ (std::size_t)(v.z) * p3; + } +}; +} + + +/** + * Compute the Normal Tangent and Bitangent vectors for a triangle. + * + * @param v0 The first vertex of the triangle. + * @param v1 The second vertex of the triangle. + * @param v2 The third vertex of the triangle. + * @param uv0 The texture coordinates of the first vertex. + * @param uv1 The texture coordinates of the second vertex. + * @param uv2 The texture coordinates of the third vertex. + * + * @return A tuple containing: + * - t The tangent vector. + * - b The bitangent vector. + * - n The normal vector. + * - mip_length The norms of the partial derivatives of the 3D coordinates with respect to the 2D texture coordinates. + */ +static std::tuple compute_TBN( + const Eigen::Vector3f& v0, + const Eigen::Vector3f& v1, + const Eigen::Vector3f& v2, + const Eigen::Vector2f& uv0, + const Eigen::Vector2f& uv1, + const Eigen::Vector2f& uv2 +) { + Eigen::Vector3f e1 = v1 - v0; + Eigen::Vector3f e2 = v2 - v0; + Eigen::Vector2f duv1 = uv1 - uv0; + Eigen::Vector2f duv2 = uv2 - uv0; + Eigen::Vector3f n = e1.cross(e2).normalized(); + + float det = duv1.x() * duv2.y() - duv1.y() * duv2.x(); + if (fabs(det) < 1e-6) { + // Use default + Eigen::Vector3f t(1.0f, 0.0f, 0.0f); + Eigen::Vector3f b(0.0f, 1.0f, 0.0f); + Eigen::Vector2f mip_length(1e6, 1e6); + return std::make_tuple(t, b, n, mip_length); + } + + float invDet = 1.0f / det; + Eigen::Vector3f t = (duv2.y() * e1 - duv1.y() * e2); + Eigen::Vector3f b = (duv1.x() * e2 - duv2.x() * e1); + float t_norm = t.norm(); + float b_norm = b.norm(); + t = t / t_norm; + b = b / b_norm; + Eigen::Vector2f mip_length(invDet * t_norm, invDet * b_norm); + + return std::make_tuple(t, b, n, mip_length); +} + + +/** + * Project a point onto a triangle defined by three vertices. + * + * @param p The point to project. + * @param a The first vertex of the triangle. + * @param b The second vertex of the triangle. + * @param c The third vertex of the triangle. + * @param n The normal of the triangle. + * + * @return The projected point represented as barycentric coordinates (u, v, w) and distance from the plane. + */ +static Eigen::Vector4f project_onto_triangle( + const Eigen::Vector3f& p, + const Eigen::Vector3f& a, + const Eigen::Vector3f& b, + const Eigen::Vector3f& c, + const Eigen::Vector3f& n +) { + float d = (p - a).dot(n); + + Eigen::Vector3f p_proj = p - d * n; + Eigen::Vector3f ab = b - a; + Eigen::Vector3f ac = c - a; + Eigen::Vector3f ap = p_proj - a; + + float d00 = ab.dot(ab); + float d01 = ab.dot(ac); + float d11 = ac.dot(ac); + float d20 = ap.dot(ab); + float d21 = ap.dot(ac); + + float denom = d00 * d11 - d01 * d01; + float v = (d11 * d20 - d01 * d21) / denom; + float w = (d00 * d21 - d01 * d20) / denom; + float u = 1.0f - v - w; + + return Eigen::Vector4f(u, v, w, d); +} + + +static inline int wrap_texcoord(const int& x, const int& W, const int& filter) { + if (filter == 0) { // REPEAT + return (x % W + W) % W; + } else if (filter == 1) { // CLAMP_TO_EDGE + return std::max(0, std::min(x, W - 1)); + } else if (filter == 2) { // MIRROR_REPEAT + int period = 2 * W; + int x_mod = (x % period + period) % period; + return (x_mod < W) ? x_mod : (period - x_mod - 1); + } else { + // Default to repeat + return (x % W + W) % W; + } +} + + +static std::vector> build_mipmaps( + const uint8_t* texture, + const int& H, const int& W, const int& C +) { + if (H != W || !is_power_of_two(H)) { + throw std::invalid_argument("Texture width and height must be equal and a power of two."); + } + std::vector> mipmaps; + const uint8_t* cur_map = texture; + int cur_H = H; + int cur_W = W; + int next_H = cur_H >> 1; + int next_W = cur_W >> 1; + while (next_H > 0 && next_W > 0) { + std::vector next_map(next_H * next_W * C); + for (int y = 0; y < next_H; y++) { + for (int x = 0; x < next_W; x++) { + for (int c = 0; c < C; c++) { + size_t sum = 0; + size_t xx = static_cast(x) << 1; + size_t yy = static_cast(y) << 1; + sum += cur_map[yy * static_cast(cur_W) * C + xx * C + c]; + sum += cur_map[(yy + 1) * static_cast(cur_W) * C + xx * C + c]; + sum += cur_map[yy * static_cast(cur_W) * C + (xx + 1) * C + c]; + sum += cur_map[(yy + 1) * static_cast(cur_W) * C + (xx + 1) * C + c]; + next_map[y * next_W * C + x * C + c] = static_cast(sum / 4); + } + } + } + mipmaps.push_back(std::move(next_map)); + cur_map = mipmaps.back().data(); + cur_H = next_H; + cur_W = next_W; + next_H = cur_H >> 1; + next_W = cur_W >> 1; + } + return mipmaps; +} + + +static void sample_texture( + const uint8_t* texture, + const int& H, const int& W, const int& C, + const float& u, const float& v, + const int& filter, const int& wrap, + float* color +) { + float x = u * W; + float y = (1 - v) * H; + if (filter == 0) { // NEAREST + int x_int = floorf(x); + int y_int = floorf(y); + x_int = wrap_texcoord(x_int, W, wrap); + y_int = wrap_texcoord(y_int, H, wrap); + for (int c = 0; c < C; c++) { + color[c] = texture[y_int * W * C + x_int * C + c] / 255.0f; + } + } + else { // LINEAR + int x_low = floorf(x - 0.5); + int x_high = x_low + 1; + int y_low = floorf(y - 0.5); + int y_high = y_low + 1; + float w_x = x - x_low - 0.5; + float w_y = y - y_low - 0.5; + x_low = wrap_texcoord(x_low, W, wrap); + x_high = wrap_texcoord(x_high, W, wrap); + y_low = wrap_texcoord(y_low, H, wrap); + y_high = wrap_texcoord(y_high, H, wrap); + for (int c = 0; c < C; c++) { + color[c] = (1 - w_x) * (1 - w_y) * texture[y_low * W * C + x_low * C + c] + + w_x * (1 - w_y) * texture[y_low * W * C + x_high * C + c] + + (1 - w_x) * w_y * texture[y_high * W * C + x_low * C + c] + + w_x * w_y * texture[y_high * W * C + x_high * C + c]; + color[c] /= 255.0f; + } + } +} + + +static void sample_texture_mipmap( + const uint8_t* texture, + const int& H, const int& W, const int& C, + const std::vector>& mipmaps, + const float& u, const float& v, const float& mip_length, const float& mipLevelOffset, + const int& filter, const int& wrap, + float* color +) { + if (filter == 0) { // NEAREST + sample_texture(texture, H, W, C, u, v, filter, wrap, color); + } + else { // LINEAR + float mip_level = std::log2(mip_length * H) + mipLevelOffset; + if (!std::isfinite(mip_level) || mip_level <= 0 || mipmaps.empty()) { + sample_texture(texture, H, W, C, u, v, filter, wrap, color); + } + else if (mip_level >= mipmaps.size()) { + sample_texture(mipmaps[mipmaps.size() - 1].data(), H >> mipmaps.size(), W >> mipmaps.size(), C, u, v, filter, wrap, color); + } + else { + int lower_mip_level = std::floor(mip_level); + int upper_mip_level = lower_mip_level + 1; + float mip_frac = mip_level - lower_mip_level; + const uint8_t* lower_mip_ptr = lower_mip_level == 0 ? texture : mipmaps[lower_mip_level - 1].data(); + const uint8_t* upper_mip_ptr = mipmaps[upper_mip_level - 1].data(); + int lower_mip_H = H >> lower_mip_level; + int lower_mip_W = W >> lower_mip_level; + int upper_mip_H = H >> upper_mip_level; + int upper_mip_W = W >> upper_mip_level; + std::vector lower_mip_sample(C); + std::vector upper_mip_sample(C); + sample_texture(lower_mip_ptr, lower_mip_H, lower_mip_W, C, u, v, filter, wrap, lower_mip_sample.data()); + sample_texture(upper_mip_ptr, upper_mip_H, upper_mip_W, C, u, v, filter, wrap, upper_mip_sample.data()); + for (int c = 0; c < C; c++) { + color[c] = (1 - mip_frac) * lower_mip_sample[c] + mip_frac * upper_mip_sample[c]; + } + } + } +} + + +static std::tuple, std::vector, std::vector, std::vector, std::vector, std::vector, std::vector> +voxelize_trimesh_pbr_impl( + const float* voxel_size, + const int* grid_range, + const int N_tri, + const float* vertices, + const float* normals, + const float* uvs, + const int* materialIds, + const std::vector baseColorFactor, + const std::vector baseColorTexture, + const std::vector H_bcTex, const std::vector W_bcTex, + const std::vector baseColorTextureFilter, + const std::vector baseColorTextureWrap, + const std::vector metallicFactor, + const std::vector metallicTexture, + const std::vector H_mtlTex, const std::vector W_mtlTex, + const std::vector metallicTextureFilter, + const std::vector metallicTextureWrap, + const std::vector roughnessFactor, + const std::vector roughnessTexture, + const std::vector H_rghTex, const std::vector W_rghTex, + const std::vector roughnessTextureFilter, + const std::vector roughnessTextureWrap, + const std::vector emissiveFactor, + const std::vector emissiveTexture, + const std::vector H_emTex, const std::vector W_emTex, + const std::vector emissiveTextureFilter, + const std::vector emissiveTextureWrap, + const std::vector alphaMode, + const std::vector alphaCutoff, + const std::vector alphaFactor, + const std::vector alphaTexture, + const std::vector H_aTex, const std::vector W_aTex, + const std::vector alphaTextureFilter, + const std::vector alphaTextureWrap, + const std::vector normalTexture, + const std::vector H_nTex, const std::vector W_nTex, + const std::vector normalTextureFilter, + const std::vector normalTextureWrap, + const float mipLevelOffset, + const bool timing +) { + clock_t start, end; + + // Common variables used in the voxelization process + Eigen::Vector3f delta_p(voxel_size[0], voxel_size[1], voxel_size[2]); + Eigen::Vector3i grid_min(grid_range[0], grid_range[1], grid_range[2]); + Eigen::Vector3i grid_max(grid_range[3], grid_range[4], grid_range[5]); + + // Construct Mipmaps + start = clock(); + std::vector>> baseColorMipmaps(baseColorTexture.size()); + std::vector>> metallicMipmaps(metallicTexture.size()); + std::vector>> roughnessMipmaps(roughnessTexture.size()); + std::vector>> emissiveMipmaps(emissiveTexture.size()); + std::vector>> alphaMipmaps(alphaTexture.size()); + std::vector>> normalMipmaps(normalTexture.size()); + for (size_t i = 0; i < baseColorTexture.size(); i++) { + if (baseColorTexture[i] != nullptr && baseColorTextureFilter[i] != 0) { + baseColorMipmaps[i] = build_mipmaps(baseColorTexture[i], H_bcTex[i], W_bcTex[i], 3); + } + } + for (size_t i = 0; i < metallicTexture.size(); i++) { + if (metallicTexture[i] != nullptr && metallicTextureFilter[i] != 0) { + metallicMipmaps[i] = build_mipmaps(metallicTexture[i], H_mtlTex[i], W_mtlTex[i], 1); + } + } + for (size_t i = 0; i < roughnessTexture.size(); i++) { + if (roughnessTexture[i] != nullptr && roughnessTextureFilter[i] != 0) { + roughnessMipmaps[i] = build_mipmaps(roughnessTexture[i], H_rghTex[i], W_rghTex[i], 1); + } + } + for (size_t i = 0; i < emissiveTexture.size(); i++) { + if (emissiveTexture[i] != nullptr && emissiveTextureFilter[i] != 0) { + emissiveMipmaps[i] = build_mipmaps(emissiveTexture[i], H_emTex[i], W_emTex[i], 3); + } + } + for (size_t i = 0; i < alphaTexture.size(); i++) { + if (alphaTexture[i] != nullptr && alphaTextureFilter[i] != 0) { + alphaMipmaps[i] = build_mipmaps(alphaTexture[i], H_aTex[i], W_aTex[i], 1); + } + } + for (size_t i = 0; i < normalTexture.size(); i++) { + if (normalTexture[i] != nullptr && normalTextureFilter[i] != 0) { + normalMipmaps[i] = build_mipmaps(normalTexture[i], H_nTex[i], W_nTex[i], 3); + } + } + end = clock(); + if (timing) std::cout << "Mipmaps construction took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + + // Buffers + std::unordered_map hash_table; + std::vector coords; + std::vector buf_weights; + std::vector buf_baseColors; + std::vector buf_metallics; + std::vector buf_roughnesses; + std::vector buf_emissives; + std::vector buf_alphas; + std::vector buf_normals; + + // Enumerate all triangles + start = clock(); + for (size_t tid = 0; tid < N_tri; tid++) { + // COMPUTE COMMON TRIANGLE PROPERTIES + // Move vertices to origin using bbox + size_t ptr = tid * 9; + Eigen::Vector3f v0(vertices[ptr], vertices[ptr + 1], vertices[ptr + 2]); + Eigen::Vector3f v1(vertices[ptr + 3], vertices[ptr + 4], vertices[ptr + 5]); + Eigen::Vector3f v2(vertices[ptr + 6], vertices[ptr + 7], vertices[ptr + 8]); + // Normals + Eigen::Vector3f n0(normals[ptr], normals[ptr + 1], normals[ptr + 2]); + Eigen::Vector3f n1(normals[ptr + 3], normals[ptr + 4], normals[ptr + 5]); + Eigen::Vector3f n2(normals[ptr + 6], normals[ptr + 7], normals[ptr + 8]); + // UV vectors + ptr = tid * 6; + Eigen::Vector2f uv0(uvs[ptr], uvs[ptr + 1]); + Eigen::Vector2f uv1(uvs[ptr + 2], uvs[ptr + 3]); + Eigen::Vector2f uv2(uvs[ptr + 4], uvs[ptr + 5]); + // TBN + auto tbn = compute_TBN(v0, v1, v2, uv0, uv1, uv2); + Eigen::Vector3f t = std::get<0>(tbn); + Eigen::Vector3f b = std::get<1>(tbn); + Eigen::Vector3f n = std::get<2>(tbn); + Eigen::Vector2f v_mip_length = std::get<3>(tbn); + float mip_length = delta_p.maxCoeff() / std::sqrt(v_mip_length.x() * v_mip_length.y()); + // Material ID + int mid = materialIds[tid]; + + // Find intersected voxel for each triangle + std::unordered_set intersected_voxels; + // Scan-line algorithm to find intersections with the voxel grid from three directions + /* + t0 + | \ + | t1 + | / + t2 + */ + auto scan_line_fill = [&] (const int ax2) { + int ax0 = (ax2 + 1) % 3; + int ax1 = (ax2 + 2) % 3; + + // Canonical question + std::array t = { + Eigen::Vector3d(v0[ax0], v0[ax1], v0[ax2]), + Eigen::Vector3d(v1[ax0], v1[ax1], v1[ax2]), + Eigen::Vector3d(v2[ax0], v2[ax1], v2[ax2]) + }; + std::sort(t.begin(), t.end(), [](const Eigen::Vector3d& a, const Eigen::Vector3d& b) { return a.y() < b.y(); }); + + // Scan-line algorithm + int start = std::clamp(int(t[0].y() / voxel_size[ax1]), grid_min[ax1], grid_max[ax1] - 1); + int mid = std::clamp(int(t[1].y() / voxel_size[ax1]), grid_min[ax1], grid_max[ax1] - 1); + int end = std::clamp(int(t[2].y() / voxel_size[ax1]), grid_min[ax1], grid_max[ax1] - 1); + + auto scan_line_half = [&] (const int row_start, const int row_end, const Eigen::Vector3d t0, const Eigen::Vector3d t1, const Eigen::Vector3d t2) { + /* + t0 + | \ + t3-t4 + | \ + t1---t2 + */ + for (int y_idx = row_start; y_idx < row_end; ++y_idx) { + double y = (y_idx + 1) * voxel_size[ax1]; + Eigen::Vector2d t3 = lerp(t0.y(), t1.y(), y, Eigen::Vector2d(t0.x(), t0.z()), Eigen::Vector2d(t1.x(), t1.z())); + Eigen::Vector2d t4 = lerp(t0.y(), t2.y(), y, Eigen::Vector2d(t0.x(), t0.z()), Eigen::Vector2d(t2.x(), t2.z())); + if (t3.x() > t4.x()) std::swap(t3, t4); + int line_start = std::clamp(int(t3.x() / voxel_size[ax0]), grid_min[ax0], grid_max[ax0] - 1); + int line_end = std::clamp(int(t4.x() / voxel_size[ax0]), grid_min[ax0], grid_max[ax0] - 1); + for (int x_idx = line_start; x_idx < line_end; ++x_idx) { + double x = (x_idx + 1) * voxel_size[ax0]; + double z = lerp(t3.x(), t4.x(), x, t3.y(), t4.y()); + int z_idx = int(z / voxel_size[ax2]); + if (z_idx >= grid_min[ax2] && z_idx < grid_max[ax2]) { + // For 4-connected voxels + for (int dx = 0; dx < 2; ++dx) { + for (int dy = 0; dy < 2; ++dy) { + VoxelCoord coord; + coord[ax0] = x_idx + dx; coord[ax1] = y_idx + dy; coord[ax2] = z_idx; + intersected_voxels.insert(coord); + } + } + } + } + } + }; + scan_line_half(start, mid, t[0], t[1], t[2]); + scan_line_half(mid, end, t[2], t[1], t[0]); + }; + scan_line_fill(0); + scan_line_fill(1); + scan_line_fill(2); + + // For all intersected voxels, ample texture and write to voxel grid + for (auto voxel : intersected_voxels) { + int x = voxel.x; + int y = voxel.y; + int z = voxel.z; + + // Compute barycentric coordinates and weight + Eigen::Vector4f barycentric = project_onto_triangle( + Eigen::Vector3f((x + 0.5f) * delta_p.x(), (y + 0.5f) * delta_p.y(), (z + 0.5f) * delta_p.z()), + v0, v1, v2, n + ); + Eigen::Vector2f uv = { + barycentric.x() * uv0.x() + barycentric.y() * uv1.x() + barycentric.z() * uv2.x(), + barycentric.x() * uv0.y() + barycentric.y() * uv1.y() + barycentric.z() * uv2.y() + }; + Eigen::Vector3f int_n = { + barycentric.x() * n0.x() + barycentric.y() * n1.x() + barycentric.z() * n2.x(), + barycentric.x() * n0.y() + barycentric.y() * n1.y() + barycentric.z() * n2.y(), + barycentric.x() * n0.z() + barycentric.y() * n1.z() + barycentric.z() * n2.z() + }; + float weight = 1 - barycentric.w(); + + /// base color + float baseColor[3] = {1, 1, 1}; + if (baseColorTexture[mid]) { + sample_texture_mipmap( + baseColorTexture[mid], + H_bcTex[mid], W_bcTex[mid], 3, + baseColorMipmaps[mid], + uv.x(), uv.y(), mip_length, mipLevelOffset, + baseColorTextureFilter[mid], baseColorTextureWrap[mid], + baseColor + ); + } + baseColor[0] *= baseColorFactor[mid][0]; + baseColor[1] *= baseColorFactor[mid][1]; + baseColor[2] *= baseColorFactor[mid][2]; + + /// metallic + float metallic = 1.0f; + if (metallicTexture[mid]) { + sample_texture_mipmap( + metallicTexture[mid], + H_mtlTex[mid], W_mtlTex[mid], 1, + metallicMipmaps[mid], + uv.x(), uv.y(), mip_length, mipLevelOffset, + metallicTextureFilter[mid], metallicTextureWrap[mid], + &metallic + ); + } + metallic *= metallicFactor[mid]; + + /// roughness + float roughness = 1.0f; + if (roughnessTexture[mid]) { + sample_texture_mipmap( + roughnessTexture[mid], + H_rghTex[mid], W_rghTex[mid], 1, + roughnessMipmaps[mid], + uv.x(), uv.y(), mip_length, mipLevelOffset, + roughnessTextureFilter[mid], roughnessTextureWrap[mid], + &roughness + ); + } + roughness *= roughnessFactor[mid]; + + /// emissive + float emissive[3] = {1, 1, 1}; + if (emissiveTexture[mid]) { + sample_texture_mipmap( + emissiveTexture[mid], + H_emTex[mid], W_emTex[mid], 3, + roughnessMipmaps[mid], + uv.x(), uv.y(), mip_length, mipLevelOffset, + emissiveTextureFilter[mid], emissiveTextureWrap[mid], + emissive + ); + } + emissive[0] *= emissiveFactor[mid][0]; + emissive[1] *= emissiveFactor[mid][1]; + emissive[2] *= emissiveFactor[mid][2]; + + /// alpha + float alpha = 1.0f; + if (alphaMode[mid] != 0) { + if (alphaTexture[mid]) { + sample_texture_mipmap( + alphaTexture[mid], + H_aTex[mid], W_aTex[mid], 1, + alphaMipmaps[mid], + uv.x(), uv.y(), mip_length, mipLevelOffset, + alphaTextureFilter[mid], alphaTextureWrap[mid], + &alpha + ); + } + alpha *= alphaFactor[mid]; + if (alphaMode[mid] == 1) { // MASK + alpha = alpha < alphaCutoff[mid] ? 0.0f : 1.0f; + } + } + + /// normal + float normal[3] = {int_n.x(), int_n.y(), int_n.z()}; + if (normalTexture[mid]) { + sample_texture_mipmap( + normalTexture[mid], + H_nTex[mid], W_nTex[mid], 3, + normalMipmaps[mid], + uv.x(), uv.y(), mip_length, mipLevelOffset, + normalTextureFilter[mid], normalTextureWrap[mid], + normal + ); + normal[0] = normal[0] * 2 - 1; + normal[1] = normal[1] * 2 - 1; + normal[2] = normal[2] * 2 - 1; + Eigen::Vector3f _n = (normal[0] * t + normal[1] * b + normal[2] * int_n).normalized(); + normal[0] = _n.x(); + normal[1] = _n.y(); + normal[2] = _n.z(); + } + + // Write to voxel grid + auto coord = VoxelCoord{x-grid_min.x(), y-grid_min.y(), z-grid_min.z()}; + auto kv = hash_table.find(coord); + if (kv == hash_table.end()) { + hash_table[coord] = coords.size(); + coords.push_back({coord.x, coord.y, coord.z}); + buf_weights.push_back(weight); + buf_baseColors.push_back(Eigen::Vector3f(baseColor[0], baseColor[1], baseColor[2]) * weight); + buf_metallics.push_back(metallic * weight); + buf_roughnesses.push_back(roughness * weight); + buf_emissives.push_back(Eigen::Vector3f(emissive[0], emissive[1], emissive[2]) * weight); + buf_alphas.push_back(alpha * weight); + buf_normals.push_back(Eigen::Vector3f(normal[0], normal[1], normal[2]) * weight); + } + else { + auto i = kv->second; + buf_weights[i] += weight; + buf_baseColors[i] += Eigen::Vector3f(baseColor[0], baseColor[1], baseColor[2]) * weight; + buf_metallics[i] += metallic * weight; + buf_roughnesses[i] += roughness * weight; + buf_emissives[i] += Eigen::Vector3f(emissive[0], emissive[1], emissive[2]) * weight; + buf_alphas[i] += alpha * weight; + buf_normals[i] += Eigen::Vector3f(normal[0], normal[1], normal[2]) * weight; + } + } + } + end = clock(); + if (timing) std::cout << "Voxelization took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + + // Normalize buffers + start = clock(); + std::vector out_coord(coords.size() * 3); + std::vector out_baseColor(coords.size() * 3); + std::vector out_metallic(coords.size()); + std::vector out_roughness(coords.size()); + std::vector out_emissive(coords.size() * 3); + std::vector out_alpha(coords.size()); + std::vector out_normal(coords.size() * 3); + for (int i = 0; i < coords.size(); i++) { + out_coord[i * 3 + 0] = coords[i].x; + out_coord[i * 3 + 1] = coords[i].y; + out_coord[i * 3 + 2] = coords[i].z; + out_baseColor[i * 3 + 0] = buf_baseColors[i].x() / buf_weights[i]; + out_baseColor[i * 3 + 1] = buf_baseColors[i].y() / buf_weights[i]; + out_baseColor[i * 3 + 2] = buf_baseColors[i].z() / buf_weights[i]; + out_metallic[i] = buf_metallics[i] / buf_weights[i]; + out_roughness[i] = buf_roughnesses[i] / buf_weights[i]; + out_emissive[i * 3 + 0] = buf_emissives[i].x() / buf_weights[i]; + out_emissive[i * 3 + 1] = buf_emissives[i].y() / buf_weights[i]; + out_emissive[i * 3 + 2] = buf_emissives[i].z() / buf_weights[i]; + out_alpha[i] = buf_alphas[i] / buf_weights[i]; + out_normal[i * 3 + 0] = buf_normals[i].x() / buf_weights[i]; + out_normal[i * 3 + 1] = buf_normals[i].y() / buf_weights[i]; + out_normal[i * 3 + 2] = buf_normals[i].z() / buf_weights[i]; + } + end = clock(); + if (timing) std::cout << "Normalization took " << double(end - start) / CLOCKS_PER_SEC << " seconds." << std::endl; + + return std::make_tuple( + std::move(out_coord), + std::move(out_baseColor), + std::move(out_metallic), + std::move(out_roughness), + std::move(out_emissive), + std::move(out_alpha), + std::move(out_normal) + ); +} + + +std::tuple +textured_mesh_to_volumetric_attr_cpu( + const torch::Tensor& voxel_size, + const torch::Tensor& grid_range, + const torch::Tensor& vertices, + const torch::Tensor& normals, + const torch::Tensor& uvs, + const torch::Tensor& materialIds, + const std::vector& baseColorFactor, + const std::vector& baseColorTexture, + const std::vector& baseColorTextureFilter, + const std::vector& baseColorTextureWrap, + const std::vector& metallicFactor, + const std::vector& metallicTexture, + const std::vector& metallicTextureFilter, + const std::vector& metallicTextureWrap, + const std::vector& roughnessFactor, + const std::vector& roughnessTexture, + const std::vector& roughnessTextureFilter, + const std::vector& roughnessTextureWrap, + const std::vector& emissiveFactor, + const std::vector& emissiveTexture, + const std::vector& emissiveTextureFilter, + const std::vector& emissiveTextureWrap, + const std::vector& alphaMode, + const std::vector& alphaCutoff, + const std::vector& alphaFactor, + const std::vector& alphaTexture, + const std::vector& alphaTextureFilter, + const std::vector& alphaTextureWrap, + const std::vector& normalTexture, + const std::vector& normalTextureFilter, + const std::vector& normalTextureWrap, + const float mipLevelOffset, + const bool timing +) { + auto N_mat = baseColorFactor.size(); + int N_tri = vertices.size(0); + + // Get the size of the input tensors + std::vector baseColorFactor_ptrs(N_mat); + std::vector baseColorTexture_ptrs(N_mat); + std::vector H_bcTex(N_mat), W_bcTex(N_mat); + std::vector metallicFactor_vec(N_mat); + std::vector metallicTexture_ptrs(N_mat); + std::vector H_mtlTex(N_mat), W_mtlTex(N_mat); + std::vector roughnessFactor_vec(N_mat); + std::vector roughnessTexture_ptrs(N_mat); + std::vector H_rghTex(N_mat), W_rghTex(N_mat); + std::vector emissiveFactor_ptrs(N_mat); + std::vector emissiveTexture_ptrs(N_mat); + std::vector H_emTex(N_mat), W_emTex(N_mat); + std::vector alphaMode_vec(N_mat); + std::vector alphaCutoff_vec(N_mat); + std::vector alphaFactor_vec(N_mat); + std::vector alphaTexture_ptrs(N_mat); + std::vector H_aTex(N_mat), W_aTex(N_mat); + std::vector normalTexture_ptrs(N_mat); + std::vector H_nTex(N_mat), W_nTex(N_mat); + + for (int i = 0; i < N_mat; ++i) { + baseColorFactor_ptrs[i] = baseColorFactor[i].contiguous().data_ptr(); + if (baseColorTexture[i].numel() > 0) { + baseColorTexture_ptrs[i] = baseColorTexture[i].contiguous().data_ptr(); + H_bcTex[i] = baseColorTexture[i].size(0); + W_bcTex[i] = baseColorTexture[i].size(1); + } + else { + baseColorTexture_ptrs[i] = nullptr; + H_bcTex[i] = 0; + W_bcTex[i] = 0; + } + metallicFactor_vec[i] = metallicFactor[i]; + if (metallicTexture[i].numel() > 0) { + metallicTexture_ptrs[i] = metallicTexture[i].contiguous().data_ptr(); + H_mtlTex[i] = metallicTexture[i].size(0); + W_mtlTex[i] = metallicTexture[i].size(1); + } + else { + metallicTexture_ptrs[i] = nullptr; + H_mtlTex[i] = 0; + W_mtlTex[i] = 0; + } + roughnessFactor_vec[i] = roughnessFactor[i]; + if (roughnessTexture[i].numel() > 0) { + roughnessTexture_ptrs[i] = roughnessTexture[i].contiguous().data_ptr(); + H_rghTex[i] = roughnessTexture[i].size(0); + W_rghTex[i] = roughnessTexture[i].size(1); + } + else { + roughnessTexture_ptrs[i] = nullptr; + H_rghTex[i] = 0; + W_rghTex[i] = 0; + } + emissiveFactor_ptrs[i] = emissiveFactor[i].contiguous().data_ptr(); + if (emissiveTexture[i].numel() > 0) { + emissiveTexture_ptrs[i] = emissiveTexture[i].contiguous().data_ptr(); + H_emTex[i] = emissiveTexture[i].size(0); + W_emTex[i] = emissiveTexture[i].size(1); + } + else { + emissiveTexture_ptrs[i] = nullptr; + H_emTex[i] = 0; + W_emTex[i] = 0; + } + alphaMode_vec[i] = alphaMode[i]; + alphaCutoff_vec[i] = alphaCutoff[i]; + alphaFactor_vec[i] = alphaFactor[i]; + if (alphaTexture[i].numel() > 0) { + alphaTexture_ptrs[i] = alphaTexture[i].contiguous().data_ptr(); + H_aTex[i] = alphaTexture[i].size(0); + W_aTex[i] = alphaTexture[i].size(1); + } + else { + alphaTexture_ptrs[i] = nullptr; + H_aTex[i] = 0; + W_aTex[i] = 0; + } + if (normalTexture[i].numel() > 0) { + normalTexture_ptrs[i] = normalTexture[i].contiguous().data_ptr(); + H_nTex[i] = normalTexture[i].size(0); + W_nTex[i] = normalTexture[i].size(1); + } + else { + normalTexture_ptrs[i] = nullptr; + H_nTex[i] = 0; + W_nTex[i] = 0; + } + } + + auto outputs = voxelize_trimesh_pbr_impl( + voxel_size.contiguous().data_ptr(), + grid_range.contiguous().data_ptr(), + N_tri, + vertices.contiguous().data_ptr(), + normals.contiguous().data_ptr(), + uvs.contiguous().data_ptr(), + materialIds.contiguous().data_ptr(), + baseColorFactor_ptrs, + baseColorTexture_ptrs, + H_bcTex, W_bcTex, + baseColorTextureFilter, baseColorTextureWrap, + metallicFactor_vec, + metallicTexture_ptrs, + H_mtlTex, W_mtlTex, + metallicTextureFilter, metallicTextureWrap, + roughnessFactor_vec, + roughnessTexture_ptrs, + H_rghTex, W_rghTex, + roughnessTextureFilter, roughnessTextureWrap, + emissiveFactor_ptrs, + emissiveTexture_ptrs, + H_emTex, W_emTex, + emissiveTextureFilter, emissiveTextureWrap, + alphaMode_vec, + alphaCutoff_vec, + alphaFactor_vec, + alphaTexture_ptrs, + H_aTex, W_aTex, + alphaTextureFilter, alphaTextureWrap, + normalTexture_ptrs, + H_nTex, W_nTex, + normalTextureFilter, normalTextureWrap, + mipLevelOffset, + timing + ); + + std::vector coords_vec = std::get<0>(outputs); + std::vector baseColors_vec = std::get<1>(outputs); + std::vector metallics_vec = std::get<2>(outputs); + std::vector roughnesses_vec = std::get<3>(outputs); + std::vector emissives_vec = std::get<4>(outputs); + std::vector alphas_vec = std::get<5>(outputs); + std::vector normals_vec = std::get<6>(outputs); + + // Create output tensors + auto out_coords = torch::from_blob(coords_vec.data(), {static_cast(coords_vec.size() / 3), 3}, torch::kInt32).clone(); + auto out_baseColors = torch::from_blob(baseColors_vec.data(), {static_cast(baseColors_vec.size() / 3), 3}, torch::kFloat32).clone(); + auto out_metallics = torch::from_blob(metallics_vec.data(), {static_cast(metallics_vec.size())}, torch::kFloat32).clone(); + auto out_roughnesses = torch::from_blob(roughnesses_vec.data(), {static_cast(roughnesses_vec.size())}, torch::kFloat32).clone(); + auto out_emissives = torch::from_blob(emissives_vec.data(), {static_cast(emissives_vec.size() / 3), 3}, torch::kFloat32).clone(); + auto out_alphas = torch::from_blob(alphas_vec.data(), {static_cast(alphas_vec.size())}, torch::kFloat32).clone(); + auto out_normals = torch::from_blob(normals_vec.data(), {static_cast(normals_vec.size() / 3), 3}, torch::kFloat32).clone(); + + return std::make_tuple( + out_coords, + out_baseColors, + out_metallics, + out_roughnesses, + out_emissives, + out_alphas, + out_normals + ); +} + diff --git a/o-voxel/src/ext.cpp b/o-voxel/src/ext.cpp new file mode 100644 index 0000000..e2ac946 --- /dev/null +++ b/o-voxel/src/ext.cpp @@ -0,0 +1,37 @@ +#include +#include "hash/api.h" +#include "convert/api.h" +#include "io/api.h" +#include "serialize/api.h" +#include "rasterize/api.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // Hash functions + m.def("hashmap_insert_cuda", &hashmap_insert_cuda); + m.def("hashmap_lookup_cuda", &hashmap_lookup_cuda); + m.def("hashmap_insert_3d_cuda", &hashmap_insert_3d_cuda); + m.def("hashmap_lookup_3d_cuda", &hashmap_lookup_3d_cuda); + m.def("hashmap_insert_3d_idx_as_val_cuda", &hashmap_insert_3d_idx_as_val_cuda); + // Convert functions + m.def("mesh_to_flexible_dual_grid_cpu", &mesh_to_flexible_dual_grid_cpu, py::call_guard()); + m.def("textured_mesh_to_volumetric_attr_cpu", &textured_mesh_to_volumetric_attr_cpu, py::call_guard()); + // Serialization functions + m.def("z_order_encode_cpu", &z_order_encode_cpu, py::call_guard()); + m.def("z_order_decode_cpu", &z_order_decode_cpu, py::call_guard()); + m.def("hilbert_encode_cpu", &hilbert_encode_cpu, py::call_guard()); + m.def("hilbert_decode_cpu", &hilbert_decode_cpu, py::call_guard()); + m.def("z_order_encode_cuda", &z_order_encode_cuda, py::call_guard()); + m.def("z_order_decode_cuda", &z_order_decode_cuda, py::call_guard()); + m.def("hilbert_encode_cuda", &hilbert_encode_cuda, py::call_guard()); + m.def("hilbert_decode_cuda", &hilbert_decode_cuda, py::call_guard()); + // IO functions + m.def("encode_sparse_voxel_octree_cpu", &encode_sparse_voxel_octree_cpu, py::call_guard()); + m.def("decode_sparse_voxel_octree_cpu", &decode_sparse_voxel_octree_cpu, py::call_guard()); + m.def("encode_sparse_voxel_octree_attr_parent_cpu", &encode_sparse_voxel_octree_attr_parent_cpu, py::call_guard()); + m.def("decode_sparse_voxel_octree_attr_parent_cpu", &decode_sparse_voxel_octree_attr_parent_cpu, py::call_guard()); + m.def("encode_sparse_voxel_octree_attr_neighbor_cpu", &encode_sparse_voxel_octree_attr_neighbor_cpu, py::call_guard()); + m.def("decode_sparse_voxel_octree_attr_neighbor_cpu", &decode_sparse_voxel_octree_attr_neighbor_cpu, py::call_guard()); + // Rasterization functions + m.def("rasterize_voxels_cuda", &rasterize_voxels_cuda); +} \ No newline at end of file diff --git a/o-voxel/src/hash/api.h b/o-voxel/src/hash/api.h new file mode 100644 index 0000000..7693667 --- /dev/null +++ b/o-voxel/src/hash/api.h @@ -0,0 +1,111 @@ +/* + * Hashmap + * + * Copyright (C) 2025, Jianfeng XIANG + * All rights reserved. + * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include + + +#define BLOCK_SIZE 256 + + +/** + * Insert keys into the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param keys [M] uint32/uint64 tensor containing the keys to be inserted + * @param values [M] uint32/uint64 tensor containing the values to be inserted + */ +void hashmap_insert_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& keys, + const torch::Tensor& values +); + + +/** + * Lookup keys in the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param keys [M] uint32/uint64 tensor containing the keys to be looked up + * @return [M] uint32/uint64 tensor containing the values of the keys + */ +torch::Tensor hashmap_lookup_cuda( + const torch::Tensor& hashmap_keys, + const torch::Tensor& hashmap_values, + const torch::Tensor& keys +); + + +/** + * Insert 3D coordinates into the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param coords [M, 4] int32 tensor containing the keys to be inserted + * @param values [M] uint32/uint64 tensor containing the values to be inserted + * @param W the number of width dimensions + * @param H the number of height dimensions + * @param D the number of depth dimensions + */ +void hashmap_insert_3d_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& coords, + const torch::Tensor& values, + int W, + int H, + int D +); + + +/** + * Lookup 3D coordinates in the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param coords [M, 4] int32 tensor containing the keys to be looked up + * @param W the number of width dimensions + * @param H the number of height dimensions + * @param D the number of depth dimensions + * + * @return [M] uint32/uint64 tensor containing the values of the keys + */ +torch::Tensor hashmap_lookup_3d_cuda( + const torch::Tensor& hashmap_keys, + const torch::Tensor& hashmap_values, + const torch::Tensor& coords, + int W, + int H, + int D +); + + +/** + * Insert 3D coordinates into the hashmap using index as value + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param coords [M, 4] int32 tensor containing the keys to be inserted + * @param W the number of width dimensions + * @param H the number of height dimensions + * @param D the number of depth dimensions + */ +void hashmap_insert_3d_idx_as_val_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& coords, + int W, + int H, + int D +); diff --git a/o-voxel/src/hash/hash.cu b/o-voxel/src/hash/hash.cu new file mode 100644 index 0000000..c56d223 --- /dev/null +++ b/o-voxel/src/hash/hash.cu @@ -0,0 +1,446 @@ +#include +#include +#include + +#include "api.h" +#include "hash.cuh" + + +template +static __global__ void hashmap_insert_cuda_kernel( + const size_t N, + const size_t M, + K* __restrict__ hashmap_keys, + V* __restrict__ hashmap_values, + const K* __restrict__ keys, + const V* __restrict__ values +) { + size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_id < M) + { + K key = keys[thread_id]; + V value = values[thread_id]; + linear_probing_insert(hashmap_keys, hashmap_values, key, value, N); + } +} + + +template +static void dispatch_hashmap_insert_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& keys, + const torch::Tensor& values +) { + hashmap_insert_cuda_kernel<<< + (keys.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, + BLOCK_SIZE + >>>( + hashmap_keys.size(0), + keys.size(0), + hashmap_keys.data_ptr(), + hashmap_values.data_ptr(), + keys.data_ptr(), + values.data_ptr() + ); +} + + +/** + * Insert keys into the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param keys [M] uint32/uint64 tensor containing the keys to be inserted + * @param values [M] uint32/uint64 tensor containing the values to be inserted + */ +void hashmap_insert_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& keys, + const torch::Tensor& values +) { + // Dispatch to 32-bit or 64-bit kernel + if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt32) { + TORCH_CHECK(keys.dtype() == torch::kUInt32, "Keys must be uint32"); + TORCH_CHECK(values.dtype() == torch::kUInt32, "Values must be uint32"); + dispatch_hashmap_insert_cuda(hashmap_keys, hashmap_values, keys, values); + } + else if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt64) { + TORCH_CHECK(keys.dtype() == torch::kUInt32, "Keys must be uint32"); + TORCH_CHECK(values.dtype() == torch::kUInt64, "Values must be uint64"); + dispatch_hashmap_insert_cuda(hashmap_keys, hashmap_values, keys, values); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt32) { + TORCH_CHECK(keys.dtype() == torch::kUInt64, "Keys must be uint64"); + TORCH_CHECK(values.dtype() == torch::kUInt32, "Values must be uint32"); + dispatch_hashmap_insert_cuda(hashmap_keys, hashmap_values, keys, values); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt64) { + TORCH_CHECK(keys.dtype() == torch::kUInt64, "Keys must be uint64"); + TORCH_CHECK(values.dtype() == torch::kUInt64, "Values must be uint64"); + dispatch_hashmap_insert_cuda(hashmap_keys, hashmap_values, keys, values); + } + else { + TORCH_CHECK(false, "Unsupported data type"); + } +} + + +template +static __global__ void hashmap_lookup_cuda_kernel( + const size_t N, + const size_t M, + const K * __restrict__ hashmap_keys, + const V * __restrict__ hashmap_values, + const K * __restrict__ keys, + V * __restrict__ values +) { + size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_id < M) { + K key = keys[thread_id]; + values[thread_id] = linear_probing_lookup(hashmap_keys, hashmap_values, key, N); + } +} + + +template +static void dispatch_hashmap_lookup_cuda( + const torch::Tensor& hashmap_keys, + const torch::Tensor& hashmap_values, + const torch::Tensor& keys, + torch::Tensor& values +) { + hashmap_lookup_cuda_kernel<<< + (keys.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, + BLOCK_SIZE + >>>( + hashmap_keys.size(0), + keys.size(0), + hashmap_keys.data_ptr(), + hashmap_values.data_ptr(), + keys.data_ptr(), + values.data_ptr() + ); +} + + +/** + * Lookup keys in the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param keys [M] uint32/uint64 tensor containing the keys to be looked up + * @return [M] uint32/uint64 tensor containing the values of the keys + */ +torch::Tensor hashmap_lookup_cuda( + const torch::Tensor& hashmap_keys, + const torch::Tensor& hashmap_values, + const torch::Tensor& keys +) { + // Allocate output tensor + auto output = torch::empty({keys.size(0)}, torch::dtype(hashmap_values.dtype()).device(hashmap_values.device())); + + // Dispatch to 32-bit or 64-bit kernel + if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt32) { + TORCH_CHECK(keys.dtype() == torch::kUInt32, "Keys must be uint32"); + TORCH_CHECK(output.dtype() == torch::kUInt32, "Output must be uint32"); + dispatch_hashmap_lookup_cuda(hashmap_keys, hashmap_values, keys, output); + } + else if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt64) { + TORCH_CHECK(keys.dtype() == torch::kUInt32, "Keys must be uint32"); + TORCH_CHECK(output.dtype() == torch::kUInt64, "Output must be uint64"); + dispatch_hashmap_lookup_cuda(hashmap_keys, hashmap_values, keys, output); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt32) { + TORCH_CHECK(keys.dtype() == torch::kUInt64, "Keys must be uint64"); + TORCH_CHECK(output.dtype() == torch::kUInt32, "Output must be uint32"); + dispatch_hashmap_lookup_cuda(hashmap_keys, hashmap_values, keys, output); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt64) { + TORCH_CHECK(keys.dtype() == torch::kUInt64, "Keys must be uint64"); + TORCH_CHECK(output.dtype() == torch::kUInt64, "Output must be uint64"); + dispatch_hashmap_lookup_cuda(hashmap_keys, hashmap_values, keys, output); + } + else { + TORCH_CHECK(false, "Unsupported data type"); + } + + return output; +} + + +template +static __global__ void hashmap_insert_3d_cuda_kernel( + const size_t N, + const size_t M, + const int W, + const int H, + const int D, + K* __restrict__ hashmap_keys, + V* __restrict__ hashmap_values, + const int32_t* __restrict__ coords, + const V* __restrict__ values +) { + size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_id < M) { + int4 coord = reinterpret_cast(coords)[thread_id]; + int b = coord.x; + int x = coord.y; + int y = coord.z; + int z = coord.w; + size_t flat_idx = (size_t)b * W * H * D + (size_t)x * H * D + (size_t)y * D + z; + K key = static_cast(flat_idx); + V value = values[thread_id]; + linear_probing_insert(hashmap_keys, hashmap_values, key, value, N); + } +} + + +template +static void dispatch_hashmap_insert_3d_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& coords, + const torch::Tensor& values, + int W, int H, int D +) { + hashmap_insert_3d_cuda_kernel<<< + (coords.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, + BLOCK_SIZE + >>>( + hashmap_keys.size(0), + coords.size(0), + W, H, D, + hashmap_keys.data_ptr(), + hashmap_values.data_ptr(), + coords.data_ptr(), + values.data_ptr() + ); +} + + +/** + * Insert 3D coordinates into the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param coords [M, 4] int32 tensor containing the keys to be inserted + * @param values [M] uint32/uint64 tensor containing the values to be inserted + * @param W the number of width dimensions + * @param H the number of height dimensions + * @param D the number of depth dimensions + */ +void hashmap_insert_3d_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& coords, + const torch::Tensor& values, + int W, + int H, + int D +) { + TORCH_CHECK(coords.dtype() == torch::kInt32, "Coords must be int32"); + + // Dispatch to 32-bit or 64-bit kernel + if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt32) { + TORCH_CHECK(values.dtype() == torch::kUInt32, "Values must be uint32"); + dispatch_hashmap_insert_3d_cuda(hashmap_keys, hashmap_values, coords, values, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt64) { + TORCH_CHECK(values.dtype() == torch::kUInt64, "Values must be uint64"); + dispatch_hashmap_insert_3d_cuda(hashmap_keys, hashmap_values, coords, values, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt32) { + TORCH_CHECK(values.dtype() == torch::kUInt32, "Values must be uint32"); + dispatch_hashmap_insert_3d_cuda(hashmap_keys, hashmap_values, coords, values, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt64) { + TORCH_CHECK(values.dtype() == torch::kUInt64, "Values must be uint64"); + dispatch_hashmap_insert_3d_cuda(hashmap_keys, hashmap_values, coords, values, W, H, D); + } + else { + TORCH_CHECK(false, "Unsupported data type"); + } +} + + +template +static __global__ void hashmap_lookup_3d_cuda_kernel( + const size_t N, + const size_t M, + const int W, + const int H, + const int D, + const K* __restrict__ hashmap_keys, + const V* __restrict__ hashmap_values, + const int32_t* __restrict__ coords, + V* __restrict__ values +) { + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_id < M) { + int4 coord = reinterpret_cast(coords)[thread_id]; + int b = coord.x; + int x = coord.y; + int y = coord.z; + int z = coord.w; + if (x < 0 || x >= W || y < 0 || y >= H || z < 0 || z >= D) { + values[thread_id] = std::numeric_limits::max(); + return; + } + size_t flat_idx = (size_t)b * W * H * D + (size_t)x * H * D + (size_t)y * D + z; + K key = static_cast(flat_idx); + values[thread_id] = linear_probing_lookup(hashmap_keys, hashmap_values, key, N); + } +} + + +template +static void dispatch_hashmap_lookup_3d_cuda( + const torch::Tensor& hashmap_keys, + const torch::Tensor& hashmap_values, + const torch::Tensor& coords, + torch::Tensor& values, + int W, int H, int D +) { + hashmap_lookup_3d_cuda_kernel<<< + (coords.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, + BLOCK_SIZE + >>>( + hashmap_keys.size(0), + coords.size(0), + W, H, D, + hashmap_keys.data_ptr(), + hashmap_values.data_ptr(), + coords.data_ptr(), + values.data_ptr() + ); +} + + +/** + * Lookup 3D coordinates in the hashmap + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param coords [M, 4] int32 tensor containing the keys to be looked up + * @param W the number of width dimensions + * @param H the number of height dimensions + * @param D the number of depth dimensions + * + * @return [M] uint32/uint64 tensor containing the values of the keys + */ +torch::Tensor hashmap_lookup_3d_cuda( + const torch::Tensor& hashmap_keys, + const torch::Tensor& hashmap_values, + const torch::Tensor& coords, + int W, + int H, + int D +) { + // Allocate output tensor + auto output = torch::empty({coords.size(0)}, torch::dtype(hashmap_values.dtype()).device(hashmap_values.device())); + + // Dispatch to 32-bit or 64-bit kernel + if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt32) { + dispatch_hashmap_lookup_3d_cuda(hashmap_keys, hashmap_values, coords, output, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt64) { + dispatch_hashmap_lookup_3d_cuda(hashmap_keys, hashmap_values, coords, output, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt32) { + dispatch_hashmap_lookup_3d_cuda(hashmap_keys, hashmap_values, coords, output, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt64) { + dispatch_hashmap_lookup_3d_cuda(hashmap_keys, hashmap_values, coords, output, W, H, D); + } + else { + TORCH_CHECK(false, "Unsupported data type"); + } + + return output; +} + + +template +static __global__ void hashmap_insert_3d_idx_as_val_cuda_kernel( + const size_t N, + const size_t M, + const int W, + const int H, + const int D, + K* __restrict__ hashmap, + V* __restrict__ values, + const int32_t* __restrict__ coords +) { + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + if (thread_id < M) { + int4 coord = reinterpret_cast(coords)[thread_id]; + int b = coord.x; + int x = coord.y; + int y = coord.z; + int z = coord.w; + size_t flat_idx = (size_t)b * W * H * D + (size_t)x * H * D + (size_t)y * D + z; + K key = static_cast(flat_idx); + V value = static_cast(thread_id); + linear_probing_insert(hashmap, values, key, value, N); + } +} + + +template +static void dispatch_hashmap_insert_3d_idx_as_val_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& coords, + int W, int H, int D +) { + hashmap_insert_3d_idx_as_val_cuda_kernel<<< + (coords.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, + BLOCK_SIZE + >>>( + hashmap_keys.size(0), + coords.size(0), + W, H, D, + hashmap_keys.data_ptr(), + hashmap_values.data_ptr(), + coords.data_ptr() + ); +} + + +/** + * Insert 3D coordinates into the hashmap using index as value + * + * @param hashmap_keys [N] uint32/uint64 tensor containing the hashmap keys + * @param hashmap_values [N] uint32/uint64 tensor containing the hashmap values + * @param coords [M, 4] int32 tensor containing the keys to be inserted + * @param W the number of width dimensions + * @param H the number of height dimensions + * @param D the number of depth dimensions + */ +void hashmap_insert_3d_idx_as_val_cuda( + torch::Tensor& hashmap_keys, + torch::Tensor& hashmap_values, + const torch::Tensor& coords, + int W, + int H, + int D +) { + // Dispatch to 32-bit or 64-bit kernel + if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt32) { + dispatch_hashmap_insert_3d_idx_as_val_cuda(hashmap_keys, hashmap_values, coords, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt32 && hashmap_values.dtype() == torch::kUInt64) { + dispatch_hashmap_insert_3d_idx_as_val_cuda(hashmap_keys, hashmap_values, coords, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt32) { + dispatch_hashmap_insert_3d_idx_as_val_cuda(hashmap_keys, hashmap_values, coords, W, H, D); + } + else if (hashmap_keys.dtype() == torch::kUInt64 && hashmap_values.dtype() == torch::kUInt64) { + dispatch_hashmap_insert_3d_idx_as_val_cuda(hashmap_keys, hashmap_values, coords, W, H, D); + } + else { + TORCH_CHECK(false, "Unsupported data type"); + } +} \ No newline at end of file diff --git a/o-voxel/src/hash/hash.cuh b/o-voxel/src/hash/hash.cuh new file mode 100644 index 0000000..b7bcf1b --- /dev/null +++ b/o-voxel/src/hash/hash.cuh @@ -0,0 +1,87 @@ +// 32 bit Murmur3 hash +__forceinline__ __device__ size_t hash(uint32_t k, size_t N) { + k ^= k >> 16; + k *= 0x85ebca6b; + k ^= k >> 13; + k *= 0xc2b2ae35; + k ^= k >> 16; + return k % N; +} + + +// 64 bit Murmur3 hash +__forceinline__ __device__ size_t hash(uint64_t k, size_t N) { + k ^= k >> 33; + k *= 0xff51afd7ed558ccdULL; + k ^= k >> 33; + k *= 0xc4ceb9fe1a85ec53ULL; + k ^= k >> 33; + return k % N; +} + + +template +__forceinline__ __device__ void linear_probing_insert( + K* hashmap_keys, + V* hashmap_values, + const K key, + const V value, + const size_t N +) { + size_t slot = hash(key, N); + while (true) { + K prev = atomicCAS(&hashmap_keys[slot], std::numeric_limits::max(), key); + if (prev == std::numeric_limits::max() || prev == key) { + hashmap_values[slot] = value; + return; + } + slot = slot + 1; + if (slot >= N) slot = 0; + } +} + + +template +__forceinline__ __device__ void linear_probing_insert( + uint64_t* hashmap_keys, + V* hashmap_values, + const uint64_t key, + const V value, + const size_t N +) { + size_t slot = hash(key, N); + while (true) { + uint64_t prev = atomicCAS( + reinterpret_cast(&hashmap_keys[slot]), + static_cast(std::numeric_limits::max()), + static_cast(key) + ); + if (prev == std::numeric_limits::max() || prev == key) { + hashmap_values[slot] = value; + return; + } + slot = (slot + 1) % N; + } +} + + +template +__forceinline__ __device__ V linear_probing_lookup( + const K* hashmap_keys, + const V* hashmap_values, + const K key, + const size_t N +) { + size_t slot = hash(key, N); + while (true) { + K prev = hashmap_keys[slot]; + if (prev == std::numeric_limits::max()) { + return std::numeric_limits::max(); + } + if (prev == key) { + return hashmap_values[slot]; + } + slot = slot + 1; + if (slot >= N) slot = 0; + } +} diff --git a/o-voxel/src/io/api.h b/o-voxel/src/io/api.h new file mode 100644 index 0000000..116eb17 --- /dev/null +++ b/o-voxel/src/io/api.h @@ -0,0 +1,109 @@ +/* + * Efficient Sparse Voxel storage as Sparse Voxel Zip files (.svz) + * + * Copyright (C) 2025, Jianfeng XIANG + * All rights reserved. + * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include +#include + + +/** + * Encode a list of sparse voxel morton codes into a sparse voxel octree + * NOTE: The input indices must be sorted in ascending order + * + * @param codes [N] uint32 tensor containing the morton codes + * @param depth The depth of the sparse voxel octree + * + * @return uint8 tensor containing the sparse voxel octree + */ +torch::Tensor encode_sparse_voxel_octree_cpu( + const torch::Tensor& codes, + const uint32_t depth +); + + +/** + * Decode a sparse voxel octree into a list of sparse voxel morton codes + * + * @param octree uint8 tensor containing the sparse voxel octree + * @param depth The depth of the sparse voxel octree + * + * @return [N] uint32 tensor containing the morton codes + * The codes are sorted in ascending order + */ +torch::Tensor decode_sparse_voxel_octree_cpu( + const torch::Tensor& octree, + const uint32_t depth +); + + + +/** + * Encode the attribute of a sparse voxel octree into deltas from its parent node. + * + * @param octree uint8 tensor containing the sparse voxel octree + * @param depth The depth of the sparse voxel octree + * @param attr [N, C] tensor containing the attribute of each sparse voxel + * + * @return uint8 tensor containing the deltas + */ +torch::Tensor encode_sparse_voxel_octree_attr_parent_cpu( + const torch::Tensor& octree, + const uint32_t depth, + const torch::Tensor& attr +); + + +/** + * Decode the attribute of a sparse voxel octree from its parent node and its deltas. + * + * @param octree uint8 tensor containing the sparse voxel octree + * @param depth The depth of the sparse voxel octree + * @param delta uint8 tensor containing the deltas + * + * @return [N, C] tensor containing the attribute of each sparse voxel + */ +torch::Tensor decode_sparse_voxel_octree_attr_parent_cpu( + const torch::Tensor& octree, + const uint32_t depth, + const torch::Tensor& delta +); + + +/** + * Encode the attribute of a sparse voxel octree into deltas from its neighbors. + * + * @param coord [N, 3] tensor containing the coordinates of each sparse voxel + * @param res The resolution of the sparse voxel grid + * @param attr [N, C] tensor containing the attribute of each sparse voxel + * + * @return uint8 tensor containing the deltas + */ +torch::Tensor encode_sparse_voxel_octree_attr_neighbor_cpu( + const torch::Tensor& coord, + const uint32_t res, + const torch::Tensor& attr +); + + +/** + * Decode the attribute of a sparse voxel octree from its neighbors and deltas. + * + * @param coord [N, 3] tensor containing the coordinates of each sparse voxel + * @param res The resolution of the sparse voxel grid + * @param delta [N, C] tensor containing the deltas + * + * @return [N, C] tensor containing the attribute of each sparse voxel + */ +torch::Tensor decode_sparse_voxel_octree_attr_neighbor_cpu( + const torch::Tensor& coord, + const uint32_t res, + const torch::Tensor& delta +); diff --git a/o-voxel/src/io/filter_neighbor.cpp b/o-voxel/src/io/filter_neighbor.cpp new file mode 100644 index 0000000..3af406e --- /dev/null +++ b/o-voxel/src/io/filter_neighbor.cpp @@ -0,0 +1,178 @@ +#include +#include "api.h" + +#include +#include +#include + + +/** + * Encode the attribute of a sparse voxel octree into deltas from its neighbors. + * + * @param coord [N, 3] tensor containing the coordinates of each sparse voxel + * @param res The resolution of the sparse voxel grid + * @param attr [N, C] tensor containing the attribute of each sparse voxel + * + * @return uint8 tensor containing the deltas + */ +torch::Tensor encode_sparse_voxel_octree_attr_neighbor_cpu( + const torch::Tensor& coord, + const uint32_t res, + const torch::Tensor& attr +) { + size_t N = coord.size(0); + size_t C = attr.size(1); + int* coord_data = coord.data_ptr(); + uint8_t* attr_data = attr.data_ptr(); + std::vector buffer(res * res * res * (C + 1), 0); + + // Densify the coordinates + for (int i = 0; i < N; i++) { + int x = coord_data[i * 3 + 0]; + int y = coord_data[i * 3 + 1]; + int z = coord_data[i * 3 + 2]; + int ptr = (z * res * res + y * res + x) * (C + 1); + buffer[ptr + C] = 1; + for (int c = 0; c < C; c++) { + buffer[ptr + c] = attr_data[i * C + c]; + } + } + + // Compute the deltas + for (int z = res-1; z >= 0; z--) { + for (int y = res-1; y >= 0; y--) { + for (int x = res-1; x >= 0; x--) { + int ptr = (z * res * res + y * res + x) * (C + 1); + int neignbor_ptr = -1; + int tmp_ptr; + if (!buffer[ptr + C]) continue; + // x + tmp_ptr = (z * res * res + y * res + (x - 1)) * (C + 1); + if (x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // y + tmp_ptr = (z * res * res + (y - 1) * res + x) * (C + 1); + if (y > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // z + tmp_ptr = ((z - 1) * res * res + y * res + x) * (C + 1); + if (z > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // xy + tmp_ptr = (z * res * res + (y - 1) * res + (x - 1)) * (C + 1); + if (y > 0 && x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // xz + tmp_ptr = ((z - 1) * res * res + y * res + (x - 1)) * (C + 1); + if (z > 0 && x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // yz + tmp_ptr = ((z - 1) * res * res + (y - 1) * res + x) * (C + 1); + if (z > 0 && y > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // xyz + tmp_ptr = ((z - 1) * res * res + (y - 1) * res + (x - 1)) * (C + 1); + if (z > 0 && y > 0 && x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + if (neignbor_ptr >= 0) { + for (int c = 0; c < C; c++) { + buffer[ptr + c] -= buffer[neignbor_ptr + c]; + } + } + } + } + } + + // Pack the deltas into a uint8 tensor + torch::Tensor delta = torch::zeros({N, C}, torch::dtype(torch::kUInt8)); + uint8_t* delta_data = delta.data_ptr(); + for (int i = 0; i < N; i++) { + int x = coord_data[i * 3 + 0]; + int y = coord_data[i * 3 + 1]; + int z = coord_data[i * 3 + 2]; + int ptr = (z * res * res + y * res + x) * (C + 1); + for (int c = 0; c < C; c++) { + delta_data[i * C + c] = buffer[ptr + c]; + } + } + return delta; +} + + +/** + * Decode the attribute of a sparse voxel octree from its neighbors and deltas. + * + * @param coord [N, 3] tensor containing the coordinates of each sparse voxel + * @param res The resolution of the sparse voxel grid + * @param delta [N, C] tensor containing the deltas + * + * @return [N, C] tensor containing the attribute of each sparse voxel + */ +torch::Tensor decode_sparse_voxel_octree_attr_neighbor_cpu( + const torch::Tensor& coord, + const uint32_t res, + const torch::Tensor& delta +) { + size_t N = coord.size(0); + size_t C = delta.size(1); + int* coord_data = coord.data_ptr(); + uint8_t* delta_data = delta.data_ptr(); + std::vector buffer(res * res * res * (C + 1), 0); + + // Densify the coordinates + for (int i = 0; i < N; i++) { + int x = coord_data[i * 3 + 0]; + int y = coord_data[i * 3 + 1]; + int z = coord_data[i * 3 + 2]; + int ptr = (z * res * res + y * res + x) * (C + 1); + buffer[ptr + C] = 1; + for (int c = 0; c < C; c++) { + buffer[ptr + c] = delta_data[i * C + c]; + } + } + + // Reconstruct the attribute + for (int z = 0; z < res; z++) { + for (int y = 0; y < res; y++) { + for (int x = 0; x < res; x++) { + int ptr = (z * res * res + y * res + x) * (C + 1); + int neignbor_ptr = -1; + int tmp_ptr; + if (!buffer[ptr + C]) continue; + // x + tmp_ptr = (z * res * res + y * res + (x - 1)) * (C + 1); + if (x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // y + tmp_ptr = (z * res * res + (y - 1) * res + x) * (C + 1); + if (y > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // z + tmp_ptr = ((z - 1) * res * res + y * res + x) * (C + 1); + if (z > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // xy + tmp_ptr = (z * res * res + (y - 1) * res + (x - 1)) * (C + 1); + if (y > 0 && x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // xz + tmp_ptr = ((z - 1) * res * res + y * res + (x - 1)) * (C + 1); + if (z > 0 && x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // yz + tmp_ptr = ((z - 1) * res * res + (y - 1) * res + x) * (C + 1); + if (z > 0 && y > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + // xyz + tmp_ptr = ((z - 1) * res * res + (y - 1) * res + (x - 1)) * (C + 1); + if (z > 0 && y > 0 && x > 0 && buffer[tmp_ptr + C]) neignbor_ptr = tmp_ptr; + if (neignbor_ptr >= 0) { + for (int c = 0; c < C; c++) { + buffer[ptr + c] += buffer[neignbor_ptr + c]; + } + } + } + } + } + + // Pack the attribute into a uint8 tensor + torch::Tensor attr = torch::zeros({N, C}, torch::dtype(torch::kUInt8)); + uint8_t* attr_data = attr.data_ptr(); + for (int i = 0; i < N; i++) { + int x = coord_data[i * 3 + 0]; + int y = coord_data[i * 3 + 1]; + int z = coord_data[i * 3 + 2]; + int ptr = (z * res * res + y * res + x) * (C + 1); + for (int c = 0; c < C; c++) { + attr_data[i * C + c] = buffer[ptr + c]; + } + } + return attr; +} diff --git a/o-voxel/src/io/filter_parent.cpp b/o-voxel/src/io/filter_parent.cpp new file mode 100644 index 0000000..0beadaf --- /dev/null +++ b/o-voxel/src/io/filter_parent.cpp @@ -0,0 +1,165 @@ +#include +#include "api.h" +#include "lut.h" + +#include +#include +#include + + +std::vector encode_recursive( + const uint8_t* svo, + const uint32_t depth, + const uint8_t* attr, + const size_t C, + uint32_t& svo_ptr, + uint32_t& attr_ptr, + uint32_t& delta_ptr, + uint32_t self_delta_ptr, + uint32_t cur_depth, + uint8_t* delta +) { + std::vector node_attr(C, 0); + if (cur_depth == depth) { + // Leaf node + for (size_t i = 0; i < C; i++) { + node_attr[i] = attr[attr_ptr + i]; + if (self_delta_ptr != 0 || cur_depth == 0) { + delta[self_delta_ptr + i] = node_attr[i]; + } + } + attr_ptr += C; + } + else { + // Internal node + uint8_t node = svo[svo_ptr]; + uint32_t child_delta_ptr = delta_ptr; + uint8_t cnt = lut_1cnt[node]; + svo_ptr++; + delta_ptr += C * (cnt - 1); + for (uint8_t i = 0; i < cnt; i++) { + auto child_attr = encode_recursive( + svo, depth, attr, C, svo_ptr, attr_ptr, delta_ptr, i == cnt-1 ? 0 : child_delta_ptr+i*C, cur_depth+1, delta + ); + for (size_t j = 0; j < C; j++) { + if (i == 0) { + node_attr[j] = child_attr[j]; + } + else { + delta[child_delta_ptr + (i-1)*C + j] = child_attr[j] - delta[child_delta_ptr + (i-1)*C + j]; + } + } + } + if (self_delta_ptr != 0 || cur_depth == 0) { + for (size_t i = 0; i < C; i++) { + delta[self_delta_ptr + i] = node_attr[i]; + } + } + } + return node_attr; +} + + +/** + * Encode the attribute of a sparse voxel octree into deltas from its parent node. + * + * @param octree uint8 tensor containing the sparse voxel octree + * @param depth The depth of the sparse voxel octree + * @param attr [N, C] tensor containing the attribute of each sparse voxel + * + * @return uint8 tensor containing the deltas + */ +torch::Tensor encode_sparse_voxel_octree_attr_parent_cpu( + const torch::Tensor& octree, + const uint32_t depth, + const torch::Tensor& attr +) { + size_t N_leaf = attr.size(0); + size_t N_node = octree.size(0); + size_t C = attr.size(1); + uint8_t* octree_data = octree.data_ptr(); + uint8_t* attr_data = attr.data_ptr(); + + torch::Tensor delta = torch::zeros({N_leaf, C}, torch::kUInt8); + uint32_t svo_ptr = 0; + uint32_t attr_ptr = 0; + uint32_t delta_ptr = C; + encode_recursive(octree_data, depth, attr_data, C, svo_ptr, attr_ptr, delta_ptr, 0, 0, delta.data_ptr()); + + return delta; +} + + +void decode_recursive( + const uint8_t* svo, + const uint32_t depth, + const uint8_t* delta, + const size_t C, + uint32_t& svo_ptr, + uint32_t& attr_ptr, + uint32_t& delta_ptr, + uint32_t cur_depth, + uint8_t* cur_attr, + uint8_t* attr +) { + if (cur_depth == depth) { + // Leaf node + for (size_t i = 0; i < C; i++) { + attr[attr_ptr + i] = cur_attr[i]; + } + attr_ptr += C; + } + else { + // Internal node + uint8_t node = svo[svo_ptr]; + uint32_t child_delta_ptr = delta_ptr; + std::vector child_attr(cur_attr, cur_attr + C); + uint8_t cnt = lut_1cnt[node]; + svo_ptr++; + delta_ptr += C * (cnt - 1); + for (uint8_t i = 0; i < cnt; i++) { + for (size_t j = 0; j < C; j++) { + if (i > 0) { + child_attr[j] += delta[child_delta_ptr + (i-1)*C + j]; + } + } + decode_recursive( + svo, depth, delta, C, svo_ptr, attr_ptr, delta_ptr, cur_depth+1, child_attr.data(), attr + ); + } + } +} + + +/** + * Decode the attribute of a sparse voxel octree from its parent node and its deltas. + * + * @param octree uint8 tensor containing the sparse voxel octree + * @param depth The depth of the sparse voxel octree + * @param delta uint8 tensor containing the deltas + * + * @return [N, C] tensor containing the attribute of each sparse voxel + */ +torch::Tensor decode_sparse_voxel_octree_attr_parent_cpu( + const torch::Tensor& octree, + const uint32_t depth, + const torch::Tensor& delta +) { + size_t N_node = octree.size(0); + size_t N_leaf = delta.size(0); + size_t C = delta.size(1); + uint8_t* octree_data = octree.data_ptr(); + uint8_t* delta_data = delta.data_ptr(); + + torch::Tensor attr = torch::zeros({N_leaf, C}, torch::kUInt8); + uint32_t svo_ptr = 0; + uint32_t attr_ptr = 0; + uint32_t delta_ptr = C; + + // Recursively decode the attribute + decode_recursive( + octree_data, depth, delta_data, C, svo_ptr, attr_ptr, delta_ptr, 0, delta_data, attr.data_ptr() + ); + + return attr; +} diff --git a/o-voxel/src/io/lut.h b/o-voxel/src/io/lut.h new file mode 100644 index 0000000..d790ae9 --- /dev/null +++ b/o-voxel/src/io/lut.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +// np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8) +uint8_t lut_1cnt[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, + 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, + 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, + 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, + 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, + 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, + 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 2, 3, 3, 4, 3, 4, + 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, + 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, + 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8 +}; diff --git a/o-voxel/src/io/svo.cpp b/o-voxel/src/io/svo.cpp new file mode 100644 index 0000000..6775284 --- /dev/null +++ b/o-voxel/src/io/svo.cpp @@ -0,0 +1,138 @@ +#include +#include "api.h" + +#include +#include + + +/** + * Encode a list of sparse voxel morton codes into a sparse voxel octree + * NOTE: The input indices must be sorted in ascending order + * + * @param codes [N] uint32 tensor containing the morton codes + * @param depth The depth of the sparse voxel octree + * + * @return uint8 tensor containing the sparse voxel octree + */ +torch::Tensor encode_sparse_voxel_octree_cpu( + const torch::Tensor& codes, + const uint32_t depth +) { + size_t N_leaf = codes.size(0); + int* codes_data = codes.data_ptr(); + + std::vector svo; + std::vector stack(depth-1); + std::vector insert_stack(depth); + std::vector stack_ptr(depth); + uint32_t code, insert_from; + + // Root node + svo.push_back(0); + stack_ptr[0] = 0; + + // Iterate over all codes and encode them into SVO + for (int i = 0; i < N_leaf; i++) { + code = codes_data[i]; + + // Convert code to insert stack (3bit per level) + for (uint32_t j = 0; j < depth; j++) { + insert_stack[j] = (code >> (3*(depth-1-j))) & 0x7; + } + + // Compare insert stack to stack to determine which level to insert + if (i == 0) { + // First code, insert at level 0 + insert_from = 0; + } + else { + // Compare insert stack to stack + for (insert_from = 0; insert_from < depth-1; insert_from++) { + if (insert_stack[insert_from] != stack[insert_from]) { + break; + } + } + } + + // Insert new nodes from insert_from to depth-1 + for (uint32_t j = insert_from; j < depth; j++) { + // Add new node to SVO + if (j > insert_from) { + svo.push_back(0); + stack_ptr[j] = svo.size()-1; + } + // Update parent pointers + svo[stack_ptr[j]] |= (1 << insert_stack[j]); + // Update stack + if (j < depth-1) { + stack[j] = insert_stack[j]; + } + } + } + + // Convert SVO to tensor + torch::Tensor svo_tensor = torch::from_blob(svo.data(), {svo.size()}, torch::kUInt8).clone(); + return svo_tensor; +} + + +void decode_sparse_voxel_octree_cpu_recursive( + const uint8_t* svo, + const uint32_t depth, + uint32_t& ptr, + std::vector& stack, + std::vector& codes +) { + uint8_t node = svo[ptr]; + if (stack.size() == depth-1) { + // Leaf node, add code to list + uint32_t code = 0; + for (uint32_t i = 0; i < depth-1; i++) { + code |= (static_cast(stack[i]) << (3*(depth-1-i))); + } + for (uint8_t i = 0; i < 8; i++) { + if (node & (1 << i)) { + code = (code & ~0x7) | i; + codes.push_back(code); + } + } + ptr++; + } + else { + // Internal node, recurse + ptr++; + for (uint8_t i = 0; i < 8; i++) { + if (node & (1 << i)) { + stack.push_back(i); + decode_sparse_voxel_octree_cpu_recursive(svo, depth, ptr, stack, codes); + stack.pop_back(); + } + } + } +} + + +/** + * Decode a sparse voxel octree into a list of sparse voxel morton codes + * + * @param octree uint8 tensor containing the sparse voxel octree + * @param depth The depth of the sparse voxel octree + * + * @return [N] uint32 tensor containing the morton codes + * The codes are sorted in ascending order + */ +torch::Tensor decode_sparse_voxel_octree_cpu( + const torch::Tensor& octree, + const uint32_t depth +) { + uint8_t* octree_data = octree.data_ptr(); + std::vector codes; + std::vector stack; + stack.reserve(depth-2); + uint32_t ptr = 0; + // Decode SVO into list of codes + decode_sparse_voxel_octree_cpu_recursive(octree_data, depth, ptr, stack, codes); + // Convert codes to tensor + torch::Tensor codes_tensor = torch::from_blob(codes.data(), {codes.size()}, torch::kInt32).clone(); + return codes_tensor; +} diff --git a/o-voxel/src/rasterize/api.h b/o-voxel/src/rasterize/api.h new file mode 100644 index 0000000..26212fd --- /dev/null +++ b/o-voxel/src/rasterize/api.h @@ -0,0 +1,47 @@ +/* + * Sparse Voxel Rasterizer + * + * Copyright (C) 2025, Jianfeng XIANG + * All rights reserved. + * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include + + +/** + * Rasterize a sparse voxel octree with CUDA backend + * + * @param positions Tensor of shape (N, 3) containing the positions of the octree nodes in [0, 1]^3 + * @param attrs Tensor of shape (N, 1) containing the attributes of the octree nodes + * @param voxel_size Float containing the size of the voxels + * @param viewmatrix Tensor of shape (4, 4) containing the view matrix + * @param projmatrix Tensor of shape (4, 4) containing the projection matrix + * @param campos Tensor of shape (3) containing the camera position + * @param tan_fovx Float containing the tangent of the horizontal field of view + * @param tan_fovy Float containing the tangent of the vertical field of view + * @param image_height Integer containing the image height + * @param image_width Integer containing the image width + * + * @return A tuple containing: + * - Tensor of shape (C, H, W) containing the output color + * - Tensor of shape (H, W) containing the output depth + * - Tensor of shape (H, W) containing the output alpha + */ +std::tuple +rasterize_voxels_cuda( + const torch::Tensor& positions, + const torch::Tensor& attrs, + const float voxel_size, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const torch::Tensor& campos, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width +); diff --git a/o-voxel/src/rasterize/auxiliary.h b/o-voxel/src/rasterize/auxiliary.h new file mode 100644 index 0000000..f4abfbc --- /dev/null +++ b/o-voxel/src/rasterize/auxiliary.h @@ -0,0 +1,285 @@ +#pragma once +#include "config.h" + + +#define BLOCK_SIZE (BLOCK_X * BLOCK_Y) + + +__forceinline__ __device__ float ndc2Pix(float v, int S) +{ + return ((v + 1.0) * S - 1.0) * 0.5; +} + + +__forceinline__ __device__ void getRect(const int4 bbox, uint2& rect_min, uint2& rect_max, dim3 grid) +{ + rect_min = { + min(grid.x, max((int)0, (int)((bbox.x) / BLOCK_X))), + min(grid.y, max((int)0, (int)((bbox.y) / BLOCK_Y))) + }; + rect_max = { + min(grid.x, max((int)0, (int)((bbox.z + BLOCK_X - 1) / BLOCK_X))), + min(grid.y, max((int)0, (int)((bbox.w + BLOCK_Y - 1) / BLOCK_Y))) + }; +} + + +__forceinline__ __device__ float3 normalize(const float3& v) +{ + float inv_norm = 1.0f / sqrt(v.x * v.x + v.y * v.y + v.z * v.z); + return { v.x * inv_norm, v.y * inv_norm, v.z * inv_norm }; +} + + +__forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], + }; + return transformed; +} + + +__forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix) +{ + float4 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14], + matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] + }; + return transformed; +} + + +__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, + }; + return transformed; +} + + +__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z, + matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z, + matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z, + }; + return transformed; +} + + +__forceinline__ __device__ float dnormvdz(float3 v, float3 dv) +{ + float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; + float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; + return dnormvdz; +} + + +__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) +{ + float sum2 = v.x * v.x + v.y * v.y + v.z * v.z; + float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + + float3 dnormvdv; + dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32; + dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32; + dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32; + return dnormvdv; +} + + +__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) +{ + float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w; + float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2); + + float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w }; + float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w; + float4 dnormvdv; + dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32; + dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32; + dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32; + dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32; + return dnormvdv; +} + + +__forceinline__ __device__ float sigmoid(float x) +{ + return 1.0f / (1.0f + expf(-x)); +} + + +__forceinline__ __device__ bool in_frustum(int idx, + const float3& p_orig, + const float* viewmatrix, + const float* projmatrix, + float3& p_view) +{ + // Bring points to screen space + float4 p_hom = transformPoint4x4(p_orig, projmatrix); + float p_w = 1.0f / (p_hom.w + 0.0000001f); + float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w }; + p_view = transformPoint4x3(p_orig, viewmatrix); + + if (p_view.z <= 0.2f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3))) + { + return false; + } + return true; +} + + +__forceinline__ __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +__forceinline__ __device__ int2 project(const float3& p, const float* matrix, const int& width, const int& height) +{ + float3 p_hom = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12], + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13], + matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15] + }; + float p_w = 1.0f / (p_hom.z + 0.0000001f); + return { (int)((p_hom.x * p_w + 1.0f) * 0.5f * width), (int)((p_hom.y * p_w + 1.0f) * 0.5f * height) }; +} + + +#define GET_BBOX_FIRST(A, B, C) \ +vertex.x = point.x A half_scale.x; \ +vertex.y = point.y B half_scale.y; \ +vertex.z = point.z C half_scale.z; \ +p_screen = project(vertex, projmatrix, width, height); \ +bbox.x = p_screen.x; \ +bbox.y = p_screen.y; \ +bbox.z = p_screen.x + 1; \ +bbox.w = p_screen.y + 1; + +#define GET_BBOX_OTHER(A, B, C) \ +vertex.x = point.x A half_scale.x; \ +vertex.y = point.y B half_scale.y; \ +vertex.z = point.z C half_scale.z; \ +p_screen = project(vertex, projmatrix, width, height); \ +bbox.x = min(bbox.x, p_screen.x); \ +bbox.y = min(bbox.y, p_screen.y); \ +bbox.z = max(bbox.z, p_screen.x + 1); \ +bbox.w = max(bbox.w, p_screen.y + 1); + + +__forceinline__ __device__ int4 get_bbox( + const float3& point, + const float3& scale, + const float* projmatrix, + const int& width, + const int& height +) { + float3 half_scale = { scale.x * 0.5f, scale.y * 0.5f, scale.z * 0.5f }; + float3 vertex; + int2 p_screen; + int4 bbox; + + GET_BBOX_FIRST(-, -, -); + GET_BBOX_OTHER(+, -, -); + GET_BBOX_OTHER(-, +, -); + GET_BBOX_OTHER(+, +, -); + GET_BBOX_OTHER(-, -, +); + GET_BBOX_OTHER(+, -, +); + GET_BBOX_OTHER(-, +, +); + GET_BBOX_OTHER(+, +, +); + + bbox.x = max(0, bbox.x); + bbox.y = max(0, bbox.y); + bbox.z = min(width, bbox.z); + bbox.w = min(height, bbox.w); + if (bbox.x >= bbox.z || bbox.y >= bbox.w) // bbox is empty + return { 0, 0, 0, 0 }; + return bbox; +} + + +// Fast ray-box intersection, returns the intersection distance +__forceinline__ __device__ float2 get_ray_voxel_intersection( + const float3& ray_origin, + const float3& ray_direction, + const float3& voxel_min, + const float3& voxel_max +) { + // Careful with the division by zero + float3 inv_direction; + inv_direction.x = ray_direction.x == 0.0f ? 1e10f : 1.0f / ray_direction.x; + inv_direction.y = ray_direction.y == 0.0f ? 1e10f : 1.0f / ray_direction.y; + inv_direction.z = ray_direction.z == 0.0f ? 1e10f : 1.0f / ray_direction.z; + float3 t0 = { + (voxel_min.x - ray_origin.x) * inv_direction.x, + (voxel_min.y - ray_origin.y) * inv_direction.y, + (voxel_min.z - ray_origin.z) * inv_direction.z + }; + float3 t1 = { + (voxel_max.x - ray_origin.x) * inv_direction.x, + (voxel_max.y - ray_origin.y) * inv_direction.y, + (voxel_max.z - ray_origin.z) * inv_direction.z + }; + float3 tmin = { + min(t0.x, t1.x), + min(t0.y, t1.y), + min(t0.z, t1.z) + }; + float3 tmax = { + max(t0.x, t1.x), + max(t0.y, t1.y), + max(t0.z, t1.z) + }; + float tmin_max = max(tmin.x, max(tmin.y, tmin.z)); + float tmax_min = min(tmax.x, min(tmax.y, tmax.z)); + return { tmin_max, tmax_min }; +} + + +__forceinline__ __device__ float3 getRayDir( + const uint2& pix, + const int& width, + const int& height, + const float& tan_fovx, + const float& tan_fovy, + const float* viewmatrix +) { + float x = (2.0f * (pix.x + 0.5f) / width - 1.0f) * tan_fovx; + float y = (2.0f * (pix.y + 0.5f) / height - 1.0f) * tan_fovy; + float3 ray_dir = { + viewmatrix[0] * x + viewmatrix[1] * y + viewmatrix[2], + viewmatrix[4] * x + viewmatrix[5] * y + viewmatrix[6], + viewmatrix[8] * x + viewmatrix[9] * y + viewmatrix[10] + }; + return normalize(ray_dir); +} + + +#ifdef DEBUG +#define CHECK_CUDA(...) __VA_ARGS__; {\ +auto ret = cudaDeviceSynchronize(); \ +if (ret != cudaSuccess) { \ +std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \ +throw std::runtime_error(cudaGetErrorString(ret)); \ +}} +#define DEBUG_PRINT(...) printf(__VA_ARGS__) +#else +#define CHECK_CUDA(...) __VA_ARGS__ +#define DEBUG_PRINT(...) +#endif \ No newline at end of file diff --git a/o-voxel/src/rasterize/config.h b/o-voxel/src/rasterize/config.h new file mode 100644 index 0000000..0bf4bab --- /dev/null +++ b/o-voxel/src/rasterize/config.h @@ -0,0 +1,5 @@ +#pragma once + +#define BLOCK_X 8 +#define BLOCK_Y 8 +#define MEM_ALIGNMENT 128 diff --git a/o-voxel/src/rasterize/rasterize.cu b/o-voxel/src/rasterize/rasterize.cu new file mode 100644 index 0000000..cc7291d --- /dev/null +++ b/o-voxel/src/rasterize/rasterize.cu @@ -0,0 +1,396 @@ +#include + +#include +#include "cuda_runtime.h" + +#include +namespace cg = cooperative_groups; + +#include "config.h" +#include "auxiliary.h" +#include "api.h" + + +/** + * Preprocess input 3D points + */ +static __global__ void preprocess( + const int num_nodes, + const float* positions, + const float voxel_size, + const float* viewmatrix, + const float* projmatrix, + const int width, + const int height, + const dim3 grid, + int4* bboxes, + float* depths, + uint32_t* tiles_touched +) { + auto idx = cg::this_grid().thread_rank(); + if (idx >= num_nodes) + return; + + // Initialize bboxes and touched tiles to 0. If this isn't changed, + // this voxel will not be processed further. + bboxes[idx] = { 0, 0, 0, 0 }; + tiles_touched[idx] = 0; + + // Perform near culling, quit if outside. + float3 p_orig = { + positions[3 * idx], + positions[3 * idx + 1], + positions[3 * idx + 2] + }; + float3 p_view; + if (!in_frustum(idx, p_orig, viewmatrix, projmatrix, p_view)) + return; + + // Project 8 vertices of the voxel to screen space to find the + // bounding box of the projected points. + float3 scale = { voxel_size, voxel_size, voxel_size }; + int4 bbox = get_bbox(p_orig, scale, projmatrix, width, height); + uint2 rect_min, rect_max; + getRect(bbox, rect_min, rect_max, grid); + if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) + return; + + // Store some useful helper data for the next steps. + depths[idx] = p_view.z; + bboxes[idx] = bbox; + tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); +} + + +/** + * Generates one key/value pair for all voxel / tile overlaps. + * Run once per voxel (1:N mapping). + * + * @param P Number of points. + * @param grid Grid size. + * @param depths Depths of points. + * @param offsets Offsets for writing keys/values. + * @param bboxes Bounding boxes of voxels. + * @param keys_unsorted Unsorted keys. + * @param values_unsorted Unsorted values. + */ +static __global__ void duplicateWithKeys( + int P, dim3 grid, + const float* depths, + const int64_t* offsets, + const int4* bboxes, + int64_t* keys_unsorted, + uint32_t* values_unsorted +) { + auto idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + // Generate no key/value pair for invisible voxels + if (bboxes[idx].w > 0) + { + // Find this voxel's offset in buffer for writing keys/values. + int64_t off = (idx == 0) ? 0 : offsets[idx - 1]; + uint2 rect_min, rect_max; + getRect(bboxes[idx], rect_min, rect_max, grid); + + // For each tile that the bounding rect overlaps, emit a + // key/value pair. The key is | tile ID | depth |, + // and the value is the ID of the voxel. Sorting the values + // with this key yields voxel IDs in a list, such that they + // are first sorted by tile and then by depth. + for (int y = rect_min.y; y < rect_max.y; y++) + { + for (int x = rect_min.x; x < rect_max.x; x++) + { + int64_t key = y * grid.x + x; + key <<= 32; + key |= *((uint32_t*)&depths[idx]); + keys_unsorted[off] = key; + values_unsorted[off] = idx; + off++; + } + } + } +} + + +/** + * Check keys to see if it is at the start/end of one tile's range in the full sorted list. If yes, write start/end of this tile. + * + * @param L Number of points. + * @param point_list_keys List of keys. + * @param ranges Ranges of tiles. + */ +static __global__ void identifyTileRanges(int L, int64_t* point_list_keys, uint2* ranges) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= L) + return; + + // Read tile ID from key. Update start/end of tile range if at limit. + int64_t key = point_list_keys[idx]; + uint32_t currtile = key >> 32; + if (idx == 0) + ranges[currtile].x = 0; + else + { + uint32_t prevtile = point_list_keys[idx - 1] >> 32; + if (currtile != prevtile) + { + ranges[prevtile].y = idx; + ranges[currtile].x = idx; + } + } + if (idx == L - 1) + ranges[currtile].y = L; +} + + +/** + * Main rasterization method. Collaboratively works on one tile per + * block, each thread treats one pixel. Alternates between fetching + * and rasterizing data. + * + * @param ranges Ranges of voxel instances for each tile. + * @param point_list List of voxel instances. + * @param C Number of channels. + * @param W Width of the image. + * @param H Height of the image. + * @param cam_pos Camera position. + * @param tan_fovx Tangent of the horizontal field of view. + * @param tan_fovy Tangent of the vertical field of view. + * @param viewmatrix View matrix. + * @param positions Centers of voxels. + * @param attrs Attributes of voxels. + * @param voxel_size Size of voxels. + * @param out_color Output color. + * @param out_depth Output depth. + * @param out_alpha Output alpha. + */ +static __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y) +render( + const uint2* ranges, + const uint32_t* point_list, + const int C, + const int W, + const int H, + const float* cam_pos, + const float tan_fovx, + const float tan_fovy, + const float* viewmatrix, + const float* positions, + const float* attrs, + const float voxel_size, + float* out_color, + float* out_depth, + float* out_alpha +) { + // Identify current tile and associated min/max pixel range. + auto block = cg::this_thread_block(); + uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X; + uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y }; + uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) }; + uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y }; + uint32_t pix_id = W * pix.y + pix.x; + + // Get ray direction and origin for this pixel. + float3 ray_dir = getRayDir(pix, W, H, tan_fovx, tan_fovy, viewmatrix); + + // Check if this thread is associated with a valid pixel or outside. + bool inside = pix.x < W&& pix.y < H; + // Done threads can help with fetching, but don't rasterize + bool done = !inside; + + // Load start/end range of IDs to process in bit sorted list. + uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x]; + const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE); + int toDo = range.y - range.x; + + // Allocate storage for batches of collectively fetched data. + __shared__ int collected_id[BLOCK_SIZE]; + __shared__ float3 collected_xyz[BLOCK_SIZE]; + + // Initialize helper variables + int hit = -1; + float D; + + // Iterate over batches until all done or range is complete + for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) + { + // End if entire block votes that it is done rasterizing + int num_done = __syncthreads_count(done); + if (num_done == BLOCK_SIZE) + break; + + // Collectively fetch per-voxel data from global to shared + int progress = i * BLOCK_SIZE + block.thread_rank(); + if (range.x + progress < range.y) + { + int coll_id = point_list[range.x + progress]; + collected_id[block.thread_rank()] = coll_id; + collected_xyz[block.thread_rank()] = { + positions[3 * coll_id], + positions[3 * coll_id + 1], + positions[3 * coll_id + 2] + }; + } + block.sync(); + + // Iterate over current batch + for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) + { + // Get ray-voxel intersection + float3 p = collected_xyz[j]; + float3 scale = { voxel_size, voxel_size, voxel_size }; + float3 voxel_min = { p.x - 0.5f * scale.x, p.y - 0.5f * scale.y, p.z - 0.5f * scale.z }; + float3 voxel_max = { p.x + 0.5f * scale.x, p.y + 0.5f * scale.y, p.z + 0.5f * scale.z }; + float2 itsc = get_ray_voxel_intersection(*(float3*)cam_pos, ray_dir, voxel_min, voxel_max); + float itsc_dist = (itsc.y >= itsc.x) ? itsc.y - itsc.x : -1.0f; + if (itsc_dist <= 0.0f) + continue; + + hit = collected_id[j]; + D = itsc.x; + done = true; + } + } + + // All threads that treat valid pixel write out their final + // rendering data to the frame and auxiliary buffers. + if (inside) + { + for (int ch = 0; ch < C; ch++) + if (hit >= 0) out_color[ch * H * W + pix_id] = attrs[hit * C + ch]; + out_depth[pix_id] = D; + out_alpha[pix_id] = hit >= 0 ? 1.0f : 0.0f; + } +} + +void forward( + const int num_nodes, + const int num_channels, + const int width, + const int height, + const float* positions, + const float* attrs, + const float voxel_size, + const float* viewmatrix, + const float* projmatrix, + const float* campos, + const float tan_fovx, + const float tan_fovy, + float* out_color, + float* out_depth, + float* out_alpha +) { + // Parrallel config (2D grid of 2D blocks) + dim3 grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); + dim3 block(BLOCK_X, BLOCK_Y, 1); + + // Run preprocessing kernel + auto pt_bboxes = torch::zeros({num_nodes, 4}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + auto pt_depths = torch::zeros({num_nodes}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + auto pt_tiles_touched = torch::zeros({num_nodes}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + preprocess<<<(num_nodes+255)/256, 256>>>( + num_nodes, positions, voxel_size, viewmatrix, projmatrix, width, height, grid, + reinterpret_cast(pt_bboxes.data_ptr()), + pt_depths.data_ptr(), + reinterpret_cast(pt_tiles_touched.data_ptr()) + ); + + // Compute prefix sum over full list of touched tile counts by voxels + // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] + auto pt_offsets = torch::cumsum(pt_tiles_touched, 0); + + // Retrieve total number of voxel instances to launch + int num_rendered = pt_offsets[num_nodes - 1].item(); + if (num_rendered == 0) return; + + // For each instance to be rendered, produce adequate [ tile | depth ] key + auto pt_keys_unsorted = torch::zeros({num_rendered}, torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA)); + auto pt_indices_unsorted = torch::zeros({num_rendered}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + duplicateWithKeys<<<(num_nodes+255)/256, 256>>>( + num_nodes, grid, + pt_depths.data_ptr(), + pt_offsets.data_ptr(), + reinterpret_cast(pt_bboxes.data_ptr()), + pt_keys_unsorted.data_ptr(), + reinterpret_cast(pt_indices_unsorted.data_ptr()) + ); + + // Sort complete list of (duplicated) voxel indices by keys + auto pt_sorted = torch::sort(pt_keys_unsorted, 0); + auto pt_keys = std::get<0>(pt_sorted); + auto pt_order = std::get<1>(pt_sorted); + auto pt_indices = torch::index_select(pt_indices_unsorted, 0, pt_order); + + // Identify start and end of per-tile workloads in sorted list + auto tile_ranges = torch::zeros({grid.x * grid.y, 2}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + identifyTileRanges<<<(num_rendered+255)/256, 256>>>( + num_rendered, + pt_keys.data_ptr(), + reinterpret_cast(tile_ranges.data_ptr()) + ); + + // Let each tile blend its range of voxels independently in parallel + render<<>>( + reinterpret_cast(tile_ranges.data_ptr()), + reinterpret_cast(pt_indices.data_ptr()), + num_channels, width, height, + campos, tan_fovx, tan_fovy, viewmatrix, + positions, attrs, voxel_size, + out_color, out_depth, out_alpha + ); +} + + +std::tuple +rasterize_voxels_cuda( + const torch::Tensor& positions, + const torch::Tensor& attrs, + const float voxel_size, + const torch::Tensor& viewmatrix, + const torch::Tensor& projmatrix, + const torch::Tensor& campos, + const float tan_fovx, + const float tan_fovy, + const int image_height, + const int image_width +) { + // Sizes + const int P = positions.size(0); + const int C = attrs.size(1); + const int H = image_height; + const int W = image_width; + + // Types + torch::TensorOptions float_opts = torch::TensorOptions().dtype(torch::kFloat32).device(positions.device()); + torch::TensorOptions byte_opts = torch::TensorOptions().dtype(torch::kUInt8).device(positions.device()); + + // Allocate output tensors + torch::Tensor out_color = torch::zeros({C, H, W}, float_opts); + torch::Tensor out_depth = torch::zeros({H, W}, float_opts); + torch::Tensor out_alpha = torch::zeros({H, W}, float_opts); + + // Call Forward + if (P > 0) { + forward( + P, C, W, H, + positions.contiguous().data_ptr(), + attrs.contiguous().data_ptr(), + voxel_size, + viewmatrix.contiguous().data_ptr(), + projmatrix.contiguous().data_ptr(), + campos.contiguous().data_ptr(), + tan_fovx, tan_fovy, + out_color.contiguous().data_ptr(), + out_depth.contiguous().data_ptr(), + out_alpha.contiguous().data_ptr() + ); + } + + return std::make_tuple( + out_color, out_depth, out_alpha + ); +} diff --git a/o-voxel/src/serialize/api.cu b/o-voxel/src/serialize/api.cu new file mode 100644 index 0000000..8a973bc --- /dev/null +++ b/o-voxel/src/serialize/api.cu @@ -0,0 +1,180 @@ +#include +#include "api.h" +#include "z_order.h" +#include "hilbert.h" + + +torch::Tensor +z_order_encode_cpu( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x, torch::dtype(torch::kInt32)); + + // Call CUDA kernel + CPU::z_order_encode( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +z_order_decode_cpu( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes, torch::dtype(torch::kInt32)); + torch::Tensor y = torch::empty_like(codes, torch::dtype(torch::kInt32)); + torch::Tensor z = torch::empty_like(codes, torch::dtype(torch::kInt32)); + + // Call CUDA kernel + CPU::z_order_decode( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} + + +torch::Tensor +hilbert_encode_cpu( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x); + + // Call CUDA kernel + CPU::hilbert_encode( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +hilbert_decode_cpu( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes); + torch::Tensor y = torch::empty_like(codes); + torch::Tensor z = torch::empty_like(codes); + + // Call CUDA kernel + CPU::hilbert_decode( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} + + +torch::Tensor +z_order_encode_cuda( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x, torch::dtype(torch::kInt32)); + + // Call CUDA kernel + CUDA::z_order_encode<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +z_order_decode_cuda( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes, torch::dtype(torch::kInt32)); + torch::Tensor y = torch::empty_like(codes, torch::dtype(torch::kInt32)); + torch::Tensor z = torch::empty_like(codes, torch::dtype(torch::kInt32)); + + // Call CUDA kernel + CUDA::z_order_decode<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} + + +torch::Tensor +hilbert_encode_cuda( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +) { + // Allocate output tensor + torch::Tensor codes = torch::empty_like(x); + + // Call CUDA kernel + CUDA::hilbert_encode<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + x.size(0), + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(y.contiguous().data_ptr()), + reinterpret_cast(z.contiguous().data_ptr()), + reinterpret_cast(codes.data_ptr()) + ); + + return codes; +} + + +std::tuple +hilbert_decode_cuda( + const torch::Tensor& codes +) { + // Allocate output tensors + torch::Tensor x = torch::empty_like(codes); + torch::Tensor y = torch::empty_like(codes); + torch::Tensor z = torch::empty_like(codes); + + // Call CUDA kernel + CUDA::hilbert_decode<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>( + codes.size(0), + reinterpret_cast(codes.contiguous().data_ptr()), + reinterpret_cast(x.data_ptr()), + reinterpret_cast(y.data_ptr()), + reinterpret_cast(z.data_ptr()) + ); + + return std::make_tuple(x, y, z); +} diff --git a/o-voxel/src/serialize/api.h b/o-voxel/src/serialize/api.h new file mode 100644 index 0000000..70da702 --- /dev/null +++ b/o-voxel/src/serialize/api.h @@ -0,0 +1,136 @@ +/* + * Serialize a voxel grid + * + * Copyright (C) 2025, Jianfeng XIANG + * All rights reserved. + * + * Licensed under The MIT License [see LICENSE for details] + * + * Written by Jianfeng XIANG + */ + +#pragma once +#include + + +#define BLOCK_SIZE 256 + + +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +torch::Tensor +z_order_encode_cuda( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +z_order_decode_cuda( + const torch::Tensor& codes +); + + +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the Hilbert encoded values + */ +torch::Tensor +hilbert_encode_cuda( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the Hilbert encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +hilbert_decode_cuda( + const torch::Tensor& codes +); + + +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +torch::Tensor +z_order_encode_cpu( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +z_order_decode_cpu( + const torch::Tensor& codes +); + + +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the Hilbert encoded values + */ +torch::Tensor +hilbert_encode_cpu( + const torch::Tensor& x, + const torch::Tensor& y, + const torch::Tensor& z +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the Hilbert encoded values + * + * @return 3 tensors [N] containing the x, y, z coordinates + */ +std::tuple +hilbert_decode_cpu( + const torch::Tensor& codes +); diff --git a/o-voxel/src/serialize/hilbert.cu b/o-voxel/src/serialize/hilbert.cu new file mode 100644 index 0000000..715b056 --- /dev/null +++ b/o-voxel/src/serialize/hilbert.cu @@ -0,0 +1,230 @@ +#include +#include + +#include +namespace cg = cooperative_groups; + +#include "hilbert.h" + + +// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. +static __host__ __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +// Removes 2 zeros after each bit in a 30-bit integer. +static __host__ __device__ uint32_t extractBits(uint32_t v) +{ + v = v & 0x49249249; + v = (v ^ (v >> 2)) & 0x030C30C3u; + v = (v ^ (v >> 4)) & 0x0300F00Fu; + v = (v ^ (v >> 8)) & 0x030000FFu; + v = (v ^ (v >> 16)) & 0x000003FFu; + return v; +} + + +__host__ void CPU::hilbert_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + for (size_t thread_id = 0; thread_id < N; thread_id++) { + uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]}; + + uint32_t m = 1 << 9, q, p, t; + + // Inverse undo excess work + q = m; + while (q > 1) { + p = q - 1; + for (int i = 0; i < 3; i++) { + if (point[i] & q) { + point[0] ^= p; // invert + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q >>= 1; + } + + // Gray encode + for (int i = 1; i < 3; i++) { + point[i] ^= point[i - 1]; + } + t = 0; + q = m; + while (q > 1) { + if (point[2] & q) { + t ^= q - 1; + } + q >>= 1; + } + for (int i = 0; i < 3; i++) { + point[i] ^= t; + } + + // Convert to 3D Hilbert code + uint32_t xx = expandBits(point[0]); + uint32_t yy = expandBits(point[1]); + uint32_t zz = expandBits(point[2]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; + } +} + + +__host__ void CPU::hilbert_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + for (size_t thread_id = 0; thread_id < N; thread_id++) { + uint32_t point[3]; + point[0] = extractBits(codes[thread_id] >> 2); + point[1] = extractBits(codes[thread_id] >> 1); + point[2] = extractBits(codes[thread_id]); + + uint32_t m = 2 << 9, q, p, t; + + // Gray decode by H ^ (H/2) + t = point[2] >> 1; + for (int i = 2; i > 0; i--) { + point[i] ^= point[i - 1]; + } + point[0] ^= t; + + // Undo excess work + q = 2; + while (q != m) { + p = q - 1; + for (int i = 2; i >= 0; i--) { + if (point[i] & q) { + point[0] ^= p; + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q <<= 1; + } + + x[thread_id] = point[0]; + y[thread_id] = point[1]; + z[thread_id] = point[2]; + } +} + + +__global__ void CUDA::hilbert_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]}; + + uint32_t m = 1 << 9, q, p, t; + + // Inverse undo excess work + q = m; + while (q > 1) { + p = q - 1; + for (int i = 0; i < 3; i++) { + if (point[i] & q) { + point[0] ^= p; // invert + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q >>= 1; + } + + // Gray encode + for (int i = 1; i < 3; i++) { + point[i] ^= point[i - 1]; + } + t = 0; + q = m; + while (q > 1) { + if (point[2] & q) { + t ^= q - 1; + } + q >>= 1; + } + for (int i = 0; i < 3; i++) { + point[i] ^= t; + } + + // Convert to 3D Hilbert code + uint32_t xx = expandBits(point[0]); + uint32_t yy = expandBits(point[1]); + uint32_t zz = expandBits(point[2]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; +} + + +__global__ void CUDA::hilbert_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t point[3]; + point[0] = extractBits(codes[thread_id] >> 2); + point[1] = extractBits(codes[thread_id] >> 1); + point[2] = extractBits(codes[thread_id]); + + uint32_t m = 2 << 9, q, p, t; + + // Gray decode by H ^ (H/2) + t = point[2] >> 1; + for (int i = 2; i > 0; i--) { + point[i] ^= point[i - 1]; + } + point[0] ^= t; + + // Undo excess work + q = 2; + while (q != m) { + p = q - 1; + for (int i = 2; i >= 0; i--) { + if (point[i] & q) { + point[0] ^= p; + } else { + t = (point[0] ^ point[i]) & p; + point[0] ^= t; + point[i] ^= t; + } + } + q <<= 1; + } + + x[thread_id] = point[0]; + y[thread_id] = point[1]; + z[thread_id] = point[2]; +} diff --git a/o-voxel/src/serialize/hilbert.h b/o-voxel/src/serialize/hilbert.h new file mode 100644 index 0000000..55f5bc0 --- /dev/null +++ b/o-voxel/src/serialize/hilbert.h @@ -0,0 +1,74 @@ +#pragma once + +namespace CUDA { +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__global__ void hilbert_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__global__ void hilbert_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); +} // namespace CUDA + + +namespace CPU { +/** + * Hilbert encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__host__ void hilbert_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Hilbert decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__host__ void hilbert_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); +} // namespace CPU diff --git a/o-voxel/src/serialize/z_order.cu b/o-voxel/src/serialize/z_order.cu new file mode 100644 index 0000000..d4ae118 --- /dev/null +++ b/o-voxel/src/serialize/z_order.cu @@ -0,0 +1,97 @@ +#include +#include + +#include +namespace cg = cooperative_groups; + +#include "z_order.h" + + +// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit. +static __host__ __device__ uint32_t expandBits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + + +// Removes 2 zeros after each bit in a 30-bit integer. +static __host__ __device__ uint32_t extractBits(uint32_t v) +{ + v = v & 0x49249249; + v = (v ^ (v >> 2)) & 0x030C30C3u; + v = (v ^ (v >> 4)) & 0x0300F00Fu; + v = (v ^ (v >> 8)) & 0x030000FFu; + v = (v ^ (v >> 16)) & 0x000003FFu; + return v; +} + + +__host__ void CPU::z_order_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + for (size_t thread_id = 0; thread_id < N; thread_id++) { + uint32_t xx = expandBits(x[thread_id]); + uint32_t yy = expandBits(y[thread_id]); + uint32_t zz = expandBits(z[thread_id]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; + } +} + + +__host__ void CPU::z_order_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + for (size_t thread_id = 0; thread_id < N; thread_id++) { + x[thread_id] = extractBits(codes[thread_id] >> 2); + y[thread_id] = extractBits(codes[thread_id] >> 1); + z[thread_id] = extractBits(codes[thread_id]); + } +} + + + +__global__ void CUDA::z_order_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + uint32_t xx = expandBits(x[thread_id]); + uint32_t yy = expandBits(y[thread_id]); + uint32_t zz = expandBits(z[thread_id]); + + codes[thread_id] = xx * 4 + yy * 2 + zz; +} + + +__global__ void CUDA::z_order_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +) { + size_t thread_id = cg::this_grid().thread_rank(); + if (thread_id >= N) return; + + x[thread_id] = extractBits(codes[thread_id] >> 2); + y[thread_id] = extractBits(codes[thread_id] >> 1); + z[thread_id] = extractBits(codes[thread_id]); +} diff --git a/o-voxel/src/serialize/z_order.h b/o-voxel/src/serialize/z_order.h new file mode 100644 index 0000000..5ae9c22 --- /dev/null +++ b/o-voxel/src/serialize/z_order.h @@ -0,0 +1,74 @@ +#pragma once + +namespace CUDA { +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__global__ void z_order_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__global__ void z_order_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); +} // namespace CUDA + + +namespace CPU { +/** + * Z-order encode 3D points + * + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + * + * @return [N] tensor containing the z-order encoded values + */ +__host__ void z_order_encode( + size_t N, + const uint32_t* x, + const uint32_t* y, + const uint32_t* z, + uint32_t* codes +); + + +/** + * Z-order decode 3D points + * + * @param codes [N] tensor containing the z-order encoded values + * @param x [N] tensor containing the x coordinates + * @param y [N] tensor containing the y coordinates + * @param z [N] tensor containing the z coordinates + */ +__host__ void z_order_decode( + size_t N, + const uint32_t* codes, + uint32_t* x, + uint32_t* y, + uint32_t* z +); +} // namespace CPU diff --git a/o-voxel/third_party/eigen b/o-voxel/third_party/eigen new file mode 160000 index 0000000..21e4582 --- /dev/null +++ b/o-voxel/third_party/eigen @@ -0,0 +1 @@ +Subproject commit 21e4582d1739107337a03460c81412981130373e diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000..e09e2b1 --- /dev/null +++ b/setup.sh @@ -0,0 +1,139 @@ +# Read Arguments +TEMP=`getopt -o h --long help,new-env,basic,flash-attn,cumesh,o-voxel,flexgemm,nvdiffrast,nvdiffrec -n 'setup.sh' -- "$@"` + +eval set -- "$TEMP" + +HELP=false +NEW_ENV=false +BASIC=false +FLASHATTN=false +CUMESH=false +OVOXEL=false +FLEXGEMM=false +NVDIFFRAST=false +NVDIFFREC=false +ERROR=false + + +if [ "$#" -eq 1 ] ; then + HELP=true +fi + +while true ; do + case "$1" in + -h|--help) HELP=true ; shift ;; + --new-env) NEW_ENV=true ; shift ;; + --basic) BASIC=true ; shift ;; + --flash-attn) FLASHATTN=true ; shift ;; + --cumesh) CUMESH=true ; shift ;; + --o-voxel) OVOXEL=true ; shift ;; + --flexgemm) FLEXGEMM=true ; shift ;; + --nvdiffrast) NVDIFFRAST=true ; shift ;; + --nvdiffrec) NVDIFFREC=true ; shift ;; + --) shift ; break ;; + *) ERROR=true ; break ;; + esac +done + +if [ "$ERROR" = true ] ; then + echo "Error: Invalid argument" + HELP=true +fi + +if [ "$HELP" = true ] ; then + echo "Usage: setup.sh [OPTIONS]" + echo "Options:" + echo " -h, --help Display this help message" + echo " --new-env Create a new conda environment" + echo " --basic Install basic dependencies" + echo " --flash-attn Install flash-attention" + echo " --cumesh Install cumesh" + echo " --o-voxel Install o-voxel" + echo " --flexgemm Install flexgemm" + echo " --nvdiffrast Install nvdiffrast" + echo " --nvdiffrec Install nvdiffrec" + return +fi + +# Get system information +WORKDIR=$(pwd) +if command -v nvidia-smi > /dev/null; then + PLATFORM="cuda" +elif command -v rocminfo > /dev/null; then + PLATFORM="hip" +else + echo "Error: No supported GPU found" + exit 1 +fi + +if [ "$NEW_ENV" = true ] ; then + conda create -n trellis2 python=3.10 + conda activate trellis2 + if [ "$PLATFORM" = "cuda" ] ; then + pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124 + elif [ "$PLATFORM" = "hip" ] ; then + pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/rocm6.2.4 + fi +fi + +if [ "$BASIC" = true ] ; then + pip install imageio imageio-ffmpeg tqdm easydict opencv-python-headless ninja trimesh transformers gradio==6.0.1 tensorboard pandas lpips zstandard + pip install git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 + sudo apt install -y libjpeg-dev + pip install pillow-simd + pip install kornia timm +fi + +if [ "$FLASHATTN" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + pip install flash-attn==2.7.3 + elif [ "$PLATFORM" = "hip" ] ; then + echo "[FLASHATTN] Prebuilt binaries not found. Building from source..." + mkdir -p /tmp/extensions + git clone --recursive https://github.com/ROCm/flash-attention.git /tmp/extensions/flash-attention + cd /tmp/extensions/flash-attention + git checkout tags/v2.7.3-cktile + GPU_ARCHS=gfx942 python setup.py install #MI300 series + cd $WORKDIR + else + echo "[FLASHATTN] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$NVDIFFRAST" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + mkdir -p /tmp/extensions + git clone -b v0.4.0 https://github.com/NVlabs/nvdiffrast.git /tmp/extensions/nvdiffrast + pip install /tmp/extensions/nvdiffrast --no-build-isolation + else + echo "[NVDIFFRAST] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$NVDIFFREC" = true ] ; then + if [ "$PLATFORM" = "cuda" ] ; then + mkdir -p /tmp/extensions + git clone -b renderutils https://github.com/JeffreyXiang/nvdiffrec.git /tmp/extensions/nvdiffrec + pip install /tmp/extensions/nvdiffrec --no-build-isolation + else + echo "[NVDIFFREC] Unsupported platform: $PLATFORM" + fi +fi + +if [ "$CUMESH" = true ] ; then + mkdir -p /tmp/extensions + git clone https://github.com/JeffreyXiang/CuMesh.git /tmp/extensions/CuMesh --recursive + pip install /tmp/extensions/CuMesh --no-build-isolation +fi + +if [ "$FLEXGEMM" = true ] ; then + mkdir -p /tmp/extensions + git clone https://github.com/JeffreyXiang/FlexGEMM.git /tmp/extensions/FlexGEMM --recursive + pip install /tmp/extensions/FlexGEMM --no-build-isolation +fi + +if [ "$OVOXEL" = true ] ; then + mkdir -p /tmp/extensions + cp -r o-voxel /tmp/extensions/o-voxel + pip install /tmp/extensions/o-voxel --no-build-isolation +fi diff --git a/trellis2/__init__.py b/trellis2/__init__.py new file mode 100644 index 0000000..20d240a --- /dev/null +++ b/trellis2/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/trellis2/datasets/__init__.py b/trellis2/datasets/__init__.py new file mode 100644 index 0000000..b8f7d94 --- /dev/null +++ b/trellis2/datasets/__init__.py @@ -0,0 +1,46 @@ +import importlib + +__attributes = { + 'FlexiDualGridDataset': 'flexi_dual_grid', + 'SparseVoxelPbrDataset':'sparse_voxel_pbr', + + 'SparseStructureLatent': 'sparse_structure_latent', + 'TextConditionedSparseStructureLatent': 'sparse_structure_latent', + 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent', + + 'SLat': 'structured_latent', + 'ImageConditionedSLat': 'structured_latent', + 'SLatShape': 'structured_latent_shape', + 'ImageConditionedSLatShape': 'structured_latent_shape', + 'SLatPbr': 'structured_latent_svpbr', + 'ImageConditionedSLatPbr': 'structured_latent_svpbr', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .flexi_dual_grid import FlexiDualGridDataset + from .sparse_voxel_pbr import SparseVoxelPbrDataset + + from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent + from .structured_latent import SLat, ImageConditionedSLat + from .structured_latent_shape import SLatShape, ImageConditionedSLatShape + from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr + \ No newline at end of file diff --git a/trellis2/datasets/components.py b/trellis2/datasets/components.py new file mode 100644 index 0000000..6c593ce --- /dev/null +++ b/trellis2/datasets/components.py @@ -0,0 +1,192 @@ +from typing import * +import json +from abc import abstractmethod +import os +import json +import torch +import numpy as np +import pandas as pd +from PIL import Image +from torch.utils.data import Dataset + + +class StandardDatasetBase(Dataset): + """ + Base class for standard datasets. + + Args: + roots (str): paths to the dataset + """ + + def __init__(self, + roots: str, + ): + super().__init__() + try: + self.roots = json.loads(roots) + root_type = 'obj' + except: + self.roots = roots.split(',') + root_type = 'list' + self.instances = [] + self.metadata = pd.DataFrame() + + self._stats = {} + if root_type == 'obj': + for key, root in self.roots.items(): + self._stats[key] = {} + metadata = pd.DataFrame(columns=['sha256']).set_index('sha256') + for _, r in root.items(): + metadata = metadata.combine_first(pd.read_csv(os.path.join(r, 'metadata.csv')).set_index('sha256')) + self._stats[key]['Total'] = len(metadata) + metadata, stats = self.filter_metadata(metadata) + self._stats[key].update(stats) + self.instances.extend([(root, sha256) for sha256 in metadata.index.values]) + self.metadata = pd.concat([self.metadata, metadata]) + else: + for root in self.roots: + key = os.path.basename(root) + self._stats[key] = {} + metadata = pd.read_csv(os.path.join(root, 'metadata.csv')) + self._stats[key]['Total'] = len(metadata) + metadata, stats = self.filter_metadata(metadata) + self._stats[key].update(stats) + self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values]) + metadata.set_index('sha256', inplace=True) + self.metadata = pd.concat([self.metadata, metadata]) + + @abstractmethod + def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]: + pass + + @abstractmethod + def get_instance(self, root, instance: str) -> Dict[str, Any]: + pass + + def __len__(self): + return len(self.instances) + + def __getitem__(self, index) -> Dict[str, Any]: + try: + root, instance = self.instances[index] + return self.get_instance(root, instance) + except Exception as e: + print(f'Error loading {instance}: {e}') + return self.__getitem__(np.random.randint(0, len(self))) + + def __str__(self): + lines = [] + lines.append(self.__class__.__name__) + lines.append(f' - Total instances: {len(self)}') + lines.append(f' - Sources:') + for key, stats in self._stats.items(): + lines.append(f' - {key}:') + for k, v in stats.items(): + lines.append(f' - {k}: {v}') + return '\n'.join(lines) + + +class ImageConditionedMixin: + def __init__(self, roots, *, image_size=518, **kwargs): + self.image_size = image_size + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata): + metadata, stats = super().filter_metadata(metadata) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + pack = super().get_instance(root, instance) + + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + n_views = len(metadata['frames']) + view = np.random.randint(n_views) + metadata = metadata['frames'][view] + + image_path = os.path.join(image_root, metadata['file_path']) + image = Image.open(image_path) + + alpha = np.array(image.getchannel(3)) + bbox = np.array(alpha).nonzero() + bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + aug_hsize = hsize + aug_center_offset = [0, 0] + aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]] + aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)] + image = image.crop(aug_bbox) + + image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = image.getchannel(3) + image = image.convert('RGB') + image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + image = image * alpha.unsqueeze(0) + pack['cond'] = image + + return pack + + +class MultiImageConditionedMixin: + def __init__(self, roots, *, image_size=518, max_image_cond_view = 4, **kwargs): + self.image_size = image_size + self.max_image_cond_view = max_image_cond_view + super().__init__(roots, **kwargs) + + def filter_metadata(self, metadata): + metadata, stats = super().filter_metadata(metadata) + metadata = metadata[metadata['cond_rendered'].notna()] + stats['Cond rendered'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + pack = super().get_instance(root, instance) + + image_root = os.path.join(root['render_cond'], instance) + with open(os.path.join(image_root, 'transforms.json')) as f: + metadata = json.load(f) + + n_views = len(metadata['frames']) + n_sample_views = np.random.randint(1, self.max_image_cond_view+1) + + assert n_views >= n_sample_views, f'Not enough views to sample {n_sample_views} unique images.' + + sampled_views = np.random.choice(n_views, size=n_sample_views, replace=False) + + cond_images = [] + for v in sampled_views: + frame_info = metadata['frames'][v] + image_path = os.path.join(image_root, frame_info['file_path']) + image = Image.open(image_path) + + alpha = np.array(image.getchannel(3)) + bbox = np.array(alpha).nonzero() + bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] + center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] + hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 + aug_hsize = hsize + aug_center = center + aug_bbox = [ + int(aug_center[0] - aug_hsize), + int(aug_center[1] - aug_hsize), + int(aug_center[0] + aug_hsize), + int(aug_center[1] + aug_hsize), + ] + + img = image.crop(aug_bbox) + img = img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = img.getchannel(3) + img = img.convert('RGB') + img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0 + alpha = torch.tensor(np.array(alpha)).float() / 255.0 + img = img * alpha.unsqueeze(0) + + cond_images.append(img) + + pack['cond'] = [torch.stack(cond_images, dim=0)] # (V,3,H,W) + return pack diff --git a/trellis2/datasets/flexi_dual_grid.py b/trellis2/datasets/flexi_dual_grid.py new file mode 100644 index 0000000..f870d83 --- /dev/null +++ b/trellis2/datasets/flexi_dual_grid.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import pickle +import torch +import utils3d +from .components import StandardDatasetBase +from ..modules import sparse as sp +from ..renderers import MeshRenderer +from ..representations import Mesh +from ..utils.data_utils import load_balanced_group_indices +import o_voxel + + +class FlexiDualGridVisMixin: + @torch.no_grad() + def visualize_sample(self, x: dict): + mesh = x['mesh'] + + renderer = MeshRenderer({'near': 1, 'far': 3}) + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + # Build each representation + images = [] + for m in mesh: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = \ + renderer.render(m.cuda(), ext, intr)['normal'] + images.append(image) + images = torch.stack(images) + + return images + + +class FlexiDualGridDataset(FlexiDualGridVisMixin, StandardDatasetBase): + """ + Flexible Dual Grid Dataset + + Args: + roots (str): path to the dataset + resolution (int): resolution of the voxel grid + min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset + """ + + def __init__( + self, + roots, + resolution: int = 1024, + max_active_voxels: int = 1000000, + max_num_faces: int = None, + min_aesthetic_score: float = 5.0, + ): + self.resolution = resolution + self.min_aesthetic_score = min_aesthetic_score + self.max_active_voxels = max_active_voxels + self.max_num_faces = max_num_faces + self.value_range = (0, 1) + + super().__init__(roots) + + self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256 in self.instances] + + def __str__(self): + lines = [ + super().__str__(), + f' - Resolution: {self.resolution}', + ] + return '\n'.join(lines) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata[f'dual_grid_converted'] == True] + stats['Dual Grid Converted'] = len(metadata) + if self.min_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata[f'dual_grid_size'] <= self.max_active_voxels] + stats[f'Active Voxels <= {self.max_active_voxels}'] = len(metadata) + if self.max_num_faces is not None: + metadata = metadata[metadata['num_faces'] <= self.max_num_faces] + stats[f'Faces <= {self.max_num_faces}'] = len(metadata) + return metadata, stats + + def read_mesh(self, root, instance): + with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: + dump = pickle.load(f) + start = 0 + vertices = [] + faces = [] + for obj in dump['objects']: + if obj['vertices'].size == 0 or obj['faces'].size == 0: + continue + vertices.append(obj['vertices']) + faces.append(obj['faces'] + start) + start += len(obj['vertices']) + vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float() + faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() + vertices_min = vertices.min(dim=0)[0] + vertices_max = vertices.max(dim=0)[0] + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' + return {'mesh': [Mesh(vertices=vertices, faces=faces)]} + + def read_dual_grid(self, root, instance): + coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) + vertices = sp.SparseTensor( + (attr['vertices'] / 255.0).float(), + torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), + ) + intersected = vertices.replace(torch.cat([ + attr['intersected'] % 2, + attr['intersected'] // 2 % 2, + attr['intersected'] // 4 % 2, + ], dim=-1).bool()) + return {'vertices': vertices, 'intersected': intersected} + + def get_instance(self, root, instance): + mesh = self.read_mesh(root['mesh_dump'], instance) + dual_grid = self.read_dual_grid(root['dual_grid'], instance) + return {**mesh, **dual_grid} + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['vertices'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], sp.SparseTensor): + pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + \ No newline at end of file diff --git a/trellis2/datasets/sparse_structure_latent.py b/trellis2/datasets/sparse_structure_latent.py new file mode 100644 index 0000000..498e115 --- /dev/null +++ b/trellis2/datasets/sparse_structure_latent.py @@ -0,0 +1,160 @@ +import os +import json +from typing import * +import numpy as np +import torch +from ..representations import Voxel +from ..renderers import VoxelRenderer +from .components import StandardDatasetBase, ImageConditionedMixin +from .. import models +from ..utils.render_utils import yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SparseStructureLatentVisMixin: + def __init__( + self, + *args, + pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16.json', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.ss_dec = None + self.pretrained_ss_dec = pretrained_ss_dec + self.ss_dec_path = ss_dec_path + self.ss_dec_ckpt = ss_dec_ckpt + + def _loading_ss_dec(self): + if self.ss_dec is not None: + return + if self.ss_dec_path is not None: + cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_ss_dec) + self.ss_dec = decoder.cuda().eval() + + def _delete_ss_dec(self): + del self.ss_dec + self.ss_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_ss_dec() + ss = [] + if self.normalization: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + ss.append(self.ss_dec(z[i:i+batch_size])) + ss = torch.cat(ss, dim=0) + self._delete_ss_dec() + return ss + + @torch.no_grad() + def visualize_sample(self, x_0: Union[torch.Tensor, dict]): + x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0'] + x_0 = self.decode_latent(x_0.cuda()) + + renderer = VoxelRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + images = [] + + # Build each representation + x_0 = x_0.cuda() + for i in range(x_0.shape[0]): + coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False) + resolution = x_0.shape[-1] + color = coords / resolution + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/resolution, + coords=coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(rep, ext, intr, colors_overwrite=color) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images.append(image) + + return torch.stack(images) + + +class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase): + """ + Sparse structure latent dataset + + Args: + roots (str): path to the dataset + min_aesthetic_score (float): minimum aesthetic score + normalization (dict): normalization stats + pretrained_ss_dec (str): name of the pretrained sparse structure decoder + ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec + ss_dec_ckpt (str): name of the sparse structure decoder checkpoint + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + normalization: Optional[dict] = None, + pretrained_ss_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16', + ss_dec_path: Optional[str] = None, + ss_dec_ckpt: Optional[str] = None, + ): + self.min_aesthetic_score = min_aesthetic_score + self.normalization = normalization + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_ss_dec=pretrained_ss_dec, + ss_dec_path=ss_dec_path, + ss_dec_ckpt=ss_dec_ckpt, + ) + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1) + self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata['ss_latent_encoded'] == True] + stats['With latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz')) + z = torch.tensor(latent['z']).float() + if self.normalization is not None: + z = (z - self.mean) / self.std + + pack = { + 'x_0': z, + } + return pack + + +class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent): + """ + Image-conditioned sparse structure dataset + """ + pass + \ No newline at end of file diff --git a/trellis2/datasets/sparse_voxel_pbr.py b/trellis2/datasets/sparse_voxel_pbr.py new file mode 100644 index 0000000..2149b5b --- /dev/null +++ b/trellis2/datasets/sparse_voxel_pbr.py @@ -0,0 +1,298 @@ +import os +import io +from typing import Union +import numpy as np +import pickle +import torch +from PIL import Image +import o_voxel +import utils3d +from .components import StandardDatasetBase +from ..modules import sparse as sp +from ..renderers import VoxelRenderer +from ..representations.mesh import Voxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture + +from ..utils.data_utils import load_balanced_group_indices +from ..utils.mesh_utils import subdivide_to_size + + +def is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + +def nearest_power_of_two(n: int) -> int: + if n < 1: + raise ValueError("n must be >= 1") + if is_power_of_two(n): + return n + lower = 2 ** (n.bit_length() - 1) + upper = 2 ** n.bit_length() + if n - lower < upper - n: + return lower + else: + return upper + + +class SparseVoxelPbrVisMixin: + @torch.no_grad() + def visualize_sample(self, x: Union[sp.SparseTensor, dict]): + x = x if isinstance(x, sp.SparseTensor) else x['x'] + + renderer = VoxelRenderer() + renderer.rendering_options.resolution = 512 + renderer.rendering_options.ssaa = 4 + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + images = {k: [] for k in self.layout} + + # Build each representation + x = x.cuda() + for i in range(x.shape[0]): + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/self.resolution, + coords=x[i].coords[:, 1:].contiguous(), + attrs=attr, + layout={ + 'color': slice(0, 3), + } + ) + for k in self.layout: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + attr = x[i].feats[:, self.layout[k]].expand(-1, 3) + res = renderer.render(rep, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images[k].append(image) + + for k in self.layout: + images[k] = torch.stack(images[k]) + + return images + + +class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase): + """ + Sparse Voxel PBR dataset. + + Args: + roots (str): path to the dataset + resolution (int): resolution of the voxel grid + min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset + """ + + def __init__( + self, + roots, + resolution: int = 1024, + max_active_voxels: int = 1000000, + max_num_faces: int = None, + min_aesthetic_score: float = 5.0, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + with_mesh: bool = True, + ): + self.resolution = resolution + self.min_aesthetic_score = min_aesthetic_score + self.max_active_voxels = max_active_voxels + self.max_num_faces = max_num_faces + self.with_mesh = with_mesh + self.value_range = (-1, 1) + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + super().__init__(roots) + + self.loads = [self.metadata.loc[sha256, f'num_pbr_voxels'] for _, sha256 in self.instances] + + def __str__(self): + lines = [ + super().__str__(), + f' - Resolution: {self.resolution}', + f' - Attributes: {list(self.layout.keys())}', + ] + return '\n'.join(lines) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata['pbr_voxelized'] == True] + stats['PBR Voxelized'] = len(metadata) + if self.min_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['num_pbr_voxels'] <= self.max_active_voxels] + stats[f'Active voxels <= {self.max_active_voxels}'] = len(metadata) + if self.max_num_faces is not None: + metadata = metadata[metadata['num_faces'] <= self.max_num_faces] + stats[f'Faces <= {self.max_num_faces}'] = len(metadata) + return metadata, stats + + @staticmethod + def _texture_from_dump(pack) -> Texture: + png_bytes = pack['image'] + image = Image.open(io.BytesIO(png_bytes)) + if image.width != image.height or not is_power_of_two(image.width): + size = nearest_power_of_two(max(image.width, image.height)) + image = image.resize((size, size), Image.LANCZOS) + texture = torch.tensor(np.array(image) / 255.0, dtype=torch.float32).reshape(image.height, image.width, -1) + filter_mode = { + 'Linear': TextureFilterMode.LINEAR, + 'Closest': TextureFilterMode.CLOSEST, + 'Cubic': TextureFilterMode.LINEAR, + 'Smart': TextureFilterMode.LINEAR, + }[pack['interpolation']] + wrap_mode = { + 'REPEAT': TextureWrapMode.REPEAT, + 'EXTEND': TextureWrapMode.CLAMP_TO_EDGE, + 'CLIP': TextureWrapMode.CLAMP_TO_EDGE, + 'MIRROR': TextureWrapMode.MIRRORED_REPEAT, + }[pack['extension']] + return Texture(texture, filter_mode=filter_mode, wrap_mode=wrap_mode) + + def read_mesh_with_texture(self, root, instance): + with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: + dump = pickle.load(f) + + # Fix dump alpha map + for mat in dump['materials']: + if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE': + mat['alphaMode'] = 'BLEND' + + # process material + materials = [] + for mat in dump['materials']: + materials.append(PbrMaterial( + base_color_texture=self._texture_from_dump(mat['baseColorTexture']) if mat['baseColorTexture'] is not None else None, + base_color_factor=mat['baseColorFactor'], + metallic_texture=self._texture_from_dump(mat['metallicTexture']) if mat['metallicTexture'] is not None else None, + metallic_factor=mat['metallicFactor'], + roughness_texture=self._texture_from_dump(mat['roughnessTexture']) if mat['roughnessTexture'] is not None else None, + roughness_factor=mat['roughnessFactor'], + alpha_texture=self._texture_from_dump(mat['alphaTexture']) if mat['alphaTexture'] is not None else None, + alpha_factor=mat['alphaFactor'], + alpha_mode={ + 'OPAQUE': AlphaMode.OPAQUE, + 'MASK': AlphaMode.MASK, + 'BLEND': AlphaMode.BLEND, + }[mat['alphaMode']], + alpha_cutoff=mat['alphaCutoff'], + )) + materials.append(PbrMaterial( + base_color_factor=[0.8, 0.8, 0.8], + alpha_factor=1.0, + metallic_factor=0.0, + roughness_factor=0.5, + alpha_mode=AlphaMode.OPAQUE, + alpha_cutoff=0.5, + )) # append default material + + # process mesh + start = 0 + vertices = [] + faces = [] + material_ids = [] + uv_coords = [] + for obj in dump['objects']: + if obj['vertices'].size == 0 or obj['faces'].size == 0: + continue + vertices.append(obj['vertices']) + faces.append(obj['faces'] + start) + obj['mat_ids'][obj['mat_ids'] == -1] = len(materials) - 1 + material_ids.append(obj['mat_ids']) + uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32)) + start += len(obj['vertices']) + + vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float() + faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() + material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long() + uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float() + + # Normalize vertices + vertices_min = vertices.min(dim=0)[0] + vertices_max = vertices.max(dim=0)[0] + center = (vertices_min + vertices_max) / 2 + scale = 0.99999 / (vertices_max - vertices_min).max() + vertices = (vertices - center) * scale + assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' + + return {'mesh': [MeshWithPbrMaterial( + vertices=vertices, + faces=faces, + material_ids=material_ids, + uv_coords=uv_coords, + materials=materials, + )]} + + def read_pbr_voxel(self, root, instance): + coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) + feats = torch.concat([attr[k] for k in self.layout], dim=-1) / 255.0 * 2 - 1 + x = sp.SparseTensor( + feats.float(), + torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), + ) + return {'x': x} + + def get_instance(self, root, instance): + if self.with_mesh: + mesh = self.read_mesh_with_texture(root['pbr_dump'], instance) + pbr_voxel = self.read_pbr_voxel(root['pbr_voxel'], instance) + return {**mesh, **pbr_voxel} + else: + return self.read_pbr_voxel(root['pbr_voxel'], instance) + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['x'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], sp.SparseTensor): + pack[k] = sp.sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs diff --git a/trellis2/datasets/structured_latent.py b/trellis2/datasets/structured_latent.py new file mode 100644 index 0000000..2e18a27 --- /dev/null +++ b/trellis2/datasets/structured_latent.py @@ -0,0 +1,210 @@ +import json +import os +from typing import * +import numpy as np +import torch +import utils3d.torch +from .components import StandardDatasetBase, ImageConditionedMixin +from ..modules.sparse.basic import SparseTensor +from .. import models +from ..utils.render_utils import get_renderer +from ..utils.data_utils import load_balanced_group_indices + + +class SLatVisMixin: + def __init__( + self, + *args, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.slat_dec = None + self.pretrained_slat_dec = pretrained_slat_dec + self.slat_dec_path = slat_dec_path + self.slat_dec_ckpt = slat_dec_ckpt + + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + self.slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.slat_dec + self.slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.normalization is not None: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + reps.append(self.slat_dec(z[i:i+batch_size])) + reps = sum(reps, []) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, x_0: Union[SparseTensor, dict]): + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(40)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + renderer = get_renderer(reps[0]) + images = [] + for representation in reps: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images.append(image) + images = torch.stack(images) + + return images + + +class SLat(SLatVisMixin, StandardDatasetBase): + """ + structured latent V2 dataset + + Args: + roots (str): path to the dataset + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + latent_key (str): key of the latent to be used + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + latent_key: str = 'shape_latent', + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + ): + self.normalization = normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.latent_key = latent_key + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + ) + + self.loads = [self.metadata.loc[sha256, f'{latent_key}_tokens'] for _, sha256 in self.instances] + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1) + self.std = torch.tensor(self.normalization['std']).reshape(1, -1) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True] + stats['With latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata[f'{self.latent_key}_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + data = np.load(os.path.join(root[self.latent_key], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + feats = torch.tensor(data['feats']).float() + if self.normalization is not None: + feats = (feats - self.mean) / self.std + return { + 'coords': coords, + 'feats': feats, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + coords = [] + feats = [] + layout = [] + start = 0 + for i, b in enumerate(sub_batch): + coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)) + feats.append(b['feats']) + layout.append(slice(start, start + b['coords'].shape[0])) + start += b['coords'].shape[0] + coords = torch.cat(coords) + feats = torch.cat(feats) + pack['x_0'] = SparseTensor( + coords=coords, + feats=feats, + ) + pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]]) + pack['x_0'].register_spatial_cache('layout', layout) + + # collate other data + keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ImageConditionedSLat(ImageConditionedMixin, SLat): + """ + Image conditioned structured latent dataset + """ + pass diff --git a/trellis2/datasets/structured_latent_shape.py b/trellis2/datasets/structured_latent_shape.py new file mode 100644 index 0000000..e4a7d88 --- /dev/null +++ b/trellis2/datasets/structured_latent_shape.py @@ -0,0 +1,96 @@ +import os +import json +from typing import * +import numpy as np +import torch +from .. import models +from .components import ImageConditionedMixin +from ..modules.sparse import SparseTensor +from .structured_latent import SLatVisMixin, SLat +from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SLatShapeVisMixin(SLatVisMixin): + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + cfg['models']['decoder']['args']['resolution'] = self.resolution + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + self.slat_dec = decoder.cuda().eval() + + @torch.no_grad() + def visualize_sample(self, x_0: Union[SparseTensor, dict]): + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # render + renderer = get_renderer(reps[0]) + images = [] + for representation in reps: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['normal'] + images.append(image) + images = torch.stack(images) + return images + + +class SLatShape(SLatShapeVisMixin, SLat): + """ + structured latent for shape generation + + Args: + roots (str): path to the dataset + resolution (int): resolution of the shape + min_aesthetic_score (float): minimum aesthetic score + max_tokens (int): maximum number of tokens + latent_key (str): key of the latent to be used + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + ): + super().__init__( + roots, + min_aesthetic_score=min_aesthetic_score, + max_tokens=max_tokens, + latent_key='shape_latent', + normalization=normalization, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + ) + self.resolution = resolution + + +class ImageConditionedSLatShape(ImageConditionedMixin, SLatShape): + """ + Image conditioned structured latent for shape generation + """ + pass diff --git a/trellis2/datasets/structured_latent_svpbr.py b/trellis2/datasets/structured_latent_svpbr.py new file mode 100644 index 0000000..4c6711e --- /dev/null +++ b/trellis2/datasets/structured_latent_svpbr.py @@ -0,0 +1,273 @@ +import os +import json +from typing import * +import numpy as np +import torch +from .. import models +from .components import StandardDatasetBase, ImageConditionedMixin +from ..modules.sparse import SparseTensor, sparse_cat +from ..representations import MeshWithVoxel +from ..utils.data_utils import load_balanced_group_indices +from ..utils.render_utils import get_renderer, yaw_pitch_r_fov_to_extrinsics_intrinsics + + +class SLatPbrVisMixin: + def __init__( + self, + *args, + pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.pbr_slat_dec = None + self.pretrained_pbr_slat_dec = pretrained_pbr_slat_dec + self.pbr_slat_dec_path = pbr_slat_dec_path + self.pbr_slat_dec_ckpt = pbr_slat_dec_ckpt + self.shape_slat_dec = None + self.pretrained_shape_slat_dec = pretrained_shape_slat_dec + self.shape_slat_dec_path = shape_slat_dec_path + self.shape_slat_dec_ckpt = shape_slat_dec_ckpt + + def _loading_slat_dec(self): + if self.pbr_slat_dec is not None and self.shape_slat_dec is not None: + return + if self.pbr_slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.pbr_slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.pbr_slat_dec_path, 'ckpts', f'decoder_{self.pbr_slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_pbr_slat_dec) + self.pbr_slat_dec = decoder.cuda().eval() + + if self.shape_slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.shape_slat_dec_path, 'config.json'), 'r')) + cfg['models']['decoder']['args']['resolution'] = self.resolution + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.shape_slat_dec_path, 'ckpts', f'decoder_{self.shape_slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_shape_slat_dec) + self.shape_slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.pbr_slat_dec + self.pbr_slat_dec = None + del self.shape_slat_dec + self.shape_slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, shape_z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.shape_slat_normalization is not None: + shape_z = shape_z * self.shape_slat_std.to(z.device) + self.shape_slat_mean.to(z.device) + if self.pbr_slat_normalization is not None: + z = z * self.pbr_slat_std.to(z.device) + self.pbr_slat_mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + mesh, subs = self.shape_slat_dec(shape_z[i:i+batch_size], return_subs=True) + vox = self.pbr_slat_dec(z[i:i+batch_size], guide_subs=subs) + reps.extend([ + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / self.resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout = self.layout, + ) + for m, v in zip(mesh, vox) + ]) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, sample: dict): + shape_z = sample['concat_cond'].cuda() + z = sample['x_0'].cuda() + reps = self.decode_latent(z, shape_z) + + # build camera + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = -16 / 180 * np.pi + yaw = [y + yaw_offset for y in yaw] + pitch = [20 / 180 * np.pi for _ in range(4)] + exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) + + # render + renderer = get_renderer(reps[0]) + images = {k: [] for k in self.layout} + for representation in reps: + image = {k: torch.zeros(3, 1024, 1024).cuda() for k in self.layout} + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr, return_types=['attr']) + for k in self.layout: + image[k][:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res[k] + for k in self.layout: + images[k].append(image[k]) + for k in self.layout: + images[k] = torch.stack(images[k], dim=0) + return images + + +class SLatPbr(SLatPbrVisMixin, StandardDatasetBase): + """ + structured latent for sparse voxel pbr dataset + + Args: + roots (str): path to the dataset + latent_key (str): key of the latent to be used + min_aesthetic_score (float): minimum aesthetic score + normalization (dict): normalization stats + resolution (int): resolution of decoded sparse voxel + attrs (list): attributes to be decoded + pretained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + resolution: int, + min_aesthetic_score: float = 5.0, + max_tokens: int = 32768, + full_pbr: bool = False, + pbr_slat_normalization: Optional[dict] = None, + shape_slat_normalization: Optional[dict] = None, + attrs: list[str] = ['base_color', 'metallic', 'roughness', 'emissive', 'alpha'], + pretrained_pbr_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16', + pbr_slat_dec_path: Optional[str] = None, + pbr_slat_dec_ckpt: Optional[str] = None, + pretrained_shape_slat_dec: str = 'JeffreyXiang/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16', + shape_slat_dec_path: Optional[str] = None, + shape_slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + self.resolution = resolution + self.pbr_slat_normalization = pbr_slat_normalization + self.shape_slat_normalization = shape_slat_normalization + self.min_aesthetic_score = min_aesthetic_score + self.max_tokens = max_tokens + self.full_pbr = full_pbr + self.value_range = (-1, 1) + + super().__init__( + roots, + pretrained_pbr_slat_dec=pretrained_pbr_slat_dec, + pbr_slat_dec_path=pbr_slat_dec_path, + pbr_slat_dec_ckpt=pbr_slat_dec_ckpt, + pretrained_shape_slat_dec=pretrained_shape_slat_dec, + shape_slat_dec_path=shape_slat_dec_path, + shape_slat_dec_ckpt=shape_slat_dec_ckpt, + **kwargs + ) + + self.loads = [self.metadata.loc[sha256, 'pbr_latent_tokens'] for _, sha256 in self.instances] + + if self.pbr_slat_normalization is not None: + self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1) + self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1) + + if self.shape_slat_normalization is not None: + self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1) + self.shape_slat_std = torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1) + + self.attrs = attrs + self.channels = { + 'base_color': 3, + 'metallic': 1, + 'roughness': 1, + 'emissive': 3, + 'alpha': 1, + } + self.layout = {} + start = 0 + for attr in attrs: + self.layout[attr] = slice(start, start + self.channels[attr]) + start += self.channels[attr] + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata['pbr_latent_encoded'] == True] + stats['With PBR latent'] = len(metadata) + metadata = metadata[metadata['shape_latent_encoded'] == True] + stats['With shape latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['pbr_latent_tokens'] <= self.max_tokens] + stats[f'Num tokens <= {self.max_tokens}'] = len(metadata) + if self.full_pbr: + metadata = metadata[metadata['num_basecolor_tex'] > 0] + metadata = metadata[metadata['num_metallic_tex'] > 0] + metadata = metadata[metadata['num_roughness_tex'] > 0] + stats['Full PBR'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + # PBR latent + data = np.load(os.path.join(root['pbr_latent'], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1) + feats = torch.tensor(data['feats']).float() + if self.pbr_slat_normalization is not None: + feats = (feats - self.pbr_slat_mean) / self.pbr_slat_std + pbr_z = SparseTensor(feats, coords) + + # Shape latent + data = np.load(os.path.join(root['shape_latent'], f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + coords = torch.cat([torch.zeros_like(coords)[:, :1], coords], dim=1) + feats = torch.tensor(data['feats']).float() + if self.shape_slat_normalization is not None: + feats = (feats - self.shape_slat_mean) / self.shape_slat_std + shape_z = SparseTensor(feats, coords) + + assert torch.equal(shape_z.coords, pbr_z.coords), \ + f"Shape latent and PBR latent have different coordinates: {shape_z.coords.shape} vs {pbr_z.coords.shape}" + + return { + 'x_0': pbr_z, + 'concat_cond': shape_z, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['x_0'].feats.shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + + keys = [k for k in sub_batch[0].keys()] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], SparseTensor): + pack[k] = sparse_cat([b[k] for b in sub_batch], dim=0) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class ImageConditionedSLatPbr(ImageConditionedMixin, SLatPbr): + """ + Image conditioned structured latent dataset + """ + pass diff --git a/trellis2/models/__init__.py b/trellis2/models/__init__.py new file mode 100644 index 0000000..d4fed03 --- /dev/null +++ b/trellis2/models/__init__.py @@ -0,0 +1,78 @@ +import importlib + +__attributes = { + # Sparse Structure + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + + # SLat Generation + 'SLatFlowModel': 'structured_latent_flow', + 'ElasticSLatFlowModel': 'structured_latent_flow', + + # SC-VAEs + 'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae', + 'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae', + 'FlexiDualGridVaeEncoder': 'sc_vaes.fdg_vae', + 'FlexiDualGridVaeDecoder': 'sc_vaes.fdg_vae' +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file), strict=False) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel + + from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder + from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder diff --git a/trellis2/models/sc_vaes/fdg_vae.py b/trellis2/models/sc_vaes/fdg_vae.py new file mode 100644 index 0000000..c9b5b07 --- /dev/null +++ b/trellis2/models/sc_vaes/fdg_vae.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .sparse_unet_vae import ( + SparseResBlock3d, + SparseConvNeXtBlock3d, + + SparseResBlockDownsample3d, + SparseResBlockUpsample3d, + SparseResBlockS2C3d, + SparseResBlockC2S3d, +) +from .sparse_unet_vae import ( + SparseUnetVaeEncoder, + SparseUnetVaeDecoder, +) +from ...representations import Mesh +from o_voxel.convert import flexible_dual_grid_to_mesh + + +class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): + def __init__( + self, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__( + 6, + model_channels, + latent_channels, + num_blocks, + block_type, + down_block_type, + block_args, + use_fp16, + ) + + def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False): + x = vertices.replace(torch.cat([ + vertices.feats - 0.5, + intersected.feats.float() - 0.5, + ], dim=1)) + return super().forward(x, sample_posterior, return_raw) + + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def forward(self, x: sp.SparseTensor, gt_intersected: sp.SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + if self.training: + h, subs_gt, subs = decoded + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected_logits = h.replace(h.feats[..., 3:6]) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(flexible_dual_grid_to_mesh( + h.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=True + )) for v, i, q in zip(vertices, gt_intersected, quad_lerp)] + return mesh, vertices, intersected_logits, subs_gt, subs + else: + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + h.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) diff --git a/trellis2/models/sc_vaes/sparse_unet_vae.py b/trellis2/models/sc_vaes/sparse_unet_vae.py new file mode 100644 index 0000000..b9902a1 --- /dev/null +++ b/trellis2/models/sc_vaes/sparse_unet_vae.py @@ -0,0 +1,522 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from ...modules.utils import convert_module_to_f16, convert_module_to_f32, zero_module +from ...modules import sparse as sp +from ...modules.norm import LayerNorm32 + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + resample_mode: Literal['nearest', 'spatial2channel'] = 'nearest', + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + self.resample_mode = resample_mode + self.use_checkpoint = use_checkpoint + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + if resample_mode == 'nearest': + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + elif resample_mode =='spatial2channel' and not self.downsample: + self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3) + elif resample_mode =='spatial2channel' and self.downsample: + self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + if resample_mode == 'nearest': + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + elif resample_mode =='spatial2channel' and self.downsample: + self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1)) + elif resample_mode =='spatial2channel' and not self.downsample: + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + self.updown = None + if self.downsample: + if resample_mode == 'nearest': + self.updown = sp.SparseDownsample(2) + elif resample_mode =='spatial2channel': + self.updown = sp.SparseSpatial2Channel(2) + elif self.upsample: + self.to_subdiv = sp.SparseLinear(channels, 8) + if resample_mode == 'nearest': + self.updown = sp.SparseUpsample(2) + elif resample_mode =='spatial2channel': + self.updown = sp.SparseChannel2Spatial(2) + + def _updown(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.downsample: + x = self.updown(x) + elif self.upsample: + x = self.updown(x, subdiv.replace(subdiv.feats > 0)) + return x + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + subdiv = None + if self.upsample: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + if self.resample_mode == 'spatial2channel': + h = self.conv1(h) + h = self._updown(h, subdiv) + x = self._updown(x, subdiv) + if self.resample_mode == 'nearest': + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.upsample: + return h, subdiv + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockDownsample3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = sp.SparseDownsample(2) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.updown(h) + x = self.updown(x) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockUpsample3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + if self.pred_subdiv: + self.to_subdiv = sp.SparseLinear(channels, 8) + self.updown = sp.SparseUpsample(2) + + def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockS2C3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels // 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = lambda x: x.replace(x.feats.reshape(x.feats.shape[0], out_channels, channels * 8 // out_channels).mean(dim=-1)) + self.updown = sp.SparseSpatial2Channel(2) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = self.updown(h) + x = self.updown(x) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = sp.SparseLinear(channels, 8) + self.updown = sp.SparseChannel2Spatial(2) + + def _forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + + def forward(self, x: sp.SparseTensor, subdiv: sp.SparseTensor = None) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, subdiv, use_reentrant=False) + else: + return self._forward(x, subdiv) + + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = sp.SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + zero_module(nn.Linear(int(channels * mlp_ratio), channels)), + ) + + def _forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseUnetVaeEncoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. + """ + def __init__( + self, + in_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = sp.SparseLinear(in_channels, model_channels[0]) + self.to_latent = sp.SparseLinear(model_channels[-1], 2 * latent_channels) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[down_block_type[i]]( + model_channels[i], + model_channels[i+1], + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, sample_posterior=False, return_raw=False): + h = self.input_layer(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.to_latent(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + +class SparseUnetVaeDecoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. + """ + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = sp.SparseLinear(model_channels[-1], out_channels) + self.from_latent = sp.SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor, guide_subs: Optional[List[sp.SparseTensor]] = None, return_subs: bool = False) -> sp.SparseTensor: + assert guide_subs is None or self.pred_subdiv == False, "Only decoders with pred_subdiv=False can be used with guide_subs" + assert return_subs == False or self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with return_subs" + + h = self.from_latent(x) + h = h.type(self.dtype) + subs_gt = [] + subs = [] + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + if self.training: + subs_gt.append(h.get_spatial_cache('subdivision')) + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if self.training and self.pred_subdiv: + return h, subs_gt, subs + else: + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: sp.SparseTensor, upsample_times: int) -> torch.Tensor: + assert self.pred_subdiv == True, "Only decoders with pred_subdiv=True can be used with upsampling" + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + \ No newline at end of file diff --git a/trellis2/models/sparse_elastic_mixin.py b/trellis2/models/sparse_elastic_mixin.py new file mode 100644 index 0000000..66d204c --- /dev/null +++ b/trellis2/models/sparse_elastic_mixin.py @@ -0,0 +1,24 @@ +from contextlib import contextmanager +from typing import * +import math +from ..modules import sparse as sp +from ..utils.elastic_utils import ElasticModuleMixin + + +class SparseTransformerElasticMixin(ElasticModuleMixin): + def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs): + return x.feats.shape[0] + + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0): + if mem_ratio == 1.0: + yield 1.0 + return + num_blocks = len(self.blocks) + num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks) + exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks + for i in range(num_blocks): + self.blocks[i].use_checkpoint = i < num_checkpoint_blocks + yield exact_mem_ratio + for i in range(num_blocks): + self.blocks[i].use_checkpoint = False diff --git a/trellis2/models/sparse_structure_flow.py b/trellis2/models/sparse_structure_flow.py new file mode 100644 index 0000000..6c97665 --- /dev/null +++ b/trellis2/models/sparse_structure_flow.py @@ -0,0 +1,247 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to, manual_cast, str_to_dtype +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.attention import RotaryPositionEmbedder + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + elif pe_mode == "rope": + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = nn.Linear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. + """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) + nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + if self.pe_mode == "ape": + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + cond = manual_cast(cond, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h diff --git a/trellis2/models/sparse_structure_vae.py b/trellis2/models/sparse_structure_vae.py new file mode 100644 index 0000000..c3e0913 --- /dev/null +++ b/trellis2/models/sparse_structure_vae.py @@ -0,0 +1,306 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis2/models/structured_latent_flow.py b/trellis2/models/structured_latent_flow.py new file mode 100644 index 0000000..9378ff7 --- /dev/null +++ b/trellis2/models/structured_latent_flow.py @@ -0,0 +1,207 @@ +from typing import * +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to, manual_cast, str_to_dtype +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder +from .sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = str_to_dtype(dtype) + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = sp.SparseLinear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to(self, dtype: torch.dtype) -> None: + """ + Convert the torso of the model to the specified dtype. + """ + self.dtype = dtype + self.blocks.apply(partial(convert_module_to, dtype=dtype)) + + def initialize_weights(self) -> None: + if self.initialization == 'vanilla': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + elif self.initialization == 'scaled': + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=np.sqrt(2.0 / (5.0 * self.model_channels))) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Scaled init for to_out and ffn2 + def _scaled_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, std=1.0 / np.sqrt(5 * self.num_blocks * self.model_channels)) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + for block in self.blocks: + block.self_attn.to_out.apply(_scaled_init) + block.cross_attn.to_out.apply(_scaled_init) + block.mlp.mlp[2].apply(_scaled_init) + + # Initialize input layer to make the initial representation have variance 1 + nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) + nn.init.zeros_(self.input_layer.bias) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, + x: sp.SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor]], + concat_cond: Optional[sp.SparseTensor] = None, + **kwargs + ) -> sp.SparseTensor: + if concat_cond is not None: + x = sp.sparse_cat([x, concat_cond], dim=-1) + if isinstance(cond, list): + cond = sp.VarLenTensor.from_tensor_list(cond) + + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + cond = manual_cast(cond, self.dtype) + + if self.pe_mode == "ape": + pe = self.pos_embedder(h.coords[:, 1:]) + h = h + manual_cast(pe, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + + +class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel): + """ + SLat Flow Model with elastic memory management. + Used for training with low VRAM. + """ + pass diff --git a/trellis2/modules/attention/__init__.py b/trellis2/modules/attention/__init__.py new file mode 100644 index 0000000..e90e901 --- /dev/null +++ b/trellis2/modules/attention/__init__.py @@ -0,0 +1,3 @@ +from .full_attn import * +from .modules import * +from .rope import * diff --git a/trellis2/modules/attention/config.py b/trellis2/modules/attention/config.py new file mode 100644 index 0000000..a6d5180 --- /dev/null +++ b/trellis2/modules/attention/config.py @@ -0,0 +1,32 @@ +from typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_attn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_attn_debug is not None: + DEBUG = env_attn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug diff --git a/trellis2/modules/attention/full_attn.py b/trellis2/modules/attention/full_attn.py new file mode 100644 index 0000000..e2f9b2a --- /dev/null +++ b/trellis2/modules/attention/full_attn.py @@ -0,0 +1,145 @@ +from typing import * +import torch +import math +from . import config + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if config.BACKEND == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif config.BACKEND == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif config.BACKEND == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 1: + out = flash_attn_3.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = flash_attn_3.flash_attn_func(q, k, v) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif config.BACKEND == 'sdpa': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif config.BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {config.BACKEND}") + + return out diff --git a/trellis2/modules/attention/modules.py b/trellis2/modules/attention/modules.py new file mode 100644 index 0000000..492784c --- /dev/null +++ b/trellis2/modules/attention/modules.py @@ -0,0 +1,102 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention +from .rope import RotaryPositionEmbedder + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis2/modules/attention/rope.py b/trellis2/modules/attention/rope.py new file mode 100644 index 0000000..1cf6c5b --- /dev/null +++ b/trellis2/modules/attention/rope.py @@ -0,0 +1,48 @@ +from typing import * +import torch +import torch.nn as nn + + +class RotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, indices: torch.Tensor) -> torch.Tensor: + """ + Args: + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases \ No newline at end of file diff --git a/trellis2/modules/image_feature_extractor.py b/trellis2/modules/image_feature_extractor.py new file mode 100644 index 0000000..c3cb515 --- /dev/null +++ b/trellis2/modules/image_feature_extractor.py @@ -0,0 +1,118 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image + + +class DinoV2FeatureExtractor: + """ + Feature extractor for DINOv2 models. + """ + def __init__(self, model_name: str): + self.model_name = model_name + self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True) + self.model.eval() + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.model(image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + +class DinoV3FeatureExtractor: + """ + Feature extractor for DINOv3 models. + """ + def __init__(self, model_name: str, image_size=512): + self.model_name = model_name + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.image_size = image_size + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for i, layer_module in enumerate(self.model.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.extract_features(image) + return features diff --git a/trellis2/modules/norm.py b/trellis2/modules/norm.py new file mode 100644 index 0000000..78675d0 --- /dev/null +++ b/trellis2/modules/norm.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn +from .utils import manual_cast + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis2/modules/sparse/__init__.py b/trellis2/modules/sparse/__init__.py new file mode 100644 index 0000000..e73f232 --- /dev/null +++ b/trellis2/modules/sparse/__init__.py @@ -0,0 +1,69 @@ +from . import config +import importlib + +__attributes = { + 'VarLenTensor': 'basic', + 'varlen_cat': 'basic', + 'varlen_unbind': 'basic', + 'SparseTensor': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_cross_attention': 'attention', + 'SparseRotaryPositionEmbedder': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide': 'spatial', + 'SparseSpatial2Channel': 'spatial', + 'SparseChannel2Spatial': 'spatial', + 'sparse_nearest_interpolate': 'spatial', + 'sparse_trilinear_interpolate': 'spatial', + 'encode_seq': 'serialize', + 'decode_seq': 'serialize', +} + +__submodules = ['transformer', 'conv'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + from .serialize import * + import transformer + import conv diff --git a/trellis2/modules/sparse/attention/__init__.py b/trellis2/modules/sparse/attention/__init__.py new file mode 100644 index 0000000..18ab3cc --- /dev/null +++ b/trellis2/modules/sparse/attention/__init__.py @@ -0,0 +1,3 @@ +from .full_attn import * +from .windowed_attn import * +from .modules import * diff --git a/trellis2/modules/sparse/attention/full_attn.py b/trellis2/modules/sparse/attention/full_attn.py new file mode 100644 index 0000000..4eb74d2 --- /dev/null +++ b/trellis2/modules/sparse/attention/full_attn.py @@ -0,0 +1,220 @@ +from typing import * +import torch +from .. import VarLenTensor +from .. import config + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs. + kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, VarLenTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, VarLenTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif config.ATTN == 'flash_attn_3': + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + cu_seqlens_kv = cu_seqlens_q.clone() + max_q_seqlen = max_kv_seqlen = max(q_seqlen) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + elif num_all_args == 3: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + else: + raise ValueError(f"Unknown attention module: {config.ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis2/modules/sparse/attention/modules.py b/trellis2/modules/sparse/attention/modules.py new file mode 100644 index 0000000..d762b4b --- /dev/null +++ b/trellis2/modules/sparse/attention/modules.py @@ -0,0 +1,141 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import VarLenTensor, SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from .rope import SparseRotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + if attn_mode == 'double_windowed': + assert window_size % 2 == 0, "Window size must be even for double windowed attention" + assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention" + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis2/modules/sparse/attention/rope.py b/trellis2/modules/sparse/attention/rope.py new file mode 100644 index 0000000..fb87729 --- /dev/null +++ b/trellis2/modules/sparse/attention/rope.py @@ -0,0 +1,58 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor + + +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (SparseTensor): [..., N, H, D] tensor of queries + k (SparseTensor): [..., N, H, D] tensor of keys + """ + assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" + phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' + phases = q.get_spatial_cache(phases_cache_name) + if phases is None: + coords = q.coords[..., 1:] + phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + q.register_spatial_cache(phases_cache_name, phases) + q_embed = q.replace(self._rotary_embedding(q.feats, phases)) + if k is None: + return q_embed + k_embed = k.replace(self._rotary_embedding(k.feats, phases)) + return q_embed, k_embed \ No newline at end of file diff --git a/trellis2/modules/sparse/attention/windowed_attn.py b/trellis2/modules/sparse/attention/windowed_attn.py new file mode 100644 index 0000000..0430788 --- /dev/null +++ b/trellis2/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,190 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import config + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', + 'sparse_windowed_scaled_dot_product_cross_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (torch.Tensor): Sequence lengths. + (dict): Attn func args. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif config.ATTN == 'flash_attn': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if config.DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if config.DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) + + +def sparse_windowed_scaled_dot_product_cross_attention( + q: SparseTensor, + kv: SparseTensor, + q_window_size: int, + kv_window_size: int, + q_shift_window: Tuple[int, int, int] = (0, 0, 0), + kv_shift_window: Tuple[int, int, int] = (0, 0, 0), +) -> SparseTensor: + """ + Apply windowed scaled dot product cross attention to two sparse tensors. + + Args: + q (SparseTensor): [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs. + q_window_size (int): The window size to use for Qs. + kv_window_size (int): The window size to use for Ks and Vs. + q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs. + kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs. + + Returns: + (SparseTensor): [N, *, H, C] sparse tensor containing the output features. + """ + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + + q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}' + q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name) + if q_serialization_spatial_cache is None: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window) + q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args)) + else: + q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache + kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}' + kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name) + if kv_serialization_spatial_cache is None: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, kv_window_size, kv_shift_window) + kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args)) + else: + kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache + + assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match" + + q_feats = q.feats[q_fwd_indices] # [M, H, C] + kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C] + + if config.ATTN == 'xformers': + if 'xops' not in globals(): + import xformers.ops as xops + k, v = kv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens) + out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C] + elif config.ATTN == 'flash_attn': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats, + cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'], + max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'], + ) # [M, H, C] + + out = out[q_bwd_indices] # [T, H, C] + + return q.replace(out) diff --git a/trellis2/modules/sparse/basic.py b/trellis2/modules/sparse/basic.py new file mode 100644 index 0000000..880973b --- /dev/null +++ b/trellis2/modules/sparse/basic.py @@ -0,0 +1,836 @@ +from typing import * +from fractions import Fraction +import torch +from . import config + + +__all__ = [ + 'VarLenTensor', + 'varlen_cat', + 'varlen_unbind', + 'SparseTensor', + 'sparse_cat', + 'sparse_unbind', +] + + +class VarLenTensor: + """ + Sequential tensor with variable length. + + Args: + feats (torch.Tensor): Features of the varlen tensor. + layout (List[slice]): Layout of the varlen tensor for each batch + """ + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod + def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. + """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def to_tensor_list(self) -> List[torch.Tensor]: + """ + Convert a VarLenTensor to a list of tensors. + """ + tensor_list = [] + for s in self.layout: + tensor_list.append(self.feats[s]) + return tensor_list + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + """ + Convert a VarLenTensor to a dense representation without for-loop. + + Returns: + dense (torch.Tensor): (N, L, C) dense tensor + mask (torch.BoolTensor): (N, L) mask indicating valid positions + """ + N = len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + + +def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor: + """ + Concatenate a list of varlen tensors. + + Args: + inputs (List[VarLenTensor]): List of varlen tensors to concatenate. + """ + if dim == 0: + new_feats = torch.cat([input.feats for input in inputs], dim=0) + start = 0 + new_layout = [] + for input in inputs: + for l in input.layout: + new_layout.append(slice(start, start + l.stop - l.start)) + start += l.stop - l.start + output = VarLenTensor(feats=new_feats, layout=new_layout) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + """ + Unbind a varlen tensor along a dimension. + + Args: + input (VarLenTensor): Varlen tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + if self.SparseTensorData is None: + import importlib + if config.CONV == 'torchsparse': + self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif config.CONV == 'spconv': + self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape = args + (None,) * (3 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + if config.CONV == 'torchsparse': + self.data = self.SparseTensorData(feats, coords, **kwargs) + elif config.CONV == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1) + self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) + self.data._features = feats + else: + self.data = { + 'feats': feats, + 'coords': coords, + } + elif method_id == 1: + data, shape = args + (None,) * (2 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + self.data = data + + self._shape = shape + self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1))) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if config.DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + @staticmethod + def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': + """ + Create a SparseTensor from a list of tensors. + """ + feats = torch.cat(feats_list, dim=0) + coords = [] + for i, coord in enumerate(coords_list): + coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1) + coords.append(coord) + coords = torch.cat(coords, dim=0) + return SparseTensor(feats, coords) + + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Convert a SparseTensor to list of tensors. + """ + feats_list = [] + coords_list = [] + for s in self.layout: + feats_list.append(self.feats[s]) + coords_list.append(self.coords[s]) + return feats_list, coords_list + + def __len__(self) -> int: + return len(self.layout) + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + def __cal_spatial_shape(self, coords): + return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) + + @property + def shape(self) -> torch.Size: + if self._shape is None: + self._shape = self.__cal_shape(self.feats, self.coords) + return self._shape + + @property + def layout(self) -> List[slice]: + layout = self.get_spatial_cache('layout') + if layout is None: + layout = self.__cal_layout(self.coords, self.shape[0]) + self.register_spatial_cache('layout', layout) + return layout + + @property + def spatial_shape(self) -> torch.Size: + spatial_shape = self.get_spatial_cache('shape') + if spatial_shape is None: + spatial_shape = self.__cal_spatial_shape(self.coords) + self.register_spatial_cache('shape', spatial_shape) + return spatial_shape + + @property + def feats(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.F + elif config.CONV == 'spconv': + return self.data.features + else: + return self.data['feats'] + + @feats.setter + def feats(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.F = value + elif config.CONV == 'spconv': + self.data.features = value + else: + self.data['feats'] = value + + @property + def coords(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.C + elif config.CONV == 'spconv': + return self.data.indices + else: + return self.data['coords'] + + @coords.setter + def coords(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.C = value + elif config.CONV == 'spconv': + self.data.indices = value + else: + self.data['coords'] = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + seqlen = self.get_spatial_cache('seqlen') + if seqlen is None: + seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + self.register_spatial_cache('seqlen', seqlen) + return seqlen + + @property + def cum_seqlen(self) -> torch.LongTensor: + cum_seqlen = self.get_spatial_cache('cum_seqlen') + if cum_seqlen is None: + cum_seqlen = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + self.register_spatial_cache('cum_seqlen', cum_seqlen) + return cum_seqlen + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map') + if batch_boardcast_map is None: + batch_boardcast_map = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map) + return batch_boardcast_map + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + if config.CONV == 'torchsparse': + new_data = self.SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif config.CONV == 'spconv': + new_data = self.SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + else: + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + new_tensor = SparseTensor( + new_data, + shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, + scale=self._scale, + spatial_cache=self._spatial_cache + ) + return new_tensor + + def to_dense(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.dense() + elif config.CONV == 'spconv': + return self.data.dense() + else: + spatial_shape = self.spatial_shape + ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) + idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) + ret[tuple(idx)] = self.feats + return ret + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_coords = [] + new_feats = [] + new_layout = [] + new_shape = torch.Size([len(idx)] + list(self.shape[1:])) + start = 0 + for new_idx, old_idx in enumerate(idx): + new_coords.append(self.coords[self.layout[old_idx]].clone()) + new_coords[-1][:, 0] = new_idx + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_coords[-1]))) + start += len(new_coords[-1]) + new_coords = torch.cat(new_coords, dim=0).contiguous() + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) + new_tensor.register_spatial_cache('layout', new_layout) + return new_tensor + + def clear_spatial_cache(self) -> None: + """ + Clear all spatial caches. + """ + self._spatial_cache = {} + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + def __repr__(self) -> str: + return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis2/modules/sparse/config.py b/trellis2/modules/sparse/config.py new file mode 100644 index 0000000..a5f4d53 --- /dev/null +++ b/trellis2/modules/sparse/config.py @@ -0,0 +1,43 @@ +from typing import * + +CONV = 'flex_gemm' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global CONV + global DEBUG + global ATTN + + env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn_backend is None: + env_sparse_attn_backend = os.environ.get('ATTN_BACKEND') + + if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']: + CONV = env_sparse_conv_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3']: + ATTN = env_sparse_attn_backend + + print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}") + + +__from_env() + + +def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']): + global CONV + CONV = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn_backend(backend: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = backend diff --git a/trellis2/modules/sparse/conv/__init__.py b/trellis2/modules/sparse/conv/__init__.py new file mode 100644 index 0000000..a7f5911 --- /dev/null +++ b/trellis2/modules/sparse/conv/__init__.py @@ -0,0 +1,2 @@ +from .conv import SparseConv3d, SparseInverseConv3d +from . import config diff --git a/trellis2/modules/sparse/conv/config.py b/trellis2/modules/sparse/conv/config.py new file mode 100644 index 0000000..ac08489 --- /dev/null +++ b/trellis2/modules/sparse/conv/config.py @@ -0,0 +1,3 @@ +SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native' +FLEX_GEMM_ALGO = 'masked_implicit_gemm_splitk' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk' +FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size diff --git a/trellis2/modules/sparse/conv/conv.py b/trellis2/modules/sparse/conv/conv.py new file mode 100644 index 0000000..4c7d407 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv.py @@ -0,0 +1,30 @@ +from .. import config +import importlib +import torch +import torch.nn as nn +from .. import SparseTensor + + +_backends = {} + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if config.CONV not in _backends: + _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__) + _backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x: SparseTensor) -> SparseTensor: + return _backends[config.CONV].sparse_conv3d_forward(self, x) + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if config.CONV not in _backends: + _backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__) + _backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key) + + def forward(self, x: SparseTensor) -> SparseTensor: + return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x) diff --git a/trellis2/modules/sparse/conv/conv_flex_gemm.py b/trellis2/modules/sparse/conv/conv_flex_gemm.py new file mode 100644 index 0000000..d256194 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_flex_gemm.py @@ -0,0 +1,68 @@ +import math +import torch +import torch.nn as nn +from .. import SparseTensor +from . import config +import flex_gemm +from flex_gemm.ops.spconv import sparse_submanifold_conv3d + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + assert stride == 1 and (padding is None), 'Currently flex_gemm implementation only support submanifold sparse convolution (stride=1, padding=None)' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + # initialize parameters + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO) + flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO) + + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + torch.Size([*x.shape, *x.spatial_shape]), + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + + +def sparse_inverse_conv3d_init(self, *args, **kwargs): + raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet') + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet') diff --git a/trellis2/modules/sparse/conv/conv_spconv.py b/trellis2/modules/sparse/conv/conv_spconv.py new file mode 100644 index 0000000..f709708 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from . import config +import spconv.pytorch as spconv + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + algo = None + if config.SPCONV_ALGO == 'native': + algo = spconv.ConvAlgo.Native + elif config.SPCONV_ALGO == 'implicit_gemm': + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/trellis2/modules/sparse/conv/conv_torchsparse.py b/trellis2/modules/sparse/conv/conv_torchsparse.py new file mode 100644 index 0000000..5234bd1 --- /dev/null +++ b/trellis2/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +import torchsparse + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + +def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + +def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)]) + return out diff --git a/trellis2/modules/sparse/linear.py b/trellis2/modules/sparse/linear.py new file mode 100644 index 0000000..4431770 --- /dev/null +++ b/trellis2/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import VarLenTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) diff --git a/trellis2/modules/sparse/nonlinearity.py b/trellis2/modules/sparse/nonlinearity.py new file mode 100644 index 0000000..950e5c0 --- /dev/null +++ b/trellis2/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import VarLenTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis2/modules/sparse/norm.py b/trellis2/modules/sparse/norm.py new file mode 100644 index 0000000..9571120 --- /dev/null +++ b/trellis2/modules/sparse/norm.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +from ..utils import manual_cast +from . import VarLenTensor +from . import config + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: VarLenTensor) -> VarLenTensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: VarLenTensor) -> VarLenTensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) diff --git a/trellis2/modules/sparse/spatial/__init__.py b/trellis2/modules/sparse/spatial/__init__.py new file mode 100644 index 0000000..e27425f --- /dev/null +++ b/trellis2/modules/sparse/spatial/__init__.py @@ -0,0 +1,2 @@ +from .basic import * +from .spatial2channel import * diff --git a/trellis2/modules/sparse/spatial/basic.py b/trellis2/modules/sparse/spatial/basic.py new file mode 100644 index 0000000..eaeb8af --- /dev/null +++ b/trellis2/modules/sparse/spatial/basic.py @@ -0,0 +1,109 @@ +from typing import * +import torch +import torch.nn as nn +from .. import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'): + super(SparseDownsample, self).__init__() + self.factor = factor + self.mode = mode + assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}' + + def forward(self, x: SparseTensor) -> SparseTensor: + cache = x.get_spatial_cache(f'downsample_{self.factor}') + if cache is None: + DIM = x.coords.shape[-1] - 1 + + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx = cache + + new_feats = torch.scatter_reduce( + torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]), + src=x.feats, + reduce=self.mode, + include_self=False, + ) + out = SparseTensor(new_feats, new_coords, x._shape) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx)) + out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx)) + out.register_spatial_cache(f'shape', torch.Size(MAX)) + if self.training: + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) + subdivision[idx, subidx] = True + out.register_spatial_cache(f'subdivision', subdivision) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__( + self, factor: int + ): + super(SparseUpsample, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'upsample_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.') + else: + sub = subdivision.feats + N_leaf = sub.sum(dim=-1) + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx = cache + + new_feats = x.feats[idx] + out = SparseTensor(new_feats, new_coords, x._shape) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + + return out + \ No newline at end of file diff --git a/trellis2/modules/sparse/spatial/spatial2channel.py b/trellis2/modules/sparse/spatial/spatial2channel.py new file mode 100644 index 0000000..577f36d --- /dev/null +++ b/trellis2/modules/sparse/spatial/spatial2channel.py @@ -0,0 +1,93 @@ +from typing import * +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseSpatial2Channel(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as rearranging its features from spatial to channel. + """ + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache(f'shape', torch.Size(MAX)) + if self.training: + subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) + subdivision[idx, subidx] = True + out.register_spatial_cache(f'subdivision', subdivision) + + return out + + +class SparseChannel2Spatial(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as rearranging its features from channel to spatial. + """ + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor: + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out diff --git a/trellis2/modules/sparse/transformer/__init__.py b/trellis2/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000..b08b0d4 --- /dev/null +++ b/trellis2/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis2/modules/sparse/transformer/blocks.py b/trellis2/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000..9d1ec60 --- /dev/null +++ b/trellis2/modules/sparse/transformer/blocks.py @@ -0,0 +1,145 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import VarLenTensor, SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: VarLenTensor) -> VarLenTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis2/modules/sparse/transformer/modulated.py b/trellis2/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000..e616932 --- /dev/null +++ b/trellis2/modules/sparse/transformer/modulated.py @@ -0,0 +1,166 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import VarLenTensor, SparseTensor +from ..attention import SparseMultiHeadAttention +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/trellis2/modules/spatial.py b/trellis2/modules/spatial.py new file mode 100644 index 0000000..79e268d --- /dev/null +++ b/trellis2/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/trellis2/modules/transformer/__init__.py b/trellis2/modules/transformer/__init__.py new file mode 100644 index 0000000..b08b0d4 --- /dev/null +++ b/trellis2/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis2/modules/transformer/blocks.py b/trellis2/modules/transformer/blocks.py new file mode 100644 index 0000000..fb6f5eb --- /dev/null +++ b/trellis2/modules/transformer/blocks.py @@ -0,0 +1,186 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = True, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h, phases=phases) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, phases, use_reentrant=False) + else: + return self._forward(x, phases) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.self_attn(h, phases=phases) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False) + else: + return self._forward(x, context, phases) + \ No newline at end of file diff --git a/trellis2/modules/transformer/modulated.py b/trellis2/modules/transformer/modulated.py new file mode 100644 index 0000000..0d71e58 --- /dev/null +++ b/trellis2/modules/transformer/modulated.py @@ -0,0 +1,165 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, phases, use_reentrant=False) + else: + return self._forward(x, mod, phases) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) + else: + return self._forward(x, mod, context, phases) + \ No newline at end of file diff --git a/trellis2/modules/utils.py b/trellis2/modules/utils.py new file mode 100644 index 0000000..5d92d7d --- /dev/null +++ b/trellis2/modules/utils.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +from ..modules import sparse as sp + +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def convert_module_to(l, dtype): + """ + Convert primitive modules to the given dtype. + """ + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.to(dtype) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def manual_cast(tensor, dtype): + """ + Cast if autocast is not enabled. + """ + if not torch.is_autocast_enabled(): + return tensor.type(dtype) + return tensor + + +def str_to_dtype(dtype_str: str): + return { + 'f16': torch.float16, + 'fp16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + 'f32': torch.float32, + 'fp32': torch.float32, + 'float32': torch.float32, + }[dtype_str] diff --git a/trellis2/pipelines/__init__.py b/trellis2/pipelines/__init__.py new file mode 100644 index 0000000..53d8917 --- /dev/null +++ b/trellis2/pipelines/__init__.py @@ -0,0 +1,55 @@ +import importlib + +__attributes = { + "Trellis2ImageTo3DPipeline": "trellis2_image_to_3d", + "Trellis2ImageTo3DCascadePipeline": "trellis2_image_to_3d_cascade", + "Trellis2ImageToTexturePipeline": "trellis2_image_to_tex", +} + +__submodules = ['samplers', 'rembg'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) + + +# For PyLance +if __name__ == '__main__': + from . import samplers, rembg + from .trellis_image_to_3d import TrellisImageTo3DPipeline + from .trellis2_image_to_3d import Trellis2ImageTo3DPipeline + from .trellis2_image_to_3d_cascade import Trellis2ImageTo3DCascadePipeline + from .trellis2_image_to_tex import Trellis2ImageToTexturePipeline diff --git a/trellis2/pipelines/base.py b/trellis2/pipelines/base.py new file mode 100644 index 0000000..d897825 --- /dev/null +++ b/trellis2/pipelines/base.py @@ -0,0 +1,70 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = {} + for k, v in args['models'].items(): + try: + _models[k] = models.from_pretrained(f"{path}/{v}") + except Exception as e: + _models[k] = models.from_pretrained(v) + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + if hasattr(self, '_device'): + return self._device + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) \ No newline at end of file diff --git a/trellis2/pipelines/rembg/BiRefNet.py b/trellis2/pipelines/rembg/BiRefNet.py new file mode 100644 index 0000000..c71a992 --- /dev/null +++ b/trellis2/pipelines/rembg/BiRefNet.py @@ -0,0 +1,42 @@ +from typing import * +from transformers import AutoModelForImageSegmentation +import torch +from torchvision import transforms +from PIL import Image + + +class BiRefNet: + def __init__(self, model_name: str = "ZhengPeng7/BiRefNet"): + self.model = AutoModelForImageSegmentation.from_pretrained( + model_name, trust_remote_code=True + ) + self.model.eval() + self.transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + def to(self, device: str): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def __call__(self, image: Image.Image) -> Image.Image: + image_size = image.size + input_images = self.transform_image(image).unsqueeze(0).to("cuda") + # Prediction + with torch.no_grad(): + preds = self.model(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + image.putalpha(mask) + return image + \ No newline at end of file diff --git a/trellis2/pipelines/rembg/__init__.py b/trellis2/pipelines/rembg/__init__.py new file mode 100644 index 0000000..fc1eed1 --- /dev/null +++ b/trellis2/pipelines/rembg/__init__.py @@ -0,0 +1 @@ +from .BiRefNet import * diff --git a/trellis2/pipelines/samplers/__init__.py b/trellis2/pipelines/samplers/__init__.py new file mode 100644 index 0000000..4a69b95 --- /dev/null +++ b/trellis2/pipelines/samplers/__init__.py @@ -0,0 +1,6 @@ +from .base import Sampler +from .flow_euler import ( + FlowEulerSampler, + FlowEulerCfgSampler, + FlowEulerGuidanceIntervalSampler, +) \ No newline at end of file diff --git a/trellis2/pipelines/samplers/base.py b/trellis2/pipelines/samplers/base.py new file mode 100644 index 0000000..1966ce7 --- /dev/null +++ b/trellis2/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000..8c7a4da --- /dev/null +++ b/trellis2/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,29 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, guidance_strength, guidance_rescale=0.0, **kwargs): + if guidance_strength == 1: + return super()._inference_model(model, x_t, t, cond, **kwargs) + elif guidance_strength == 0: + return super()._inference_model(model, x_t, t, neg_cond, **kwargs) + else: + pred_pos = super()._inference_model(model, x_t, t, cond, **kwargs) + pred_neg = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + pred = guidance_strength * pred_pos + (1 - guidance_strength) * pred_neg + + # CFG rescale + if guidance_rescale > 0: + x_0_pos = self._pred_to_xstart(x_t, t, pred_pos) + x_0_cfg = self._pred_to_xstart(x_t, t, pred) + std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True) + std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True) + x_0_rescaled = x_0_cfg * (std_pos / std_cfg) + x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg + pred = self._xstart_to_pred(x_t, t, x_0) + + return pred diff --git a/trellis2/pipelines/samplers/flow_euler.py b/trellis2/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000..5ff72b8 --- /dev/null +++ b/trellis2/pipelines/samplers/flow_euler.py @@ -0,0 +1,208 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. + """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _pred_to_xstart(self, x_t, t, pred): + return (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * pred + + def _xstart_to_pred(self, x_t, t, x_0): + return ((1 - self.sigma_min) * x_t - x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + tqdm_desc: str = "Sampling", + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + tqdm_desc: A customized tqdm desc. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_seq = t_seq.tolist() + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc=tqdm_desc, disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + guidance_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + guidance_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + guidance_strength: float = 3.0, + guidance_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + guidance_strength: The strength of classifier-free guidance. + guidance_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, guidance_strength=guidance_strength, guidance_interval=guidance_interval, **kwargs) diff --git a/trellis2/pipelines/samplers/guidance_interval_mixin.py b/trellis2/pipelines/samplers/guidance_interval_mixin.py new file mode 100644 index 0000000..3f57869 --- /dev/null +++ b/trellis2/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,13 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, guidance_strength, guidance_interval, **kwargs): + if guidance_interval[0] <= t <= guidance_interval[1]: + return super()._inference_model(model, x_t, t, cond, guidance_strength=guidance_strength, **kwargs) + else: + return super()._inference_model(model, x_t, t, cond, guidance_strength=1, **kwargs) diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py new file mode 100644 index 0000000..8d7afd5 --- /dev/null +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -0,0 +1,588 @@ +from typing import * +import torch +import torch.nn as nn +import numpy as np +from PIL import Image +from .base import Pipeline +from . import samplers, rembg +from ..modules.sparse import SparseTensor +from ..modules import image_feature_extractor +from ..representations import Mesh, MeshWithVoxel + + +class Trellis2ImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis2 image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + shape_slat_sampler (samplers.Sampler): The sampler for the structured latent. + tex_slat_sampler (samplers.Sampler): The sampler for the texture latent. + sparse_structure_sampler_params (dict): The parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): The parameters for the structured latent sampler. + tex_slat_sampler_params (dict): The parameters for the texture latent sampler. + shape_slat_normalization (dict): The normalization parameters for the structured latent. + tex_slat_normalization (dict): The normalization parameters for the texture latent. + image_cond_model (Callable): The image conditioning model. + rembg_model (Callable): The model for removing background. + low_vram (bool): Whether to use low-VRAM mode. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + shape_slat_sampler: samplers.Sampler = None, + tex_slat_sampler: samplers.Sampler = None, + sparse_structure_sampler_params: dict = None, + shape_slat_sampler_params: dict = None, + tex_slat_sampler_params: dict = None, + shape_slat_normalization: dict = None, + tex_slat_normalization: dict = None, + image_cond_model: Callable = None, + rembg_model: Callable = None, + low_vram: bool = True, + default_pipeline_type: str = '1024_cascade', + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.shape_slat_sampler = shape_slat_sampler + self.tex_slat_sampler = tex_slat_sampler + self.sparse_structure_sampler_params = sparse_structure_sampler_params + self.shape_slat_sampler_params = shape_slat_sampler_params + self.tex_slat_sampler_params = tex_slat_sampler_params + self.shape_slat_normalization = shape_slat_normalization + self.tex_slat_normalization = tex_slat_normalization + self.image_cond_model = image_cond_model + self.rembg_model = rembg_model + self.low_vram = low_vram + self.default_pipeline_type = default_pipeline_type + self.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + self._device = 'cpu' + + @staticmethod + def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path) + new_pipeline = Trellis2ImageTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args']) + new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params'] + + new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args']) + new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params'] + + new_pipeline.shape_slat_normalization = args['shape_slat_normalization'] + new_pipeline.tex_slat_normalization = args['tex_slat_normalization'] + + new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) + new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) + + new_pipeline.low_vram = args.get('low_vram', True) + new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') + new_pipeline.pbr_attr_layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + new_pipeline._device = 'cpu' + + return new_pipeline + + def to(self, device: torch.device) -> None: + self._device = device + if not self.low_vram: + super().to(device) + self.image_cond_model.to(device) + if self.rembg_model is not None: + self.rembg_model.to(device) + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if has_alpha: + output = input + else: + input = input.convert('RGB') + if self.low_vram: + self.rembg_model.to(self.device) + output = self.rembg_model(input) + if self.low_vram: + self.rembg_model.cpu() + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + self.image_cond_model.image_size = resolution + if self.low_vram: + self.image_cond_model.to(self.device) + cond = self.image_cond_model(image) + if self.low_vram: + self.image_cond_model.cpu() + if not include_neg_cond: + return {'cond': cond} + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + resolution: int, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + resolution (int): The resolution of the sparse structure. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample sparse structure latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + in_channels = flow_model.in_channels + noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling sparse structure", + ).samples + if self.low_vram: + flow_model.cpu() + + # Decode sparse structure latent + decoder = self.models['sparse_structure_decoder'] + if self.low_vram: + decoder.to(self.device) + decoded = decoder(z_s)>0 + if self.low_vram: + decoder.cpu() + if resolution != decoded.shape[2]: + ratio = decoded.shape[2] // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + + return coords + + def sample_shape_slat( + self, + cond: dict, + flow_model, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def sample_shape_slat_cascade( + self, + lr_cond: dict, + cond: dict, + flow_model_lr, + flow_model, + lr_resolution: int, + resolution: int, + coords: torch.Tensor, + sampler_params: dict = {}, + max_num_tokens: int = 49152, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # LR + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model_lr.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model_lr, + noise, + **lr_cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model_lr.cpu() + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + # Upsample + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + hr_resolution = resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + coords = quant_coords.unique(dim=0) + num_tokens = coords.shape[0] + if num_tokens < max_num_tokens or hr_resolution == 1024: + if hr_resolution != resolution: + print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.") + break + hr_resolution -= 128 + + # Sample structured latent + noise = SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.shape_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.shape_slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling shape SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat, hr_resolution + + def decode_shape_slat( + self, + slat: SparseTensor, + resolution: int, + ) -> Tuple[List[Mesh], List[SparseTensor]]: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + List[Mesh]: The decoded meshes. + List[SparseTensor]: The decoded substructures. + """ + self.models['shape_slat_decoder'].set_resolution(resolution) + if self.low_vram: + self.models['shape_slat_decoder'].to(self.device) + self.models['shape_slat_decoder'].low_vram = True + ret = self.models['shape_slat_decoder'](slat, return_subs=True) + if self.low_vram: + self.models['shape_slat_decoder'].cpu() + self.models['shape_slat_decoder'].low_vram = False + return ret + + def sample_tex_slat( + self, + cond: dict, + flow_model, + shape_slat: SparseTensor, + sampler_params: dict = {}, + ) -> SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + shape_slat (SparseTensor): The structured latent for shape + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) + mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) + shape_slat = (shape_slat - mean) / std + + in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels + noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device)) + sampler_params = {**self.tex_slat_sampler_params, **sampler_params} + if self.low_vram: + flow_model.to(self.device) + slat = self.tex_slat_sampler.sample( + flow_model, + noise, + concat_cond=shape_slat, + **cond, + **sampler_params, + verbose=True, + tqdm_desc="Sampling texture SLat", + ).samples + if self.low_vram: + flow_model.cpu() + + std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + def decode_tex_slat( + self, + slat: SparseTensor, + subs: List[SparseTensor], + ) -> SparseTensor: + """ + Decode the structured latent. + + Args: + slat (SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + List[SparseTensor]: The decoded texture voxels + """ + if self.low_vram: + self.models['tex_slat_decoder'].to(self.device) + ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5 + if self.low_vram: + self.models['tex_slat_decoder'].cpu() + return ret + + @torch.no_grad() + def decode_latent( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ) -> List[MeshWithVoxel]: + """ + Decode the latent codes. + + Args: + shape_slat (SparseTensor): The structured latent for shape. + tex_slat (SparseTensor): The structured latent for texture. + resolution (int): The resolution of the output. + """ + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + for m, v in zip(meshes, tex_voxels): + m.fill_holes() + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh + + @torch.no_grad() + def run( + self, + image: Image.Image, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + shape_slat_sampler_params: dict = {}, + tex_slat_sampler_params: dict = {}, + preprocess_image: bool = True, + return_latent: bool = False, + pipeline_type: Optional[str] = None, + max_num_tokens: int = 49152, + ) -> List[MeshWithVoxel]: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + seed (int): The random seed. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + shape_slat_sampler_params (dict): Additional parameters for the shape SLat sampler. + tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler. + preprocess_image (bool): Whether to preprocess the image. + return_latent (bool): Whether to return the latent codes. + pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'. + max_num_tokens (int): The maximum number of tokens to use. + """ + # Check pipeline type + pipeline_type = pipeline_type or self.default_pipeline_type + if pipeline_type == '512': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found." + elif pipeline_type == '1024': + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1024_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + elif pipeline_type == '1536_cascade': + assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found." + assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found." + assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." + else: + raise ValueError(f"Invalid pipeline type: {pipeline_type}") + + if preprocess_image: + image = self.preprocess_image(image) + torch.manual_seed(seed) + cond_512 = self.get_cond([image], 512) + cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None + ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type] + coords = self.sample_sparse_structure( + cond_512, ss_res, + num_samples, sparse_structure_sampler_params + ) + if pipeline_type == '512': + shape_slat = self.sample_shape_slat( + cond_512, self.models['shape_slat_flow_model_512'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_512, self.models['tex_slat_flow_model_512'], + shape_slat, tex_slat_sampler_params + ) + res = 512 + elif pipeline_type == '1024': + shape_slat = self.sample_shape_slat( + cond_1024, self.models['shape_slat_flow_model_1024'], + coords, shape_slat_sampler_params + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + res = 1024 + elif pipeline_type == '1024_cascade': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1024, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + elif pipeline_type == '1536_cascade': + shape_slat, res = self.sample_shape_slat_cascade( + cond_512, cond_1024, + self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], + 512, 1536, + coords, shape_slat_sampler_params, + max_num_tokens + ) + tex_slat = self.sample_tex_slat( + cond_1024, self.models['tex_slat_flow_model_1024'], + shape_slat, tex_slat_sampler_params + ) + torch.cuda.empty_cache() + out_mesh = self.decode_latent(shape_slat, tex_slat, res) + if return_latent: + return out_mesh, (shape_slat, tex_slat, res) + else: + return out_mesh diff --git a/trellis2/renderers/__init__.py b/trellis2/renderers/__init__.py new file mode 100644 index 0000000..de3203d --- /dev/null +++ b/trellis2/renderers/__init__.py @@ -0,0 +1,33 @@ +import importlib + +__attributes = { + 'MeshRenderer': 'mesh_renderer', + 'VoxelRenderer': 'voxel_renderer', + 'PbrMeshRenderer': 'pbr_mesh_renderer', + 'EnvMap': 'pbr_mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh_renderer import MeshRenderer + from .voxel_renderer import VoxelRenderer + from .pbr_mesh_renderer import PbrMeshRenderer, EnvMap + \ No newline at end of file diff --git a/trellis2/renderers/mesh_renderer.py b/trellis2/renderers/mesh_renderer.py new file mode 100644 index 0000000..e20efc5 --- /dev/null +++ b/trellis2/renderers/mesh_renderer.py @@ -0,0 +1,414 @@ +from typing import * +import torch +from easydict import EasyDict as edict +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "chunk_size": None, + "antialias": True, + "clamp_barycentric_coords": False, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"], + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "attr", "mask", "depth", "coord", "normal" + + Returns: + edict based on return_types containing: + attr (torch.Tensor): [C, H, W] rendered attr image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + mask (torch.Tensor): [H, W] rendered mask image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + chunk_size = self.rendering_options["chunk_size"] + antialias = self.rendering_options["antialias"] + clamp_barycentric_coords = self.rendering_options["clamp_barycentric_coords"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + ret_dict = edict() + for type in return_types: + if type == "mask" : + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "depth": + ret_dict[type] = torch.zeros((resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "normal": + ret_dict[type] = torch.full((3, resolution, resolution), 0.5, dtype=torch.float32, device=self.device) + elif type == "coord": + ret_dict[type] = torch.zeros((3, resolution, resolution), dtype=torch.float32, device=self.device) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + ret_dict[type] = torch.zeros((mesh.attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + else: + ret_dict[type] = torch.zeros((mesh.vertex_attrs.shape[-1], resolution, resolution), dtype=torch.float32, device=self.device) + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + if 'normal' in return_types: + v0 = vertices_camera[0, mesh.faces[:, 0], :3] + v1 = vertices_camera[0, mesh.faces[:, 1], :3] + v2 = vertices_camera[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + face_normal = torch.where(torch.sum(face_normal * v0, dim=1, keepdim=True) > 0, face_normal, -face_normal) + + out_dict = edict() + if chunk_size is None: + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa) + ) + if clamp_barycentric_coords: + rast[..., :2] = torch.clamp(rast[..., :2], 0, 1) + rast[..., :2] /= torch.where(rast[..., :2].sum(dim=-1, keepdim=True) > 1, rast[..., :2].sum(dim=-1, keepdim=True), torch.ones_like(rast[..., :2])) + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "normal" : + img = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces)[0] + if antialias: img = dr.antialias(img, rast, vertices_clip, faces) + + out_dict[type] = img + else: + z_buffer = torch.full((1, resolution * ssaa, resolution * ssaa), torch.inf, device=self.device, dtype=torch.float32) + for i in range(0, faces.shape[0], chunk_size): + faces_chunk = faces[i:i+chunk_size] + rast, rast_db = dr.rasterize( + self.glctx, vertices_clip, faces_chunk, (resolution * ssaa, resolution * ssaa) + ) + z_filter = torch.logical_and( + rast[..., 3] != 0, + rast[..., 2] < z_buffer + ) + z_buffer[z_filter] = rast[z_filter][..., 2] + + for type in return_types: + img = None + if type == "mask" : + img = (rast[..., -1:] > 0).float() + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_chunk)[0] + elif type == "normal" : + face_normal_chunk = face_normal[i:i+chunk_size] + img = dr.interpolate(face_normal_chunk.unsqueeze(0), rast, torch.arange(face_normal_chunk.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0] + img = (img + 1) / 2 + elif type == "coord": + img = dr.interpolate(vertices, rast, faces_chunk)[0] + elif type == "attr": + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices, rast, faces_chunk)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + imgs = { + 'base_color': torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device), + 'metallic': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'roughness': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device), + 'alpha': torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + } + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + base_color = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['base_color'] += base_color * mat.base_color_factor * mat_mask + else: + imgs['base_color'] += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + metallic = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['metallic'] += metallic * mat.metallic_factor * mat_mask + else: + imgs['metallic'] += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + roughness = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + imgs['roughness'] += roughness * mat.roughness_factor * mat_mask + else: + imgs['roughness'] += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + imgs['alpha'] += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + alpha = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (alpha * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += alpha * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + imgs['alpha'] += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + imgs['alpha'] += mat.alpha_factor * mat_mask + + img = torch.cat([imgs[name] for name in imgs.keys()], dim=-1).unsqueeze(0) + else: + img = dr.interpolate(mesh.vertex_attrs.unsqueeze(0), rast, faces_chunk)[0] + + if type not in out_dict: + out_dict[type] = img + else: + out_dict[type][z_filter] = img[z_filter] + + for type in return_types: + img = out_dict[type] + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + + if isinstance(mesh, (MeshWithVoxel, MeshWithPbrMaterial)) and 'attr' in return_types: + for k, s in mesh.layout.items(): + out_dict[k] = out_dict['attr'][s] + del out_dict['attr'] + + return out_dict diff --git a/trellis2/renderers/pbr_mesh_renderer.py b/trellis2/renderers/pbr_mesh_renderer.py new file mode 100644 index 0000000..876378f --- /dev/null +++ b/trellis2/renderers/pbr_mesh_renderer.py @@ -0,0 +1,480 @@ +from typing import * +import torch +from easydict import EasyDict as edict +import numpy as np +import utils3d +from ..representations.mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, AlphaMode, TextureWrapMode +import torch.nn.functional as F + + +def cube_to_dir(s, x, y): + if s == 0: rx, ry, rz = torch.ones_like(x), -x, -y + elif s == 1: rx, ry, rz = -torch.ones_like(x), x, -y + elif s == 2: rx, ry, rz = x, y, torch.ones_like(x) + elif s == 3: rx, ry, rz = x, -y, -torch.ones_like(x) + elif s == 4: rx, ry, rz = x, torch.ones_like(x), -y + elif s == 5: rx, ry, rz = -x, -torch.ones_like(x), -y + return torch.stack((rx, ry, rz), dim=-1) + + +def latlong_to_cubemap(latlong_map, res): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda') + for s in range(6): + gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), + torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'), + indexing='ij') + v = F.normalize(cube_to_dir(s, gx, gy), dim=-1) + + tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5 + tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi + texcoord = torch.cat((tu, tv), dim=-1) + + cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0] + return cubemap + + +class EnvMap: + def __init__(self, image: torch.Tensor): + self.image = image + + @property + def _backend(self): + if not hasattr(self, '_nvdiffrec_envlight'): + if 'EnvironmentLight' not in globals(): + from nvdiffrec_render.light import EnvironmentLight + cubemap = latlong_to_cubemap(self.image, [512, 512]) + self._nvdiffrec_envlight = EnvironmentLight(cubemap) + self._nvdiffrec_envlight.build_mips() + return self._nvdiffrec_envlight + + def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True): + return self._backend.shade(gb_pos, gb_normal, kd, ks, view_pos, specular) + + def sample(self, directions: torch.Tensor): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + return dr.texture( + self._backend.base.unsqueeze(0), + directions.unsqueeze(0), + boundary_mode='cube', + )[0] + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = (far + near) / (far - near) + ret[2, 3] = 2 * near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def screen_space_ambient_occlusion( + depth: torch.Tensor, + normal: torch.Tensor, + perspective: torch.Tensor, + radius: float = 0.1, + bias: float = 1e-6, + samples: int = 64, + intensity: float = 1.0, +) -> torch.Tensor: + """ + Screen space ambient occlusion (SSAO) + + Args: + depth (torch.Tensor): [H, W, 1] depth image + normal (torch.Tensor): [H, W, 3] normal image + perspective (torch.Tensor): [4, 4] camera projection matrix + radius (float): radius of the SSAO kernel + bias (float): bias to avoid self-occlusion + samples (int): number of samples to use for the SSAO kernel + intensity (float): intensity of the SSAO effect + Returns: + (torch.Tensor): [H, W, 1] SSAO image + """ + device = depth.device + H, W, _ = depth.shape + + fx = perspective[0, 0] + fy = perspective[1, 1] + cx = perspective[0, 2] + cy = perspective[1, 2] + + y_grid, x_grid = torch.meshgrid( + (torch.arange(H, device=device) + 0.5) / H * 2 - 1, + (torch.arange(W, device=device) + 0.5) / W * 2 - 1, + indexing='ij' + ) + x_view = (x_grid.float() - cx) * depth[..., 0] / fx + y_view = (y_grid.float() - cy) * depth[..., 0] / fy + view_pos = torch.stack([x_view, y_view, depth[..., 0]], dim=-1) # [H, W, 3] + + depth_feat = depth.permute(2, 0, 1).unsqueeze(0) + occlusion = torch.zeros((H, W), device=device) + + # start sampling + for _ in range(samples): + # sample normal distribution, if inside, flip the sign + rnd_vec = torch.randn(H, W, 3, device=device) + rnd_vec = F.normalize(rnd_vec, p=2, dim=-1) + dot_val = torch.sum(rnd_vec * normal, dim=-1, keepdim=True) + sample_dir = torch.sign(dot_val) * rnd_vec + scale = torch.rand(H, W, 1, device=device) + scale = scale * scale + sample_pos = view_pos + sample_dir * radius * scale + sample_z = sample_pos[..., 2] + + # project to screen space + z_safe = torch.clamp(sample_pos[..., 2], min=1e-5) + proj_u = (sample_pos[..., 0] * fx / z_safe) + cx + proj_v = (sample_pos[..., 1] * fy / z_safe) + cy + grid = torch.stack([proj_u, proj_v], dim=-1).unsqueeze(0) + geo_z = F.grid_sample(depth_feat, grid, mode='nearest', padding_mode='border').squeeze() + range_check = torch.abs(geo_z - sample_z) < radius + is_occluded = (geo_z <= sample_z - bias) & range_check + occlusion += is_occluded.float() + + f_occ = occlusion / samples * intensity + f_occ = torch.clamp(f_occ, 0.0, 1.0) + + return f_occ.unsqueeze(-1) + + +def aces_tonemapping(x: torch.Tensor) -> torch.Tensor: + """ + Applies ACES tone mapping curve to an HDR image tensor. + Input: x - HDR tensor, shape (..., 3), range [0, +inf) + Output: LDR tensor, same shape, range [0, 1] + """ + a = 2.51 + b = 0.03 + c = 2.43 + d = 0.59 + e = 0.14 + + # Apply the ACES fitted curve + mapped = (x * (a * x + b)) / (x * (c * x + d) + e) + + # Clamp to [0, 1] for display or saving + return torch.clamp(mapped, 0.0, 1.0) + + +def gamma_correction(x: torch.Tensor, gamma: float = 2.2) -> torch.Tensor: + """ + Applies gamma correction to an HDR image tensor. + """ + return torch.clamp(x ** (1.0 / gamma), 0.0, 1.0) + + +class PbrMeshRenderer: + """ + Renderer for the PBR mesh. + + Args: + rendering_options (dict): Rendering options. + """ + def __init__(self, rendering_options={}, device='cuda'): + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "peel_layers": 8, + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : Mesh, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + envmap : Union[EnvMap, Dict[str, EnvMap]], + use_envmap_bg : bool = False, + transformation : Optional[torch.Tensor] = None + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + envmap (Union[EnvMap, Dict[str, EnvMap]]): environment map or a dictionary of environment maps + use_envmap_bg (bool): whether to use envmap as background + transformation (torch.Tensor): (4, 4) transformation matrix + + Returns: + edict based on return_types containing: + shaded (torch.Tensor): [3, H, W] shaded color image + normal (torch.Tensor): [3, H, W] normal image + base_color (torch.Tensor): [3, H, W] base color image + metallic (torch.Tensor): [H, W] metallic image + roughness (torch.Tensor): [H, W] roughness image + """ + if 'dr' not in globals(): + import nvdiffrast.torch as dr + + if not isinstance(envmap, dict): + envmap = {'' : envmap} + num_envmaps = len(envmap) + + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + return edict( + shaded=torch.full((4, resolution, resolution), 0.5, dtype=torch.float32, device=self.device), + ) + + rays_o, rays_d = utils3d.torch.get_image_rays( + extrinsics, intrinsics, resolution * ssaa, resolution * ssaa + ) + + perspective = intrinsics_to_projection(intrinsics, near, far) + + full_proj = (perspective @ extrinsics).unsqueeze(0) + extrinsics = extrinsics.unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + vertices_orig = vertices.clone() + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + if transformation is not None: + vertices_homo = torch.bmm(vertices_homo, transformation.unsqueeze(0).transpose(-1, -2)) + vertices = vertices_homo[..., :3].contiguous() + vertices_camera = torch.bmm(vertices_homo, extrinsics.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces = mesh.faces + + v0 = vertices[0, mesh.faces[:, 0], :3] + v1 = vertices[0, mesh.faces[:, 1], :3] + v2 = vertices[0, mesh.faces[:, 2], :3] + e0 = v1 - v0 + e1 = v2 - v0 + face_normal = torch.cross(e0, e1, dim=1) + face_normal = F.normalize(face_normal, dim=1) + + out_dict = edict() + shaded = torch.zeros((num_envmaps, resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + depth = torch.full((resolution * ssaa, resolution * ssaa, 1), 1e10, dtype=torch.float32, device=self.device) + normal = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + max_w = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + with dr.DepthPeeler(self.glctx, vertices_clip, faces, (resolution * ssaa, resolution * ssaa)) as peeler: + for _ in range(self.rendering_options["peel_layers"]): + rast, rast_db = peeler.rasterize_next_layer() + + # Pos + pos = dr.interpolate(vertices, rast, faces)[0][0] + + # Depth + gb_depth = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces)[0][0] + + # Normal + gb_normal = dr.interpolate(face_normal.unsqueeze(0), rast, torch.arange(face_normal.shape[0], dtype=torch.int, device=self.device).unsqueeze(1).repeat(1, 3).contiguous())[0][0] + gb_normal = torch.where( + torch.sum(gb_normal * (pos - rays_o), dim=-1, keepdim=True) > 0, + -gb_normal, + gb_normal + ) + gb_cam_normal = (extrinsics[..., :3, :3].reshape(1, 1, 3, 3) @ gb_normal.unsqueeze(-1)).squeeze(-1) + if _ == 0: + out_dict.normal = -gb_cam_normal * 0.5 + 0.5 + mask = (rast[0, ..., -1:] > 0).float() + out_dict.mask = mask + + # PBR attributes + if isinstance(mesh, MeshWithVoxel): + if 'grid_sample_3d' not in globals(): + from flex_gemm.ops.grid_sample import grid_sample_3d + mask = rast[..., -1:] > 0 + xyz = dr.interpolate(vertices_orig, rast, faces)[0] + xyz = ((xyz - mesh.origin) / mesh.voxel_size).reshape(1, -1, 3) + img = grid_sample_3d( + mesh.attrs, + torch.cat([torch.zeros_like(mesh.coords[..., :1]), mesh.coords], dim=-1), + mesh.voxel_shape, + xyz, + mode='trilinear' + ) + img = img.reshape(1, resolution * ssaa, resolution * ssaa, mesh.attrs.shape[-1]) * mask + gb_basecolor = img[0, ..., mesh.layout['base_color']] + gb_metallic = img[0, ..., mesh.layout['metallic']] + gb_roughness = img[0, ..., mesh.layout['roughness']] + gb_alpha = img[0, ..., mesh.layout['alpha']] + elif isinstance(mesh, MeshWithPbrMaterial): + tri_id = rast[0, :, :, -1:] + mask = tri_id > 0 + uv_coords = mesh.uv_coords.reshape(1, -1, 2) + texc, texd = dr.interpolate( + uv_coords, + rast, + torch.arange(mesh.uv_coords.shape[0] * 3, dtype=torch.int, device=self.device).reshape(-1, 3), + rast_db=rast_db, + diff_attrs='all' + ) + # Fix problematic texture coordinates + texc = torch.nan_to_num(texc, nan=0.0, posinf=1e3, neginf=-1e3) + texc = torch.clamp(texc, min=-1e3, max=1e3) + texd = torch.nan_to_num(texd, nan=0.0, posinf=1e3, neginf=-1e3) + texd = torch.clamp(texd, min=-1e3, max=1e3) + mid = mesh.material_ids[(tri_id - 1).long()] + gb_basecolor = torch.zeros((resolution * ssaa, resolution * ssaa, 3), dtype=torch.float32, device=self.device) + gb_metallic = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + gb_roughness = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + gb_alpha = torch.zeros((resolution * ssaa, resolution * ssaa, 1), dtype=torch.float32, device=self.device) + for id, mat in enumerate(mesh.materials): + mat_mask = (mid == id).float() * mask.float() + mat_texc = texc * mat_mask + mat_texd = texd * mat_mask + + if mat.base_color_texture is not None: + bc = dr.texture( + mat.base_color_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.base_color_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.base_color_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_basecolor += bc * mat.base_color_factor * mat_mask + else: + gb_basecolor += mat.base_color_factor * mat_mask + + if mat.metallic_texture is not None: + m = dr.texture( + mat.metallic_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.metallic_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.metallic_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_metallic += m * mat.metallic_factor * mat_mask + else: + gb_metallic += mat.metallic_factor * mat_mask + + if mat.roughness_texture is not None: + r = dr.texture( + mat.roughness_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.roughness_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.roughness_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + gb_roughness += r * mat.roughness_factor * mat_mask + else: + gb_roughness += mat.roughness_factor * mat_mask + + if mat.alpha_mode == AlphaMode.OPAQUE: + gb_alpha += 1.0 * mat_mask + else: + if mat.alpha_texture is not None: + a = dr.texture( + mat.alpha_texture.image.unsqueeze(0), + mat_texc, + mat_texd, + filter_mode='linear-mipmap-linear' if mat.alpha_texture.filter_mode == TextureFilterMode.LINEAR else 'nearest', + boundary_mode='clamp' if mat.alpha_texture.wrap_mode == TextureWrapMode.CLAMP_TO_EDGE else 'wrap' + )[0] + if mat.alpha_mode == AlphaMode.MASK: + gb_alpha += (a * mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + gb_alpha += a * mat.alpha_factor * mat_mask + else: + if mat.alpha_mode == AlphaMode.MASK: + gb_alpha += (mat.alpha_factor > mat.alpha_cutoff).float() * mat_mask + elif mat.alpha_mode == AlphaMode.BLEND: + gb_alpha += mat.alpha_factor * mat_mask + if _ == 0: + out_dict.base_color = gb_basecolor + out_dict.metallic = gb_metallic + out_dict.roughness = gb_roughness + out_dict.alpha = gb_alpha + + # Shading + gb_basecolor = torch.clamp(gb_basecolor, 0.0, 1.0) ** 2.2 + gb_metallic = torch.clamp(gb_metallic, 0.0, 1.0) + gb_roughness = torch.clamp(gb_roughness, 0.0, 1.0) + gb_alpha = torch.clamp(gb_alpha, 0.0, 1.0) + gb_orm = torch.cat([ + torch.zeros_like(gb_metallic), + gb_roughness, + gb_metallic, + ], dim=-1) + gb_shaded = torch.stack([ + e.shade( + pos.unsqueeze(0), + gb_normal.unsqueeze(0), + gb_basecolor.unsqueeze(0), + gb_orm.unsqueeze(0), + rays_o, + specular=True, + )[0] + for e in envmap.values() + ], dim=0) + + # Compositing + w = (1 - alpha) * gb_alpha + depth = torch.where(w > max_w, gb_depth, depth) + normal = torch.where(w > max_w, gb_cam_normal, normal) + max_w = torch.maximum(max_w, w) + shaded += w * gb_shaded + alpha += w + + # Ambient occulusion + f_occ = screen_space_ambient_occlusion( + depth, normal, perspective, intensity=1.5 + ) + shaded *= (1 - f_occ) + out_dict.clay = (1 - f_occ) + + # Background + if use_envmap_bg: + bg = torch.stack([e.sample(rays_d) for e in envmap.values()], dim=0) + shaded += (1 - alpha) * bg + + for i, k in enumerate(envmap.keys()): + shaded_key = f"shaded_{k}" if k != '' else "shaded" + out_dict[shaded_key] = shaded[i] + + # SSAA + for k in out_dict.keys(): + if ssaa > 1: + out_dict[k] = F.interpolate(out_dict[k].unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + else: + out_dict[k] = out_dict[k].permute(2, 0, 1) + out_dict[k] = out_dict[k].squeeze() + + # Post processing + for k in envmap.keys(): + shaded_key = f"shaded_{k}" if k != '' else "shaded" + out_dict[shaded_key] = aces_tonemapping(out_dict[shaded_key]) + out_dict[shaded_key] = gamma_correction(out_dict[shaded_key]) + + return out_dict diff --git a/trellis2/renderers/voxel_renderer.py b/trellis2/renderers/voxel_renderer.py new file mode 100644 index 0000000..dfe28ad --- /dev/null +++ b/trellis2/renderers/voxel_renderer.py @@ -0,0 +1,68 @@ +import torch +from easydict import EasyDict as edict +from ..representations import Voxel +from easydict import EasyDict as edict + + +class VoxelRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.rendering_options = edict({ + "resolution": None, + "near": 0.1, + "far": 10.0, + "ssaa": 1, + }) + self.rendering_options.update(rendering_options) + + def render( + self, + voxel: Voxel, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + """ + Render the gausssian. + + Args: + voxel (Voxel): Voxel representation. + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + ... + """ + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + renderer = o_voxel.rasterize.VoxelRenderer(self.rendering_options) + positions = voxel.position + attrs = voxel.attrs if colors_overwrite is None else colors_overwrite + voxel_size = voxel.voxel_size + + # Render + render_ret = renderer.render(positions, attrs, voxel_size, extrinsics, intrinsics) + + ret = { + 'depth': render_ret['depth'], + 'alpha': render_ret['alpha'], + } + if colors_overwrite is not None: + ret['color'] = render_ret['attr'] + else: + for k, s in voxel.layout.items(): + ret[k] = render_ret['attr'][s] + + return ret diff --git a/trellis2/representations/__init__.py b/trellis2/representations/__init__.py new file mode 100644 index 0000000..0e7d929 --- /dev/null +++ b/trellis2/representations/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'Mesh': 'mesh', + 'Voxel': 'voxel', + 'MeshWithVoxel': 'mesh', + 'MeshWithPbrMaterial': 'mesh', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .mesh import Mesh, MeshWithVoxel, MeshWithPbrMaterial + from .voxel import Voxel diff --git a/trellis2/representations/mesh/__init__.py b/trellis2/representations/mesh/__init__.py new file mode 100644 index 0000000..aff4c99 --- /dev/null +++ b/trellis2/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .base import Mesh, MeshWithVoxel, MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture diff --git a/trellis2/representations/mesh/base.py b/trellis2/representations/mesh/base.py new file mode 100644 index 0000000..b70e4cc --- /dev/null +++ b/trellis2/representations/mesh/base.py @@ -0,0 +1,234 @@ +from typing import * +import torch +from ..voxel import Voxel +import cumesh +from flex_gemm.ops.grid_sample import grid_sample_3d + + +class Mesh: + def __init__(self, + vertices, + faces, + vertex_attrs=None + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.vertex_attrs = vertex_attrs + + @property + def device(self): + return self.vertices.device + + def to(self, device, non_blocking=False): + return Mesh( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, + ) + + def cuda(self, non_blocking=False): + return self.to('cuda', non_blocking=non_blocking) + + def cpu(self): + return self.to('cpu') + + def fill_holes(self, max_hole_perimeter=3e-2): + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.get_edges() + mesh.get_boundary_info() + if mesh.num_boundaries == 0: + return + mesh.get_vertex_edge_adjacency() + mesh.get_vertex_boundary_adjacency() + mesh.get_manifold_boundary_adjacency() + mesh.read_manifold_boundary_adjacency() + mesh.get_boundary_connected_components() + mesh.get_boundary_loops() + if mesh.num_boundary_loops == 0: + return + mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + def remove_faces(self, face_mask: torch.Tensor): + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.remove_faces(face_mask) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + def simplify(self, target=1000000, verbose: bool=False, options: dict={}): + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.simplify(target, verbose=verbose, options=options) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + +class TextureFilterMode: + CLOSEST = 0 + LINEAR = 1 + + +class TextureWrapMode: + CLAMP_TO_EDGE = 0 + REPEAT = 1 + MIRRORED_REPEAT = 2 + + +class AlphaMode: + OPAQUE = 0 + MASK = 1 + BLEND = 2 + + +class Texture: + def __init__( + self, + image: torch.Tensor, + filter_mode: TextureFilterMode = TextureFilterMode.LINEAR, + wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT + ): + self.image = image + self.filter_mode = filter_mode + self.wrap_mode = wrap_mode + + def to(self, device, non_blocking=False): + return Texture( + self.image.to(device, non_blocking=non_blocking), + self.filter_mode, + self.wrap_mode, + ) + + +class PbrMaterial: + def __init__( + self, + base_color_texture: Optional[Texture] = None, + base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0], + metallic_texture: Optional[Texture] = None, + metallic_factor: float = 1.0, + roughness_texture: Optional[Texture] = None, + roughness_factor: float = 1.0, + alpha_texture: Optional[Texture] = None, + alpha_factor: float = 1.0, + alpha_mode: AlphaMode = AlphaMode.OPAQUE, + alpha_cutoff: float = 0.5, + ): + self.base_color_texture = base_color_texture + self.base_color_factor = torch.tensor(base_color_factor, dtype=torch.float32)[:3] + self.metallic_texture = metallic_texture + self.metallic_factor = metallic_factor + self.roughness_texture = roughness_texture + self.roughness_factor = roughness_factor + self.alpha_texture = alpha_texture + self.alpha_factor = alpha_factor + self.alpha_mode = alpha_mode + self.alpha_cutoff = alpha_cutoff + + def to(self, device, non_blocking=False): + return PbrMaterial( + base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None, + base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking), + metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None, + metallic_factor=self.metallic_factor, + roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None, + roughness_factor=self.roughness_factor, + alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None, + alpha_factor=self.alpha_factor, + alpha_mode=self.alpha_mode, + alpha_cutoff=self.alpha_cutoff, + ) + + +class MeshWithPbrMaterial(Mesh): + def __init__(self, + vertices, + faces, + material_ids, + uv_coords, + materials: List[PbrMaterial], + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.material_ids = material_ids # [M] + self.uv_coords = uv_coords # [M, 3, 2] + self.materials = materials + self.layout = { + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + + def to(self, device, non_blocking=False): + return MeshWithPbrMaterial( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.material_ids.to(device, non_blocking=non_blocking), + self.uv_coords.to(device, non_blocking=non_blocking), + [material.to(device, non_blocking=non_blocking) for material in self.materials], + ) + + +class MeshWithVoxel(Mesh, Voxel): + def __init__(self, + vertices: torch.Tensor, + faces: torch.Tensor, + origin: list, + voxel_size: float, + coords: torch.Tensor, + attrs: torch.Tensor, + voxel_shape: torch.Size, + layout: Dict = {}, + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.voxel_shape = voxel_shape + self.layout = layout + + def to(self, device, non_blocking=False): + return MeshWithVoxel( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.origin.tolist(), + self.voxel_size, + self.coords.to(device, non_blocking=non_blocking), + self.attrs.to(device, non_blocking=non_blocking), + self.voxel_shape, + self.layout, + ) + + def query_attrs(self, xyz): + grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3) + vertex_attrs = grid_sample_3d( + self.attrs, + torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1), + self.voxel_shape, + grid, + mode='trilinear' + )[0] + return vertex_attrs + + def query_vertex_attrs(self): + return self.query_attrs(self.vertices) diff --git a/trellis2/representations/voxel/__init__.py b/trellis2/representations/voxel/__init__.py new file mode 100644 index 0000000..b5792ea --- /dev/null +++ b/trellis2/representations/voxel/__init__.py @@ -0,0 +1 @@ +from .voxel_model import Voxel \ No newline at end of file diff --git a/trellis2/representations/voxel/voxel_model.py b/trellis2/representations/voxel/voxel_model.py new file mode 100644 index 0000000..9317ab2 --- /dev/null +++ b/trellis2/representations/voxel/voxel_model.py @@ -0,0 +1,54 @@ +from typing import Dict +import torch + + +class Voxel: + def __init__( + self, + origin: list, + voxel_size: float, + coords: torch.Tensor = None, + attrs: torch.Tensor = None, + layout: Dict = {}, + device: torch.device = 'cuda' + ): + self.origin = torch.tensor(origin, dtype=torch.float32, device=device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.layout = layout + self.device = device + + @property + def position(self): + return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] + + def split_attrs(self): + return { + k: self.attrs[:, self.layout[k]] + for k in self.layout + } + + def save(self, path): + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + o_voxel.io.write( + path, + self.coords, + self.split_attrs(), + ) + + def load(self, path): + # lazy import + if 'o_voxel' not in globals(): + import o_voxel + coord, attrs = o_voxel.io.read(path) + self.coords = coord.int().to(self.device) + self.attrs = torch.cat([attrs[k] for k in attrs], dim=1).to(self.device) + # build layout + start = 0 + self.layout = {} + for k in attrs: + self.layout[k] = slice(start, start + attrs[k].shape[1]) + start += attrs[k].shape[1] diff --git a/trellis2/trainers/__init__.py b/trellis2/trainers/__init__.py new file mode 100644 index 0000000..8f25130 --- /dev/null +++ b/trellis2/trainers/__init__.py @@ -0,0 +1,68 @@ +import importlib + +__attributes = { + 'BasicTrainer': 'basic', + + 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae', + 'ShapeVaeTrainer': 'vae.shape_vae', + 'PbrVaeTrainer': 'vae.pbr_vae', + + 'FlowMatchingTrainer': 'flow_matching.flow_matching', + 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching', + 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', + 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', + + 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching', + 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + 'MultiImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', + + 'DinoV2FeatureExtractor': 'flow_matching.mixins.image_conditioned', + 'DinoV3FeatureExtractor': 'flow_matching.mixins.image_conditioned', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import BasicTrainer + + from .vae.sparse_structure_vae import SparseStructureVaeTrainer + from .vae.shape_vae import ShapeVaeTrainer + from .vae.pbr_vae import PbrVaeTrainer + + from .flow_matching.flow_matching import ( + FlowMatchingTrainer, + FlowMatchingCFGTrainer, + TextConditionedFlowMatchingCFGTrainer, + ImageConditionedFlowMatchingCFGTrainer, + ) + + from .flow_matching.sparse_flow_matching import ( + SparseFlowMatchingTrainer, + SparseFlowMatchingCFGTrainer, + TextConditionedSparseFlowMatchingCFGTrainer, + ImageConditionedSparseFlowMatchingCFGTrainer, + ) + + from .flow_matching.mixins.image_conditioned import ( + DinoV2FeatureExtractor, + DinoV3FeatureExtractor, + ) diff --git a/trellis2/trainers/basic.py b/trellis2/trainers/basic.py new file mode 100644 index 0000000..c8e4b4c --- /dev/null +++ b/trellis2/trainers/basic.py @@ -0,0 +1,910 @@ +from abc import abstractmethod +import os +import time +import json +import copy +import threading +from functools import partial +from contextlib import nullcontext + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel as DDP +import numpy as np + +from torchvision import utils +from torch.utils.tensorboard import SummaryWriter + +from .utils import * +from ..utils.general_utils import * +from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler +from ..utils.dist_utils import * +from ..utils import grad_clip_utils, elastic_utils + + +class BasicTrainer: + """ + Trainer for basic training loop. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + mix_precision_mode (str): + - None: No mixed precision. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + mix_precision_dtype (str): Mixed precision dtype. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + parallel_mode (str): Parallel mode. Options are 'ddp'. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + """ + def __init__(self, + models, + dataset, + *, + output_dir, + load_dir, + step, + max_steps, + batch_size=None, + batch_size_per_gpu=None, + batch_split=None, + optimizer={}, + lr_scheduler=None, + elastic=None, + grad_clip=None, + ema_rate=0.9999, + fp16_mode=None, + mix_precision_mode='inflat_all', + mix_precision_dtype='float16', + fp16_scale_growth=1e-3, + parallel_mode='ddp', + finetune_ckpt=None, + log_param_stats=False, + prefetch_data=True, + snapshot_batch_size=4, + i_print=1000, + i_log=500, + i_sample=10000, + i_save=10000, + i_ddpcheck=10000, + **kwargs + ): + assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.' + + self.models = models + self.dataset = dataset + self.batch_split = batch_split if batch_split is not None else 1 + self.max_steps = max_steps + self.optimizer_config = optimizer + self.lr_scheduler_config = lr_scheduler + self.elastic_controller_config = elastic + self.grad_clip = grad_clip + self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate + if fp16_mode is not None: + mix_precision_dtype = 'float16' + mix_precision_mode = fp16_mode + self.mix_precision_mode = mix_precision_mode + self.mix_precision_dtype = str_to_dtype(mix_precision_dtype) + self.fp16_scale_growth = fp16_scale_growth + self.parallel_mode = parallel_mode + self.log_param_stats = log_param_stats + self.prefetch_data = prefetch_data + self.snapshot_batch_size = snapshot_batch_size + self.log = [] + if self.prefetch_data: + self._data_prefetched = None + + self.output_dir = output_dir + self.i_print = i_print + self.i_log = i_log + self.i_sample = i_sample + self.i_save = i_save + self.i_ddpcheck = i_ddpcheck + + if dist.is_initialized(): + # Multi-GPU params + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + self.local_rank = dist.get_rank() % torch.cuda.device_count() + self.is_master = self.rank == 0 + else: + # Single-GPU params + self.world_size = 1 + self.rank = 0 + self.local_rank = 0 + self.is_master = True + + self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size + self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size + assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.' + assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.' + + self.init_models_and_more(**kwargs) + self.prepare_dataloader(**kwargs) + + # Load checkpoint + self.step = 0 + if load_dir is not None and step is not None: + self.load(load_dir, step) + elif finetune_ckpt is not None: + self.finetune_from(finetune_ckpt) + + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True) + os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True) + self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs')) + + if self.parallel_mode == 'ddp' and self.world_size > 1: + self.check_ddp() + + if self.is_master: + print('\n\nTrainer initialized.') + print(self) + + def __str__(self): + lines = [] + lines.append(self.__class__.__name__) + lines.append(f' - Models:') + for name, model in self.models.items(): + lines.append(f' - {name}: {model.__class__.__name__}') + lines.append(f' - Dataset: {indent(str(self.dataset), 2)}') + lines.append(f' - Dataloader:') + lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}') + lines.append(f' - Num workers: {self.dataloader.num_workers}') + lines.append(f' - Number of steps: {self.max_steps}') + lines.append(f' - Number of GPUs: {self.world_size}') + lines.append(f' - Batch size: {self.batch_size}') + lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}') + lines.append(f' - Batch split: {self.batch_split}') + lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}') + lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}') + if self.lr_scheduler_config is not None: + lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}') + if self.elastic_controller_config is not None: + lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}') + if self.grad_clip is not None: + lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}') + lines.append(f' - EMA rate: {self.ema_rate}') + lines.append(f' - Mixed precision dtype: {self.mix_precision_dtype}') + lines.append(f' - Mixed precision mode: {self.mix_precision_mode}') + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + lines.append(f' - FP16 scale growth: {self.fp16_scale_growth}') + lines.append(f' - Parallel mode: {self.parallel_mode}') + return '\n'.join(lines) + + @property + def device(self): + for _, model in self.models.items(): + if hasattr(model, 'device'): + return model.device + return next(list(self.models.values())[0].parameters()).device + + def init_models_and_more(self, **kwargs): + """ + Initialize models and more. + """ + if self.world_size > 1: + # Prepare distributed data parallel + self.training_models = { + name: DDP( + model, + device_ids=[self.local_rank], + output_device=self.local_rank, + bucket_cap_mb=128, + find_unused_parameters=False + ) + for name, model in self.models.items() + } + else: + self.training_models = self.models + + # Build master params + self.model_params = sum( + [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()] + , []) + if self.mix_precision_mode == 'amp': + self.master_params = self.model_params + if self.mix_precision_dtype == torch.float16: + self.scaler = torch.GradScaler() + elif self.mix_precision_mode == 'inflat_all': + self.master_params = make_master_params(self.model_params) + if self.mix_precision_dtype == torch.float16: + self.log_scale = 20.0 + elif self.mix_precision_mode is None: + self.master_params = self.model_params + else: + raise NotImplementedError(f'Mix precision mode {self.mix_precision_mode} is not implemented.') + + # Build EMA params + if self.is_master: + self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate] + + # Initialize optimizer + if hasattr(torch.optim, self.optimizer_config['name']): + self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args']) + else: + self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args']) + + # Initalize learning rate scheduler + if self.lr_scheduler_config is not None: + if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']): + self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args']) + else: + self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args']) + + # Initialize elastic memory controller + if self.elastic_controller_config is not None: + assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \ + 'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin' + self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args']) + for model in self.models.values(): + if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)): + model.register_memory_controller(self.elastic_controller) + + # Initialize gradient clipper + if self.grad_clip is not None: + if isinstance(self.grad_clip, (float, int)): + self.grad_clip = float(self.grad_clip) + else: + self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args']) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = ResumableSampler( + self.dataset, + shuffle=True, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _master_params_to_state_dicts(self, master_params): + """ + Convert master params to dict of state_dicts. + """ + if self.mix_precision_mode == 'inflat_all': + master_params = unflatten_master_params(self.model_params, master_params) + state_dicts = {name: model.state_dict() for name, model in self.models.items()} + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + for i, (model_name, param_name) in enumerate(master_params_names): + state_dicts[model_name][param_name] = master_params[i] + return state_dicts + + def _state_dicts_to_master_params(self, master_params, state_dicts): + """ + Convert a state_dict to master params. + """ + master_params_names = sum( + [[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()] + , []) + params = [state_dicts[name][param_name] for name, param_name in master_params_names] + if self.mix_precision_mode == 'inflat_all': + model_params_to_master_params(params, master_params) + else: + for i, param in enumerate(params): + master_params[i].data.copy_(param.data) + + def load(self, load_dir, step=0): + """ + Load a checkpoint. + Should be called by all processes. + """ + if self.is_master: + print(f'\nLoading checkpoint from step {step}...', end='') + + model_ckpts = {} + for name, model in self.models.items(): + model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + self._state_dicts_to_master_params(self.master_params, model_ckpts) + del model_ckpts + + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = {} + for name, model in self.models.items(): + ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True) + ema_ckpts[name] = ema_ckpt + self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) + del ema_ckpts + + misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) + self.optimizer.load_state_dict(misc_ckpt['optimizer']) + self.step = misc_ckpt['step'] + self.data_sampler.load_state_dict(misc_ckpt['data_sampler']) + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.load_state_dict(misc_ckpt['scaler']) + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + self.log_scale = misc_ckpt['log_scale'] + if self.lr_scheduler_config is not None: + self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler']) + if self.elastic_controller_config is not None: + self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller']) + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + self.grad_clip.load_state_dict(misc_ckpt['grad_clip']) + del misc_ckpt + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print(' Done.') + + if self.world_size > 1: + self.check_ddp() + + def save(self, non_blocking=True): + """ + Save a checkpoint. + Should be called only by the rank 0 process. + """ + assert self.is_master, 'save() should be called only by the rank 0 process.' + print(f'\nSaving checkpoint at step {self.step}...', end='') + + model_ckpts = self._master_params_to_state_dicts(self.master_params) + for name, model_ckpt in model_ckpts.items(): + model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving + if non_blocking: + threading.Thread( + target=torch.save, + args=(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')), + ).start() + else: + torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')) + + for i, ema_rate in enumerate(self.ema_rate): + ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i]) + for name, ema_ckpt in ema_ckpts.items(): + ema_ckpt = {k: v.cpu() for k, v in ema_ckpt.items()} # Move to CPU for saving + if non_blocking: + threading.Thread( + target=torch.save, + args=(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')), + ).start() + else: + torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt')) + + misc_ckpt = { + 'optimizer': self.optimizer.state_dict(), + 'step': self.step, + 'data_sampler': self.data_sampler.state_dict(), + } + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + misc_ckpt['scaler'] = self.scaler.state_dict() + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + misc_ckpt['log_scale'] = self.log_scale + if self.lr_scheduler_config is not None: + misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict() + if self.elastic_controller_config is not None: + misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict() + if self.grad_clip is not None and not isinstance(self.grad_clip, float): + misc_ckpt['grad_clip'] = self.grad_clip.state_dict() + if non_blocking: + threading.Thread( + target=torch.save, + args=(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')), + ).start() + else: + torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt')) + print(' Done.') + + def finetune_from(self, finetune_ckpt): + """ + Finetune from a checkpoint. + Should be called by all processes. + """ + if self.is_master: + print('\nFinetuning from:') + for name, path in finetune_ckpt.items(): + print(f' - {name}: {path}') + + model_ckpts = {} + for name, model in self.models.items(): + model_state_dict = model.state_dict() + if name in finetune_ckpt: + model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True) + for k, v in model_ckpt.items(): + if k not in model_state_dict: + if self.is_master: + print(f'Warning: {k} not found in model_state_dict, skipped.') + model_ckpt[k] = None + elif model_ckpt[k].shape != model_state_dict[k].shape: + if self.is_master: + print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.') + model_ckpt[k] = model_state_dict[k] + model_ckpt = {k: v for k, v in model_ckpt.items() if v is not None} + model_ckpts[name] = model_ckpt + model.load_state_dict(model_ckpt) + else: + if self.is_master: + print(f'Warning: {name} not found in finetune_ckpt, skipped.') + model_ckpts[name] = model_state_dict + self._state_dicts_to_master_params(self.master_params, model_ckpts) + if self.is_master: + for i, ema_rate in enumerate(self.ema_rate): + self._state_dicts_to_master_params(self.ema_params[i], model_ckpts) + del model_ckpts + + if self.world_size > 1: + dist.barrier() + if self.is_master: + print('Done.') + + if self.world_size > 1: + self.check_ddp() + + @abstractmethod + def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs): + """ + Run a snapshot of the model. + """ + pass + + @torch.no_grad() + def visualize_sample(self, sample): + """ + Convert a sample to an image. + """ + if hasattr(self.dataset, 'visualize_sample'): + return self.dataset.visualize_sample(sample) + else: + return sample + + @torch.no_grad() + def snapshot_dataset(self, num_samples=100, batch_size=4): + """ + Sample images from the dataset. + """ + dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=batch_size, + num_workers=1, + shuffle=True, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + save_cfg = {} + for i in range(0, num_samples, batch_size): + data = next(iter(dataloader)) + data = {k: v[:min(num_samples - i, batch_size)] for k, v in data.items()} + data = recursive_to_device(data, self.device) + vis = self.visualize_sample(data) + if isinstance(vis, dict): + for k, v in vis.items(): + if f'dataset_{k}' not in save_cfg: + save_cfg[f'dataset_{k}'] = [] + save_cfg[f'dataset_{k}'].append(v) + else: + if 'dataset' not in save_cfg: + save_cfg['dataset'] = [] + save_cfg['dataset'].append(vis) + for name, image in save_cfg.items(): + utils.save_image( + torch.cat(image, dim=0), + os.path.join(self.output_dir, 'samples', f'{name}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False): + """ + Sample images from the model. + NOTE: This function should be called by all processes. + """ + if self.is_master: + print(f'\nSampling {num_samples} images...', end='') + + if suffix is None: + suffix = f'step{self.step:07d}' + + # Assign tasks + num_samples_per_process = int(np.ceil(num_samples / self.world_size)) + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + with amp_context(): + samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose) + + # Preprocess images + for key in list(samples.keys()): + if samples[key]['type'] == 'sample': + vis = self.visualize_sample(samples[key]['value']) + if isinstance(vis, dict): + for k, v in vis.items(): + samples[f'{key}_{k}'] = {'value': v, 'type': 'image'} + del samples[key] + else: + samples[key] = {'value': vis, 'type': 'image'} + + # Gather results + if self.world_size > 1: + for key in samples.keys(): + samples[key]['value'] = samples[key]['value'].contiguous() + if self.is_master: + all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)] + else: + all_images = [] + dist.gather(samples[key]['value'], all_images, dst=0) + if self.is_master: + samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples] + + # Save images + if self.is_master: + os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True) + for key in samples.keys(): + if samples[key]['type'] == 'image': + utils.save_image( + samples[key]['value'], + os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), + nrow=int(np.sqrt(num_samples)), + normalize=True, + value_range=self.dataset.value_range, + ) + elif samples[key]['type'] == 'number': + min = samples[key]['value'].min() + max = samples[key]['value'].max() + images = (samples[key]['value'] - min) / (max - min) + images = utils.make_grid( + images, + nrow=int(np.sqrt(num_samples)), + normalize=False, + ) + save_image_with_notes( + images, + os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'), + notes=f'{key} min: {min}, max: {max}', + ) + + if self.is_master: + print(' Done.') + + def update_ema(self): + """ + Update exponential moving average. + Should only be called by the rank 0 process. + """ + assert self.is_master, 'update_ema() should be called only by the rank 0 process.' + for i, ema_rate in enumerate(self.ema_rate): + for master_param, ema_param in zip(self.master_params, self.ema_params[i]): + ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate) + + def check_ddp(self): + """ + Check if DDP is working properly. + Should be called by all process. + """ + if self.is_master: + print('\nPerforming DDP check...') + + if self.is_master: + print('Checking if parameters are consistent across processes...') + dist.barrier() + try: + for p in self.master_params: + # split to avoid OOM + for i in range(0, p.numel(), 10000000): + sub_size = min(10000000, p.numel() - i) + sub_p = p.detach().view(-1)[i:i+sub_size] + # gather from all processes + sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)] + dist.all_gather(sub_p_gather, sub_p) + # check if equal + assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes' + except AssertionError as e: + if self.is_master: + print(f'\n\033[91mError: {e}\033[0m') + print('DDP check failed.') + raise e + + dist.barrier() + if self.is_master: + print('Done.') + + @abstractmethod + def training_losses(**mb_data): + """ + Compute training losses. + """ + pass + + def load_data(self): + """ + Load data. + """ + if self.prefetch_data: + if self._data_prefetched is None: + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + data = self._data_prefetched + self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + else: + data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) + + # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu + if isinstance(data, dict): + if self.batch_split == 1: + data_list = [data] + else: + batch_size = list(data.values())[0].shape[0] + data_list = [ + {k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()} + for i in range(self.batch_split) + ] + elif isinstance(data, list): + data_list = data + else: + raise ValueError('Data must be a dict or a list of dicts.') + + return data_list + + def run_step(self, data_list): + """ + Run a training step. + """ + step_log = {'loss': {}, 'status': {}} + amp_context = partial(torch.autocast, device_type='cuda', dtype=self.mix_precision_dtype) if self.mix_precision_mode == 'amp' else nullcontext + elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext + + # Train + losses = [] + statuses = [] + elastic_controller_logs = [] + zero_grad(self.model_params) + for i, mb_data in enumerate(data_list): + ## sync at the end of each batch split + sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext] + with nested_contexts(*sync_contexts), elastic_controller_context(): + with amp_context(): + loss, status = self.training_losses(**mb_data) + l = loss['loss'] / len(data_list) + ## backward + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.scale(l).backward() + elif self.mix_precision_mode == 'inflat_all' and self.mix_precision_dtype == torch.float16: + scaled_l = l * (2 ** self.log_scale) + scaled_l.backward() + else: + l.backward() + ## log + losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x)) + if self.elastic_controller_config is not None: + elastic_controller_logs.append(self.elastic_controller.log()) + ## gradient clip + if self.grad_clip is not None: + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + self.scaler.unscale_(self.optimizer) + elif self.mix_precision_mode == 'inflat_all': + model_grads_to_master_grads(self.model_params, self.master_params) + if self.mix_precision_dtype == torch.float16: + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + if isinstance(self.grad_clip, float): + grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip) + else: + grad_norm = self.grad_clip(self.master_params) + if torch.isfinite(grad_norm): + statuses[-1]['grad_norm'] = grad_norm.item() + ## step + if self.mix_precision_mode == 'amp' and self.mix_precision_dtype == torch.float16: + prev_scale = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + elif self.mix_precision_mode == 'inflat_all': + if self.mix_precision_dtype == torch.float16: + prev_scale = 2 ** self.log_scale + if not any(not p.grad.isfinite().all() for p in self.model_params): + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale)) + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + self.log_scale += self.fp16_scale_growth + else: + self.log_scale -= 1 + else: + prev_scale = 1.0 + if self.grad_clip is None: + model_grads_to_master_grads(self.model_params, self.master_params) + if not any(not p.grad.isfinite().all() for p in self.master_params): + self.optimizer.step() + master_params_to_model_params(self.model_params, self.master_params) + else: + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + else: + prev_scale = 1.0 + if not any(not p.grad.isfinite().all() for p in self.model_params): + self.optimizer.step() + else: + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + ## adjust learning rate + if self.lr_scheduler_config is not None: + statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0] + self.lr_scheduler.step() + + # Logs + step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x)) + step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)}) + if self.elastic_controller_config is not None: + step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x)) + if self.grad_clip is not None: + step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log() + + # Check grad and norm of each param + if self.log_param_stats: + param_norms = {} + param_grads = {} + for model_name, model in self.models.items(): + for name, param in model.named_parameters(): + if param.requires_grad: + param_norms[f'{model_name}.{name}'] = param.norm().item() + if param.grad is not None and torch.isfinite(param.grad).all(): + param_grads[f'{model_name}.{name}'] = param.grad.norm().item() / prev_scale + step_log['param_norms'] = param_norms + step_log['param_grads'] = param_grads + + # Update exponential moving average + if self.is_master: + self.update_ema() + + return step_log + + def save_logs(self): + log_str = '\n'.join([ + f'{step}: {json.dumps(dict_foreach(log, lambda x: float(x)))}' for step, log in self.log + ]) + with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file: + log_file.write(log_str + '\n') + + # show with mlflow + log_show = [l for _, l in self.log if not dict_any(l, lambda x: np.isnan(x))] + log_show = dict_reduce(log_show, lambda x: np.mean(x)) + log_show = dict_flatten(log_show, sep='/') + for key, value in log_show.items(): + self.writer.add_scalar(key, value, self.step) + self.log = [] + + def check_abort(self): + """ + Check if training should be aborted due to certain conditions. + """ + # 1. If log_scale in inflat_all mode is less than 0 + if self.mix_precision_dtype == torch.float16 and \ + self.mix_precision_mode == 'inflat_all' and \ + self.log_scale < 0: + if self.is_master: + print ('\n\n\033[91m') + print (f'ABORT: log_scale in inflat_all mode is less than 0 at step {self.step}.') + print ('This indicates that the model is diverging. You should look into the model and the data.') + print ('\033[0m') + self.save(non_blocking=False) + self.save_logs() + if self.world_size > 1: + dist.barrier() + raise ValueError('ABORT: log_scale in inflat_all mode is less than 0.') + + def run(self): + """ + Run training. + """ + if self.is_master: + print('\nStarting training...') + self.snapshot_dataset(batch_size=self.snapshot_batch_size) + if self.step == 0: + self.snapshot(suffix='init', batch_size=self.snapshot_batch_size) + else: # resume + self.snapshot(suffix=f'resume_step{self.step:07d}', batch_size=self.snapshot_batch_size) + + time_last_print = 0.0 + time_elapsed = 0.0 + while self.step < self.max_steps: + time_start = time.time() + + data_list = self.load_data() + step_log = self.run_step(data_list) + + time_end = time.time() + time_elapsed += time_end - time_start + + self.step += 1 + + # Print progress + if self.is_master and self.step % self.i_print == 0: + speed = self.i_print / (time_elapsed - time_last_print) * 3600 + columns = [ + f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)', + f'Elapsed: {time_elapsed / 3600:.2f} h', + f'Speed: {speed:.2f} steps/h', + f'ETA: {(self.max_steps - self.step) / speed:.2f} h', + ] + print(' | '.join([c.ljust(25) for c in columns]), flush=True) + time_last_print = time_elapsed + + # Check ddp + if self.parallel_mode == 'ddp' and self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0: + self.check_ddp() + + # Sample images + if self.step % self.i_sample == 0: + self.snapshot() + + if self.is_master: + self.log.append((self.step, {})) + + # Log time + self.log[-1][1]['time'] = { + 'step': time_end - time_start, + 'elapsed': time_elapsed, + } + + # Log losses + if step_log is not None: + self.log[-1][1].update(step_log) + + # Log scale + if self.mix_precision_dtype == torch.float16: + if self.mix_precision_mode == 'amp': + self.log[-1][1]['scale'] = self.scaler.get_scale() + elif self.mix_precision_mode == 'inflat_all': + self.log[-1][1]['log_scale'] = self.log_scale + + # Save log + if self.step % self.i_log == 0: + self.save_logs() + + # Save checkpoint + if self.step % self.i_save == 0: + self.save() + + # Check abort + self.check_abort() + + self.snapshot(suffix='final', batch_size=self.snapshot_batch_size) + if self.world_size > 1: + dist.barrier() + if self.is_master: + self.writer.close() + print('Training finished.') + + def profile(self, wait=2, warmup=3, active=5): + """ + Profile the training loop. + """ + with torch.profiler.profile( + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')), + profile_memory=True, + with_stack=True, + ) as prof: + for _ in range(wait + warmup + active): + self.run_step() + prof.step() diff --git a/trellis2/trainers/flow_matching/flow_matching.py b/trellis2/trainers/flow_matching/flow_matching.py new file mode 100644 index 0000000..7b4390e --- /dev/null +++ b/trellis2/trainers/flow_matching/flow_matching.py @@ -0,0 +1,353 @@ +from typing import * +import copy +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...pipelines import samplers +from ...utils.general_utils import dict_reduce +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin + + +class FlowMatchingTrainer(BasicTrainer): + """ + Trainer for diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + def __init__( + self, + *args, + t_schedule: dict = { + 'name': 'logitNormal', + 'args': { + 'mean': 0.0, + 'std': 1.0, + } + }, + sigma_min: float = 1e-5, + **kwargs + ): + super().__init__(*args, **kwargs) + self.t_schedule = t_schedule + self.sigma_min = sigma_min + + def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + t: The [N] tensor of diffusion steps [0-1]. + noise: If specified, use this noise instead of generating new noise. + + Returns: + x_t, the noisy version of x_0 under timestep t. + """ + if noise is None: + noise = torch.randn_like(x_0) + assert noise.shape == x_0.shape, "noise must have same shape as x_0" + + t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)]) + x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise + + return x_t + + def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + """ + Get original image from noisy version under timestep t. + """ + assert noise.shape == x_t.shape, "noise must have same shape as x_t" + t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)]) + x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t) + return x_0 + + def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Compute the velocity of the diffusion process at time t. + """ + return (1 - self.sigma_min) * noise - x_0 + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + return {'cond': cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler: + """ + Get the sampler for the diffusion process. + """ + return samplers.FlowEulerSampler(self.sigma_min) + + def vis_cond(self, **kwargs): + """ + Visualize the conditioning data. + """ + return {} + + def sample_t(self, batch_size: int) -> torch.Tensor: + """ + Sample timesteps. + """ + if self.t_schedule['name'] == 'uniform': + t = torch.rand(batch_size) + elif self.t_schedule['name'] == 'logitNormal': + mean = self.t_schedule['args']['mean'] + std = self.t_schedule['args']['std'] + t = torch.sigmoid(torch.randn(batch_size) * std + mean) + else: + raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}") + return t + + def training_losses( + self, + x_0: torch.Tensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. + """ + noise = torch.randn_like(x_0) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred, target) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred[i], target[i]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + sampler = self.get_sampler() + sample_gt = [] + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + noise = torch.randn_like(data['x_0']) + sample_gt.append(data['x_0']) + cond_vis.append(self.vis_cond(**data)) + del data['x_0'] + args = self.get_inference_cond(**data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=50, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + + sample_gt = torch.cat(sample_gt, dim=0) + sample = torch.cat(sample, dim=0) + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer): + """ + Trainer for diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + """ + pass + + +class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass diff --git a/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py new file mode 100644 index 0000000..548e007 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/classifier_free_guidance.py @@ -0,0 +1,59 @@ +import torch +import numpy as np +from ....utils.general_utils import dict_foreach +from ....pipelines import samplers + + +class ClassifierFreeGuidanceMixin: + def __init__(self, *args, p_uncond: float = 0.1, **kwargs): + super().__init__(*args, **kwargs) + self.p_uncond = p_uncond + + def get_cond(self, cond, neg_cond=None, **kwargs): + """ + Get the conditioning data. + """ + assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" + + if self.p_uncond > 0: + # randomly drop the class label + def get_batch_size(cond): + if isinstance(cond, torch.Tensor): + return cond.shape[0] + elif isinstance(cond, list): + return len(cond) + else: + raise ValueError(f"Unsupported type of cond: {type(cond)}") + + ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]] + B = get_batch_size(ref_cond) + + def select(cond, neg_cond, mask): + if isinstance(cond, torch.Tensor): + mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1)) + return torch.where(mask, neg_cond, cond) + elif isinstance(cond, list): + return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)] + else: + raise ValueError(f"Unsupported type of cond: {type(cond)}") + + mask = list(np.random.rand(B) < self.p_uncond) + if not isinstance(cond, dict): + cond = select(cond, neg_cond, mask) + else: + cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask)) + + return cond + + def get_inference_cond(self, cond, neg_cond=None, **kwargs): + """ + Get the conditioning data for inference. + """ + assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance" + return {'cond': cond, 'neg_cond': neg_cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler: + """ + Get the sampler for the diffusion process. + """ + return samplers.FlowEulerCfgSampler(self.sigma_min) diff --git a/trellis2/trainers/flow_matching/mixins/image_conditioned.py b/trellis2/trainers/flow_matching/mixins/image_conditioned.py new file mode 100644 index 0000000..ab8da40 --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/image_conditioned.py @@ -0,0 +1,248 @@ +from typing import * +import torch +import torch.nn.functional as F +from torchvision import transforms +from transformers import DINOv3ViTModel +import numpy as np +from PIL import Image + +from ....utils import dist_utils + + +class DinoV2FeatureExtractor: + """ + Feature extractor for DINOv2 models. + """ + def __init__(self, model_name: str): + self.model_name = model_name + self.model = torch.hub.load('facebookresearch/dinov2', model_name, pretrained=True) + self.model.eval() + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.model(image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + +class DinoV3FeatureExtractor: + """ + Feature extractor for DINOv3 models. + """ + def __init__(self, model_name: str, image_size=512): + self.model_name = model_name + self.model = DINOv3ViTModel.from_pretrained(model_name) + self.model.eval() + self.image_size = image_size + self.transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + def to(self, device): + self.model.to(device) + + def cuda(self): + self.model.cuda() + + def cpu(self): + self.model.cpu() + + def extract_features(self, image: torch.Tensor) -> torch.Tensor: + image = image.to(self.model.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.model.embeddings(image, bool_masked_pos=None) + position_embeddings = self.model.rope_embeddings(image) + + for i, layer_module in enumerate(self.model.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + return F.layer_norm(hidden_states, hidden_states.shape[-1:]) + + @torch.no_grad() + def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Extract features from the image. + + Args: + image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. + + Returns: + A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((self.image_size, self.image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).cuda() + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.transform(image).cuda() + features = self.extract_features(image) + return features + + +class ImageConditionedMixin: + """ + Mixin for image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # the model is init lazily + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. + """ + with dist_utils.local_master_first(): + self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {})) + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + """ + if self.image_cond_model is None: + self._init_image_cond_model() + features = self.image_cond_model(image) + return features + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_image(cond) + kwargs['neg_cond'] = torch.zeros_like(cond) + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + return {'image': {'value': cond, 'type': 'image'}} + + +class MultiImageConditionedMixin: + """ + Mixin for multiple-image-conditioned models. + + Args: + image_cond_model: The image conditioning model. + """ + def __init__(self, *args, image_cond_model: dict, **kwargs): + super().__init__(*args, **kwargs) + self.image_cond_model_config = image_cond_model + self.image_cond_model = None # the model is init lazily + + def _init_image_cond_model(self): + """ + Initialize the image conditioning model. + """ + with dist_utils.local_master_first(): + self.image_cond_model = globals()[self.image_cond_model_config['name']](**self.image_cond_model_config.get('args', {})) + + @torch.no_grad() + def encode_images(self, images: Union[List[torch.Tensor], List[List[Image.Image]]]) -> List[torch.Tensor]: + """ + Encode the image. + """ + if self.image_cond_model is None: + self._init_image_cond_model() + seqlen = [len(i) for i in images] + images = torch.cat(images, dim=0) if isinstance(images[0], torch.Tensor) else sum(images, []) + features = self.image_cond_model(images) + features = torch.split(features, seqlen) + features = [feature.reshape(-1, feature.shape[-1]) for feature in features] + return features + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_images(cond) + kwargs['neg_cond'] = [ + torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond)) + ] + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_images(cond) + kwargs['neg_cond'] = [ + torch.zeros_like(cond[0][:1, :]) for _ in range(len(cond)) + ] + cond = super().get_inference_cond(cond, **kwargs) + return cond + + def vis_cond(self, cond, **kwargs): + """ + Visualize the conditioning data. + """ + H, W = cond[0].shape[-2:] + vis = [] + for images in cond: + canvas = torch.zeros(3, H * 2, W * 2, device=images.device, dtype=images.dtype) + for i, image in enumerate(images): + if i == 4: + break + kh = i // 2 + kw = i % 2 + canvas[:, kh*H:(kh+1)*H, kw*W:(kw+1)*W] = image + vis.append(canvas) + vis = torch.stack(vis) + return {'image': {'value': vis, 'type': 'image'}} diff --git a/trellis2/trainers/flow_matching/mixins/text_conditioned.py b/trellis2/trainers/flow_matching/mixins/text_conditioned.py new file mode 100644 index 0000000..85f1dcf --- /dev/null +++ b/trellis2/trainers/flow_matching/mixins/text_conditioned.py @@ -0,0 +1,68 @@ +from typing import * +import os +os.environ['TOKENIZERS_PARALLELISM'] = 'true' +import torch +from transformers import AutoTokenizer, CLIPTextModel + +from ....utils import dist_utils + + +class TextConditionedMixin: + """ + Mixin for text-conditioned models. + + Args: + text_cond_model: The text conditioning model. + """ + def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs): + super().__init__(*args, **kwargs) + self.text_cond_model_name = text_cond_model + self.text_cond_model = None # the model is init lazily + + def _init_text_cond_model(self): + """ + Initialize the text conditioning model. + """ + # load model + with dist_utils.local_master_first(): + model = CLIPTextModel.from_pretrained(self.text_cond_model_name) + tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name) + model.eval() + model = model.cuda() + self.text_cond_model = { + 'model': model, + 'tokenizer': tokenizer, + } + self.text_cond_model['null_cond'] = self.encode_text(['']) + + @torch.no_grad() + def encode_text(self, text: List[str]) -> torch.Tensor: + """ + Encode the text. + """ + assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond" + if self.text_cond_model is None: + self._init_text_cond_model() + encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt') + tokens = encoding['input_ids'].cuda() + embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state + + return embeddings + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + cond = self.encode_text(cond) + kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) + cond = super().get_cond(cond, **kwargs) + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + cond = self.encode_text(cond) + kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1) + cond = super().get_inference_cond(cond, **kwargs) + return cond diff --git a/trellis2/trainers/flow_matching/sparse_flow_matching.py b/trellis2/trainers/flow_matching/sparse_flow_matching.py new file mode 100644 index 0000000..c6735d0 --- /dev/null +++ b/trellis2/trainers/flow_matching/sparse_flow_matching.py @@ -0,0 +1,325 @@ +from typing import * +import os +import copy +import functools +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ...modules import sparse as sp +from ...utils.general_utils import dict_reduce +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from .flow_matching import FlowMatchingTrainer +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin, MultiImageConditionedMixin + + +class SparseFlowMatchingTrainer(FlowMatchingTrainer): + """ + Trainer for sparse diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def training_losses( + self, + x_0: sp.SparseTensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x ... x C] sparse tensor of the inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. + """ + noise = x_0.replace(torch.randn_like(x_0.feats)) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred.feats, target.feats) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=num_samples, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + data = next(iter(dataloader)) + + # inference + sampler = self.get_sampler() + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch_data = {k: v[i:i+batch_size] for k, v in data.items()} + batch_data = recursive_to_device(batch_data, 'cuda') + noise = batch_data['x_0'].replace(torch.randn_like(batch_data['x_0'].feats)) + cond_vis.append(self.vis_cond(**batch_data)) + del batch_data['x_0'] + args = self.get_inference_cond(**batch_data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=12, guidance_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + sample = sp.sparse_cat(sample) + + sample_gt = {k: v for k, v in data.items()} + sample = {k: v if k != 'x_0' else sample for k, v in data.items()} + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer): + """ + Trainer for sparse diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + """ + pass + + +class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass + + +class MultiImageConditionedSparseFlowMatchingCFGTrainer(MultiImageConditionedMixin, SparseFlowMatchingCFGTrainer): + """ + Trainer for sparse image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. + """ + pass diff --git a/trellis2/trainers/utils.py b/trellis2/trainers/utils.py new file mode 100644 index 0000000..23e4286 --- /dev/null +++ b/trellis2/trainers/utils.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +# FP16 utils +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def str_to_dtype(dtype_str: str): + return { + 'f16': torch.float16, + 'fp16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + 'f32': torch.float32, + 'fp32': torch.float32, + 'float32': torch.float32, + }[dtype_str] + + +def make_master_params(model_params): + """ + Copy model parameters into a inflated tensor of full-precision parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors(master_params[0].detach(), model_params) + + +def model_params_to_master_params(model_params, master_params): + """ + Copy the model parameter data into the master parameters. + """ + master_params[0].detach().copy_( + _flatten_dense_tensors([param.detach().float() for param in model_params]) + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + for param, master_param in zip( + model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params) + ): + param.detach().copy_(master_param) + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def zero_grad(model_params): + for param in model_params: + if param.grad is not None: + if param.grad.grad_fn is not None: + param.grad.detach_() + else: + param.grad.requires_grad_(False) + param.grad.zero_() + + +# LR Schedulers +from torch.optim.lr_scheduler import LambdaLR + +class LinearWarmupLRScheduler(LambdaLR): + def __init__(self, optimizer, warmup_steps, last_epoch=-1): + self.warmup_steps = warmup_steps + super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) + + def lr_lambda(self, current_step): + if current_step < self.warmup_steps: + return float(current_step + 1) / self.warmup_steps + return 1.0 + \ No newline at end of file diff --git a/trellis2/trainers/vae/pbr_vae.py b/trellis2/trainers/vae/pbr_vae.py new file mode 100644 index 0000000..5527236 --- /dev/null +++ b/trellis2/trainers/vae/pbr_vae.py @@ -0,0 +1,281 @@ +from typing import * +import os +import copy +import functools +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import utils3d +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...modules import sparse as sp +from ...renderers import MeshRenderer +from ...representations import Mesh, MeshWithPbrMaterial, MeshWithVoxel +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips + + +class PbrVaeTrainer(BasicTrainer): + """ + Trainer for PBR attributes VAE + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. + lambda_kl (float): KL loss weight. + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + """ + + def __init__( + self, + *args, + loss_type: str = 'l1', + lambda_kl: float = 1e-6, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + lambda_render: float = 1.0, + render_resolution: float = 1024, + camera_randomization_config: dict = { + 'radius_range': [2, 100], + }, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_kl = lambda_kl + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.lambda_render = lambda_render + self.camera_randomization_config = camera_randomization_config + + self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _randomize_camera(self, num_samples: int): + # sample radius and fov + r_min, r_max = self.camera_randomization_config['radius_range'] + k_min = 1 / r_max**2 + k_max = 1 / r_min**2 + ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min + radius = 1 / torch.sqrt(ks) + fov = 2 * torch.arcsin(0.5 / radius) + origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1) + + # build camera + extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device)) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + near = [np.random.uniform(r - 1, r) for r in radius.tolist()] + + return { + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + 'near': near, + } + + def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List, + ) -> Dict[str, torch.Tensor]: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + + Returns: + a dict with + base_color : [N x 3 x H x W] tensor of base color. + metallic : [N x 1 x H x W] tensor of metallic. + roughness : [N x 1 x H x W] tensor of roughness. + alpha : [N x 1 x H x W] tensor of alpha. + """ + ret = {k : [] for k in ['base_color', 'metallic', 'roughness', 'alpha']} + for i, rep in enumerate(reps): + self.renderer.rendering_options['near'] = near[i] + self.renderer.rendering_options['far'] = near[i] + 2 + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=['attr']) + for k in out_dict: + ret[k].append(out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + def training_losses( + self, + x: sp.SparseTensor, + mesh: List[MeshWithPbrMaterial] = None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + x (SparseTensor): Input sparse tensor for pbr materials. + mesh (List[MeshWithPbrMaterial]): The list of meshes with PBR materials. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + + """ + z, mean, logvar = self.training_models['encoder'](x, sample_posterior=True, return_raw=True) + y = self.training_models['decoder'](z) + + terms = edict(loss = 0.0) + + # direct regression + if self.loss_type == 'l1': + terms["l1"] = l1_loss(x.feats, y.feats) + terms["loss"] = terms["loss"] + terms["l1"] + elif self.loss_type == 'l2': + terms["l2"] = l2_loss(x.feats, y.feats) + terms["loss"] = terms["loss"] + terms["l2"] + else: + raise ValueError(f'Invalid loss type {self.loss_type}') + + # rendering loss + if self.lambda_render != 0.0: + recon = [MeshWithVoxel( + m.vertices, + m.faces, + [-0.5, -0.5, -0.5], + 1 / self.dataset.resolution, + v.coords[:, 1:], + v.feats * 0.5 + 0.5, + torch.Size([*v.shape, *v.spatial_shape]), + layout={ + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + ) for m, v in zip(mesh, y)] + cameras = self._randomize_camera(len(mesh)) + gt_renders = self._render_batch(mesh, **cameras) + pred_renders = self._render_batch(recon, **cameras) + gt_base_color = gt_renders['base_color'] + pred_base_color = pred_renders['base_color'] + gt_mra = torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) + pred_mra = torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) + terms['render/base_color/ssim'] = 1 - ssim(pred_base_color, gt_base_color) + terms['render/base_color/lpips'] = lpips(pred_base_color, gt_base_color) + terms['render/mra/ssim'] = 1 - ssim(pred_mra, gt_mra) + terms['render/mra/lpips'] = lpips(pred_mra, gt_mra) + terms['loss'] = terms['loss'] + \ + self.lambda_render * (self.lambda_ssim * terms['render/base_color/ssim'] + self.lambda_lpips * terms['render/base_color/lpips'] + \ + self.lambda_ssim * terms['render/mra/ssim'] + self.lambda_lpips * terms['render/mra/lpips']) + + # KL regularization + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=1, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + dataloader.dataset.with_mesh = True + + # inference + gts = [] + recons = [] + self.models['encoder'].eval() + self.models['decoder'].eval() + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch] for k, v in data.items()} + args = recursive_to_device(args, self.device) + z = self.models['encoder'](args['x']) + y = self.models['decoder'](z) + gts.extend(args['mesh']) + recons.extend([MeshWithVoxel( + m.vertices, + m.faces, + [-0.5, -0.5, -0.5], + 1 / self.dataset.resolution, + v.coords[:, 1:], + v.feats * 0.5 + 0.5, + torch.Size([*v.shape, *v.spatial_shape]), + layout={ + 'base_color': slice(0, 3), + 'metallic': slice(3, 4), + 'roughness': slice(4, 5), + 'alpha': slice(5, 6), + } + ) for m, v in zip(args['mesh'], y)]) + self.models['encoder'].train() + self.models['decoder'].train() + + cameras = self._randomize_camera(num_samples) + gt_renders = self._render_batch(gts, **cameras) + pred_renders = self._render_batch(recons, **cameras) + + sample_dict = { + 'gt_base_color': {'value': gt_renders['base_color'] * 2 - 1, 'type': 'image'}, + 'pred_base_color': {'value': pred_renders['base_color'] * 2 - 1, 'type': 'image'}, + 'gt_mra': {'value': torch.cat([gt_renders['metallic'], gt_renders['roughness'], gt_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'}, + 'pred_mra': {'value': torch.cat([pred_renders['metallic'], pred_renders['roughness'], pred_renders['alpha']], dim=1) * 2 - 1, 'type': 'image'}, + } + + return sample_dict diff --git a/trellis2/trainers/vae/shape_vae.py b/trellis2/trainers/vae/shape_vae.py new file mode 100644 index 0000000..f9441fa --- /dev/null +++ b/trellis2/trainers/vae/shape_vae.py @@ -0,0 +1,266 @@ +from typing import * +import os +import copy +import functools +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import utils3d +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...modules import sparse as sp +from ...renderers import MeshRenderer +from ...representations import Mesh +from ...utils.data_utils import recursive_to_device, cycle, BalancedResumableSampler +from ...utils.loss_utils import l1_loss, ssim, lpips + + +class ShapeVaeTrainer(BasicTrainer): + """ + Trainer for Shape VAE + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + lambda_subdiv (float): Subdivision loss weight. + lambda_intersected (float): Intersected loss weight. + lambda_vertice (float): Vertice loss weight. + lambda_kl (float): KL loss weight. + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + """ + + def __init__( + self, + *args, + lambda_subdiv: float = 0.1, + lambda_intersected: float = 0.1, + lambda_vertice: float = 1e-2, + lambda_mask: float = 1, + lambda_depth: float = 10, + lambda_normal: float = 1, + lambda_kl: float = 1e-6, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + render_resolution: float = 1024, + camera_randomization_config: dict = { + 'radius_range': [2, 100], + }, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lambda_subdiv = lambda_subdiv + self.lambda_intersected = lambda_intersected + self.lambda_mask = lambda_mask + self.lambda_vertice = lambda_vertice + self.lambda_depth = lambda_depth + self.lambda_normal = lambda_normal + self.lambda_kl = lambda_kl + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.camera_randomization_config = camera_randomization_config + + self.renderer = MeshRenderer({'near': 1, 'far': 3, 'resolution': render_resolution}, device=self.device) + + def prepare_dataloader(self, **kwargs): + """ + Prepare dataloader. + """ + self.data_sampler = BalancedResumableSampler( + self.dataset, + shuffle=True, + batch_size=self.batch_size_per_gpu, + ) + self.dataloader = DataLoader( + self.dataset, + batch_size=self.batch_size_per_gpu, + num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())), + pin_memory=True, + drop_last=True, + persistent_workers=True, + collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split), + sampler=self.data_sampler, + ) + self.data_iterator = cycle(self.dataloader) + + def _randomize_camera(self, num_samples: int): + # sample radius and fov + r_min, r_max = self.camera_randomization_config['radius_range'] + k_min = 1 / r_max**2 + k_max = 1 / r_min**2 + ks = torch.rand(num_samples, device=self.device) * (k_max - k_min) + k_min + radius = 1 / torch.sqrt(ks) + fov = 2 * torch.arcsin(0.5 / radius) + origin = radius.unsqueeze(-1) * F.normalize(torch.randn(num_samples, 3, device=self.device), dim=-1) + + # build camera + extrinsics = utils3d.torch.extrinsics_look_at(origin, torch.zeros_like(origin), torch.tensor([0, 0, 1], dtype=torch.float32, device=self.device)) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + near = [np.random.uniform(r - 1, r) for r in radius.tolist()] + + return { + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + 'near': near, + } + + def _render_batch(self, reps: List[Mesh], extrinsics: torch.Tensor, intrinsics: torch.Tensor, near: List, + return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color'] + + Returns: + a dict with + mask : [N x 1 x H x W] tensor of rendered masks + normal : [N x 3 x H x W] tensor of rendered normals + depth : [N x 1 x H x W] tensor of rendered depths + """ + ret = {k : [] for k in return_types} + for i, rep in enumerate(reps): + self.renderer.rendering_options['near'] = near[i] + self.renderer.rendering_options['far'] = near[i] + 2 + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types) + for k in out_dict: + ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + def training_losses( + self, + vertices: sp.SparseTensor, + intersected: sp.SparseTensor, + mesh: List[Mesh], + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + vertices (SparseTensor): vertices of each active voxel + intersected (SparseTensor): intersected flag of each active voxel + mesh (List[Mesh]): the list of meshes to render + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + z, mean, logvar = self.training_models['encoder'](vertices, intersected, sample_posterior=True, return_raw=True) + recon, pred_vertice, pred_intersected, subs_gt, subs = self.training_models['decoder'](z, intersected) + + terms = edict(loss = 0.0) + + # direct regression + if self.lambda_intersected > 0: + terms["direct/intersected"] = F.binary_cross_entropy_with_logits(pred_intersected.feats.flatten(), intersected.feats.flatten().float()) + terms["loss"] = terms["loss"] + self.lambda_intersected * terms["direct/intersected"] + if self.lambda_vertice > 0: + terms["direct/vertice"] = F.mse_loss(pred_vertice.feats, vertices.feats) + terms["loss"] = terms["loss"] + self.lambda_vertice * terms["direct/vertice"] + + # subdivision prediction loss + for i, (sub_gt, sub) in enumerate(zip(subs_gt, subs)): + terms[f"bce_sub{i}"] = F.binary_cross_entropy_with_logits(sub.feats, sub_gt.float()) + terms["loss"] = terms["loss"] + self.lambda_subdiv * terms[f"bce_sub{i}"] + + # rendering loss + cameras = self._randomize_camera(len(mesh)) + gt_renders = self._render_batch(mesh, **cameras, return_types=['mask', 'normal', 'depth']) + pred_renders = self._render_batch(recon, **cameras, return_types=['mask', 'normal', 'depth']) + terms['render/mask'] = l1_loss(pred_renders['mask'], gt_renders['mask']) + terms['render/depth'] = l1_loss(pred_renders['depth'], gt_renders['depth']) + terms['render/normal/l1'] = l1_loss(pred_renders['normal'], gt_renders['normal']) + terms['render/normal/ssim'] = 1 - ssim(pred_renders['normal'], gt_renders['normal']) + terms['render/normal/lpips'] = lpips(pred_renders['normal'], gt_renders['normal']) + terms['loss'] = terms['loss'] + \ + self.lambda_mask * terms['render/mask'] + \ + self.lambda_depth * terms['render/depth'] + \ + self.lambda_normal * (terms['render/normal/l1'] + self.lambda_ssim * terms['render/normal/ssim'] + self.lambda_lpips * terms['render/normal/lpips']) + + # KL regularization + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=1, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + gts = [] + recons = [] + recons2 = [] + self.models['encoder'].eval() + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch] for k, v in data.items()} + args = recursive_to_device(args, self.device) + z = self.models['encoder'](args['vertices'], args['intersected']) + self.models['decoder'].train() + y = self.models['decoder'](z, args['intersected'])[0] + z.clear_spatial_cache() + self.models['decoder'].eval() + y2 = self.models['decoder'](z) + gts.extend(args['mesh']) + recons.extend(y) + recons2.extend(y2) + self.models['encoder'].train() + self.models['decoder'].train() + + cameras = self._randomize_camera(num_samples) + gt_renders = self._render_batch(gts, **cameras, return_types=['normal']) + recons_renders = self._render_batch(recons, **cameras, return_types=['normal']) + recons2_renders = self._render_batch(recons2, **cameras, return_types=['normal']) + + sample_dict = { + 'gt': {'value': gt_renders['normal'], 'type': 'image'}, + 'rec': {'value': recons_renders['normal'], 'type': 'image'}, + 'rec2': {'value': recons2_renders['normal'], 'type': 'image'}, + } + + return sample_dict diff --git a/trellis2/trainers/vae/sparse_structure_vae.py b/trellis2/trainers/vae/sparse_structure_vae.py new file mode 100644 index 0000000..5563641 --- /dev/null +++ b/trellis2/trainers/vae/sparse_structure_vae.py @@ -0,0 +1,130 @@ +from typing import * +import copy +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from easydict import EasyDict as edict + +from ..basic import BasicTrainer + + +class SparseStructureVaeTrainer(BasicTrainer): + """ + Trainer for Sparse Structure VAE. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. 'bce' for binary cross entropy, 'l1' for L1 loss, 'dice' for Dice loss. + lambda_kl (float): KL divergence loss weight. + """ + + def __init__( + self, + *args, + loss_type='bce', + lambda_kl=1e-6, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_kl = lambda_kl + + def training_losses( + self, + ss: torch.Tensor, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + ss: The [N x 1 x H x W x D] tensor of binary sparse structure. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + z, mean, logvar = self.training_models['encoder'](ss.float(), sample_posterior=True, return_raw=True) + logits = self.training_models['decoder'](z) + + terms = edict(loss = 0.0) + if self.loss_type == 'bce': + terms["bce"] = F.binary_cross_entropy_with_logits(logits, ss.float(), reduction='mean') + terms["loss"] = terms["loss"] + terms["bce"] + elif self.loss_type == 'l1': + terms["l1"] = F.l1_loss(F.sigmoid(logits), ss.float(), reduction='mean') + terms["loss"] = terms["loss"] + terms["l1"] + elif self.loss_type == 'dice': + logits = F.sigmoid(logits) + terms["dice"] = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1) + terms["loss"] = terms["loss"] + terms["dice"] + else: + raise ValueError(f'Invalid loss type {self.loss_type}') + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lamda_kl * terms["kl"] + + return terms, {} + + @torch.no_grad() + def snapshot(self, suffix=None, num_samples=64, batch_size=1, verbose=False): + super().snapshot(suffix=suffix, num_samples=num_samples, batch_size=batch_size, verbose=verbose) + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + gts = [] + recons = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + z = self.models['encoder'](args['ss'].float(), sample_posterior=False) + logits = self.models['decoder'](z) + recon = (logits > 0).long() + gts.append(args['ss']) + recons.append(recon) + + sample_dict = { + 'gt': {'value': torch.cat(gts, dim=0), 'type': 'sample'}, + 'recon': {'value': torch.cat(recons, dim=0), 'type': 'sample'}, + } + return sample_dict diff --git a/trellis2/utils/__init__.py b/trellis2/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trellis2/utils/data_utils.py b/trellis2/utils/data_utils.py new file mode 100644 index 0000000..805b6cc --- /dev/null +++ b/trellis2/utils/data_utils.py @@ -0,0 +1,226 @@ +from typing import * +import math +import torch +import numpy as np +from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler +import torch.distributed as dist + + +def recursive_to_device( + data: Any, + device: torch.device, + non_blocking: bool = False, +) -> Any: + """ + Recursively move all tensors in a data structure to a device. + """ + if hasattr(data, "to"): + return data.to(device, non_blocking=non_blocking) + elif isinstance(data, (list, tuple)): + return type(data)(recursive_to_device(d, device, non_blocking) for d in data) + elif isinstance(data, dict): + return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()} + else: + return data + + +def load_balanced_group_indices( + load: List[int], + num_groups: int, + equal_size: bool = False, +) -> List[List[int]]: + """ + Split indices into groups with balanced load. + """ + if equal_size: + group_size = len(load) // num_groups + indices = np.argsort(load)[::-1] + groups = [[] for _ in range(num_groups)] + group_load = np.zeros(num_groups) + for idx in indices: + min_group_idx = np.argmin(group_load) + groups[min_group_idx].append(idx) + if equal_size and len(groups[min_group_idx]) == group_size: + group_load[min_group_idx] = float('inf') + else: + group_load[min_group_idx] += load[idx] + return groups + + +def cycle(data_loader: DataLoader) -> Iterator: + while True: + for data in data_loader: + if isinstance(data_loader.sampler, ResumableSampler): + data_loader.sampler.idx += data_loader.batch_size # type: ignore[attr-defined] + yield data + if isinstance(data_loader.sampler, DistributedSampler): + data_loader.sampler.epoch += 1 + if isinstance(data_loader.sampler, ResumableSampler): + data_loader.sampler.epoch += 1 + data_loader.sampler.idx = 0 + + +class ResumableSampler(Sampler): + """ + Distributed sampler that is resumable. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + self.dataset = dataset + self.epoch = 0 + self.idx = 0 + self.drop_last = drop_last + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.rank = dist.get_rank() if dist.is_initialized() else 0 + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.world_size) / self.world_size # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type] + self.total_size = self.num_samples * self.world_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.world_size] + + # resume from previous state + indices = indices[self.idx:] + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def state_dict(self) -> dict[str, int]: + return { + 'epoch': self.epoch, + 'idx': self.idx, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict['epoch'] + self.idx = state_dict['idx'] + + +class BalancedResumableSampler(ResumableSampler): + """ + Distributed sampler that is resumable and balances the load among the processes. + + Args: + dataset: Dataset used for sampling. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + batch_size: int = 1, + ) -> None: + assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler' + super().__init__(dataset, shuffle, seed, drop_last) + self.batch_size = batch_size + self.loads = dataset.loads + + def __iter__(self) -> Iterator: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # balance load among processes + num_batches = len(indices) // (self.batch_size * self.world_size) + balanced_indices = [] + for i in range(num_batches): + start_idx = i * self.batch_size * self.world_size + end_idx = (i + 1) * self.batch_size * self.world_size + batch_indices = indices[start_idx:end_idx] + batch_loads = [self.loads[idx] for idx in batch_indices] + groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True) + balanced_indices.extend([batch_indices[j] for j in groups[self.rank]]) + + # resume from previous state + indices = balanced_indices[self.idx:] + + return iter(indices) diff --git a/trellis2/utils/dist_utils.py b/trellis2/utils/dist_utils.py new file mode 100644 index 0000000..348799c --- /dev/null +++ b/trellis2/utils/dist_utils.py @@ -0,0 +1,93 @@ +import os +import io +from contextlib import contextmanager +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +def setup_dist(rank, local_rank, world_size, master_addr, master_port): + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(local_rank) + torch.cuda.set_device(local_rank) + dist.init_process_group('nccl', rank=rank, world_size=world_size) + + +def read_file_dist(path): + """ + Read the binary file distributedly. + File is only read once by the rank 0 process and broadcasted to other processes. + + Returns: + data (io.BytesIO): The binary data read from the file. + """ + if dist.is_initialized() and dist.get_world_size() > 1: + # read file + size = torch.LongTensor(1).cuda() + if dist.get_rank() == 0: + with open(path, 'rb') as f: + data = f.read() + data = torch.ByteTensor( + torch.UntypedStorage.from_buffer(data, dtype=torch.uint8) + ).cuda() + size[0] = data.shape[0] + # broadcast size + dist.broadcast(size, src=0) + if dist.get_rank() != 0: + data = torch.ByteTensor(size[0].item()).cuda() + # broadcast data + dist.broadcast(data, src=0) + # convert to io.BytesIO + data = data.cpu().numpy().tobytes() + data = io.BytesIO(data) + return data + else: + with open(path, 'rb') as f: + data = f.read() + data = io.BytesIO(data) + return data + + +def unwrap_dist(model): + """ + Unwrap the model from distributed training. + """ + if isinstance(model, DDP): + return model.module + return model + + +@contextmanager +def master_first(): + """ + A context manager that ensures master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + + +@contextmanager +def local_master_first(): + """ + A context manager that ensures local master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() % torch.cuda.device_count() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + \ No newline at end of file diff --git a/trellis2/utils/elastic_utils.py b/trellis2/utils/elastic_utils.py new file mode 100644 index 0000000..cba3cf8 --- /dev/null +++ b/trellis2/utils/elastic_utils.py @@ -0,0 +1,228 @@ +from abc import abstractmethod +from contextlib import contextmanager +from typing import Tuple +import torch +import torch.nn as nn +import numpy as np + + +class MemoryController: + """ + Base class for memory management during training. + """ + + _last_input_size = None + _last_mem_ratio = [] + + @contextmanager + def record(self): + pass + + def update_run_states(self, input_size=None, mem_ratio=None): + if self._last_input_size is None: + self._last_input_size = input_size + elif self._last_input_size!= input_size: + raise ValueError(f'Input size should not change for different ElasticModules.') + self._last_mem_ratio.append(mem_ratio) + + @abstractmethod + def get_mem_ratio(self, input_size): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def log(self): + pass + + +class LinearMemoryController(MemoryController): + """ + A simple controller for memory management during training. + The memory usage is modeled as a linear function of: + - the number of input parameters + - the ratio of memory the model use compared to the maximum usage (with no checkpointing) + memory_usage = k * input_size * mem_ratio + b + The controller keeps track of the memory usage and gives the + expected memory ratio to keep the memory usage under a target + """ + def __init__( + self, + buffer_size=1000, + update_every=500, + target_ratio=0.8, + available_memory=None, + max_mem_ratio_start=0.1, + params=None, + device=None + ): + self.buffer_size = buffer_size + self.update_every = update_every + self.target_ratio = target_ratio + self.device = device or torch.cuda.current_device() + self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3 + + self._memory = np.zeros(buffer_size, dtype=np.float32) + self._input_size = np.zeros(buffer_size, dtype=np.float32) + self._mem_ratio = np.zeros(buffer_size, dtype=np.float32) + self._buffer_ptr = 0 + self._buffer_length = 0 + self._params = tuple(params) if params is not None else (0.0, 0.0) + self._max_mem_ratio = max_mem_ratio_start + self.step = 0 + + def __repr__(self): + return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})' + + def _add_sample(self, memory, input_size, mem_ratio): + self._memory[self._buffer_ptr] = memory + self._input_size[self._buffer_ptr] = input_size + self._mem_ratio[self._buffer_ptr] = mem_ratio + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + + @contextmanager + def record(self): + torch.cuda.reset_peak_memory_stats(self.device) + self._last_input_size = None + self._last_mem_ratio = [] + yield + self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3 + self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio) + self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio) + self.step += 1 + if self.step % self.update_every == 0: + self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1) + self._fit_params() + + def _fit_params(self): + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + + x = input_size * mem_ratio + y = memory_usage + k, b = np.polyfit(x, y, 1) + self._params = (k, b) + # self._visualize() + + def _visualize(self): + import matplotlib.pyplot as plt + memory_usage = self._memory[:self._buffer_length] + input_size = self._input_size[:self._buffer_length] + mem_ratio = self._mem_ratio[:self._buffer_length] + k, b = self._params + + plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis') + x = np.array([0.0, 20000.0]) + plt.plot(x, k * x + b, c='r') + plt.savefig(f'linear_memory_controller_{self.step}.png') + plt.cla() + + def get_mem_ratio(self, input_size): + k, b = self._params + if k == 0: return np.random.rand() * self._max_mem_ratio + pred = (self.available_memory * self.target_ratio - b) / (k * input_size) + return min(self._max_mem_ratio, max(0.0, pred)) + + def state_dict(self): + return { + 'params': self._params, + } + + def load_state_dict(self, state_dict): + self._params = tuple(state_dict['params']) + + def log(self): + return { + 'params/k': self._params[0], + 'params/b': self._params[1], + 'memory': self._last_memory, + 'input_size': self._last_input_size, + 'mem_ratio': self._last_mem_ratio, + } + + +class ElasticModule(nn.Module): + """ + Module for training with elastic memory management. + """ + def __init__(self): + super().__init__() + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]: + """ + Forward with a given memory ratio. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + _, ret = self._forward_with_mem_ratio(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs) + self._memory_controller.update_run_states(input_size, mem_ratio) + return ret + + +class ElasticModuleMixin: + """ + Mixin for training with elastic memory management. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._memory_controller: MemoryController = None + + @abstractmethod + def _get_input_size(self, *args, **kwargs) -> int: + """ + Get the size of the input data. + + Returns: + int: The size of the input data. + """ + pass + + @abstractmethod + @contextmanager + def with_mem_ratio(self, mem_ratio=1.0) -> float: + """ + Context manager for training with a reduced memory ratio compared to the full memory usage. + + Returns: + float: The exact memory ratio used during the forward pass. + """ + pass + + def register_memory_controller(self, memory_controller: MemoryController): + self._memory_controller = memory_controller + + def forward(self, *args, **kwargs): + if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: + ret = super().forward(*args, **kwargs) + else: + input_size = self._get_input_size(*args, **kwargs) + mem_ratio = self._memory_controller.get_mem_ratio(input_size) + with self.with_mem_ratio(mem_ratio) as exact_mem_ratio: + ret = super().forward(*args, **kwargs) + self._memory_controller.update_run_states(input_size, exact_mem_ratio) + return ret diff --git a/trellis2/utils/general_utils.py b/trellis2/utils/general_utils.py new file mode 100644 index 0000000..589c103 --- /dev/null +++ b/trellis2/utils/general_utils.py @@ -0,0 +1,373 @@ +import re +import numpy as np +import cv2 +import torch +import contextlib + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +# Context utils +@contextlib.contextmanager +def nested_contexts(*contexts): + with contextlib.ExitStack() as stack: + for ctx in contexts: + stack.enter_context(ctx()) + yield + + +# Image utils +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + if images[0].ndim == 2: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype) + else: + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + + +def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_align="center"): + """ + Draw text on an image of the given resolution. The text is automatically wrapped + and scaled so that it fits completely within the image while preserving any explicit + line breaks and original spacing. Horizontal and vertical alignment can be controlled + via flags. + + Parameters: + text (str): The input text. Newline characters and spacing are preserved. + resolution (tuple): The image resolution as (width, height). + max_size (float): The maximum font size. + h_align (str): Horizontal alignment. Options: "left", "center", "right". + v_align (str): Vertical alignment. Options: "top", "center", "bottom". + + Returns: + numpy.ndarray: The resulting image (BGR format) with the text drawn. + """ + width, height = resolution + # Create a white background image + img = np.full((height, width, 3), 255, dtype=np.uint8) + + # Set margins and compute available drawing area + margin = 10 + avail_width = width - 2 * margin + avail_height = height - 2 * margin + + # Choose OpenCV font and text thickness + font = cv2.FONT_HERSHEY_SIMPLEX + thickness = 1 + # Ratio for additional spacing between lines (relative to the height of "A") + line_spacing_ratio = 0.5 + + def wrap_line(line, max_width, font, thickness, scale): + """ + Wrap a single line of text into multiple lines such that each line's + width (measured at the given scale) does not exceed max_width. + This function preserves the original spacing by splitting the line into tokens + (words and whitespace) using a regular expression. + + Parameters: + line (str): The input text line. + max_width (int): Maximum allowed width in pixels. + font (int): OpenCV font identifier. + thickness (int): Text thickness. + scale (float): The current font scale. + + Returns: + List[str]: A list of wrapped lines. + """ + # Split the line into tokens (words and whitespace), preserving spacing + tokens = re.split(r'(\s+)', line) + if not tokens: + return [''] + + wrapped_lines = [] + current_line = "" + for token in tokens: + candidate = current_line + token + candidate_width = cv2.getTextSize(candidate, font, scale, thickness)[0][0] + if candidate_width <= max_width: + current_line = candidate + else: + # If current_line is empty, the token itself is too wide; + # break the token character by character. + if current_line == "": + sub_token = "" + for char in token: + candidate_char = sub_token + char + if cv2.getTextSize(candidate_char, font, scale, thickness)[0][0] <= max_width: + sub_token = candidate_char + else: + if sub_token: + wrapped_lines.append(sub_token) + sub_token = char + current_line = sub_token + else: + wrapped_lines.append(current_line) + current_line = token + if current_line: + wrapped_lines.append(current_line) + return wrapped_lines + + def compute_text_block(scale): + """ + Wrap the entire text (splitting at explicit newline characters) using the + provided scale, and then compute the overall width and height of the text block. + + Returns: + wrapped_lines (List[str]): The list of wrapped lines. + block_width (int): Maximum width among the wrapped lines. + block_height (int): Total height of the text block including spacing. + sizes (List[tuple]): A list of (width, height) for each wrapped line. + spacing (int): The spacing between lines (computed from the scaled "A" height). + """ + # Split text by explicit newlines + input_lines = text.splitlines() if text else [''] + wrapped_lines = [] + for line in input_lines: + wrapped = wrap_line(line, avail_width, font, thickness, scale) + wrapped_lines.extend(wrapped) + + sizes = [] + for line in wrapped_lines: + (text_size, _) = cv2.getTextSize(line, font, scale, thickness) + sizes.append(text_size) # (width, height) + + block_width = max((w for w, h in sizes), default=0) + # Use the height of "A" (at the current scale) to compute line spacing + base_height = cv2.getTextSize("A", font, scale, thickness)[0][1] + spacing = int(line_spacing_ratio * base_height) + block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0 + + return wrapped_lines, block_width, block_height, sizes, spacing + + # Use binary search to find the maximum scale that allows the text block to fit + lo = 0.001 + hi = max_size + eps = 0.001 # convergence threshold + best_scale = lo + best_result = None + + while hi - lo > eps: + mid = (lo + hi) / 2 + wrapped_lines, block_width, block_height, sizes, spacing = compute_text_block(mid) + # Ensure that both width and height constraints are met + if block_width <= avail_width and block_height <= avail_height: + best_scale = mid + best_result = (wrapped_lines, block_width, block_height, sizes, spacing) + lo = mid # try a larger scale + else: + hi = mid # reduce the scale + + if best_result is None: + best_scale = 0.5 + best_result = compute_text_block(best_scale) + + wrapped_lines, block_width, block_height, sizes, spacing = best_result + + # Compute starting y-coordinate based on vertical alignment flag + if v_align == "top": + y_top = margin + elif v_align == "center": + y_top = margin + (avail_height - block_height) // 2 + elif v_align == "bottom": + y_top = margin + (avail_height - block_height) + else: + y_top = margin + (avail_height - block_height) // 2 # default to center if invalid flag + + # For cv2.putText, the y coordinate represents the text baseline; + # so for the first line add its height. + y = y_top + (sizes[0][1] if sizes else 0) + + # Draw each line with horizontal alignment based on the flag + for i, line in enumerate(wrapped_lines): + line_width, line_height = sizes[i] + if h_align == "left": + x = margin + elif h_align == "center": + x = margin + (avail_width - line_width) // 2 + elif h_align == "right": + x = margin + (avail_width - line_width) + else: + x = margin # default to left if invalid flag + + cv2.putText(img, line, (x, y), font, best_scale, (0, 0, 0), thickness, cv2.LINE_AA) + y += line_height + spacing + + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/trellis2/utils/grad_clip_utils.py b/trellis2/utils/grad_clip_utils.py new file mode 100644 index 0000000..990a435 --- /dev/null +++ b/trellis2/utils/grad_clip_utils.py @@ -0,0 +1,81 @@ +from typing import * +import torch +import numpy as np +import torch.utils + + +class AdaptiveGradClipper: + """ + Adaptive gradient clipping for training. + """ + def __init__( + self, + max_norm=None, + clip_percentile=95.0, + buffer_size=1000, + ): + self.max_norm = max_norm + self.clip_percentile = clip_percentile + self.buffer_size = buffer_size + + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) + self._max_norm = max_norm + self._buffer_ptr = 0 + self._buffer_length = 0 + + def __repr__(self): + return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' + + def state_dict(self): + return { + 'grad_norm': self._grad_norm, + 'max_norm': self._max_norm, + 'buffer_ptr': self._buffer_ptr, + 'buffer_length': self._buffer_length, + } + + def load_state_dict(self, state_dict): + self._grad_norm = state_dict['grad_norm'] + self._max_norm = state_dict['max_norm'] + self._buffer_ptr = state_dict['buffer_ptr'] + self._buffer_length = state_dict['buffer_length'] + + def log(self): + return { + 'max_norm': self._max_norm, + } + + def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None): + """Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + max_norm = self._max_norm if self._max_norm is not None else float('inf') + grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach) + + if torch.isfinite(grad_norm): + self._grad_norm[self._buffer_ptr] = grad_norm + self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size + self._buffer_length = min(self._buffer_length + 1, self.buffer_size) + if self._buffer_length == self.buffer_size: + self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) + self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm + + return grad_norm \ No newline at end of file diff --git a/trellis2/utils/loss_utils.py b/trellis2/utils/loss_utils.py new file mode 100644 index 0000000..52049f6 --- /dev/null +++ b/trellis2/utils/loss_utils.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from lpips import LPIPS + + +def smooth_l1_loss(pred, target, beta=1.0): + diff = torch.abs(pred - target) + loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta) + return loss.mean() + + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def psnr(img1, img2, max_val=1.0): + mse = F.mse_loss(img1, img2) + return 20 * torch.log10(max_val / torch.sqrt(mse)) + + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +loss_fn_vgg = None +def lpips(img1, img2, value_range=(0, 1)): + global loss_fn_vgg + if loss_fn_vgg is None: + loss_fn_vgg = LPIPS(net='vgg').cuda().eval() + # normalize to [-1, 1] + img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 + return loss_fn_vgg(img1, img2).mean() + + +def normal_angle(pred, gt): + pred = pred * 2.0 - 1.0 + gt = gt * 2.0 - 1.0 + norms = pred.norm(dim=-1) * gt.norm(dim=-1) + cos_sim = (pred * gt).sum(-1) / (norms + 1e-9) + cos_sim = torch.clamp(cos_sim, -1.0, 1.0) + ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean() + if ang.isnan(): + return -1 + return ang diff --git a/trellis2/utils/mesh_utils.py b/trellis2/utils/mesh_utils.py new file mode 100644 index 0000000..a9f1451 --- /dev/null +++ b/trellis2/utils/mesh_utils.py @@ -0,0 +1,268 @@ +from typing import Tuple, Dict +import numpy as np +from trimesh import grouping, util, remesh +import struct +import re +from plyfile import PlyData, PlyElement + + +def read_ply(filename): + """ + Read a PLY file and return vertices, triangle faces, and quad faces. + + Args: + filename (str): The file path to read from. + + Returns: + vertices (np.ndarray): Array of shape [N, 3] containing vertex positions. + tris (np.ndarray): Array of shape [M, 3] containing triangle face indices (empty if none). + quads (np.ndarray): Array of shape [K, 4] containing quad face indices (empty if none). + """ + with open(filename, 'rb') as f: + # Read the header until 'end_header' is encountered + header_bytes = b"" + while True: + line = f.readline() + if not line: + raise ValueError("PLY header not found") + header_bytes += line + if b"end_header" in line: + break + header = header_bytes.decode('utf-8') + + # Determine if the file is in ASCII or binary format + is_ascii = "ascii" in header + + # Extract the number of vertices and faces from the header using regex + vertex_match = re.search(r'element vertex (\d+)', header) + if vertex_match: + num_vertices = int(vertex_match.group(1)) + else: + raise ValueError("Vertex count not found in header") + + face_match = re.search(r'element face (\d+)', header) + if face_match: + num_faces = int(face_match.group(1)) + else: + raise ValueError("Face count not found in header") + + vertices = [] + tris = [] + quads = [] + + if is_ascii: + # For ASCII format, read each line of vertex data (each line contains 3 floats) + for _ in range(num_vertices): + line = f.readline().decode('utf-8').strip() + if not line: + continue + parts = line.split() + vertices.append([float(parts[0]), float(parts[1]), float(parts[2])]) + + # Read face data, where the first number indicates the number of vertices for the face + for _ in range(num_faces): + line = f.readline().decode('utf-8').strip() + if not line: + continue + parts = line.split() + count = int(parts[0]) + indices = list(map(int, parts[1:])) + if count == 3: + tris.append(indices) + elif count == 4: + quads.append(indices) + else: + # Skip faces with other numbers of vertices (can be extended as needed) + pass + else: + # For binary format: read directly from the binary stream + # Each vertex consists of 3 floats (12 bytes per vertex) + for _ in range(num_vertices): + data = f.read(12) + if len(data) < 12: + raise ValueError("Insufficient vertex data") + v = struct.unpack(' 0 else np.empty((0, 3), dtype=np.int32) + quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32) + + return vertices, tris, quads + + +def write_ply( + filename: str, + vertices: np.ndarray, + tris: np.ndarray, + quads: np.ndarray, + vertex_colors: np.ndarray = None, + ascii: bool = False +): + """ + Write a mesh to a PLY file, with the option to save in ASCII or binary format, + and optional per-vertex colors. + + Args: + filename (str): The filename to write to. + vertices (np.ndarray): [N, 3] The vertex positions. + tris (np.ndarray): [M, 3] The triangle indices. + quads (np.ndarray): [K, 4] The quad indices. + vertex_colors (np.ndarray, optional): [N, 3] or [N, 4] UInt8 colors for each vertex (RGB or RGBA). + ascii (bool): If True, write in ASCII format; otherwise binary little-endian. + """ + import struct + + num_vertices = len(vertices) + num_faces = len(tris) + len(quads) + + # Build header + header_lines = [ + "ply", + f"format {'ascii 1.0' if ascii else 'binary_little_endian 1.0'}", + f"element vertex {num_vertices}", + "property float x", + "property float y", + "property float z", + ] + + # Add vertex color properties if provided + has_color = vertex_colors is not None + if has_color: + # Expect uint8 values 0-255 + header_lines += [ + "property uchar red", + "property uchar green", + "property uchar blue", + ] + # Include alpha if RGBA + if vertex_colors.shape[1] == 4: + header_lines.append("property uchar alpha") + + header_lines += [ + f"element face {num_faces}", + "property list uchar int vertex_index", + "end_header", + "" + ] + header = "\n".join(header_lines) + + mode = 'w' if ascii else 'wb' + with open(filename, mode) as f: + # Write header + if ascii: + f.write(header) + else: + f.write(header.encode('utf-8')) + + # Write vertex data + for i, v in enumerate(vertices): + if ascii: + line = f"{v[0]} {v[1]} {v[2]}" + if has_color: + col = vertex_colors[i] + line += ' ' + ' '.join(str(int(c)) for c in col) + f.write(line + '\n') + else: + # pack xyz as floats + f.write(struct.pack(' 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/trellis2/utils/render_utils.py b/trellis2/utils/render_utils.py new file mode 100644 index 0000000..d6fb3e9 --- /dev/null +++ b/trellis2/utils/render_utils.py @@ -0,0 +1,129 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import MeshRenderer, VoxelRenderer, PbrMeshRenderer +from ..representations import Mesh, Voxel, MeshWithPbrMaterial, MeshWithVoxel +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda() * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def get_renderer(sample, **kwargs): + if isinstance(sample, (MeshWithPbrMaterial, MeshWithVoxel)): + renderer = PbrMeshRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 1) + renderer.rendering_options.far = kwargs.get('far', 100) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + renderer.rendering_options.peel_layers = kwargs.get('peel_layers', 8) + elif isinstance(sample, Mesh): + renderer = MeshRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 1) + renderer.rendering_options.far = kwargs.get('far', 100) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + renderer.rendering_options.chunk_size = kwargs.get('chunk_size', None) + elif isinstance(sample, Voxel): + renderer = VoxelRenderer() + renderer.rendering_options.resolution = kwargs.get('resolution', 512) + renderer.rendering_options.near = kwargs.get('near', 0.1) + renderer.rendering_options.far = kwargs.get('far', 10.0) + renderer.rendering_options.ssaa = kwargs.get('ssaa', 2) + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + return renderer + + +def render_frames(sample, extrinsics, intrinsics, options={}, verbose=True, **kwargs): + renderer = get_renderer(sample, **options) + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), total=len(extrinsics), desc='Rendering', disable=not verbose): + res = renderer.render(sample, extr, intr, **kwargs) + for k, v in res.items(): + if k not in rets: rets[k] = [] + if v.dim() == 2: v = v[None].repeat(3, 1, 1) + rets[k].append(np.clip(v.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + return rets + + +def render_video(sample, resolution=1024, bg_color=(0, 0, 0), num_frames=120, r=2, fov=40, **kwargs): + yaws = -torch.linspace(0, 2 * 3.1415, num_frames) + np.pi/2 + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + return res['color'], extrinsics, intrinsics + + +def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, nviews=4, **kwargs): + yaw = np.linspace(0, 2 * np.pi, nviews, endpoint=False) + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(nviews)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def make_pbr_vis_frames(result, resolution=1024): + num_frames = len(result['shaded']) + frames = [] + for i in range(num_frames): + shaded = Image.fromarray(result['shaded'][i]) + normal = Image.fromarray(result['normal'][i]) + base_color = Image.fromarray(result['base_color'][i]) + metallic = Image.fromarray(result['metallic'][i]) + roughness = Image.fromarray(result['roughness'][i]) + alpha = Image.fromarray(result['alpha'][i]) + shaded = shaded.resize((resolution, resolution)) + normal = normal.resize((resolution, resolution)) + base_color = base_color.resize((resolution//2, resolution//2)) + metallic = metallic.resize((resolution//2, resolution//2)) + roughness = roughness.resize((resolution//2, resolution//2)) + alpha = alpha.resize((resolution//2, resolution//2)) + row1 = np.concatenate([shaded, normal], axis=1) + row2 = np.concatenate([base_color, metallic, roughness, alpha], axis=1) + frame = np.concatenate([row1, row2], axis=0) + frames.append(frame) + return frames diff --git a/trellis2/utils/vis_utils.py b/trellis2/utils/vis_utils.py new file mode 100644 index 0000000..0e5f58e --- /dev/null +++ b/trellis2/utils/vis_utils.py @@ -0,0 +1,44 @@ +from typing import * +import numpy as np +import torch +from ..modules import sparse as sp +from ..representations import Voxel +from .render_utils import render_video + + +def pca_color(feats: torch.Tensor, channels: Tuple[int, int, int] = (0, 1, 2)) -> torch.Tensor: + """ + Apply PCA to the features and return the first three principal components. + """ + feats = feats.detach() + u, s, v = torch.svd(feats) + color = u[:, channels] + color = (color - color.min(dim=0, keepdim=True)[0]) / (color.max(dim=0, keepdim=True)[0] - color.min(dim=0, keepdim=True)[0]) + return color + + +def vis_sparse_tensor( + x: sp.SparseTensor, + num_frames: int = 300, +): + assert x.shape[0] == 1, "Only support batch size 1" + assert x.coords.shape[1] == 4, "Only support 3D coordinates" + + coords = x.coords.cuda().detach()[:, 1:] + feats = x.feats.cuda().detach() + color = pca_color(feats) + + resolution = max(list(x.spatial_shape)) + resolution = int(2**np.ceil(np.log2(resolution))) + + rep = Voxel( + origin=[-0.5, -0.5, -0.5], + voxel_size=1/resolution, + coords=coords, + attrs=color, + layout={ + 'color': slice(0, 3), + } + ) + + return render_video(rep, colors_overwrite=color, num_frames=num_frames)['color']