# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
In this demo, we use the VolumeRenderer from PyTorch3D as a custom implicit function in Implicitron. We will see how to plug a custom component into Implicitron's main model and train it on a simple single-scene dataset.
Ensure torch and torchvision are installed. If pytorch3d is not installed, install it with the following cell:
import os
import sys
import torch
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
Ensure omegaconf and visdom are installed. If not, run this cell. (It should not be necessary to restart the runtime.)
!pip install omegaconf visdom
import logging
from typing import Tuple
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from IPython.display import HTML
from omegaconf import OmegaConf
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components
from pytorch3d.renderer.implicit.renderer import VolumeSampler
from pytorch3d.structures import Volumes
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene
output_resolution = 80
torch.set_printoptions(sci_mode=False)
In Implicitron, the training, validation and test parts of a dataset are represented as a dataset_map, provided by an implementation of DatasetMapProvider. RenderedMeshDatasetMapProvider is one such implementation; it generates a single-scene dataset with only a train component by taking a mesh and rendering it. We use it with the cow mesh.
If you are running this notebook in Google Colab, run the following cell to fetch the mesh obj and texture files and save them at the path data/cow_mesh. If running locally, the data is already available at the correct path.
!mkdir -p data/cow_mesh
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png
cow_provider = RenderedMeshDatasetMapProvider(
    data_file="data/cow_mesh/cow.obj",
    use_point_light=False,
    resolution=output_resolution,
)
dataset_map = cow_provider.get_dataset_map()
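As a quick check of what the provider produced, we can inspect the train split. This is a minimal sketch: it assumes each item of the split is an uncollated FrameData whose image_rgb tensor has shape (3, H, W).
# Optional inspection sketch (assumptions noted above).
print(len(dataset_map.train))          # number of rendered training views
example_frame = next(iter(dataset_map.train))
print(type(example_frame).__name__)    # expected: FrameData
print(example_frame.image_rgb.shape)   # expected: torch.Size([3, output_resolution, output_resolution])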
tr_cameras = [training_frame.camera for training_frame in dataset_map.train]
# The cameras are all in the XZ plane, in a circle about 2.7 from the origin
centers = torch.cat([i.get_camera_center() for i in tr_cameras])
print(centers.min(0).values)
print(centers.max(0).values)
# visualization of the cameras
plot = plot_scene({"k": {i: camera for i, camera in enumerate(tr_cameras)}}, camera_scale=0.25)
plot.layout.scene.aspectmode = "data"
plot
At the core of a neural rendering method is a function of spatial coordinates called an implicit function, which is used in some kind of rendering process. (These functions can often take additional data as well, such as a view direction.) A common rendering process is ray marching over densities and colors provided by the implicit function. In our case, taking samples from a 3D volume grid is a very simple function of spatial coordinates.
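To make the ray-marching step concrete, here is a simplified emission-absorption compositing sketch for a single ray with made-up sample values. It is illustrative only and not Implicitron's exact implementation.
# Simplified emission-absorption compositing along one ray (illustrative values).
# Each sample i has density d_i and color c_i; opacity a_i = 1 - exp(-d_i * delta_i),
# transmittance T_i = prod_{j<i} (1 - a_j), and the rendered color is sum_i T_i * a_i * c_i.
deltas = torch.full((4,), 0.1)                    # spacing between consecutive samples
d = torch.tensor([0.0, 1.0, 5.0, 0.5])            # made-up densities
c = torch.rand(4, 3)                              # made-up per-sample colors
a = 1 - torch.exp(-d * deltas)                    # per-sample opacities
T = torch.cumprod(torch.cat([torch.ones(1), 1 - a[:-1]]), dim=0)  # accumulated transmittance
rendered_rgb = (T[:, None] * a[:, None] * c).sum(dim=0)
print(rendered_rgb)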
Here we define our own implicit function, which uses PyTorch3D's existing functionality for sampling from a volume grid. We do this by subclassing ImplicitFunctionBase. Our subclass must be registered with the special decorator, and we use Python's dataclass annotations to configure the module.
@registry.register
class MyVolumes(ImplicitFunctionBase, torch.nn.Module):
    grid_resolution: int = 50  # common HWD of volumes, the number of voxels in each direction
    extent: float = 1.0  # In world coordinates, the volume occupies [-extent, extent] along each axis

    def __post_init__(self):
        # We have to call this explicitly if there are other base classes like Module
        super().__init__()

        # We define parameters like other torch.nn.Module objects.
        # In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.
        density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)
        self.density = torch.nn.Parameter(density)
        color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)
        self.color = torch.nn.Parameter(color)
        self.density_activation = torch.nn.Softplus()

    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        fun_viewpool=None,
        global_code=None,
    ):
        densities = self.density_activation(self.density[None, None])
        voxel_size = 2.0 * float(self.extent) / self.grid_resolution
        features = self.color.sigmoid()[None]

        # Like other PyTorch3D structures, the actual Volumes object should only exist for
        # one iteration of training. It is local to this function.
        volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)
        sampler = VolumeSampler(volumes=volume)
        densities, features = sampler(ray_bundle)

        # When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,
        # it must return (densities, features, an auxiliary tuple).
        return densities, features, {}
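Before plugging MyVolumes into the full model, we can sanity-check it in isolation. This is an optional sketch: it assumes that expand_args_fields (from pytorch3d.implicitron.tools.config) turns the annotated fields into dataclass-style constructor arguments, after which __post_init__ builds the parameter grids.
# Optional sanity check (assumes expand_args_fields behaves as described above).
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(MyVolumes)            # make the annotated fields constructible
probe = MyVolumes(grid_resolution=8)     # tiny grid, just for the check
print(probe.density.shape)               # expected: torch.Size([8, 8, 8])
print(probe.color.shape)                 # expected: torch.Size([3, 8, 8, 8])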
The main model object in PyTorch3D's Implicitron is GenericModel, which has pluggable components for the major steps, including the renderer and the implicit function(s). There are two equivalent ways to construct it here.
CONSTRUCT_MODEL_FROM_CONFIG = True
if CONSTRUCT_MODEL_FROM_CONFIG:
    # Via a DictConfig - this is how our training loop with hydra works
    cfg = get_default_args(GenericModel)
    cfg.implicit_function_class_type = "MyVolumes"
    cfg.render_image_height = output_resolution
    cfg.render_image_width = output_resolution
    cfg.loss_weights = {"loss_rgb_huber": 1.0}
    cfg.tqdm_trigger_threshold = 19000
    cfg.raysampler_AdaptiveRaySampler_args.scene_extent = 4.0
    gm = GenericModel(**cfg)
else:
    # constructing GenericModel directly
    gm = GenericModel(
        implicit_function_class_type="MyVolumes",
        render_image_height=output_resolution,
        render_image_width=output_resolution,
        loss_weights={"loss_rgb_huber": 1.0},
        tqdm_trigger_threshold=19000,
        raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0},
    )
    # In this case we can get the DictConfig cfg object equivalent to gm's configuration as follows
    cfg = OmegaConf.structured(gm)
The default renderer is an emission-absorption raymarcher. We keep that default.
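If you wanted something other than the default, the renderer is selected the same way as the implicit function: by class type name in the config. A hedged sketch follows (it assumes renderer_class_type is the relevant field; we do not change it in this notebook).
# Illustrative only - we keep the default emission-absorption raymarcher.
# cfg.renderer_class_type = "MultiPassEmissionAbsorptionRenderer"  # assumed default class type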
# We can display the configuration in use as follows.
remove_unused_components(cfg)
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
%page -r yaml
device = torch.device("cuda:0")
gm.to(device)
assert next(gm.parameters()).is_cuda
train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]
gm.train()
optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)
iterator = tqdm.tqdm(range(2000))
for n_batch in iterator:
    optimizer.zero_grad()

    frame = train_data_collated[n_batch % len(dataset_map.train)]
    out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)
    out["objective"].backward()
    if n_batch % 100 == 0:
        iterator.set_postfix_str(f"loss: {float(out['objective']):.5f}")
    optimizer.step()
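Since GenericModel is an ordinary torch.nn.Module, standard checkpointing works if you want to keep the fitted grid. A minimal sketch; the file name is illustrative:
# Optional checkpointing sketch (file name is hypothetical).
torch.save(gm.state_dict(), "my_volumes_cow.pt")
# To restore later into a model constructed with the same config:
# gm.load_state_dict(torch.load("my_volumes_cow.pt"))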
We generate complete images from all the viewpoints to see how they look.
def to_numpy_image(image):
    # Takes an image of shape (C, H, W) in [0, 1], where C=3 or 1,
    # to a numpy uint8 image of shape (H, W, 3)
    return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()


def resize_image(image):
    # Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)
    return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))
gm.eval()
images = []
expected = []
masks = []
masks_expected = []
for frame in tqdm.tqdm(train_data_collated):
    with torch.no_grad():
        out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)

    image_rgb = to_numpy_image(out["images_render"][0])
    mask = to_numpy_image(out["masks_render"][0])
    expd = to_numpy_image(resize_image(frame.image_rgb)[0])
    mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])

    images.append(image_rgb)
    masks.append(mask)
    expected.append(expd)
    masks_expected.append(mask_expected)
We draw a grid showing the predicted and expected images, followed by the predicted and expected masks, for each viewpoint. This is a grid of four rows of images, wrapped into several large rows, i.e.
┌────────┬────────┐   ┌────────┐
│pred    │pred    │   │pred    │
│image   │image   │   │image   │
│1       │2       │   │n       │
├────────┼────────┤   ├────────┤
│expected│expected│   │expected│
│image   │image   │...│image   │
│1       │2       │   │n       │
├────────┼────────┤   ├────────┤
│pred    │pred    │   │pred    │
│mask    │mask    │   │mask    │
│1       │2       │   │n       │
├────────┼────────┤   ├────────┤
│expected│expected│   │expected│
│mask    │mask    │   │mask    │
│1       │2       │   │n       │
├────────┼────────┤   ├────────┤
│pred    │pred    │   │pred    │
│image   │image   │   │image   │
│n+1     │n+2     │   │2n      │
├────────┼────────┤   ├────────┤
│expected│expected│   │expected│
│image   │image   │...│image   │
│n+1     │n+2     │   │2n      │
├────────┼────────┤   ├────────┤
│pred    │pred    │   │pred    │
│mask    │mask    │   │mask    │
│n+1     │n+2     │   │2n      │
├────────┼────────┤   ├────────┤
│expected│expected│   │expected│
│mask    │mask    │   │mask    │
│n+1     │n+2     │   │2n      │
└────────┴────────┘   └────────┘
...
images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]
n_rows = 4
n_images = len(images)
blank_image = images[0] * 0
n_per_row = 1+(n_images-1)//n_rows
for _ in range(n_per_row*n_rows - n_images):
    for group in images_to_display:
        group.append(blank_image)
images_to_display_listed = [[[i] for i in j] for j in images_to_display]
split = []
for row in range(n_rows):
    for group in images_to_display_listed:
        split.append(group[row*n_per_row:(row+1)*n_per_row])
Image.fromarray(np.block(split))
# Print the maximum channel intensity of one of the rendered images (values are in [0, 1]).
print(images[1].max()/255)
plt.ioff()
fig, ax = plt.subplots(figsize=(3,3))
ax.grid(None)
ims = [[ax.imshow(im, animated=True)] for im in images]
ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)
ani_html = ani.to_jshtml()
HTML(ani_html)
# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate
# with torch.no_grad():
# gm._implicit_functions[0]._fn.density.fill_(9.0)
# gm._implicit_functions[0]._fn.color.fill_(9.0)