# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
This tutorial shows how to fit a volume, given a set of views of a scene, using differentiable volumetric rendering.

More specifically, this tutorial will explain how to:

1. Create a differentiable volumetric renderer.
2. Create a volumetric model (including how to use the Volumes class).
3. Fit the volume based on the images using differentiable volumetric rendering.
4. Visualize the predicted volume.

Ensure torch and torchvision are installed. If pytorch3d is not installed, install it using the following cell:
import os
import sys
import torch
need_pytorch3d = False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d = True
if need_pytorch3d:
    if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
        version_str = "".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".", ""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
import os
import sys
import time
import json
import glob
import torch
import math
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from IPython import display
# Data structures and functions for rendering
from pytorch3d.structures import Volumes
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    VolumeRenderer,
    NDCMultinomialRaysampler,
    EmissionAbsorptionRaymarcher,
)
from pytorch3d.transforms import so3_exp_map
# obtain the utilized device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py
!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/generate_cow_renders.py
from plot_image_grid import image_grid
from generate_cow_renders import generate_cow_renders
OR if running locally, uncomment and run the following cell instead:
# from utils.generate_cow_renders import generate_cow_renders
# from utils import image_grid
The cell below generates our training data. It renders the cow mesh from the fit_textured_mesh.ipynb tutorial from several viewpoints, and returns the cameras used for rendering together with the corresponding images and silhouettes.

Note: The aim of this tutorial is to explain the details of volumetric rendering, so the mesh rendering implemented in the generate_cow_renders function is out of scope here. Please refer to fit_textured_mesh.ipynb for a detailed explanation of mesh rendering.
target_cameras, target_images, target_silhouettes = generate_cow_renders(num_views=40)
print(f'Generated {len(target_images)} images/silhouettes/cameras.')
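Before fitting, it can help to sanity-check the generated data. A minimal sketch follows; the shapes in the comments assume the generator's defaults (40 views at 128 x 128 resolution), which is an assumption about generate_cow_renders rather than something fixed by this tutorial.

print(target_images.shape)       # e.g. torch.Size([40, 128, 128, 3])
print(target_silhouettes.shape)  # e.g. torch.Size([40, 128, 128])
print(len(target_cameras))       # 40 cameras, one per rendered view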
Next we initialize a volumetric renderer that emits a ray from each pixel of a target image and samples a set of uniformly-spaced points along the ray. At each ray point, the corresponding density and color are obtained by querying the corresponding location in the volumetric model of the scene (the model is described and instantiated later).

The renderer is composed of a raysampler and a raymarcher.

The raysampler is responsible for emitting rays from image pixels and sampling points along them. Here we use the NDCMultinomialRaysampler, which follows the standard PyTorch3D coordinate grid convention (+X from right to left; +Y from bottom to top; +Z away from the user).

The raymarcher takes the densities and colors sampled along each ray and renders the ray into a single color and an opacity value. Here we use the EmissionAbsorptionRaymarcher, which implements the standard emission-absorption raymarching algorithm.
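As a side note (our own summary, not text from the original tutorial), the emission-absorption composition can be written compactly. With per-point opacities $\sigma_i \in [0,1]$ and colors $c_i$ sampled along a ray, ordered front to back, a standard form of the EA model is

$$
T_i = \prod_{j < i} \left(1 - \sigma_j\right), \qquad
C = \sum_i T_i \, \sigma_i \, c_i, \qquad
O = \sum_i T_i \, \sigma_i = 1 - \prod_i \left(1 - \sigma_i\right),
$$

where $T_i$ is the transmittance of the points in front of point $i$, $C$ is the rendered pixel color, and $O$ its opacity. The exact weighting used by PyTorch3D's implementation may differ in small details.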
# render_size describes the size of both sides of the
# rendered images in pixels. We set this to the same size
# as the target images. I.e. we render at the same
# size as the ground truth images.
render_size = target_images.shape[1]
# Our rendered scene is centered around (0,0,0)
# and is enclosed inside a bounding box
# whose side is roughly equal to 3.0 (world units).
volume_extent_world = 3.0
# 1) Instantiate the raysampler.
# Here, NDCMultinomialRaysampler generates a rectangular image
# grid of rays whose coordinates follow the PyTorch3D
# coordinate conventions.
# Since we use a volume of size 128^3, we sample n_pts_per_ray=150,
# which roughly corresponds to one ray-point per voxel.
# We further set the min_depth=0.1 since there is no surface within
# 0.1 units of any camera plane.
raysampler = NDCMultinomialRaysampler(
    image_width=render_size,
    image_height=render_size,
    n_pts_per_ray=150,
    min_depth=0.1,
    max_depth=volume_extent_world,
)
# 2) Instantiate the raymarcher.
# Here, we use the standard EmissionAbsorptionRaymarcher
# which marches along each ray in order to render
# each ray into a single 3D color vector
# and an opacity scalar.
raymarcher = EmissionAbsorptionRaymarcher()
# Finally, instantiate the volumetric renderer
# with the raysampler and raymarcher objects.
renderer = VolumeRenderer(
    raysampler=raysampler, raymarcher=raymarcher,
)
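To make the renderer's interface concrete, here is a minimal, hypothetical smoke test (the dummy_volumes name and the 64^3 size are our own choices, not part of the tutorial): the renderer takes a batch of cameras and a Volumes object and returns a (batch, H, W, 4) tensor whose last dimension holds the RGB render and the opacity.

# A minimal sketch (assumed names/sizes): render a uniform gray volume
# from one of the target camera poses and inspect the output shape.
dummy_volumes = Volumes(
    densities=0.1 * torch.ones(1, 1, 64, 64, 64, device=device),
    features=0.5 * torch.ones(1, 3, 64, 64, 64, device=device),
    voxel_size=volume_extent_world / 64,
)
dummy_camera = FoVPerspectiveCameras(
    R=target_cameras.R[:1], T=target_cameras.T[:1], device=device,
)
rendered, _ = renderer(cameras=dummy_camera, volumes=dummy_volumes)
print(rendered.shape)  # (1, render_size, render_size, 4): RGB + opacity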
Next we instantiate a volumetric model of the scene. This quantizes the 3D space into cubical voxels, where each voxel is described by a 3D vector representing the voxel's RGB color and a density scalar describing the voxel's opacity (ranging between [0-1]; the higher, the more opaque).

In order to ensure the range of densities and colors lies in [0-1], we represent both volume colors and densities in log-space. During the model's forward function, the log-space values are passed through a sigmoid function to bring them into the correct range.

Additionally, VolumeModel contains the renderer object. This object stays unaltered throughout the optimization.

In this cell we also define the huber loss function, which computes the discrepancy between the rendered colors/masks and their targets.
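For intuition on the log-space initialization used below, a quick check of the sigmoid values (plain PyTorch, nothing tutorial-specific):

import torch
print(torch.sigmoid(torch.tensor(-4.0)))  # ~0.018: initial densities close to zero
print(torch.sigmoid(torch.tensor(0.0)))   # 0.5: initial colors are neutral gray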
class VolumeModel(torch.nn.Module):
    def __init__(self, renderer, volume_size=[64] * 3, voxel_size=0.1):
        super().__init__()
        # After evaluating torch.sigmoid(self.log_densities), we get
        # densities close to zero.
        self.log_densities = torch.nn.Parameter(-4.0 * torch.ones(1, *volume_size))
        # After evaluating torch.sigmoid(self.log_colors), we get
        # a neutral gray color everywhere.
        self.log_colors = torch.nn.Parameter(torch.zeros(3, *volume_size))
        self._voxel_size = voxel_size
        # Store the renderer module as well.
        self._renderer = renderer

    def forward(self, cameras):
        batch_size = cameras.R.shape[0]

        # Convert the log-space values to the densities/colors.
        densities = torch.sigmoid(self.log_densities)
        colors = torch.sigmoid(self.log_colors)

        # Instantiate the Volumes object, making sure
        # the densities and colors are correctly
        # expanded batch_size-times.
        volumes = Volumes(
            densities=densities[None].expand(
                batch_size, *self.log_densities.shape),
            features=colors[None].expand(
                batch_size, *self.log_colors.shape),
            voxel_size=self._voxel_size,
        )

        # Given cameras and volumes, run the renderer
        # and return only the first output value
        # (the 2nd output is a representation of the sampled
        # rays which can be omitted for our purpose).
        return self._renderer(cameras=cameras, volumes=volumes)[0]
# A helper function for evaluating the smooth L1 (huber) loss
# between the rendered silhouettes and colors.
def huber(x, y, scaling=0.1):
    diff_sq = (x - y) ** 2
    loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)
    return loss
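A quick sanity check of huber's behavior (the values in the comment are approximate, computed by hand from the formula above): for differences much smaller than scaling the loss is roughly quadratic, while for large differences it grows roughly linearly, which makes the fit robust to outlier pixels.

x = torch.tensor([0.01, 0.1, 1.0])
y = torch.zeros(3)
print(huber(x, y))  # ~[0.0005, 0.0414, 0.9050]: quadratic near 0, linear for large diffs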
Here we carry out the volume fitting with differentiable rendering.

In order to fit the volume, we render it from the viewpoints of the target_cameras and compare the resulting renders with the observed target_images and target_silhouettes.

The comparison is done by evaluating the mean huber (smooth L1) error between corresponding pairs of target_images/rendered_images and target_silhouettes/rendered_silhouettes.
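Written out (this is simply the objective computed in the cell below), the total loss for a minibatch $B$ of views is

$$
\mathcal{L} = \frac{1}{|B|} \sum_{b \in B} \Big[ \overline{\mathrm{huber}}\big(I^{\text{rend}}_b, I^{\text{target}}_b\big) + \overline{\mathrm{huber}}\big(S^{\text{rend}}_b, S^{\text{target}}_b\big) \Big],
$$

where $I$ denotes RGB renders/images, $S$ denotes silhouettes, and the bar indicates averaging over pixels (and color channels).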
# First move all relevant variables to the correct device.
target_cameras = target_cameras.to(device)
target_images = target_images.to(device)
target_silhouettes = target_silhouettes.to(device)

# Instantiate the volumetric model.
# We use a cubical volume with the size of
# one side = 128. The size of each voxel of the volume
# is set to volume_extent_world / volume_size s.t. the
# volume represents the space enclosed in a 3D bounding box
# centered at (0, 0, 0) with the size of each side equal to 3.
volume_size = 128
volume_model = VolumeModel(
    renderer,
    volume_size=[volume_size] * 3,
    voxel_size=volume_extent_world / volume_size,
).to(device)

# Instantiate the Adam optimizer. We set its master learning rate to 0.1.
lr = 0.1
optimizer = torch.optim.Adam(volume_model.parameters(), lr=lr)

# We do 300 Adam iterations and sample 10 random images in each minibatch.
batch_size = 10
n_iter = 300
for iteration in range(n_iter):

    # In case we reached the last 25% of iterations,
    # decrease the learning rate of the optimizer 10-fold.
    if iteration == round(n_iter * 0.75):
        print('Decreasing LR 10-fold ...')
        optimizer = torch.optim.Adam(
            volume_model.parameters(), lr=lr * 0.1
        )

    # Zero the optimizer gradient.
    optimizer.zero_grad()

    # Sample random batch indices.
    batch_idx = torch.randperm(len(target_cameras))[:batch_size]

    # Sample the minibatch of cameras.
    batch_cameras = FoVPerspectiveCameras(
        R=target_cameras.R[batch_idx],
        T=target_cameras.T[batch_idx],
        znear=target_cameras.znear[batch_idx],
        zfar=target_cameras.zfar[batch_idx],
        aspect_ratio=target_cameras.aspect_ratio[batch_idx],
        fov=target_cameras.fov[batch_idx],
        device=device,
    )

    # Evaluate the volumetric model.
    rendered_images, rendered_silhouettes = volume_model(
        batch_cameras
    ).split([3, 1], dim=-1)

    # Compute the silhouette error as the mean huber
    # loss between the predicted masks and the
    # target silhouettes.
    sil_err = huber(
        rendered_silhouettes[..., 0], target_silhouettes[batch_idx],
    ).abs().mean()

    # Compute the color error as the mean huber
    # loss between the rendered colors and the
    # target ground truth images.
    color_err = huber(
        rendered_images, target_images[batch_idx],
    ).abs().mean()

    # The optimization loss is a simple
    # sum of the color and silhouette errors.
    loss = color_err + sil_err

    # Print the current values of the losses.
    if iteration % 10 == 0:
        print(
            f'Iteration {iteration:05d}:'
            + f' color_err = {float(color_err):1.2e}'
            + f' mask_err = {float(sil_err):1.2e}'
        )

    # Take the optimization step.
    loss.backward()
    optimizer.step()

    # Visualize the renders every 40 iterations.
    if iteration % 40 == 0:
        # Visualize only a single randomly selected element of the batch.
        im_show_idx = int(torch.randint(low=0, high=batch_size, size=(1,)))
        fig, ax = plt.subplots(2, 2, figsize=(10, 10))
        ax = ax.ravel()
        clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()
        ax[0].imshow(clamp_and_detach(rendered_images[im_show_idx]))
        ax[1].imshow(clamp_and_detach(target_images[batch_idx[im_show_idx], ..., :3]))
        ax[2].imshow(clamp_and_detach(rendered_silhouettes[im_show_idx, ..., 0]))
        ax[3].imshow(clamp_and_detach(target_silhouettes[batch_idx[im_show_idx]]))
        for ax_, title_ in zip(
            ax,
            ("rendered image", "target image", "rendered silhouette", "target silhouette")
        ):
            ax_.grid(False)
            ax_.axis("off")
            ax_.set_title(title_)
        fig.canvas.draw(); fig.show()
        display.clear_output(wait=True)
        display.display(fig)
Finally, we visualize the optimized volume by rendering it from multiple viewpoints that rotate around the volume's y-axis.
def generate_rotating_volume(volume_model, n_frames=50):
    # Log-rotation vectors: rotation angles about the y-axis
    # that sweep a full circle over n_frames frames.
    logRs = torch.zeros(n_frames, 3, device=device)
    logRs[:, 1] = torch.linspace(0.0, 2.0 * math.pi, n_frames, device=device)
    # Convert the log-rotations to rotation matrices.
    Rs = so3_exp_map(logRs)
    # Place each camera 2.7 world units away from the volume center.
    Ts = torch.zeros(n_frames, 3, device=device)
    Ts[:, 2] = 2.7
    frames = []
    print('Generating rotating volume ...')
    for R, T in zip(tqdm(Rs), Ts):
        camera = FoVPerspectiveCameras(
            R=R[None],
            T=T[None],
            znear=target_cameras.znear[0],
            zfar=target_cameras.zfar[0],
            aspect_ratio=target_cameras.aspect_ratio[0],
            fov=target_cameras.fov[0],
            device=device,
        )
        frames.append(volume_model(camera)[..., :3].clamp(0.0, 1.0))
    return torch.cat(frames)
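As a quick aside on so3_exp_map (the standard axis-angle to rotation-matrix conversion): a log-rotation vector whose direction is the rotation axis and whose magnitude is the rotation angle maps to a 3x3 rotation matrix, e.g. for a quarter turn about y:

# A rotation of pi/2 radians about the y-axis, as used for the turntable above.
log_rot = torch.tensor([[0.0, math.pi / 2, 0.0]], device=device)
print(so3_exp_map(log_rot).shape)  # (1, 3, 3): a batch of one rotation matrix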
with torch.no_grad():
    rotating_volume_frames = generate_rotating_volume(volume_model, n_frames=7 * 4)

image_grid(rotating_volume_frames.clamp(0., 1.).cpu().numpy(), rows=4, cols=7, rgb=True, fill=True)
plt.show()
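If you would like to keep the turntable render as an animation, a minimal sketch using PIL (already imported above) follows; the rotation.gif filename and the frame conversion are our own choices, not part of the original tutorial.

# A minimal sketch (assumed filename): save the rotating frames as a GIF.
pil_frames = [
    Image.fromarray((frame.cpu().numpy() * 255).astype(np.uint8))
    for frame in rotating_volume_frames
]
pil_frames[0].save(
    "rotation.gif", save_all=True, append_images=pil_frames[1:],
    duration=100, loop=0,  # 100 ms per frame, loop forever
)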
In this tutorial, we have shown how to optimize a 3D volumetric representation of a scene so that the renders of the volume from known viewpoints match the observed images for each viewpoint. The rendering was carried out using the PyTorch3D volumetric renderer, composed of an NDCMultinomialRaysampler and an EmissionAbsorptionRaymarcher.