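Starting with version 1.0, torch supports the jit module, which roughly consists of the following parts:
A new intermediate representation (IR) for computation graphs, referred to below simply as IR.
Two ways of exporting IR from Python code: trace and script.
IR optimization, plus an IR interpreter (which translates the IR into concrete ops).
This article covers:
A brief introduction to jit, with usage examples of the two export methods
The form of the IR in jit
A source-code walkthrough of the two export methods, trace and script
A brief introduction to IR optimization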
1 A brief introduction to jit, with usage examples
JIT overview
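As mentioned in the introduction, although this walkthrough is titled JIT, the part that really deserves the name just-in-time compiler comes after the IR has been exported: optimizing the IR computation graph and interpreting it as the corresponding operations. Consequently, the optimizations brought by PyTorch's jit code are generally graph-level, such as fusing some operations. Individual operators (such as convolution) receive no specific optimization and are still executed by calling into torch's basic operator library.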
After exporting the IR, i.e. TorchScript, you can also apply other compiler optimizations or interpreters to it. For example, TRTorch compiles a script into a TensorRT engine.

    import torch
    import torchvision.models as models

    resnet = torch.jit.trace(models.resnet18(), torch.rand(1, 3, 224, 224))
    output = resnet(torch.ones(1, 3, 224, 224))
    print(output)
    output = resnet(torch.ones(1, 3, 224, 224))

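resnet, the traced module, is the exported intermediate representation. It can be saved (with resnet.save) and then loaded and run elsewhere, for example from C++.
Let's look at the IR it holds, i.e. what the computation graph expressed in TorchScript looks like (it can be printed with print(resnet.graph)):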

    graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
          %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
      %1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
      %1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
      %1468 : __torch__.torch.nn.modules.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
      %1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
      ....
      %1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
      %1202 : int = prim::Constant[value=1]()
      %1203 : int = prim::Constant[value=-1]()
      %input : Float(1:512, 512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203)
      %1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
      return (%1557)

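That is how trace is used. Its core entry point is torch.jit.trace, whose arguments are the model to export and a valid input. As the name suggests, it works by following the model's inference: each operation the model performs on the input is recorded and mapped to the corresponding IR operation, which yields an IR of the model's original forward.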
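To make the recording idea concrete, here is a toy tracer in plain Python. It is only an analogy under stated assumptions: values are ordinary numbers rather than tensors, only `+` and `*` are recorded, and the names `Proxy`, `Graph` and `trace_fn` are hypothetical, not part of torch.jit. Running the function once on a proxy input is enough to collect the operations it performed.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class Graph:
    inputs: List[str] = field(default_factory=list)
    # Each node is (output name, op name, input names).
    nodes: List[Tuple[str, str, Tuple[str, ...]]] = field(default_factory=list)
    output: str = ""


class Proxy:
    """Wraps a concrete value and records every op applied to it (toy only)."""

    def __init__(self, value, name: str, graph: Graph):
        self.value, self.name, self.graph = value, name, graph

    def _record(self, op: str, other) -> "Proxy":
        if isinstance(other, Proxy):
            other_name, other_value = other.name, other.value
        else:
            # Plain Python operands are baked into the graph as constants.
            other_name, other_value = repr(other), other
        out_name = f"%{len(self.graph.nodes)}"
        self.graph.nodes.append((out_name, op, (self.name, other_name)))
        # Still compute the real result so the traced function can keep running.
        result = self.value + other_value if op == "add" else self.value * other_value
        return Proxy(result, out_name, self.graph)

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)


def trace_fn(fn: Callable, example_input) -> Graph:
    graph = Graph(inputs=["%x"])
    out = fn(Proxy(example_input, "%x", graph))
    graph.output = out.name
    return graph


def f(x):
    return (x + 2) * x


g = trace_fn(f, 3)
print(g.nodes)   # [('%0', 'add', ('%x', '2')), ('%1', 'mul', ('%0', '%x'))]
print(g.output)  # %1
```

Because the function is executed rather than read, any Python `if` that inspected a concrete value during this run would simply pick one branch, and only that branch's operations would end up in `nodes`.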
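Note: this approach has an obvious flaw. PyTorch builds dynamic graphs, and models often contain input-dependent control flow whose execution differs from one input to another (an if statement, or a loop of variable length). Such control flow cannot be fully traced into the computation graph. Here is a case where trace fails: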

    import torch

    def test(x):
        if x > 2.0:
            r = torch.tensor(1.0)
        else:
            r = torch.tensor(2.0)
        return r

    ftrace = torch.jit.trace(test, (torch.ones(1)))
    y = torch.ones(1) * 5
    print(ftrace(y))
    # result: tensor(2.)
    # because the example input only took the else branch, only that branch was recorded
With script, by contrast, the control flow is preserved:

    import torch

    @torch.jit.script
    def foo(x, y):
        if x.max() > y.max():
            r = x
        else:
            r = y
        return r

    print(foo.graph)
    print(foo(torch.Tensor([0]), torch.Tensor([1])))
    print(foo(torch.Tensor([1]), torch.Tensor([0])))

Output:

    graph(%x.1 : Tensor,
          %y.1 : Tensor):
      %3 : Tensor = aten::max(%x.1)
      %5 : Tensor = aten::max(%y.1)
      %6 : Tensor = aten::gt(%3, %5)
      %7 : bool = aten::Bool(%6)
      %r : Tensor = prim::If(%7)  # the control statement has indeed been captured
        block0():
          -> (%x.1)
        block1():
          -> (%y.1)
      return (%r)

    tensor([1.])
    tensor([1.])

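To use script, apply torch.jit.script where you need it: to a function, or to an nn.Module (whose forward is compiled by default). Its conversion approach is completely different from trace: script parses your PyTorch code directly, turns your logic into a syntax tree through syntactic analysis, and then converts that tree into the IR.
Note: although this solves trace's inability to follow dynamic logic, Python is an extremely flexible language, and fully supporting every Python construct is practically impossible. We therefore have to spend extra time learning which constructs can be parsed, which noticeably degrades the experience of writing code.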
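The contrast with trace can be seen with the standard-library `ast` module alone. The sketch below is plain Python, not torch.jit code: it parses the source of a function shaped like `foo` above and shows that both branches of the `if` are present in the syntax tree, regardless of what the inputs will later be.

```python
import ast

SOURCE = """
def foo(x, y):
    if x > y:
        r = x
    else:
        r = y
    return r
"""

fn_def = ast.parse(SOURCE).body[0]  # the FunctionDef for foo
if_stmt = fn_def.body[0]            # the ast.If node

# Both branches are in the tree; nothing had to be executed to find them.
then_values = [stmt.value.id for stmt in if_stmt.body]
else_values = [stmt.value.id for stmt in if_stmt.orelse]
print(type(if_stmt).__name__, then_values, else_values)  # If ['x'] ['y']
```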
trace and script can also be combined, for example:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # torch.jit.trace produces ScriptModules for conv1 and conv2
            self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
            self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

        def forward(self, input):
            input = F.relu(self.conv1(input))
            input = F.relu(self.conv2(input))
            return input

    scripted_module = torch.jit.script(MyModule())

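A rule of thumb:
1 In most cases a model contains only tensor operations, so just trace it.
2 If it has control flow (if-else, for-loop), use scripting.
3 If you hit syntax that scripting cannot handle, either rewrite it or combine tracing and scripting (for example, script only the code with control flow and trace the rest).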
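Neither trace nor script can convert functions from third-party Python libraries, so implement as much as possible in PyTorch. Custom ops must be registered as jit operators (torch's own ops are registered the same way) before the model can be converted to TorchScript, for example: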
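
    #include <torch/script.h>

    // warp_perspective is the user's own C++ implementation of the op (not shown);
    // this registers it as my_ops::warp_perspective
    TORCH_LIBRARY(my_ops, m) {
      m.def("warp_perspective", warp_perspective);
    }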
2 The basic representation of IR (TorchScript)
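How do PyTorch's various concepts (parameters, compute nodes, and so on) map onto TorchScript?
The exported IR is built from the following structures:
Module: the counterpart of nn.Module
Parameter: the counterpart of a PyTorch parameter
Method: contains a FunctionSchema (the method's signature), a Graph (the actual computation graph), and a GraphExecutor (which does the optimization and execution)
Graph: defines the concrete implementation of a function, made up of Nodes, Blocks, and Values
Block: control statements (if, loop), each holding a list of nodes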
Annotated on the IR of foo from above:

    # %x.1 is a Value
    graph(%x.1 : Tensor,
          %y.1 : Tensor):
      # aten::max is a Node
      # Tensor is its Type (TensorType)
      %3 : Tensor = aten::max(%x.1)
      %5 : Tensor = aten::max(%y.1)
      %6 : Tensor = aten::gt(%3, %5)
      %7 : bool = aten::Bool(%6)
      %r : Tensor = prim::If(%7)
        # Blocks
        block0():
          -> (%x.1)
        block1():
          -> (%y.1)
      return (%r)
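The relationships listed above can be modelled with a few plain Python data classes. This is a hypothetical sketch, not the C++ classes in torch/csrc/jit: a `Graph` owns `Node`s, each node consumes and produces `Value`s, and only control-flow nodes such as `prim::If` carry nested `Block`s. It rebuilds the graph of `foo` shown above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Value:
    name: str
    type: str = "Tensor"


@dataclass
class Block:
    nodes: List["Node"] = field(default_factory=list)
    outputs: List[Value] = field(default_factory=list)


@dataclass
class Node:
    kind: str            # e.g. "aten::max" or "prim::If"
    inputs: List[Value]
    outputs: List[Value]
    blocks: List[Block] = field(default_factory=list)  # only control flow uses blocks


@dataclass
class Graph:
    inputs: List[Value]
    nodes: List[Node] = field(default_factory=list)
    outputs: List[Value] = field(default_factory=list)


x, y = Value("%x.1"), Value("%y.1")
m3, m5 = Value("%3"), Value("%5")
gt, cond, r = Value("%6"), Value("%7", "bool"), Value("%r")

graph = Graph(inputs=[x, y], outputs=[r])
graph.nodes += [
    Node("aten::max", [x], [m3]),
    Node("aten::max", [y], [m5]),
    Node("aten::gt", [m3, m5], [gt]),
    Node("aten::Bool", [gt], [cond]),
    # prim::If owns two blocks; each block yields the value of its branch.
    Node("prim::If", [cond], [r], blocks=[Block(outputs=[x]), Block(outputs=[y])]),
]

for node in graph.nodes:
    ins = ", ".join(v.name for v in node.inputs)
    outs = ", ".join(f"{v.name} : {v.type}" for v in node.outputs)
    print(f"{outs} = {node.kind}({ins})")
```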

3 The two ways of exporting IR: trace and script
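Because the actual implementation is quite complex, the source code quoted below keeps only the branches taken by a simple case and leaves out most details. Readers who need more can consult the source themselves.
trace implementation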

    def trace(
        func,
        example_inputs,
        optimize=None,
        check_trace=True,
        check_inputs=None,
        check_tolerance=1e-5,
        strict=True,
        _force_outplace=False,
        _module_class=None,
        _compilation_unit=_python_cu,
    ):
        # func is an nn.Module instance: trace its forward
        if isinstance(func, torch.nn.Module):
            return trace_module(
                func, {"forward": example_inputs}, None, check_trace,
                wrap_check_inputs(check_inputs), check_tolerance, strict,
                _force_outplace, _module_class,
            )
        # func is the forward method of some module instance
        if (
            hasattr(func, "__self__")
            and isinstance(func.__self__, torch.nn.Module)
            and func.__name__ == "forward"
        ):
            return trace_module(
                func.__self__, {"forward": example_inputs}, None, check_trace,
                wrap_check_inputs(check_inputs), check_tolerance, strict,
                _force_outplace, _module_class,
            )
        # ... (details omitted; `name` is the qualified name of func)
        # a helper for looking up variable names
        var_lookup_fn = _create_interpreter_name_lookup_fn(0)
        # C++ entry point
        traced = torch._C._create_function_from_trace(
            name, func, example_inputs, var_lookup_fn, strict, _force_outplace
        )
        # check whether traced differs from the original func
        if check_trace:
            if check_inputs is not None:
                _check_trace(
                    check_inputs, func, traced, check_tolerance, strict,
                    _force_outplace, False, _module_class,
                )
            else:
                _check_trace(
                    [example_inputs], func, traced, check_tolerance, strict,
                    _force_outplace, False, _module_class,
                )
        return traced

After a few simple checks, the code enters a C++ function:

    traced = torch._C._create_function_from_trace(
        name, func, example_inputs, var_lookup_fn, strict, _force_outplace
    )

Let's see what happens on the C++ side:

    std::pair<std::shared_ptr<TracingState>, Stack> trace(
        Stack inputs,
        const std::function<Stack(Stack)>& traced_fn,
        std::function<std::string(const Variable&)> var_name_lookup_fn,
        bool strict,
        bool force_outplace,
        Module* self) {
      try {
        auto state = std::make_shared<TracingState>();
        // setTracingState stores this state instance; later, each compute node
        // gets it back and inserts its computation into it
        setTracingState(state);
        // during forward, the state data structure stores the traced computation
        if (self) {
          Value* self_value = state->graph->insertInput(0, "self")->setType(
              self->_ivalue()->type());
          gatherParametersAndBuffers(state, self_value, *self, {"__module"});
        }
        for (IValue& input : inputs) {
          input = addInput(state, input, input.type(), state->graph->addInput());
        }
        auto graph = state->graph;
        // bind the Python variable-name lookup function
        getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
        getTracingState()->strict = strict;
        getTracingState()->force_outplace = force_outplace;
        // run forward; each computation is recorded into state as it happens
        auto out_stack = traced_fn(inputs);
        // Exit a trace, treating 'out_stack' as the outputs of the trace. These
        // are the variables whose values will be computed upon subsequent
        // invocations of the trace.
        size_t i = 0;
        for (auto& output : out_stack) {
          // NB: The stack is in "reverse" order, so when we pass the diagnostic
          // number we need to flip it based on size.
          state->graph->registerOutput(
              state->getOutput(output, out_stack.size() - i));
          i++;
        }
        setTracingState(nullptr);
        if (getInlineEverythingMode()) {
          Inline(*graph);
        }
        FixupTraceScopeBlocks(graph, self);
        NormalizeOps(graph);
        return {state, out_stack};
      } catch (...) {
        tracer::abandon();
        throw;
      }
    }

So where exactly does the recording of each operation happen?

    Operator createOperatorFromC10_withTracingHandledHere(
        const c10::OperatorHandle& op) {
      return Operator(op, [op](Stack& stack) {
        const auto input_size = op.schema().arguments().size();
        const auto output_size = op.schema().returns().size();

        Node* node = nullptr;
        std::shared_ptr<jit::tracer::TracingState> tracer_state;

        // trace the input before unwrapping, otherwise we may lose
        // the input information
        if (jit::tracer::isTracing()) {
          // fetch the tracer_state
          tracer_state = jit::tracer::getTracingState();
          auto symbol = Symbol::fromQualString(op.schema().name());
          const auto& graph = tracer::getTracingState()->graph;
          node = graph->create(symbol, 0);
          tracer::recordSourceLocation(node);
          const auto& args = op.schema().arguments();
          int i = 0;
          // record the args
          for (auto iter = stack.end() - input_size; iter != stack.end();
               ++iter, ++i) {
            // TODO we need to refactor graph APIs (e.g., addInputs)
            // appropriately; after that, we can get rid of the giant if-else
            // block we will clean this tech debt together in the following PRs
            auto type = args[i].type();
            if (type->kind() == TypeKind::OptionalType) {
              if (iter->isNone()) {
                Value* none = graph->insertNode(graph->createNone())->output();
                node->addInput(none);
                continue;
              } else {
                type = type->expect<OptionalType>()->getElementType();
              }
            }
            if (type->isSubtypeOf(TensorType::get())) {
              AT_ASSERT(iter->isTensor());
              tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
            } else if (type->kind() == TypeKind::FloatType) {
              AT_ASSERT(iter->isDouble());
              tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
            } else if (type->kind() == TypeKind::IntType) {
              AT_ASSERT(iter->isInt());
              tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
            } else if (type->kind() == TypeKind::BoolType) {
              AT_ASSERT(iter->isBool());
              tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
            } else if (type->kind() == TypeKind::StringType) {
              AT_ASSERT(iter->isString());
              tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
            } else if (type->kind() == TypeKind::NumberType) {
              tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
            } else if (type->kind() == TypeKind::ListType) {
              const auto& elem_type = type->expect<ListType>()->getElementType();
              if (elem_type->isSubtypeOf(TensorType::get())) {
                AT_ASSERT(iter->isTensorList());
                auto list = iter->toTensorVector();
                tracer::addInputs(node, args[i].name().c_str(), list);
              } else if (elem_type->kind() == TypeKind::FloatType) {
                AT_ASSERT(iter->isDoubleList());
                // NB: now, tracer doesn't support tracing double list. We add
                // special handling here, since in our case, we assume that all the
                // doubles in the list are constants
                auto value = iter->toDoubleVector();
                std::vector<Value*> info(value.size());
                for (size_t value_index = 0; value_index < value.size();
                     ++value_index) {
                  info[value_index] = graph->insertConstant(value[value_index]);
                  tracer::recordSourceLocation(info[value_index]->node());
                }
                node->addInput(
                    graph
                        ->insertNode(graph->createList(jit::FloatType::get(), info))
                        ->output());
              } else if (elem_type->kind() == TypeKind::IntType) {
                AT_ASSERT(iter->isIntList());
                tracer::addInputs(
                    node, args[i].name().c_str(), iter->toIntVector());
              } else if (elem_type->kind() == TypeKind::BoolType) {
                AT_ASSERT(iter->isBoolList());
                tracer::addInputs(
                    node, args[i].name().c_str(), iter->toBoolList().vec());
              } else {
                throw std::runtime_error(
                    "unsupported input list type: " + elem_type->str());
              }
            } else if (iter->isObject()) {
              tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
            } else {
              throw std::runtime_error("unsupported input type: " + type->str());
            }
          }
          // insert the node into the graph
          graph->insertNode(node);

          jit::tracer::setTracingState(nullptr);
        }
        // ... (the rest of the function is omitted)

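As we can see, when an actual computation happens, the operator calls getTracingState() to fetch the state created when forward started. It derives the kind of computation (for example, an addition) from op.schema().name(), creates a compute node of that kind with graph->create(symbol, 0), adds the node's inputs, and finally inserts the node into the graph, completing the record of one computation.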
Because script obtains its IR by parsing source code, the handling differs slightly for different kinds of code (functions, classes, and nn.Module instances).
1 Python functions (simplified code)

    def script(obj, optimize=None, _frames_up=0, _rcb=None):
        # function branch
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        # check for overloads
        _check_directly_compile_overloaded(obj)
        # has obj already been scripted?
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            return maybe_already_compiled_fn
        # get the AST
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        # C++ entry point: build the IR from the AST
        # (qualified_name is the qualified name of obj; its computation is omitted)
        fn = torch._C._jit_script_compile(
            qualified_name, ast, _rcb, get_default_args(obj)
        )
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        # cache the result
        _set_jit_function_cache(obj, fn)
        return fn

Let's see how get_jit_def produces the AST defined by jit:

    def get_jit_def(fn, def_name, self_name=None):
        # gather information about the source code
        sourcelines, file_lineno, filename = get_source_lines_and_file(
            fn, torch._C.ErrorReport.call_stack()
        )
        sourcelines = normalize_source_lines(sourcelines)
        source = "".join(sourcelines)
        # dedent_src is the string containing the function to be scripted
        dedent_src = dedent(source)
        # parse the string into a Python AST with Python's ast module
        py_ast = ast.parse(dedent_src)
        # ... (computation of leading_whitespace_len omitted)
        # get the Python type comment
        type_line = torch.jit.annotations.get_type_line(source)
        # ctx holds all the original information about the function
        ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
        fn_def = py_ast.body[0]
        # build_def converts the Python AST into the AST format used by torch jit
        return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
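Why the dedent step matters can be checked with the standard library alone. The source of a method defined inside a class is indented, and `ast.parse` rejects it; `textwrap.dedent` fixes that. The whitespace-length line at the end only illustrates the kind of bookkeeping the elided `leading_whitespace_len` computation presumably does (so positions can be mapped back to the original source); it is this sketch's assumption, not the quoted code.

```python
import ast
import textwrap

# Source of a method as it appears inside a class body: every line is indented.
method_src = "    def forward(self, x):\n        return x + 1\n"

try:
    ast.parse(method_src)
    parse_failed = False
except IndentationError:
    parse_failed = True  # "unexpected indent"

dedented = textwrap.dedent(method_src)
fn_def = ast.parse(dedented).body[0]

# How many columns were stripped from the first line.
leading_ws = len(method_src.split("\n", 1)[0]) - len(dedented.split("\n", 1)[0])
print(parse_failed, fn_def.name, leading_ws)  # True forward 4
```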

A simple example shows what py_ast.body[0] is:

    >>> import ast
    >>> func_def = \
    ... """def test(a):
    ...     a = a + 2
    ...     return a + 1"""
    >>> results = ast.parse(func_def)

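[Figure: the AST that Python parses from this string]
As the figure shows, results.body is a list with one entry per top-level statement in the parsed string; here it holds a single FunctionDef. Inside that function, the first statement is an Assign whose value is a BinOp, specifically an Add: left is a Name with id `a`, and right is a Num (a Constant in Python 3.8+), i.e. 2. This BinOp is the parsed a = a + 2.
Since get_source_lines_and_file always returns a single top-level function, we can simply take element 0, i.e. py_ast.body[0].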
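The same exploration can be done without a screenshot by walking the tree directly. This snippet uses only the standard `ast` module and assumes Python 3.8 or newer, where the literal 2 is an `ast.Constant` rather than an `ast.Num`.

```python
import ast

func_def = """def test(a):
    a = a + 2
    return a + 1"""

results = ast.parse(func_def)
fn = results.body[0]   # the only top-level statement: a FunctionDef
assign = fn.body[0]    # a = a + 2
binop = assign.value   # BinOp(left=Name('a'), op=Add(), right=Constant(2))

print(len(results.body), type(fn).__name__, type(assign).__name__)  # 1 FunctionDef Assign
print(type(binop.op).__name__, binop.left.id, binop.right.value)    # Add a 2
print(ast.dump(binop))
```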
Next, let's see how build_def converts the Python AST into the AST it needs.

    def build_def(ctx, py_def, type_line, def_name, self_name=None):
        ....
        return Def(Ident(r, def_name), decl, build_stmts(ctx, body))

Since ctx holds all the information about the source code and body is the result of Python's AST parsing, build_stmts should contain the answer we are looking for.

    import ast

    # the basic AST structures defined by jit
    from torch._C._jit_tree_views import (
        ClassDef, Ident, Stmt, Decl, Def, Var, EmptyTypeAnnotation, Param, ExprStmt,
        Assign, Delete, Return, Raise, Assert, AugAssign, While, For, If, Pass,
        Break, Continue, Apply, Dots, Select, TrueLiteral, FalseLiteral, NoneLiteral,
        Starred, ListLiteral, TupleLiteral, DictLiteral, Const, StringLiteral,
        ListComp, Attribute, BinOp, UnaryOp, SliceExpr, Subscript, TernaryIf, With,
        WithItem, Property, DictComp,
    )

    def build_stmts(ctx, stmts):
        # calls build_stmt on each statement
        stmts = [build_stmt(ctx, s) for s in stmts]
        return list(filter(None, stmts))

    class Builder(object):
        def __call__(self, ctx, node):
            # dispatch to the build method matching the type of the parsed AST node;
            # the screenshot shows that `a = a + 2` is an `Assign`, so build_Assign is called
            method = getattr(self, 'build_' + node.__class__.__name__, None)
            if method is None:
                raise UnsupportedNodeError(ctx, node)
            return method(ctx, node)

    class StmtBuilder(Builder):
        @staticmethod
        def build_Assign(ctx, stmt):
            # the screenshot shows that stmt.value is a BinOp; build_expr is an
            # ExprBuilder instance, so this calls build_BinOp
            rhs = build_expr(ctx, stmt.value)
            lhs = [build_expr(ctx, x) for x in stmt.targets]
            return Assign(lhs, rhs)

        @staticmethod
        def build_Expr(ctx, stmt):
            value = stmt.value
            if value.__class__.__name__ == 'Str':
                # If a statement is a string literal expression,
                # then it is a docstring. Just ignore it.
                return None
            else:
                return ExprStmt(build_expr(ctx, value))

    class ExprBuilder(Builder):
        binop_map = {
            ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/',
            ast.Pow: '**', ast.Mod: '%', ast.FloorDiv: '//', ast.BitAnd: '&',
            ast.BitXor: '^', ast.BitOr: '|', ast.LShift: '<<', ast.RShift: '>>',
        }

        @staticmethod
        def build_BinOp(ctx, expr):
            # expr.left is a `Name`, so this calls build_Name
            lhs = build_expr(ctx, expr.left)
            rhs = build_expr(ctx, expr.right)
            op = type(expr.op)
            # convert to the agreed-upon string symbol for the operation
            op_token = ExprBuilder.binop_map.get(op)
            return BinOp(op_token, lhs, rhs)

    # build_stmt and build_expr are instances of the builders above
    build_stmt = StmtBuilder()
    build_expr = ExprBuilder()

The final result is in a format similar to an S-expression:

    (def
      (ident test)
      (decl
        (list
          (param
            (ident a)
            (option)
            (option)
            (False)))
        (option))
      (list
        (assign
          (list (variable (ident a)))
          (option
            (+
              (variable (ident a))
              (const 2)))
          (option))
        (return
          (+
            (variable (ident a))
            (const 1)))))
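The `'build_' + class name` dispatch in `Builder.__call__` is easy to reproduce in isolation. The toy below is a hypothetical sketch, not torch's `StmtBuilder`/`ExprBuilder`: it handles only the node types in the `test` example, assumes Python 3.8+ (`ast.Constant`), and omits the `(option)` placeholders, but it produces a nested string in the same spirit as the S-expression above.

```python
import ast


class ToySExprBuilder:
    """Dispatches on node.__class__.__name__, like Builder.__call__ above."""

    def __call__(self, node):
        method = getattr(self, "build_" + node.__class__.__name__, None)
        if method is None:
            raise NotImplementedError(f"unsupported node: {node.__class__.__name__}")
        return method(node)

    def build_FunctionDef(self, node):
        params = " ".join(f"(param (ident {a.arg}))" for a in node.args.args)
        body = " ".join(self(stmt) for stmt in node.body)
        return f"(def (ident {node.name}) (decl (list {params})) (list {body}))"

    def build_Assign(self, node):
        targets = " ".join(self(t) for t in node.targets)
        return f"(assign (list {targets}) {self(node.value)})"

    def build_Return(self, node):
        return f"(return {self(node.value)})"

    def build_BinOp(self, node):
        op = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*"}[type(node.op)]
        return f"({op} {self(node.left)} {self(node.right)})"

    def build_Name(self, node):
        return f"(variable (ident {node.id}))"

    def build_Constant(self, node):
        return f"(const {node.value})"


src = "def test(a):\n    a = a + 2\n    return a + 1"
out = ToySExprBuilder()(ast.parse(src).body[0])
print(out)
```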

We now have the AST defined by jit. Next, let's go into torch._C._jit_script_compile to see how such an AST is turned into IR.
The C++ entry point is script_compile_function:

    static StrongFunctionPtr script_compile_function(
        const c10::QualifiedName& name,
        const Def& def,
        const FunctionDefaults& defaults,
        const ResolutionCallback& rcb) {
      // def contains the AST; following it will lead us to the answer
      auto cu = get_python_cu();
      // the work is done by define() of the CompilationUnit returned by get_python_cu
      auto defined_functions = cu->define(
          QualifiedName(name.prefix()),
          /*properties=*/{},
          /*propResolvers=*/{},
          {def},
          {pythonResolver(rcb)},
          nullptr,
          true);
      TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
      auto& defined = defined_functions[0];
      defined->setSchema(getSchemaWithNameAndDefaults(
          def.range(), defined->getSchema(), /* ... */, defaults));
      StrongFunctionPtr ret(std::move(cu), defined);
      didFinishEmitFunction(ret);
      return ret;
    }

    // get_python_cu is just a thin wrapper around CompilationUnit
    inline std::shared_ptr<CompilationUnit> get_python_cu() {
      return py::module::import("torch.jit._state")
          .attr("_python_cu")
          .cast<std::shared_ptr<CompilationUnit>>();
    }

    // About CompilationUnit:
    // /torch/csrc/jit/api/compilation_unit.h
      // for historic reasons, these are defined in ir_emitter.cpp
      // Returns the list of Functions just defined.
      std::vector<Function*> define(
          const c10::optional<c10::QualifiedName>& prefix,
          const std::vector<Property>& properties,
          const std::vector<ResolverPtr>& propResolvers,
          const std::vector<Def>& definitions,
          const std::vector<ResolverPtr>& defResolvers, /* determines how we handle
                                                          free variables in each
                                                          definition*/
          // if non-null, the first argument to each def, is bound to this value
          const Self* self,
          // see [name mangling]
          bool shouldMangle = false);

    // implemented in torch/csrc/jit/frontend/ir_emitter.cpp
    std::unique_ptr<Function> CompilationUnit::define(
        const c10::optional<QualifiedName>& prefix,
        const Def& def,
        const ResolverPtr& resolver,
        const Self* self,
        const std::unordered_map<std::string, Function*>& function_table,
        bool shouldMangle) const {
      auto _resolver = resolver;
      .....
      auto creator = [def, _resolver, self](Function& method) {
        ....
        // the core code: to_ir
        to_ir(def, _resolver, self, method);
      };
      auto fn = torch::make_unique<GraphFunction>(
          std::move(name), std::make_shared<Graph>(), creator);
      return fn;
    }

Following def, we arrive at to_ir, the key struct for converting to IR. Its inputs include def, i.e. the AST, and _resolver, the name-resolution function passed in from Python. Inside it we find the key parts:

    to_ir(
        const Def& def,
        ResolverPtr resolver_,
        const Self* self,
        Function& method) // method being constructed
        : method(method),
          graph(method.graph()),
          resolver(std::move(resolver_)),
          typeParser_(resolver),
          environment_stack(nullptr) {
      AT_ASSERT(resolver);
      pushFrame(graph->block(), /*starts_def=*/true);
      // emitDef calls emitStatements
      method.setSchema(emitDef(def, self, graph->block()));
      ConvertToSSA(graph);
      CanonicalizeModifiedLoops(graph);
      NormalizeOps(graph);
      runCleanupPasses(graph);
    }

    private:
    // among to_ir's private members we find Graph, Function and the other
    // IR components introduced earlier
    Function& method;
    std::shared_ptr<Graph> graph;
    ResolverPtr resolver;
    std::unordered_map<int64_t, Value*> integral_constants;

    // emitDef calls emitStatements
    FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
      ......
      // body
      auto stmts_list = def.statements();
      emitStatements(stmts_list.begin(), stmts_list.end());
      ........
    }

    void emitStatements(
        List<Stmt>::const_iterator begin,
        List<Stmt>::const_iterator end) {
      for (; begin != end; ++begin) {
        auto stmt = *begin;
        ErrorReport::CallStack::update_pending_range(stmt.range());
        switch (stmt.kind()) {
          case TK_IF:
            emitIf(If(stmt));
            break;
          case TK_WHILE:
            emitWhile(While(stmt));
            break;
          case TK_FOR:
            emitFor(For(stmt));
            break;
          case TK_ASSIGN:
            emitAssignment(Assign(stmt));
            .................
            break;
          default:
            throw ErrorReport(stmt)
                << "Unrecognized statement kind " << kindToString(stmt.kind());
        }
        // Found an exit statement in this block. The remaining statements aren't
        // reachable so we don't emit them.
        if (exit_blocks.count(environment_stack->block()))
          return;
      }
    }

As we can see, depending on stmt.kind(), control enters one of the various emit* functions, and inside them we will always find operations like graph->insertNode(graph->create(.....)); these correspond to building the IR graph.
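A toy version of this statement-kind dispatch fits in a few dozen lines of Python. It is a hypothetical sketch, not torch's `to_ir`: it reads Python's own `ast`, supports only assignments, `if`/`else`, `return`, `>` and `+`, and represents nodes as `(kind, inputs, outputs, blocks)` tuples. The interesting part is `emit_if`: each branch is emitted into its own block, and a variable assigned in both branches becomes an output of the `prim::If` node, which is what the `%r : Tensor = prim::If(%7)` line in the earlier IR expresses.

```python
import ast

OPS = {ast.Gt: "aten::gt", ast.Add: "aten::add"}


class ToyEmitter:
    def __init__(self):
        self.count = 0

    def fresh(self):
        self.count += 1
        return f"%{self.count}"

    def emit_statements(self, stmts, block, env):
        # Dispatch on the statement kind, like the switch in emitStatements.
        for stmt in stmts:
            kind = type(stmt).__name__
            if kind == "Assign":
                env[stmt.targets[0].id] = self.emit_expr(stmt.value, block, env)
            elif kind == "If":
                self.emit_if(stmt, block, env)
            elif kind == "Return":
                block["outputs"].append(self.emit_expr(stmt.value, block, env))
            else:
                raise NotImplementedError(f"Unrecognized statement kind {kind}")

    def emit_if(self, stmt, block, env):
        cond = self.emit_expr(stmt.test, block, env)
        merged = sorted(self.assigned(stmt.body) & self.assigned(stmt.orelse))
        blocks = []
        for body in (stmt.body, stmt.orelse):
            sub_env = dict(env)  # each branch sees the outer variables
            sub_block = {"nodes": [], "outputs": []}
            self.emit_statements(body, sub_block, sub_env)
            sub_block["outputs"] = [sub_env[name] for name in merged]
            blocks.append(sub_block)
        outs = [self.fresh() for _ in merged]
        block["nodes"].append(("prim::If", [cond], outs, blocks))
        env.update(zip(merged, outs))  # after the If, names refer to its outputs

    @staticmethod
    def assigned(stmts):
        return {t.id for s in stmts if isinstance(s, ast.Assign) for t in s.targets}

    def emit_expr(self, expr, block, env):
        if isinstance(expr, ast.Name):
            return env[expr.id]  # reading a variable emits no node
        if isinstance(expr, ast.Compare):
            op, args = type(expr.ops[0]), [expr.left, expr.comparators[0]]
        elif isinstance(expr, ast.BinOp):
            op, args = type(expr.op), [expr.left, expr.right]
        else:
            raise NotImplementedError(type(expr).__name__)
        if op not in OPS:
            raise NotImplementedError(op.__name__)
        inputs = [self.emit_expr(a, block, env) for a in args]
        out = self.fresh()
        block["nodes"].append((OPS[op], inputs, [out], []))
        return out


def compile_source(src):
    fn = ast.parse(src).body[0]
    env = {arg.arg: f"%{arg.arg}" for arg in fn.args.args}
    graph = {"nodes": [], "outputs": []}
    ToyEmitter().emit_statements(fn.body, graph, env)
    return graph


SRC = """
def foo(x, y):
    if x > y:
        r = x
    else:
        r = y + y
    return r
"""

graph = compile_source(SRC)
for kind, inputs, outputs, _ in graph["nodes"]:
    print(outputs, "=", kind, inputs)
print("return", graph["outputs"])
```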

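That covered a function. Next, let's script a module, which brings its own challenges. Some variables refer to things that are only known after initialization, and we also want the scripted module to keep the same external interface, i.e. the original module's attributes should remain accessible as usual. How can that be done?
The approach: right after the module's original __init__ finishes, script the forward function in full and replace every function involved with its scripted version.
How can behavior be attached to the end of a class's __init__? Metaclasses come to mind, and torch.jit implements a metaclass called ScriptMeta.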

    class MyModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def f(self, x):
            return x * x

        @torch.jit.script_method
        def forward(self, x):
            return x + self.f(x)

About script_method:

    def script_method(fn):
        _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
        ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
        # nothing is scripted yet; just return a namedtuple holding the ast
        return ScriptMethodStub(_rcb, ast, fn)

    ScriptMethodStub = collections.namedtuple(
        'ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))

ScriptMeta does two things:
1 It removes every attribute that is a script_method (a method decorated with @script_method), so that what gets accessed is the scripted function.
2 It modifies the module's __init__ so that all script_methods are compiled immediately after the module's self.param or self.module attributes are initialized; the forward of the resulting instance has therefore already been replaced.

    class ScriptMeta(type):
        def __init__(cls, name, bases, attrs):  # noqa: B902
            # cls is an instance of ScriptMeta, i.e. a class such as ScriptModule
            cls._methods: Dict[str, Any] = {}
            cls._constants_set = set(getattr(cls, "__constants__", ()))
            for base in reversed(bases):
                # recall that a traced module also has a _methods attribute
                for k, v in getattr(base, "_methods", {}).items():
                    cls._methods[k] = v
                base_constants = getattr(base, "_constants_set", set())
                cls._constants_set = cls._constants_set.union(base_constants)

            # collect every method decorated with @script_method into _methods and
            # delete the original attribute; they are all scripted after __init__
            for k, v in sorted(attrs.items()):
                if isinstance(v, ScriptMethodStub):
                    delattr(cls, k)
                    cls._methods[v.original_method.__name__] = v

            original_init = getattr(cls, "__init__", lambda self: None)

            # once __init__ has finished, call create_script_module to do the scripting
            @functools.wraps(original_init)
            def init_then_script(self, *args, **kwargs):
                # here self is the instance
                num_methods = len(cls._methods)
                original_init(self, *args, **kwargs)
                added_methods_in_init = len(cls._methods) > num_methods

                if type(self) == cls:
                    # choose the methods that need scripting
                    def make_stubs(module):
                        cls = type(module)
                        if hasattr(cls, "_methods"):
                            return [v for k, v in sorted(cls._methods.items())]
                        else:
                            # infer_methods_to_compile picks the functions to script
                            return infer_methods_to_compile(module)

                    # compile all script_methods together into the
                    # _actual_script_module attribute
                    self.__dict__[
                        "_actual_script_module"
                    ] = torch.jit._recursive.create_script_module(
                        self, make_stubs, share_types=not added_methods_in_init)

                    # Delete the Python attributes that now shadow the ScriptModule
                    # ones, so that __getattr__ and __setattr__ will properly find
                    # the scripted versions.
                    concrete_type = self._actual_script_module._concrete_type
                    for name in concrete_type.get_attributes():
                        delattr(self, name)
                    for name, _ in concrete_type.get_modules():
                        delattr(self, name)
                    for name in ("_parameters", "_buffers", "_modules"):
                        delattr(self, name)

            cls.__init__ = init_then_script  # type: ignore
            return super(ScriptMeta, cls).__init__(name, bases, attrs)


    class _CachedForward(object):
        def __get__(self, obj, cls):
            return self.__getattr__("forward")  # type: ignore


    class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore
        def __init__(self):
            super(ScriptModule, self).__init__()

        forward = _CachedForward()

        # accessing a module attribute returns the attribute of _actual_script_module
        def __getattr__(self, attr):
            if "_actual_script_module" not in self.__dict__:
                return super(ScriptModule, self).__getattr__(attr)
            return getattr(self._actual_script_module, attr)

        def __setattr__(self, attr, value):
            if "_actual_script_module" not in self.__dict__:
                # Unwrap torch.jit.Attribute into a regular setattr + recording
                # the provided type in __annotations__.
                #
                # This ensures that if we use the attr again in `__init__`, it
                # will look like the actual value, not an instance of Attribute.
                if isinstance(value, Attribute):
                    if "__annotations__" not in self.__class__.__dict__:
                        self.__class__.__annotations__ = {}
                    self.__annotations__[attr] = value.type
                    value = value.value
                return super(ScriptModule, self).__setattr__(attr, value)

            setattr(self._actual_script_module, attr, value)
        ...
The create_script_module function scripts the methods and returns a RecursiveScriptModule. Its logic is fairly complex, so we won't go into it here.
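The metaclass trick can be isolated from everything scripting-specific. The sketch below is hypothetical: `ToyScriptMeta`, `mark` and the `compiled` attribute are invented for illustration, and "compiling" only records method names. It wraps each class's `__init__` so that an extra step runs right after the user's own `__init__`, when attributes such as `self.scale` already exist, and it uses the same `type(self) is cls` style of guard as `init_then_script`, so the step runs once, for the class actually being instantiated.

```python
import functools


def mark(fn):
    """Hypothetical stand-in for @script_method: it only tags the function."""
    fn._to_compile = True
    return fn


class ToyScriptMeta(type):
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        inherited = {m for base in bases for m in getattr(base, "_marked", ())}
        own = {k for k, v in attrs.items() if getattr(v, "_to_compile", False)}
        cls._marked = sorted(inherited | own)

        original_init = cls.__init__

        @functools.wraps(original_init)
        def init_then_compile(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # Only the class being instantiated runs the step, not its bases.
            if type(self) is cls:
                self.compiled = {m: f"<compiled {m}>" for m in cls._marked}

        cls.__init__ = init_then_compile


class Model(metaclass=ToyScriptMeta):
    def __init__(self, scale):
        self.scale = scale  # already set when the compile step runs

    @mark
    def forward(self, x):
        return x * self.scale


class Bigger(Model):
    @mark
    def helper(self, x):
        return x + 1


print(Model(3).compiled)   # {'forward': '<compiled forward>'}
print(Bigger(2).compiled)  # {'forward': '<compiled forward>', 'helper': '<compiled helper>'}
```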
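About __getattribute__ vs __getattr__
When an instance attribute is accessed, __getattribute__ is called unconditionally. If the attribute is not found there, __getattr__ is called: without a custom __getattr__, an AttributeError is raised saying the attribute cannot be found; if you have defined your own __getattr__, it is called in exactly this not-found case.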
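The following plain-Python sketch shows the rule in action and mirrors how ScriptModule forwards lookups to `_actual_script_module`. `Wrapper` and `Inner` are hypothetical names; the inner object is stored through `self.__dict__` directly, as `init_then_script` does, so that looking it up never recurses into `__getattr__`.

```python
class Inner:
    def __init__(self):
        self.weight = 42


class Wrapper:
    def __init__(self):
        self.local = "found by normal lookup"
        self.__dict__["_inner"] = Inner()

    def __getattr__(self, attr):
        # Reached only after __getattribute__ has failed to find `attr`.
        return getattr(self.__dict__["_inner"], attr)


w = Wrapper()
print(w.local)   # found by normal lookup (__getattr__ is not called)
print(w.weight)  # 42, resolved through __getattr__ on the inner object

try:
    w.missing
    missing_raised = False
except AttributeError:
    missing_raised = True  # neither lookup found it
print(missing_raised)  # True
```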
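4 A brief introduction to IR optimization
jit typically involves optimizations such as loop unrolling, peephole optimization, constant propagation, DCE (dead code elimination), fusion, inlining, and so on. Consider the following example: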

    import torch
    from time import time

    def test(x):
        # dead code elimination
        for i in range(1000):
            y = x + 1
        for i in range(100):
            # peephole optimization
            x = x.t()
            x = x.t()
        return x.sum()

    opt_test = torch.jit.script(test)
    inputs = torch.ones(4, 4).cuda()

    s = time()
    for i in range(10000):
        test(inputs)
    print(time() - s)  # 95s

    s = time()
    for i in range(10000):
        opt_test(inputs)
    print(time() - s)  # 0.13s

    print(opt_test.graph)
    print(opt_test.graph_for(inputs))

Output:

    95.13823795318604
    0.13010907173156738
    graph(%x.1 : Tensor):
      %22 : None = prim::Constant()
      %13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
      %10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
      %x : Tensor = prim::Loop(%10, %13, %x.1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
        block0(%i : int, %x.10 : Tensor):
          %x.4 : Tensor = aten::t(%x.10) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
          %x.7 : Tensor = aten::t(%x.4) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
          -> (%13, %x.7)
      %23 : Tensor = aten::sum(%x, %22) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
      return (%23)

    graph(%x.1 : Tensor):
      %1 : None = prim::Constant()
      %2 : Tensor = aten::sum(%x.1, %1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/
      return (%2)

The dead loop is already gone from opt_test.graph, while the t()/t() pair and its loop disappear only in the graph that graph_for returns for these inputs.
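The two passes named in the comments can be illustrated on a toy straight-line IR. This is a simplified, hypothetical sketch, not the passes in torch/csrc/jit/passes: instructions are `(output, op, inputs)` tuples, there are no loops, and every op is assumed to be free of side effects, which is what makes it safe for the DCE pass to drop unused results.

```python
from typing import List, Tuple

Instr = Tuple[str, str, Tuple[str, ...]]  # (output, op, inputs)


def peephole(instrs: List[Instr], outputs: List[str]):
    """Rewrite t(t(x)) to x, since transposing twice is the identity."""
    producer, alias, kept = {}, {}, []
    for dst, op, args in instrs:
        args = tuple(alias.get(a, a) for a in args)  # use rewritten operands
        src = producer.get(args[0]) if args else None
        if op == "t" and src is not None and src[0] == "t":
            alias[dst] = src[1][0]  # dst is just the input of the inner t
            continue
        producer[dst] = (op, args)
        kept.append((dst, op, args))
    return kept, [alias.get(o, o) for o in outputs]


def dce(instrs: List[Instr], outputs: List[str]):
    """Dead code elimination: keep only instructions whose result is needed."""
    live, kept = set(outputs), []
    for dst, op, args in reversed(instrs):
        if dst in live:
            kept.append((dst, op, args))
            live.update(args)
    return list(reversed(kept)), outputs


program = [
    ("y", "add", ("x", "one")),  # dead: y is never used
    ("a", "t", ("x",)),
    ("b", "t", ("a",)),          # b is just x again
    ("s", "sum", ("b",)),
]
instrs, outs = dce(*peephole(program, ["s"]))
print(instrs, outs)  # [('s', 'sum', ('x',))] ['s']
```

Running DCE after the peephole pass matters here: rewriting `t(t(x))` to `x` is what leaves the first `t` without users, so DCE can then remove it.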

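About IR graph optimization
An IR Method has a built-in GraphExecutor object, which is created on first execution and is responsible for optimization. In pytorch-master/torch/csrc/jit/api/method.h, the C++ prototype of the script method contains:

    GraphExecutor& get_executor() {
      return function_->get_executor();
    }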

GraphExecutor is defined in /torch/csrc/jit/runtime/graph_executor.cpp. As shown below, it is constructed from a graph, and it defines a run method for execution:

    GraphExecutor::GraphExecutor(
        const std::shared_ptr<Graph>& graph,
        std::string function_name)
        : pImpl(
              IsNewExecutorEnabled()
                  ? dynamic_cast<GraphExecutorImplBase*>(
                        new ProfilingGraphExecutorImpl(
                            graph,
                            std::move(function_name)))
                  : dynamic_cast<GraphExecutorImplBase*>(
                        new GraphExecutorImpl(graph, std::move(function_name)))) {}

    std::shared_ptr<Graph> GraphExecutor::graph() const {
      return pImpl->graph;
    }

    const ExecutionPlan& GraphExecutor::getPlanFor(
        Stack& inputs,
        size_t remaining_bailout_depth) {
      return pImpl->getPlanFor(inputs, remaining_bailout_depth);
    }

    std::shared_ptr<GraphExecutorImplBase> pImpl;
    .....

    // About GraphExecutorImplBase, /torch/csrc/jit/runtime/graph_executor.cpp
    const ExecutionPlan& getOrCompile(const Stack& stack) {
      .....
        auto plan = compileSpec(spec);
      }
    }

    // compileSpec returns a plan
    ExecutionPlan compileSpec(const ArgumentSpec& spec) {
      auto opt_graph = graph->copy();
      GRAPH_DUMP("Optimizing the following function:", opt_graph);
      arg_spec_creator_.specializeTypes(*opt_graph, spec);

      // Phase 0. Inline functions, then clean up any artifacts that the inliner
      //          left in that may inhibit optimization
      .....
      runRequiredPasses(opt_graph);
      GRAPH_DEBUG(
          "After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

      // Phase 2. Propagate detailed information about the spec through the
      //          graph (enabled more specializations in later passes).
      //          Shape propagation sometimes depends on certain arguments being
      //          constants, and constant propagation doesn't need shape
      //          information anyway, so it's better to run it first.
      ConstantPropagation(opt_graph);
      GRAPH_DEBUG(
          "After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
      PropagateInputShapes(opt_graph);
      GRAPH_DEBUG(
          "After PropagateInputShapes, before PropagateRequiresGrad\n",
          *opt_graph);
      PropagateRequiresGrad(opt_graph);
      GRAPH_DEBUG(
          "After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

      // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
      //          that we can still execute using autograd).
      runOptimization(opt_graph);
      .....  // various other optimizations
      return ExecutionPlan(opt_graph, function_name_);
    }
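The caching behaviour of getOrCompile/compileSpec, compiling once per kind of input and then reusing the plan, can be sketched in Python. Everything here is hypothetical: `ToyGraphExecutor` is not a torch class, the "spec" is just each argument's type name and length (standing in for the input properties that an ArgumentSpec records), and "compiling" merely counts how often it happens instead of running optimization passes.

```python
from typing import Callable, Dict, Tuple


class ToyGraphExecutor:
    def __init__(self, fn: Callable):
        self.fn = fn
        self.plans: Dict[Tuple, Callable] = {}
        self.compilations = 0

    @staticmethod
    def spec_of(args) -> Tuple:
        # One entry per argument: (type name, length or None).
        return tuple(
            (type(a).__name__, len(a) if hasattr(a, "__len__") else None) for a in args
        )

    def compile_spec(self, spec: Tuple) -> Callable:
        # A real executor would copy the graph, specialize it to `spec`
        # and run optimization passes; here we only count compilations.
        self.compilations += 1
        return self.fn

    def run(self, *args):
        spec = self.spec_of(args)
        plan = self.plans.get(spec)
        if plan is None:
            plan = self.plans[spec] = self.compile_spec(spec)
        return plan(*args)


executor = ToyGraphExecutor(lambda xs: sum(xs))
results = [executor.run(xs) for xs in ([1, 2, 3], [4, 5, 6], [1, 2])]
print(results, executor.compilations)  # [6, 15, 3] 2
```

The second call reuses the plan compiled for the first, because both inputs are lists of length 3; only the length-2 input triggers another compilation.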

These optimizations live in the torch/csrc/jit/passes/ folder, for example:
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/fuse_linear.cpp
torch/csrc/jit/passes/remove_dropout.cpp
torch/csrc/jit/passes/fold_conv_bn.cpp



