# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
Implicitron 的元件全數都基於統一的分層設定系統。這允許針對每個新元件定義可設定變數和所有預設值。然後自動將與實驗相關的所有設定組合成一個單一設定檔,此檔案會完整指定實驗。其中一項特別重要的功能是延伸點,使用者可以在延伸點插入 Implicitron 基礎元件的自己的子類別。
定義本系統的檔案位於 PyTorch3D 回應中的此處。Implicitron 量體教學指南包含使用設定系統的一個簡單範例。本教學指南提供使用和修改 Implicitron 可設定元件的詳細實作經驗。
確保已安裝 torch
和 torchvision
。如果尚未安裝 pytorch3d
,請使用下列儲存格進行安裝
import os
import sys
import torch
need_pytorch3d=False
try:
import pytorch3d
except ModuleNotFoundError:
need_pytorch3d=True
if need_pytorch3d:
if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
# We try to install PyTorch3D via a released wheel.
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
f"py3{sys.version_info.minor}_cu",
torch.version.cuda.replace(".",""),
f"_pyt{pyt_version_str}"
])
!pip install fvcore iopath
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
else:
# We try to install PyTorch3D from source.
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
確保已安裝 omegaconf。如果尚未安裝,請執行此儲存格。(不應需要重新啟動執行時期。)
!pip install omegaconf
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
expand_args_fields,
get_default_args,
registry,
run_auto_creation,
)
@dataclass
class MyDataclass:
a: int
b: int = 8
c: Optional[Tuple[int, ...]] = None
def __post_init__(self):
print(f"created with a = {self.a}")
self.d = 2 * self.b
my_dataclass_instance = MyDataclass(a=18)
assert my_dataclass_instance.d == 16
👷 請注意此處的 dataclass
裝飾器是修改類別本身定義的函數。它會在定義後立即執行。我們的設定系統需要 implicitron 函式庫程式碼包含類別,其中其修改的版本需要了解使用者定義的實作。因此,類別的修改需要延後。我們不使用裝飾器。
dc = DictConfig({"a": 2, "b": True, "c": None, "d": "hello"})
assert dc.a == dc["a"] == 2
OmegaConf 可至 yaml 進行序列處理,反之亦然。Hydra 函式庫的設定檔依賴此功能。
print(OmegaConf.to_yaml(dc))
assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc
OmegaConf.structured 提供一個 DictConfig,其來源為資料類別或資料類別的實例。相較於一般的 DictConfig,它會進行類型檢查,而且只能新增已知的鍵。
structured = OmegaConf.structured(MyDataclass)
assert isinstance(structured, DictConfig)
print(structured)
print()
print(OmegaConf.to_yaml(structured))
structured
知道它缺少 a
的值。
此類物件具有相容於資料類別的成員,因此可執行下列初始化。
structured.a = 21
my_dataclass_instance2 = MyDataclass(**structured)
print(my_dataclass_instance2)
您也可以對實例呼叫 OmegaConf.structured。
structured_from_instance = OmegaConf.structured(my_dataclass_instance)
my_dataclass_instance3 = MyDataclass(**structured_from_instance)
print(my_dataclass_instance3)
我們提供與 OmegaConf.structured
等效但支援更多功能的函式。若要使用我們的函式達成上述目的,請使用下列內容。請注意,我們使用特殊的基礎類別 Configurable
指出可組態類別,而不是裝飾器。
class MyConfigurable(Configurable):
a: int
b: int = 8
c: Optional[Tuple[int, ...]] = None
def __post_init__(self):
print(f"created with a = {self.a}")
self.d = 2 * self.b
# The expand_args_fields function modifies the class like @dataclasses.dataclass.
# If it has not been called on a Configurable object before it has been instantiated, it will
# be called automatically.
expand_args_fields(MyConfigurable)
my_configurable_instance = MyConfigurable(a=18)
assert my_configurable_instance.d == 16
# get_default_args also calls expand_args_fields automatically
our_structured = get_default_args(MyConfigurable)
assert isinstance(our_structured, DictConfig)
print(OmegaConf.to_yaml(our_structured))
our_structured.a = 21
print(MyConfigurable(**our_structured))
我們的系統允許 Configurable
類別彼此包含。請記住一件事:在 __post_init__
中加入對 run_auto_creation
的呼叫。
class Inner(Configurable):
a: int = 8
b: bool = True
c: Tuple[int, ...] = (2, 3, 4, 6)
class Outer(Configurable):
inner: Inner
x: str = "hello"
xx: bool = False
def __post_init__(self):
run_auto_creation(self)
outer_dc = get_default_args(Outer)
print(OmegaConf.to_yaml(outer_dc))
outer = Outer(**outer_dc)
assert isinstance(outer, Outer)
assert isinstance(outer.inner, Inner)
print(vars(outer))
print(outer.inner)
請注意,inner_args
是 outer
的額外成員。run_auto_creation(self)
等同於
self.inner = Inner(**self.inner_args)
如果類別使用 ReplaceableBase
而非 Configurable
作為基礎類別,我們稱其為可取代的。它表示它設計為子類別可使用它來取代它本身。我們可能會使用 NotImplementedError
來表示子類別預期實作的功能。系統會維護一個包含每個 ReplaceableBase
子類別的全球 registry
。子類別使用裝飾器向其中註冊自己。
包含 ReplaceableBase
的可組態類別(即使用我們的系統的類別,即 Configurable
或 ReplaceableBase
的子項)也必須包含對應的 str
型別的 class_type
欄位,以指出要使用的具體子類別。
class InnerBase(ReplaceableBase):
def say_something(self):
raise NotImplementedError
@registry.register
class Inner1(InnerBase):
a: int = 1
b: str = "h"
def say_something(self):
print("hello from an Inner1")
@registry.register
class Inner2(InnerBase):
a: int = 2
def say_something(self):
print("hello from an Inner2")
class Out(Configurable):
inner: InnerBase
inner_class_type: str = "Inner1"
x: int = 19
def __post_init__(self):
run_auto_creation(self)
def talk(self):
self.inner.say_something()
Out_dc = get_default_args(Out)
print(OmegaConf.to_yaml(Out_dc))
Out_dc.inner_class_type = "Inner2"
out = Out(**Out_dc)
print(out.inner)
out.talk()
請注意,在此情況下有很多 args
成員。通常可以在程式碼中忽略它們。它們是設定檔所需要的。
print(vars(out))
class MyLinear(torch.nn.Module, Configurable):
d_in: int = 2
d_out: int = 200
def __post_init__(self):
super().__init__()
self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)
def forward(self, x):
return self.linear.forward(x)
my_linear = MyLinear()
input = torch.zeros(2)
output = my_linear(input)
print("output shape:", output.shape)
my_linear
具有 Module
的所有一般特性。例如,它可以使用 torch.save
和 torch.load
儲存和載入。它有參數。
for name, value in my_linear.named_parameters():
print(name, value.shape)
假設我正在使用帶有 第 5 節 中的 Out
的程式庫,但我想實作自己的 InnerBase
子項。我所需要做的就是註冊其定義,但我需要在對於 Out
明確或隱式地呼叫 expand_args_fields
之前執行此操作。
@registry.register
class UserImplementedInner(InnerBase):
a: int = 200
def say_something(self):
print("hello from the user")
在這個時候,我們需要重新定義 Out
類別。否則,如果它已經在沒有 UserImplementedInner
的情況下擴充,則下列情況會無法運作,因為類別所知道的實作會在它擴充時固定下來。
如果您從腳本執行實驗,則這裡要記住的事情是,您必須匯入自己的模組,才能在使用程式庫類別之前,註冊自己的實作。
class Out(Configurable):
inner: InnerBase
inner_class_type: str = "Inner1"
x: int = 19
def __post_init__(self):
run_auto_creation(self)
def talk(self):
self.inner.say_something()
out2 = Out(inner_class_type="UserImplementedInner")
print(out2.inner)
讓我們看看如果我們有一個可插入的子元件,我們需要做些什麼,以允許使用者提供自己的子元件。
class SubComponent(Configurable):
x: float = 0.25
def apply(self, a: float) -> float:
return a + self.x
class LargeComponent(Configurable):
repeats: int = 4
subcomponent: SubComponent
def __post_init__(self):
run_auto_creation(self)
def apply(self, a: float) -> float:
for _ in range(self.repeats):
a = self.subcomponent.apply(a)
return a
large_component = LargeComponent()
assert large_component.apply(3) == 4
print(OmegaConf.to_yaml(LargeComponent))
使用泛型
class SubComponentBase(ReplaceableBase):
def apply(self, a: float) -> float:
raise NotImplementedError
@registry.register
class SubComponent(SubComponentBase):
x: float = 0.25
def apply(self, a: float) -> float:
return a + self.x
class LargeComponent(Configurable):
repeats: int = 4
subcomponent: SubComponentBase
subcomponent_class_type: str = "SubComponent"
def __post_init__(self):
run_auto_creation(self)
def apply(self, a: float) -> float:
for _ in range(self.repeats):
a = self.subcomponent.apply(a)
return a
large_component = LargeComponent()
assert large_component.apply(3) == 4
print(OmegaConf.to_yaml(LargeComponent))
以下事項必須變更
SubComponentBase
。SubComponent
獲得 @registry.register
裝飾,並將其基礎類別變更為新的基礎類別。subcomponent_class_type
為外部類別的成員。subcomponent_args
必須變更為 subcomponent_SubComponent_args
。__post_init__
或未在其中呼叫 run_auto_creation
。 subcomponent_class_type = "SubComponent"
而非 subcomponent_class_type: str = "SubComponent"