# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
In this tutorial we learn how to use differentiable rendering to find the [x, y, z] position of a camera given a reference image.
We first initialize a renderer with a starting position for the camera. We then use it to generate an image, compute a loss with respect to the reference image, and finally backpropagate through the entire pipeline to update the position of the camera.
This tutorial shows how to:
- load a mesh from an .obj file
- initialize a Camera, Shader and Renderer
Ensure torch and torchvision are installed. If pytorch3d is not installed, install it using the following cell:
import os
import sys
import torch
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte
# io utils
from pytorch3d.io import load_obj
# datastructures
from pytorch3d.structures import Meshes
# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate
# rendering components
from pytorch3d.renderer import (
FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,
RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)
We will load an obj file and create a Meshes object. Meshes is a unique datastructure provided in PyTorch3D for working with batches of meshes of different sizes. It has several useful class methods which are used in the rendering pipeline.
If you are running this notebook locally after cloning the PyTorch3D repository, the mesh is already available. If using Google Colab, fetch the mesh and save it at the path data/:
!mkdir -p data
!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj
# Set the cuda device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
# Load the obj and ignore the textures and materials.
verts, faces_idx, _ = load_obj("./data/teapot.obj")
faces = faces_idx.verts_idx
# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None] # (1, V, 3)
textures = TexturesVertex(verts_features=verts_rgb.to(device))
# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.to(device)],
    textures=textures
)
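Because Meshes accepts per-mesh lists of vertex and face tensors, meshes with different numbers of vertices and faces can live in a single batch. A minimal illustrative sketch (the toy tensors below are made up for illustration and are not part of this tutorial's data):
# Two toy meshes of different sizes in a single Meshes batch (illustrative values only).
toy_verts_a = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]], device=device)                # 3 vertices
toy_faces_a = torch.tensor([[0, 1, 2]], dtype=torch.int64, device=device)                            # 1 face
toy_verts_b = torch.tensor([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]], device=device)  # 4 vertices
toy_faces_b = torch.tensor([[0, 1, 2], [0, 2, 3]], dtype=torch.int64, device=device)                 # 2 faces
toy_batch = Meshes(verts=[toy_verts_a, toy_verts_b], faces=[toy_faces_a, toy_faces_b])
print(toy_batch.num_verts_per_mesh())  # tensor([3, 4])
print(toy_batch.num_faces_per_mesh())  # tensor([1, 2])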
A renderer in PyTorch3D is composed of a rasterizer and a shader, each of which has a number of subcomponents such as a camera (orthographic/perspective). Here we initialize some of these components and use default values for the rest.
For optimizing the camera position we will use a renderer which produces only a silhouette of the object, without applying any lighting or shading. We will also initialize another renderer which applies full Phong shading and use it for visualizing the outputs.
# Initialize a perspective camera.
cameras = FoVPerspectiveCameras(device=device)
# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of
# edges. Refer to blending.py for more details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. We also leave bin_size and max_faces_per_bin at their default of None, which ensures that
# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for
# explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of
# the difference between naive and coarse-to-fine rasterization.
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
)
# Create a silhouette mesh renderer by composing a rasterizer and a shader.
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)
# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
)
# We can add a point light in front of the object.
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)
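The blur_radius used for the silhouette rasterization is tied to sigma: with the SoftRasterizer-style sigmoid blending that SoftSilhouetteShader relies on (see blending.py), a face's contribution to a pixel decays as a sigmoid of its signed squared distance divided by sigma, so choosing blur_radius = log(1/1e-4 - 1) * sigma means a face contributes only about 1e-4 once it is blur_radius away. A quick numeric check of that relationship:
# Contribution of a face whose distance to a pixel equals blur_radius, under sigmoid blending.
sigma = blend_params.sigma                           # 1e-4, as set above
blur_radius = np.log(1. / 1e-4 - 1.) * sigma
print(1. / (1. + np.exp(blur_radius / sigma)))       # sigmoid(-blur_radius / sigma) ≈ 1e-4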
We will first position the teapot and generate an image. We use helper functions to rotate the teapot to a desired viewpoint. Then we can use the renderers to produce an image. Here we will use both renderers and visualize the silhouette and the full shaded image.
The world coordinate system is defined as +Y up, +X left and +Z in. The teapot in world coordinates has the spout pointing to the left.
We defined a camera which is positioned on the positive z axis, hence it sees the spout on the right.
# Select the viewpoint using spherical angles
distance = 3 # distance from camera to the object
elevation = 50.0 # angle of elevation in degrees
azimuth = 0.0 # No rotation so the camera is positioned on the +Z axis.
# Get the position of the camera based on the spherical angles
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)
# Render the teapot providing the values of R and T.
silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)
silhouette = silhouette.cpu().numpy()
image_ref = image_ref.cpu().numpy()
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(silhouette.squeeze()[..., 3]) # only plot the alpha channel of the RGBA image
plt.grid(False)
plt.subplot(1, 2, 2)
plt.imshow(image_ref.squeeze())
plt.grid(False)
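PyTorch3D cameras use a row-vector convention in which world points transform as X_cam = X_world · R + T, so the world-space position of the camera can be recovered from R and T, for example with the cameras' get_camera_center method. As a quick sanity check of the viewpoint above, the camera for distance=3, elevation=50°, azimuth=0° should sit at roughly (0, 3·sin 50°, 3·cos 50°) ≈ (0.0, 2.30, 1.93):
# Recover the world-space camera position from the R, T used above.
cameras_check = FoVPerspectiveCameras(device=device, R=R, T=T)
print(cameras_check.get_camera_center())  # ≈ tensor([[0.00, 2.30, 1.93]])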
Here we create a simple model class and initialize a parameter for the camera position.
class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer

        # Get the silhouette of the reference RGB image by finding all non-white pixel values.
        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))
        self.register_buffer('image_ref', image_ref)

        # Create an optimizable parameter for the x, y, z position of the camera.
        self.camera_position = nn.Parameter(
            torch.from_numpy(np.array([3.0, 6.9, +2.5], dtype=np.float32)).to(meshes.device))

    def forward(self):
        # Render the image using the updated camera position. Based on the new position of the
        # camera we calculate the rotation and translation matrices.
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)

        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)

        # Calculate the silhouette loss.
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image
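The expression for T in forward() follows from the same row-vector convention X_cam = X_world · R + T: requiring the camera's own position C to map to the origin of camera coordinates gives T = -Rᵀ·C (written as a column vector), which is exactly the torch.bmm line above. A minimal sketch verifying this identity for an arbitrary example position (the coordinates below are arbitrary):
# Verify that T = -R^T C maps the camera position C to the origin of camera coordinates.
C_check = torch.tensor([[3.0, 6.9, 2.5]], device=device)                     # example position, shape (1, 3)
R_check = look_at_rotation(C_check, device=device)                           # (1, 3, 3)
T_check = -torch.bmm(R_check.transpose(1, 2), C_check[:, :, None])[:, :, 0]  # (1, 3)
print(torch.bmm(C_check[:, None, :], R_check)[:, 0, :] + T_check)            # ≈ tensor([[0., 0., 0.]])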
Now we can create an instance of the model above and set up an optimizer for the camera position parameter.
# We will save images periodically and compose them into a GIF.
filename_output = "./teapot_optimization_demo.gif"
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)
# Initialize a model using the renderer, mesh and reference image
model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)
# Create an optimizer. Here we are using Adam and we pass in the parameters of the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
plt.figure(figsize=(10, 10))
_, image_init = model()
plt.subplot(1, 2, 1)
plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])
plt.grid(False)
plt.title("Starting position")
plt.subplot(1, 2, 2)
plt.imshow(model.image_ref.cpu().numpy().squeeze())
plt.grid(False)
plt.title("Reference silhouette");
We run several iterations of the forward and backward passes and save outputs every 10 iterations. When this has finished, take a look at ./teapot_optimization_demo.gif for a cool GIF of the optimization process!
loop = tqdm(range(200))
for i in loop:
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()

    loop.set_description('Optimizing (loss %.4f)' % loss.data)

    if loss.item() < 200:
        break

    # Save outputs to create a GIF.
    if i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)

        plt.figure()
        plt.imshow(image[..., :3])
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.axis("off")
writer.close()
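Once the loop finishes (it stops early when the silhouette loss drops below 200), the recovered camera position can be compared with the viewpoint used to render the reference image; because the stopping criterion is a loss threshold, expect it to be close to, but not exactly at, the reference position. A minimal check:
import math
# Camera position recovered by the optimization.
print(model.camera_position)
# Camera position that produced the reference image (dist=3, elev=50 deg, azim=0 deg),
# i.e. roughly (0, 3*sin(50 deg), 3*cos(50 deg)).
print(0.0, 3 * math.sin(math.radians(50.0)), 3 * math.cos(math.radians(50.0)))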
In this tutorial we learned how to load a mesh from an obj file, initialize a PyTorch3D datastructure called Meshes, set up a renderer consisting of a Rasterizer and a Shader, set up an optimization loop including a model and a loss function, and run the optimization.