TorchScript是什麼?
TorchScript是一種從PyTorch程式碼建立可序列化和可最佳化模型的方法。任何TorchScript程式都可以從Python程序中儲存,並載入到沒有Python依賴的程序中。
我們提供了一些工具來增量地將模型從純Python程式轉換為能夠獨立於Python執行的TorchScript程式,例如在獨立的c++程式中。這使得使用熟悉的Python工具在PyTorch中訓練模型,然後透過TorchScript將模型匯出到生產環境中成為可能,在這種環境中,Python程式可能由於效能和多執行緒的原因不適用。
編寫TorchScript程式碼
torch.jit.script(obj)
指令碼化一個函式或者nn。Module物件,將會檢查它的原始碼, 將其作為TorchScript程式碼使用TorchScrit編譯器編譯它,返回一個ScriptModule或ScriptFunction。
TorchScript語言自身是Python語言的一個子類, 因此它並非具有所有的Python語言特性
。 torch。jit。script能夠被
作為函式或裝飾器使用
。引數obj可以是class, function, nn。Module。
具體地,
指令碼化一個函式
:
torch。jit。script
裝飾器將會透過編譯函式被裝飾函式體來構造一個ScriptFunction物件。例如:
import torch
@torch。jit。script
def foo(x, y):
if x。max() > y。max():
r = x
else:
r = y
return r
print(type(foo)) # torch。jit。ScriptFuncion
# See the compiled graph as Python code
print(foo。code)
指令碼化一個nn.Module:
預設地編譯其forward方法,並遞迴地編譯其子模組以及被forward呼叫的函式。如果一個模組只使用TorchScript中支援的特性,則不需要更改原始模組程式碼。編譯器將構建ScriptModule,其中包含原始模組的屬性、引數和方法的副本。例如:
import torch
class MyModule(torch。nn。Module):
def __init__(self, N, M):
super(MyModule, self)。__init__()
# This parameter will be copied to the new ScriptModule
self。weight = torch。nn。Parameter(torch。rand(N, M))
# When this submodule is used, it will be compiled
self。linear = torch。nn。Linear(N, M)
def forward(self, input):
output = self。weight。mv(input)
# This calls the `forward` method of the `nn。Linear` module, which will
# cause the `self。linear` submodule to be compiled to a `ScriptModule` here
output = self。linear(output)
return output
scripted_module = torch。jit。script(MyModule(2, 3))
編譯一個不在forward中的方法以及遞迴地編譯其內的所有方法,可在此方法上使用裝飾器
torch。jit。export
為了忽視某些方法也可以使用裝飾器為了忽視某些方法也可以使用裝飾器
torch。jit。ignore
和
torch。jit。unused
import torch
import torch。nn as nn
class MyModule(nn。Module):
def __init__(self):
super(MyModule, self)。__init__()
@torch。jit。export
def some_entry_point(self, input):
return input + 10
@torch。jit。ignore
def python_only_fn(self, input):
# This function won‘t be compiled, so any
# Python APIs can be used
import pdb
pdb。set_trace()
def forward(self, input):
if self。training:
self。python_only_fn(input)
return input * 99
scripted_module = torch。jit。script(MyModule())
print(scripted_module。some_entry_point(torch。randn(2, 2)))
print(scripted_module(torch。randn(2, 2)))
torch.jit.trace(func,example_inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)
跟蹤一個函式並返回一個可執行的或ScriptFunction物件,將使用即時編譯(JIT)進行最佳化。跟蹤非常適合那些只操作單張量或張量的列表、字典和元組的程式碼。使用
torch。jit。trace
和
torch。jit。trace_module
,你能將一個模型或python函式轉為TorchScript中的
ScriptModule
或
ScriptFunction
。根據你提供的輸入樣例,它將會執行 該函式並記錄所有張量上執行的操作。
Tracing 僅僅正確地記錄那些不是資料依賴的函式和nn。Module(例如沒有對資料的條件判斷) 並且它們也沒有任何未跟蹤的外部依賴(例如執行輸入輸出或訪問全域性變數)。 Tracing 只記錄在給定張量上執行給定函式時所執行的操作。 因此,返回的ScriptModule將始終在任何輸入上執行相同的跟蹤圖。當你的模組需要根據輸入和/或模組狀態執行不同的操作集時,這就產生了一些重要的影響。例如:
Tracing不會記錄任何類似if語句或迴圈的控制流。當這個控制流在您的模組中是常量時,這是沒有問題的,並且它通常內聯了控制流決策。但有時控制流實際上是模型本身的一部分。例如,一個遞迴網路是一個輸入序列長度(可能是動態的)的迴圈。
在返回的ScriptModule中,無論ScriptModule處於哪種模式,在train和eval模式中具有不同行為的操作都將始終表現為處於跟蹤時所處的模式。
在這種情況下,Trace是不合適的,Script是更好的選擇。如果你跟蹤這樣的模型,您可能會在後續的模型呼叫中得到不正確的結果。當執行可能導致產生錯誤跟蹤的操作時,跟蹤程式將嘗試發出警告。
tracing a function:
import torch
def foo(x, y):
return 2 * x + y
# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch。jit。trace(foo, (torch。rand(3), torch。rand(3)))
# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
tracing a existing module
import torch
import torch。nn as nn
class Net(nn。Module):
def __init__(self):
super(Net, self)。__init__()
self。conv = nn。Conv2d(1, 1, 3)
def forward(self, x):
return self。conv(x)
n = Net()
example_weight = torch。rand(1, 1, 3, 3)
example_forward_input = torch。rand(1, 1, 3, 3)
# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch。jit。trace(n。forward, example_forward_input)
# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch。jit。trace(n, example_forward_input)
torch.jit.trace_module(mod,inputs,optimize=None,check_trace=True,check_inputs=None,check_tolerance=1e-5)
跟蹤一個模組並返回一個可執行的ScriptModule,該指令碼模組將使用即時編譯進行最佳化。當一個模組被傳遞到
torch。jit。trace
,只執行和跟蹤forward方法。使用trace_module,您可以為要跟蹤的示例輸入指定一個方法名字典(參見下面的example_input引數)。
import torch
import torch。nn as nn
class Net(nn。Module):
def __init__(self):
super(Net, self)。__init__()
self。conv = nn。Conv2d(1, 1, 3)
def forward(self, x):
return self。conv(x)
def weighted_kernel_sum(self, weight):
return weight * self。conv。weight
n = Net()
example_weight = torch。rand(1, 1, 3, 3)
example_forward_input = torch。rand(1, 1, 3, 3)
# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch。jit。trace(n。forward, example_forward_input)
# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch。jit。trace(n, example_forward_input)
# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {’forward‘ : example_forward_input, ’weighted_kernel_sum‘ : example_weight}
module = torch。jit。trace_module(n, inputs)
class torch.jit.ScriptModule
ScriptModule 封裝一個c++介面中的
torch::jit::Module
類, 有下列屬性及方法:
code
返回forward方法的內部圖的打印表示(具有有效的Python語法)
graph
返回forward方法的內部圖的字串表示形式
inlined_graph
返回forward方法的內部圖的字串表示形式。此圖將被預處理為內聯所有函式和方法呼叫。
save(f,_extra_files=ExtraFilesMap{})
class torch.jit.ScriptFunction
與上者類似
torch.jit.save(m,f,_extra_files=ExtraFilesMap{})
儲存此模組的離線版本,以便在單獨的程序中使用。所儲存的模組序列化此模組的所有方法、子模組、引數和屬性。它可以使用torch::jit::load(檔名)載入到c++ API中,也可以使用torch。jit。load載入到Python API中。為了能夠儲存模組,它必須不呼叫任何本機Python函式。這意味著所有子模組也必須是ScriptModule的子類。
所有模組,不管它們的裝置是什麼,總是在載入過程中載入到CPU上。這與torch.load()的語義不同,將來可能會改變。
import torch
import io
class MyModule(torch。nn。Module):
def forward(self, x):
return x + 10
m = torch。jit。script(MyModule())
# Save to file
torch。jit。save(m, ’scriptmodule。pt‘)
# This line is equivalent to the previous
m。save(“scriptmodule。pt”)
# Save to io。BytesIO buffer
buffer = io。BytesIO()
torch。jit。save(m, buffer)
# Save with extra files
extra_files = torch。_C。ExtraFilesMap()
extra_files[’foo。txt‘] = ’bar‘
torch。jit。save(m, ’scriptmodule。pt‘, _extra_files=extra_files)
torch.jit.load(f,map_location=None,_extra_files=ExtraFilesMap{})
載入先前用torch。jit。save儲存的ScriptModule或ScriptFunction所有之前儲存的模組,無論它們的裝置是什麼,都首先載入到CPU上,然後移動到它們儲存的裝置上。如果失敗(例如,因為執行時系統沒有特定的裝置),就會引發異常。
import torch
import io
torch。jit。load(’scriptmodule。pt‘)
# Load ScriptModule from io。BytesIO object
with open(’scriptmodule。pt‘, ’rb‘) as f:
buffer = io。BytesIO(f。read())
# Load all tensors to the original device
torch。jit。load(buffer)
# Load all tensors onto CPU, using a device
buffer。seek(0)
torch。jit。load(buffer, map_location=torch。device(’cpu‘))
# Load all tensors onto CPU, using a string
buffer。seek(0)
torch。jit。load(buffer, map_location=’cpu‘)
# Load with extra files。
extra_files = torch。_C。ExtraFilesMap()
extra_files[’foo。txt‘] = ’bar‘
torch。jit。load(’scriptmodule。pt‘, _extra_files=extra_files)
print(extra_files[’foo。txt‘])
torch.jit.ignore(drop=False, **kwargs)
這個裝飾器向編譯器表明,一個函式或方法應該被忽略,並保留為Python函式。這允許您在模型中保留尚未與TorchScript相容的程式碼。如果從TorchScript呼叫,被忽略的函式將把呼叫分派給Python直譯器。函式被忽略的模型不能匯出。使用drop=True引數時可以,但會丟擲異常。最好使用torch。jit。unused
import torch
import torch。nn as nn
class MyModule(nn。Module):
@torch。jit。ignore
def debugger(self, x):
import pdb
pdb。set_trace()
def forward(self, x):
x += 10
# The compiler would normally try to compile `debugger`,
# but since it is `@ignore`d, it will be left as a call
# to Python
self。debugger(x)
return x
m = torch。jit。script(MyModule())
# Error! The call `debugger` cannot be saved since it calls into Python
m。save(“m。pt”)
使用torch。jit。ignore(drop=True), 這一方法已被torch。jit。unused替代。
import torch
import torch。nn as nn
class MyModule(nn。Module):
@torch。jit。ignore(drop=True)
def training_method(self, x):
import pdb
pdb。set_trace()
def forward(self, x):
if self。training:
self。training_method(x)
return x
m = torch。jit。script(MyModule())
# This is OK since `training_method` is not saved, the call is replaced
# with a `raise`。
m。save(“m。pt”)
torch.jit.unused(fn)
這個裝飾器向編譯器表明,應該忽略一個函式或方法,並用引發異常來替換它。這允許您在模型中保留與TorchScript不相容的程式碼,同時仍然匯出模型。
import torch
import torch。nn as nn
class MyModule(nn。Module):
def __init__(self, use_memory_efficent):
super(MyModule, self)。__init__()
self。use_memory_efficent = use_memory_efficent
@torch。jit。unused
def memory_efficient(self, x):
import pdb
pdb。set_trace()
return x + 10
def forward(self, x):
# Use not-yet-scriptable memory efficient mode
if self。use_memory_efficient:
return self。memory_efficient(x)
else:
return x + 10
m = torch。jit。script(MyModule(use_memory_efficent=False))
m。save(“m。pt”)
m = torch。jit。script(MyModule(use_memory_efficient=True))
# exception raised
m(torch。rand(100))
混合Tracing和Scripting
在許多情況下,跟蹤或指令碼是將模型轉換為TorchScript的一種更簡單的方法。可以編寫跟蹤和指令碼來滿足模型某一部分的特定需求。
指令碼函式可以呼叫跟蹤函式。當您需要圍繞一個簡單的前饋模型使用控制流時,這一點特別有用。例如,序列到序列模型的波束搜尋通常用指令碼編寫,但可以呼叫使用跟蹤生成的編碼器模組。
例如在指令碼中呼叫跟蹤函式
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch。jit。trace(foo, (torch。rand(3), torch。rand(3)))
@torch。jit。script
def bar(x):
return traced_foo(x, x)
跟蹤函式也可以呼叫指令碼函式。當模型的一小部分需要一些控制流時,這是很有用的,即使大部分模型只是一個前饋網路。由跟蹤函式呼叫的指令碼函式中的控制流被正確儲存。
例如在跟蹤函式中呼叫指令碼函式
import torch
@torch。jit。script
def foo(x, y):
if x。max() > y。max():
r = x
else:
r = y
return r
def bar(x, y, z):
return foo(x, y) + z
traced_bar = torch。jit。trace(bar, (torch。rand(3), torch。rand(3), torch。rand(3)))
這個組合也適用於nn。Module。
import torch
import torchvision
class MyScriptModule(torch。nn。Module):
def __init__(self):
super(MyScriptModule, self)。__init__()
self。means = torch。nn。Parameter(torch。tensor([103。939, 116。779, 123。68])
。resize_(1, 3, 1, 1))
self。resnet = torch。jit。trace(torchvision。models。resnet18(),
torch。rand(1, 3, 224, 224))
def forward(self, input):
return self。resnet(input - self。means)
my_script_module = torch。jit。script(MyScriptModule())