Source: https://zhuanlan.zhihu.com/p/361101354

This article is a detailed walkthrough of torch's jit module. It covers usage examples of the two export paths, the form of the IR, a source-code reading of how each of the two paths produces the IR, and a brief introduction to IR optimization.
Preface

torch has supported the jit module since version 1.0. It consists roughly of the following parts:

- A new intermediate representation (Intermediate Representation) for computation graphs, hereafter called the IR.
- Two ways to export IR from Python code: trace and script.
- IR optimization, plus an interpreter for the IR (which translates it into concrete ops).
- A source-code reading of trace and script, the two ways of exporting IR.

1 A brief introduction to jit, with usage examples

About JIT

As the preface says, although this write-up is titled "JIT", the part that truly deserves to be called a just-in-time compiler comes after the IR is exported: optimizing the IR graph and interpreting it into the corresponding operations. In other words, the optimizations that PyTorch's jit code brings are graph-level optimizations, such as fusing certain operations; there is no operator-specific optimization (e.g. for convolution), which still goes through torch's base operator library. Once you have exported the IR, i.e. TorchScript, you can also hand it to other compilers or interpreters; for example there is now TRTorch (https://github.com/NVIDIA/TRTorch), which converts a script to a TensorRT engine.

trace

```python
import torch
import torchvision.models as models

resnet = torch.jit.trace(models.resnet18(), torch.rand(1, 3, 224, 224))
output = resnet(torch.ones(1, 3, 224, 224))
print(output)
resnet.save('resnet.pt')
```
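The saved archive can be loaded back later (shown in Python here; libtorch offers the same in C++). A minimal sketch, assuming the `resnet.pt` file produced above:

```python
import torch

loaded = torch.jit.load('resnet.pt')   # restores the ScriptModule from the archive
print(loaded.graph)                    # the IR travels with the file
out = loaded(torch.ones(1, 3, 224, 224))
```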
The traced `resnet` above is our exported intermediate representation; it can be saved and then consumed elsewhere. Let's look at its IR, i.e. what the computation graph expressed as TorchScript looks like:

```
graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
      %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
  %1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
  %1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
  %1468 : __torch__.torch.nn.modules.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
  %1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
  ....
  %1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
  %1202 : int = prim::Constant[value=1]()
  %1203 : int = prim::Constant[value=-1]()
  %input : Float(1:512, 512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203)
  %1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
  return (%1557)
```
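Besides `.graph`, a traced module also exposes `.code`, a Python-like rendering of the same IR that is often easier to read:

```python
print(resnet.code)  # e.g. "def forward(self, input: Tensor) -> Tensor: ..."
```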
That is how the trace method is used. Its entry point is `torch.jit.trace`, whose arguments are the model you want to export and a valid example input. As the name suggests, it works by following the model through one inference pass, recording the operations performed on the input one by one and mapping each to the corresponding IR operation, thereby recovering the IR of the model's original forward.

Note: this approach has an obvious flaw. PyTorch being a dynamic-graph framework, models contain many input-dependent control-flow statements, so execution may differ from input to input (an `if`, or a loop of variable length); in that case tracing cannot capture the complete computation graph. The following trace shows the problem:

```python
def test(x):
    if x > 2.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

ftrace = torch.jit.trace(test, (torch.ones(1)))
y = torch.ones(1) * 5
print(ftrace(y))
# result: tensor(2.)
# the example input only exercised the else branch, so only that branch was traced
```

script

```python
@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(foo.graph)
print(foo(torch.Tensor([0]), torch.Tensor([1])))
print(foo(torch.Tensor([1]), torch.Tensor([0])))
```

```
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : Tensor = aten::max(%x.1)
  %5 : Tensor = aten::max(%y.1)
  # the control-flow statement really was captured
  %6 : Tensor = aten::gt(%3, %5)
  %7 : bool = aten::Bool(%6)
  %r : Tensor = prim::If(%7)
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)

tensor([1.])
tensor([1.])
```
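Since both branches live in the graph as Blocks of `prim::If`, the scripted `foo` keeps working for either input ordering. Continuing from the two examples above, a small sketch that checks this programmatically:

```python
# the scripted function kept its control flow ...
print(any(n.kind() == "prim::If" for n in foo.graph.nodes()))     # True
# ... while the traced ftrace from earlier has none left
print(any(n.kind() == "prim::If" for n in ftrace.graph.nodes()))  # False
```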
script is used by attaching the `torch.jit.script` decorator where you need it (a function, or an nn.Module, whose forward function is followed by default). Its conversion works on a completely different principle from trace: script directly parses your PyTorch source code, turning your logic into a syntax tree via syntactic analysis, and then converts that into the intermediate representation IR.

Note: although this solves trace's inability to follow dynamic logic, Python is an extremely flexible language, and fully supporting the parsing of every Python construct is practically impossible. You therefore need extra time to learn which constructs can be parsed, which noticeably degrades the coding experience.

Combining the two

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces ScriptModules for conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())
```
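The resulting object mixes both worlds: forward was compiled by script, while the submodules' graphs came from tracing. A small sketch of how you might inspect and persist it (the file name is my own illustration):

```python
print(scripted_module.graph)        # forward: produced by script, keeps control flow
print(scripted_module.conv1.graph)  # conv1: produced by trace
scripted_module.save("my_module.pt")
```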
In practice:

1. In most cases, if the model only contains tensor operations, just trace it directly.
2. If it has control flow (if-else, for-loop), use scripting.
3. If you hit syntax that scripting cannot handle, either rewrite the code or combine tracing and scripting (e.g. use scripting only for the code with control flow and tracing for everything else).

How to extend it

Neither trace nor script can convert functions from third-party Python libraries, so implement as much as possible directly in PyTorch. A custom op must be registered as a jit operation (torch's own ops are registered the same way) before it can end up in TorchScript:

```cpp
TORCH_LIBRARY(my_ops, m) {
  m.def("warp_perspective", warp_perspective);
}
```

1 EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html

2 The basic representation of the IR (TorchScript)

How do PyTorch's various constructs (parameters, computation nodes, etc.) map into TorchScript? The converted IR is composed of the following structures:

- FunctionSchema: describes the method's signature.
- Graph: the actual computation graph; it defines a function's concrete implementation and consists of Nodes, Blocks, and Values.
- GraphExecutor: does the optimization and execution.
- Control statements (if, loop): a Block plus its list of nodes.

```
# %x.1 is a Value
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  # aten::max is a Node
  # Tensor: Type - TensorType
  %3 : Tensor = aten::max(%x.1)
  %5 : Tensor = aten::max(%y.1)
  %6 : Tensor = aten::gt(%3, %5)
  %7 : bool = aten::Bool(%6)
  %r : Tensor = prim::If(%7)
    # Blocks
    block0():
      -> (%x.1)
    block1():
      -> (%y.1)
  return (%r)
```
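These pieces are all visible from Python. A minimal sketch that rebuilds the scripted `foo` from earlier and walks its Graph, printing each Node's kind and the Blocks it owns:

```python
import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

for node in foo.graph.nodes():
    # kind() is e.g. 'aten::max' or 'prim::If'; the prim::If node
    # owns two Blocks, one per branch
    print(node.kind(), list(node.blocks()))
```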
3 The two ways of exporting IR: trace and script

Since the actual implementation is quite involved, the pasted source keeps only the branches exercised by a simple case and omits the vast majority of details; consult the source tree if you need more.

The trace implementation

```python
def trace(
    func,
    example_inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
):
    # an nn.Module instance was passed in: trace its forward
    if isinstance(func, torch.nn.Module):
        return trace_module(
            func,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
        )

    # what was passed in is some module instance's forward
    if (
        hasattr(func, "__self__")
        and isinstance(func.__self__, torch.nn.Module)
        and func.__name__ == "forward"
    ):
        return trace_module(
            func.__self__,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
        )

    # an interface for looking up variable names
    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    # the C++ entry point
    traced = torch._C._create_function_from_trace(
        name, func, example_inputs, var_lookup_fn, strict, _force_outplace
    )

    # check whether the traced function diverges from the original func
    if check_trace:
        if check_inputs is not None:
            _check_trace(
                check_inputs, func, traced, check_tolerance, strict,
                _force_outplace, False, _module_class,
            )
        else:
            _check_trace(
                [example_inputs], func, traced, check_tolerance, strict,
                _force_outplace, False, _module_class,
            )
    return traced
```
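The `check_trace` branch at the end is worth knowing as a user: passing `check_inputs` makes trace re-run the original function on extra inputs and compare the results against the traced graph. A sketch with the branching `test` from earlier, where the check should report the divergence:

```python
import torch

def test(x):
    if x > 2.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

# _check_trace re-runs test on each tuple in check_inputs; here the traced
# output tensor(2.) no longer matches, and the mismatch is reported
ftrace = torch.jit.trace(
    test,
    (torch.ones(1),),
    check_inputs=[(torch.ones(1) * 5,)],
)
```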
After these simple checks, the code drops into the C++ side:

```python
traced = torch._C._create_function_from_trace(
    name, func, example_inputs, var_lookup_fn, strict, _force_outplace
)
```
```cpp
std::pair<std::shared_ptr<TracingState>, Stack> trace(
    Stack inputs,
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool strict,
    bool force_outplace,
    Module* self) {
  try {
    auto state = std::make_shared<TracingState>();
    // setTracingState stores this state instance so that later, when a
    // computation happens, it can be fetched again and the computation
    // inserted into it
    setTracingState(state);
    // state is the data structure that accumulates the computation
    // recorded during forward
    if (self) {
      Value* self_value = state->graph->insertInput(0, "self")->setType(
          self->_ivalue()->type());
      gatherParametersAndBuffers(state, self_value, *self, {"__module"});
    }

    for (IValue& input : inputs) {
      input = addInput(state, input, input.type(), state->graph->addInput());
    }

    auto graph = state->graph;

    // bind the variable-name resolution function from the Python side
    getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
    getTracingState()->strict = strict;
    getTracingState()->force_outplace = force_outplace;

    // start forward; as computations happen they are recorded into state
    auto out_stack = traced_fn(inputs);

    // Exit a trace, treating 'out_stack' as the outputs of the trace. These
    // are the variables whose values will be computed upon subsequent
    // invocations of the trace.
    size_t i = 0;
    for (auto& output : out_stack) {
      // NB: The stack is in "reverse" order, so when we pass the diagnostic
      // number we need to flip it based on size.
      state->graph->registerOutput(
          state->getOutput(output, out_stack.size() - i));
      i++;
    }
    setTracingState(nullptr);

    if (getInlineEverythingMode()) {
      Inline(*graph);
    }
    FixupTraceScopeBlocks(graph, self);
    NormalizeOps(graph);
    return {state, out_stack};
  } catch (...) {
    tracer::abandon();
    throw;
  }
}
```
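The globally installed TracingState is observable from Python too: `torch.jit.is_tracing()` reports whether a state is currently set, i.e. whether we are inside the `traced_fn(inputs)` call above. A tiny sketch:

```python
import torch

def f(x):
    # reflects whether a TracingState is currently installed
    print("tracing:", torch.jit.is_tracing())
    return x + 1

f(torch.ones(1))                      # tracing: False
torch.jit.trace(f, (torch.ones(1),))  # prints "tracing: True" during the trace
```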
So where does the actual recording of each operation happen? In pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp (https://github.com/pytorch/pytorch/blob/4e976b9334acbcaa015a27d56540cd2115c2639b/torch/csrc/jit/runtime/register_c10_ops.cpp#L30):

```cpp
Operator createOperatorFromC10_withTracingHandledHere(
    const c10::OperatorHandle& op) {
  return Operator(op, [op](Stack& stack) {
    const auto input_size = op.schema().arguments().size();
    const auto output_size = op.schema().returns().size();

    Node* node = nullptr;
    std::shared_ptr<jit::tracer::TracingState> tracer_state;

    // trace the input before unwrapping, otherwise we may lose
    // the input information
    if (jit::tracer::isTracing()) {
      // fetch the tracer_state
      tracer_state = jit::tracer::getTracingState();
      auto symbol = Symbol::fromQualString(op.schema().name());
      const auto& graph = tracer::getTracingState()->graph;
      node = graph->create(symbol, 0);
      tracer::recordSourceLocation(node);
      const auto& args = op.schema().arguments();
      int i = 0;
      // record the args
      for (auto iter = stack.end() - input_size; iter != stack.end();
           ++iter, ++i) {
        // TODO we need to refactor graph APIs (e.g., addInputs)
        // appropriately; after that, we can get rid of the giant if-else
        // block we will clean this tech debt together in the following PRs
        auto type = args[i].type();
        if (type->kind() == TypeKind::OptionalType) {
          if (iter->isNone()) {
            Value* none = graph->insertNode(graph->createNone())->output();
            node->addInput(none);
            continue;
          } else {
            type = type->expect<OptionalType>()->getElementType();
          }
        }
        if (type->isSubtypeOf(TensorType::get())) {
          AT_ASSERT(iter->isTensor());
          tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
        } else if (type->kind() == TypeKind::FloatType) {
          AT_ASSERT(iter->isDouble());
          tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
        } else if (type->kind() == TypeKind::IntType) {
          AT_ASSERT(iter->isInt());
          tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
        } else if (type->kind() == TypeKind::BoolType) {
          AT_ASSERT(iter->isBool());
          tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
        } else if (type->kind() == TypeKind::StringType) {
          AT_ASSERT(iter->isString());
          tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
        } else if (type->kind() == TypeKind::NumberType) {
          tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
        } else if (type->kind() == TypeKind::ListType) {
          const auto& elem_type = type->expect<ListType>()->getElementType();
          if (elem_type->isSubtypeOf(TensorType::get())) {
            AT_ASSERT(iter->isTensorList());
            auto list = iter->toTensorVector();
            tracer::addInputs(node, args[i].name().c_str(), list);
          } else if (elem_type->kind() == TypeKind::FloatType) {
            AT_ASSERT(iter->isDoubleList());
            // NB: now, tracer doesn't support tracing double list. We add
            // special handling here, since in our case, we assume that all the
            // doubles in the list are constants
            auto value = iter->toDoubleVector();
            std::vector<Value*> info(value.size());
            for (size_t value_index = 0; value_index < value.size();
                 ++value_index) {
              info[value_index] = graph->insertConstant(value[value_index]);
              tracer::recordSourceLocation(info[value_index]->node());
            }
            node->addInput(
                graph
                    ->insertNode(graph->createList(jit::FloatType::get(), info))
                    ->output());
          } else if (elem_type->kind() == TypeKind::IntType) {
            AT_ASSERT(iter->isIntList());
            tracer::addInputs(
                node, args[i].name().c_str(), iter->toIntVector());
          } else if (elem_type->kind() == TypeKind::BoolType) {
            AT_ASSERT(iter->isBoolList());
            tracer::addInputs(
                node, args[i].name().c_str(), iter->toBoolList().vec());
          } else {
            throw std::runtime_error(
                "unsupported input list type: " + elem_type->str());
          }
        } else if (iter->isObject()) {
          tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
        } else {
          throw std::runtime_error("unsupported input type: " + type->str());
        }
      }
      // insert the node into the graph
      graph->insertNode(node);

      jit::tracer::setTracingState(nullptr);
    }
```
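The effect of this handler is easy to observe from Python: each recorded Node carries the op's qualified name and its schema-ordered inputs. A small sketch:

```python
import torch

def f(a, b):
    return torch.add(a, b)

t = torch.jit.trace(f, (torch.ones(2), torch.ones(2)))
for n in t.graph.nodes():
    # e.g. aten::add with inputs a, b, and the default alpha constant,
    # attached in schema order via tracer::addInputs
    print(n.kind(), [v.debugName() for v in n.inputs()])
```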
So when a computation actually happens, `getTracingState()` retrieves the state created at the start of forward; `op.schema().name()` gives the kind of computation (say, an add); a computation node of that kind is created via `graph->create(symbol, 0)`; its inputs are attached; and finally the node is inserted into the graph, completing the record of one computation.

The script implementation

Since script obtains the IR by parsing the source code, the handling differs slightly with the form of the code (a function, a class, or an nn.Module instance).

1 A Python function, simplified code:

```python
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    # the function branch
    if hasattr(obj, "__script_if_tracing_wrapper"):
        obj = obj.__original_fn
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # check for overloads
    _check_directly_compile_overloaded(obj)

    # was this already scripted before?
    maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
    if maybe_already_compiled_fn:
        return maybe_already_compiled_fn

    # build the ast syntax tree
    ast = get_jit_def(obj, obj.__name__)
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

    # the C++ entry point: derive the IR from the ast
    fn = torch._C._jit_script_compile(
        qualified_name, ast, _rcb, get_default_args(obj)
    )
    # Forward docstrings
    fn.__doc__ = obj.__doc__
    # cache it
    _set_jit_function_cache(obj, fn)
    return fn
```
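A side effect of `_try_get_jit_cached_function` / `_set_jit_function_cache` worth knowing: scripting the same function twice should return the same compiled object. A quick check:

```python
import torch

def add1(x):
    return x + 1

s1 = torch.jit.script(add1)
s2 = torch.jit.script(add1)
print(s1 is s2)  # True: the second call hits the jit function cache
```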
Let's look at how get_jit_def (https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/jit/frontend.py#L225) builds the AST in the form jit expects:

```python
def get_jit_def(fn, def_name, self_name=None):
    # gather information about the source code
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, torch._C.ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    # dedent_src is the (dedented) string containing the function to be scripted
    dedent_src = dedent(source)
    # call Python's ast package to parse the string into a Python ast
    py_ast = ast.parse(dedent_src)
    # pick up the Python type comment, if any
    type_line = torch.jit.annotations.get_type_line(source)
    # ctx holds all the function's source metadata
    ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
    fn_def = py_ast.body[0]
    # build_def converts the Python ast into the ast format torch jit uses
    return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
```
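`get_source_lines_and_file` is essentially `inspect.getsourcelines` plus better error reporting, and the `dedent` matters because the source of a method (or any nested function) comes back indented, which `ast.parse` would reject. A standalone sketch mirroring the same steps:

```python
import ast
import inspect
from textwrap import dedent

def outer():
    def test(a):
        a = a + 2
        return a + 1
    return test

lines, _ = inspect.getsourcelines(outer())
src = dedent(''.join(lines))  # without dedent, the leading indent is a SyntaxError
py_ast = ast.parse(src)
print(type(py_ast.body[0]).__name__)  # FunctionDef
```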
A simple example of what py_ast.body[0] is:

```python
import ast

func_def = \
"""def test(a):
    a = a + 2
    return a + 1"""

results = ast.parse(func_def)
```
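You can look at the parsed tree with `ast.dump` (output abbreviated here; the exact node names vary with the Python version, e.g. `Num` on older versions vs `Constant` on newer ones):

```python
print(ast.dump(results.body[0]))
# FunctionDef(name='test',
#   args=arguments(args=[arg(arg='a')], ...),
#   body=[Assign(targets=[Name(id='a', ctx=Store())],
#                value=BinOp(left=Name(id='a', ctx=Load()), op=Add(), right=Num(n=2))),
#         Return(value=BinOp(left=Name(id='a', ctx=Load()), op=Add(), right=Num(n=1)))])
```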
As you can see, `ast.body` is a list whose length equals the number of functions contained in the parsed string. Looking at the first element: its `value` is a `BinOp`, specifically an `Add`; `left` is a `Name` node whose `id` is `a`; `right` is a `Num`, namely `2`. This `BinOp` is the parse of `a = a + 2`. Since `get_source_lines_and_file` always returns a single top-level function, we can simply take element 0, i.e. `py_ast.body[0]`.

Next, let's see how build_def converts the Python ast into the ast it needs:

```python
def build_def(ctx, py_def, type_line, def_name, self_name=None):
    ....
    return Def(Ident(r, def_name), decl, build_stmts(ctx, body))
```
Since ctx holds all the source-code information and body is the Python ast parse result, build_stmts must contain the answer we're after. Taking the `a + 2` from our example, let's follow the conversion; this part lives in frontend.py (https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/jit/frontend.py#L528):

```python
# the basic ast structures defined by jit
from torch._C._jit_tree_views import (
    ClassDef, Ident, Stmt, Decl, Def, Var,
    EmptyTypeAnnotation, Param, ExprStmt, Assign,
    Delete, Return, Raise, Assert, AugAssign, While,
    For, If, Pass, Break, Continue, Apply, Dots, Select,
    TrueLiteral, FalseLiteral, NoneLiteral, Starred,
    ListLiteral, TupleLiteral, DictLiteral, Const,
    StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
    SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
    DictComp,
)

def build_stmts(ctx, stmts):
    # note that it calls `build_stmt` on every statement
    stmts = [build_stmt(ctx, s) for s in stmts]
    return list(filter(None, stmts))

# `build_stmt` is an instance of StmtBuilder()
build_stmt = StmtBuilder()
build_expr = ExprBuilder()

class Builder(object):
    def __call__(self, ctx, node):
        # dispatch on the type of the parsed ast node; as the dump above
        # shows, `a + 2` sits inside an `Assign` statement,
        # so build_Assign will be called
        method = getattr(self, 'build_' + node.__class__.__name__, None)
        if method is None:
            raise UnsupportedNodeError(ctx, node)
        return method(ctx, node)

class StmtBuilder(Builder):
    @staticmethod
    def build_Assign(ctx, stmt):
        # stmt.value is a BinOp;
        # build_expr is an instance of ExprBuilder and will call `build_BinOp`
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        if value.__class__.__name__ == 'Str':
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        else:
            return ExprStmt(build_expr(ctx, value))

class ExprBuilder(Builder):
    binop_map = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        ast.Pow: '**',
        ast.Mod: '%',
        ast.FloorDiv: '//',
        ast.BitAnd: '&',
        ast.BitXor: '^',
        ast.BitOr: '|',
        ast.LShift: '<<',
        ast.RShift: '>>',
    }

    @staticmethod
    def build_BinOp(ctx, expr):
        # expr.left is a `Name`, so build_Name is called
        lhs = build_expr(ctx, expr.left)
        rhs = build_expr(ctx, expr.right)
        op = type(expr.op)
        # map the operator to the agreed-upon string token
        op_token = ExprBuilder.binop_map.get(op)
        return BinOp(op_token, lhs, rhs)
```
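You can run this pipeline end to end yourself: `torch.jit.frontend.get_jit_def` performs exactly the Python-ast-to-jit-ast conversion above, and printing the returned tree view should produce a dump like the one shown next (a sketch; whether and how the tree prints may vary across versions):

```python
from torch.jit.frontend import get_jit_def

def test(a):
    a = a + 2
    return a + 1

jit_ast = get_jit_def(test, test.__name__)
print(jit_ast)
```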
The final converted format resembles an S-expression (https://en.wikipedia.org/wiki/S-expression):

```
(def
  (ident test)
  (decl
    (list
      (param (ident a) (option) (option) (False)))
    (option))
  (list
    (assign
      (list (variable (ident a)))
      (option (+ (variable (ident a)) (const 2)))
      (option))
    (return (+ (variable (ident a)) (const 1)))))
```
Good, we now have the AST tree in the form jit expects. Next we step into `torch._C._jit_script_compile` to see how such an ast tree is converted into IR. The C++ entry point is script_compile_function:

```cpp
static StrongFunctionPtr script_compile_function(
    const c10::QualifiedName& name,
    const Def& def,
    const FunctionDefaults& defaults,
    const ResolutionCallback& rcb) {
  // def holds the ast; follow it to find the answer
  auto cu = get_python_cu();
  // so it is the define function on this CompilationUnit that does the work
  auto defined_functions = cu->define(
      QualifiedName(name.prefix()),
      /*properties=*/{},
      /*propResolvers=*/{},
      {def},
      {pythonResolver(rcb)},
      nullptr,
      true);
  TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
  auto& defined = defined_functions[0];
  defined->setSchema(getSchemaWithNameAndDefaults(
      def.range(), defined->getSchema(), def.name().name(), defaults));
  StrongFunctionPtr ret(std::move(cu), defined);
  didFinishEmitFunction(ret);
  return ret;
}

// it turns out this merely wraps CompilationUnit
inline std::shared_ptr<CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit._state")
      .attr("_python_cu")
      .cast<std::shared_ptr<CompilationUnit>>();
}

// about CompilationUnit: torch/csrc/jit/api/compilation_unit.h
  // for historic reasons, these are defined in ir_emitter.cpp
  // Returns the list of Functions just defined.
  std::vector<Function*> define(
      const c10::optional<c10::QualifiedName>& prefix,
      const std::vector<Property>& properties,
      const std::vector<ResolverPtr>& propResolvers,
      const std::vector<Def>& definitions,
      /* determines how we handle free variables in each definition */
      const std::vector<ResolverPtr>& defResolvers,
      // if non-null, the first argument to each def, is bound to this value
      const Self* self,
      // see [name mangling]
      bool shouldMangle = false);

// the implementation is in torch/csrc/jit/frontend/ir_emitter.cpp
std::unique_ptr<Function> CompilationUnit::define(
    const c10::optional<QualifiedName>& prefix,
    const Def& def,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle) const {
  auto _resolver = resolver;
  .....
  auto creator = [def, _resolver, self](Function& method) {
    ....
    // the core: to_ir
    to_ir(def, _resolver, self, method);
  };
  auto fn = torch::make_unique<GraphFunction>(
      std::move(name), std::make_shared<Graph>(), creator);
  return fn;
}
```
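The same `define()` path is reachable directly from Python: `torch.jit.CompilationUnit` takes a source string, compiles each function in it, and exposes the results as attributes. A quick sketch:

```python
import torch

cu = torch.jit.CompilationUnit('''
def add2(x):
    return x + 2
''')
print(cu.add2.graph)           # the IR emitted for this function
print(cu.add2(torch.ones(2)))  # tensor([3., 3.])
```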
Following def, we find the key struct that performs the conversion to IR: to_ir. Its inputs include def, i.e. the ast, and _resolver, the name-resolution function passed over from Python. Inside it we can locate the key parts:

```cpp
to_ir(
    const Def& def,
    ResolverPtr resolver_,
    const Self* self,
    Function& method) // method being constructed
    : method(method),
      graph(method.graph()),
      resolver(std::move(resolver_)),
      typeParser_(resolver),
      environment_stack(nullptr) {
  AT_ASSERT(resolver);
  pushFrame(graph->block(), /*starts_def=*/true);

  // emitDef calls emitStatements
  method.setSchema(emitDef(def, self, graph->block()));

  ConvertToSSA(graph);
  CanonicalizeModifiedLoops(graph);
  NormalizeOps(graph);
  runCleanupPasses(graph);
}

private:
  // in to_ir's private members we can see Graph, Function, ... the IR
  // building blocks introduced earlier
  Function& method;
  std::shared_ptr<Graph> graph;
  ResolverPtr resolver;
  std::unordered_map<int64_t, Value*> integral_constants;

  // emitDef calls emitStatements
  FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
    ......
    // body
    auto stmts_list = def.statements();
    emitStatements(stmts_list.begin(), stmts_list.end());
    ........
  }

  void emitStatements(
      List<Stmt>::const_iterator begin,
      List<Stmt>::const_iterator end) {
    for (; begin != end; ++begin) {
      auto stmt = *begin;
      ErrorReport::CallStack::update_pending_range(stmt.range());
      switch (stmt.kind()) {
        case TK_IF:
          emitIf(If(stmt));
          break;
        case TK_WHILE:
          emitWhile(While(stmt));
          break;
        case TK_FOR:
          emitFor(For(stmt));
          break;
        case TK_ASSIGN:
          emitAssignment(Assign(stmt));
          .................
          break;
        default:
          throw ErrorReport(stmt)
              << "Unrecognized statement kind " << kindToString(stmt.kind());
      }
      // Found an exit statement in this block. The remaining statements aren't
      // reachable so we don't emit them.
      if (exit_blocks.count(environment_stack->block()))
        return;
    }
  }
```

Based on stmt.kind(), we enter one of the various emit functions, inside each of which you will find an operation along the lines of `graph->insertNode(graph->create(.....));`, which is what builds up our IR graph.
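The mapping from statement kinds to graph nodes is easy to confirm from Python. A sketch: scripting a function with a for-loop should yield a `prim::Loop` node (the product of emitFor), just as the `if` earlier yielded `prim::If`:

```python
import torch

@torch.jit.script
def loop(x: torch.Tensor, n: int):
    for i in range(n):
        x = x + i
    return x

print(loop.graph)  # contains a prim::Loop node with one Block for the body
```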
Above we used a function as the example. Next let's script a module, which poses some unique challenges: some name references can only be resolved after initialization, and we also want the scripted module to keep the same interface, i.e. the attributes of the original module should remain accessible as usual. How is this done?

Right after the module's original init finishes, the complete forward function is scripted, and every function involved is replaced with its scripted version.

How do we attach behavior after a class's init function? A metaclass comes to mind, and torch.jit implements the ScriptMeta metaclass.

```python
class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def f(self, x):
        return x * x

    @torch.jit.script_method
    def forward(self, x):
        return x + self.f(x)
```

About script_method:

```python
def script_method(fn):
    _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
    ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
    # nothing is scripted yet; this just returns a namedtuple holding the ast
    return ScriptMethodStub(_rcb, ast, fn)

ScriptMethodStub = collections.namedtuple(
    'ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
```

ScriptMeta then does two things:

1. Remove all script_method attributes (the methods decorated with @script_method), ensuring that what gets accessed is the scripted function.
2. Patch the module's `__init__` so that, as soon as the module's self.param / self.module are initialized, all script_methods are compiled, so the resulting instance's forward has already been replaced.

```python
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        # cls, the instance of ScriptMeta, is a class such as ScriptModule
        cls._methods: Dict[str, Any] = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for base in reversed(bases):
            # remember? the traced module also has a _methods attribute
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # collect every method currently decorated with @script_method into
        # _methods and delete the original attr;
        # they are scripted together after init
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        original_init = getattr(cls, "__init__", lambda self: None)

        # this is where "call create_script_module to script after init
        # finishes" is implemented
        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            # self here is the instance
            num_methods = len(cls._methods)
            original_init(self, *args, **kwargs)
            added_methods_in_init = len(cls._methods) > num_methods

            if type(self) == cls:
                # pick the methods that need scripting
                def make_stubs(module):
                    cls = type(module)
                    if hasattr(cls, "_methods"):
                        return [v for k, v in sorted(cls._methods.items())]
                    else:
                        # infer_methods_to_compile selects the functions to script
                        return infer_methods_to_compile(module)

                # compile all the script_methods in one go into the
                # _actual_script_module attribute
                self.__dict__[
                    "_actual_script_module"
                ] = torch.jit._recursive.create_script_module(
                    self, make_stubs, share_types=not added_methods_in_init)

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script  # type: ignore
        return super(ScriptMeta, cls).__init__(name, bases, attrs)

class _CachedForward(object):
    def __get__(self, obj, cls):
        return self.__getattr__("forward")  # type: ignore

class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore
    def __init__(self):
        super(ScriptModule, self).__init__()

    forward = _CachedForward()

    # accessing an attribute of the module returns the attribute of
    # _actual_script_module
    def __getattr__(self, attr):
        if "_actual_script_module" not in self.__dict__:
            return super(ScriptModule, self).__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        if "_actual_script_module" not in self.__dict__:
            # Unwrap torch.jit.Attribute into a regular setattr + recording
            # the provided type in __annotations__.
            #
            # This ensures that if we use the attr again in `__init__`, it
            # will look like the actual value, not an instance of Attribute.
            if isinstance(value, Attribute):
                if "__annotations__" not in self.__class__.__dict__:
                    self.__class__.__annotations__ = {}
                self.__annotations__[attr] = value.type
                value = value.value
            return super(ScriptModule, self).__setattr__(attr, value)
        setattr(self._actual_script_module, attr, value)
...
```

The create_script_module function scripts the methods and returns a RecursiveScriptModule, but its logic is fairly involved, so we won't expand on it here.

About `__getattribute__` vs `__getattr__`: when an instance attribute is accessed, `__getattribute__` is called unconditionally; only when the attribute does not exist is `__getattr__` called. If you haven't implemented your own `__getattr__`, an AttributeError is raised saying the attribute cannot be found; if you have defined one, it is invoked in exactly this attribute-not-found case.

4 A brief introduction to IR optimization

jit generally performs optimizations such as: loop unrolling, peephole optimization, constant propagation, DCE (dead code elimination), fusion, inlining... Consider the following example:

```python
from time import time

import torch

def test(x):
    # dead code elimination
    for i in range(1000):
        y = x + 1
    # peephole optimization
    for i in range(100):
        x = x.t()
        x = x.t()
    return x.sum()

opt_test = torch.jit.script(test)
inputs = torch.ones(4, 4).cuda()

s = time()
for i in range(10000):
    test(inputs)
print(time() - s)
# 95s

s = time()
for i in range(10000):
    opt_test(inputs)
print(time() - s)
# 0.13s

print(opt_test.graph)
print(opt_test.graph_for(inputs))
```

```
95.13823795318604
0.13010907173156738

graph(%x.1 : Tensor):
  %22 : None = prim::Constant()
  %13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
  %10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
  %x : Tensor = prim::Loop(%10, %13, %x.1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
    block0(%i : int, %x.10 : Tensor):
      %x.4 : Tensor = aten::t(%x.10) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
      %x.7 : Tensor = aten::t(%x.4) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
      -> (%13, %x.7)
  %23 : Tensor = aten::sum(%x, %22) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
  return (%23)

graph(%x.1 : Tensor):
  %1 : None = prim::Constant()
  %2 : Tensor = aten::sum(%x.1, %1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
  return (%2)
```
About IR computation-graph optimization: the Method of the IR has a built-in GraphExecutor object, created on the first execution, which is responsible for the optimization. See pytorch-master/torch/csrc/jit/api/method.h, in the C++ prototype of script_method:

```cpp
GraphExecutor& get_executor() {
  return function_->get_executor();
}
```
GraphExecutor is defined in torch/csrc/jit/runtime/graph_executor.cpp. As you can see it is produced from a graph and defines the run method that executes it:

```cpp
GraphExecutor::GraphExecutor(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : pImpl(
          IsNewExecutorEnabled()
              ? dynamic_cast<GraphExecutorImplBase*>(
                    new ProfilingGraphExecutorImpl(
                        graph, std::move(function_name)))
              : dynamic_cast<GraphExecutorImplBase*>(
                    new GraphExecutorImpl(graph, std::move(function_name)))) {}

std::shared_ptr<Graph> GraphExecutor::graph() const {
  return pImpl->graph;
}

const ExecutionPlan& GraphExecutor::getPlanFor(
    Stack& inputs,
    size_t remaining_bailout_depth) {
  return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}

std::shared_ptr<GraphExecutorImplBase> pImpl;
.....
```

About GraphExecutorImplBase, in torch/csrc/jit/runtime/graph_executor.cpp:

```cpp
const ExecutionPlan& getOrCompile(const Stack& stack) {
  .....
  auto plan = compileSpec(spec);
}

// compileSpec returns a plan
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
  auto opt_graph = graph->copy();
  GRAPH_DUMP("Optimizing the following function:", opt_graph);
  arg_spec_creator_.specializeTypes(*opt_graph, spec);

  // Phase 0. Inline functions, then clean up any artifacts that the inliner
  //          left in that may inhibit optimization
  .....
  runRequiredPasses(opt_graph);
  GRAPH_DEBUG(
      "After runRequiredPasses, before ConstantPropagation\n", *opt_graph);

  // Phase 2. Propagate detailed information about the spec through the
  //          graph (enabled more specializations in later passes).
  //          Shape propagation sometimes depends on certain arguments being
  //          constants, and constant propagation doesn't need shape
  //          information anyway, so it's better to run it first.
  ConstantPropagation(opt_graph);
  GRAPH_DEBUG(
      "After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
  PropagateInputShapes(opt_graph);
  GRAPH_DEBUG(
      "After PropagateInputShapes, before PropagateRequiresGrad\n", *opt_graph);
  PropagateRequiresGrad(opt_graph);
  GRAPH_DEBUG(
      "After PropagateRequiresGrad, before runOptimization\n", *opt_graph);

  // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
  //          that we can still execute using autograd).
  runOptimization(opt_graph);
  ..... // various further optimizations
  return ExecutionPlan(opt_graph, function_name_);
}
```
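Those GRAPH_DUMP / GRAPH_DEBUG calls can be made visible by setting the jit logging environment variable before the first run, which is a convenient way to watch compileSpec work. A sketch (my understanding from torch/csrc/jit/jit_log.h is that the `>>` prefix raises verbosity to the GRAPH_DEBUG level):

```python
import os
os.environ["PYTORCH_JIT_LOG_LEVEL"] = ">>graph_executor"  # set before importing/running

import torch

@torch.jit.script
def f(x):
    return x.t().t().sum()

f(torch.ones(4, 4))  # the first call builds the ExecutionPlan and logs each pass
```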
These optimization passes live under the torch/csrc/jit/passes/ folder, for example:

- torch/csrc/jit/passes/dead_code_elimination.cpp
- torch/csrc/jit/passes/fuse_linear.cpp
- torch/csrc/jit/passes/remove_dropout.cpp
- torch/csrc/jit/passes/fold_conv_bn.cpp
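Many of these passes are also exposed as (internal, version-dependent) bindings under torch._C, so you can apply them to a graph copy by hand and watch the rewrites. A sketch using pass bindings I believe exist:

```python
import torch

def f(x):
    y = x + 1  # dead code
    return x.t().t().sum()

g = torch.jit.script(f).graph.copy()
torch._C._jit_pass_dce(g)                   # dead code elimination drops y
torch._C._jit_pass_peephole(g)              # removes the t().t() pair
torch._C._jit_pass_constant_propagation(g)
print(g)
```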
References:
1. INTRODUCTION TO TORCHSCRIPT (https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)
2. PyTorch 部署_TorchScript (https://zhuanlan.zhihu.com/p/135911580)
3. pytorch wiki (https://github.com/pytorch/pytorch/wiki)
4. PyTorch-JIT-Source-Code-Read-Note (https://zasdfgbnm.github.io/2018/09/20/PyTorch-JIT-Source-Code-Read-Note/)
5. Abstract syntax tree (https://en.wikipedia.org/wiki/Abstract_syntax_tree)

Author's note: I'm just a humble algorithm tuner and my understanding of deployment is not deep. If there are mistakes, please don't hesitate to point them out in the comments.