Conclusion first
Good news:
ONNX can now define the loss and the optimizer within its format. However, the only loss operators currently available are NegativeLogLikelihoodLoss and SoftmaxCrossEntropyLoss, and the only optimizers it can store are Adagrad, Adam, and Momentum (SGD with standard momentum).
Bad news:
we need to upgrade ONNX to 1.7, which was released last week and may not be very stable yet. In this release, ONNX defines a complicated node called GraphCall to specify which gradients should be computed and how to update the tensors using those gradients. Since we will update the weights ourselves following the backward pass, this part may not be useful for us.
ONNX Training Preview (TrainingInfoProto)
Last week, the ONNX team released version 1.7.0, which upgrades its opset version to 12. In this new release, they add a new feature called TrainingInfoProto.
This new feature describes training information. It has two main parts: initialization-step and training-algorithm-step.
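As a quick orientation, here is a minimal sketch (assuming the onnx 1.7 Python package) of where this information lives:

```python
import onnx

# TrainingInfoProto (new in ONNX 1.7 / IR version 7) has four fields:
#   initialization         - a GraphProto producing initial values
#   algorithm              - a GraphProto describing one training step
#   initialization_binding - key/value pairs binding initialization outputs to tensors
#   update_binding         - key/value pairs binding algorithm outputs to tensors
training_info = onnx.TrainingInfoProto()
print(type(training_info.initialization))  # GraphProto
print(type(training_info.algorithm))       # GraphProto

# A ModelProto carries a list of these in its repeated training_info field.
model = onnx.ModelProto()
model.training_info.add().CopyFrom(training_info)
```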
initialization-step
initialization-step means the developer can define an initialization. By type, the initialization is a formal ONNX graph. It has no inputs but several outputs. The developer can define nodes in this graph, such as RandomNormal or RandomUniform, and in another field called initialization_binding, assign these outputs to specific tensors in the inference graph.
The currently supported random methods are RandomNormal and RandomUniform.
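As a rough sketch of the idea (the weight name W, the output name W_init, and the shape are made up for illustration), an initialization graph with a RandomNormal node could be built and bound to a weight of the inference graph like this:

```python
import onnx
from onnx import helper, TensorProto

# Initialization graph: no inputs, one output produced by RandomNormal.
init_node = helper.make_node(
    "RandomNormal",
    inputs=[],
    outputs=["W_init"],
    mean=0.0,
    scale=1.0,
    dtype=TensorProto.FLOAT,
    shape=[64, 32],                     # hypothetical weight shape
)
init_graph = helper.make_graph(
    nodes=[init_node],
    name="Initialization",
    inputs=[],
    outputs=[helper.make_tensor_value_info("W_init", TensorProto.FLOAT, [64, 32])],
)

training_info = onnx.TrainingInfoProto()
training_info.initialization.CopyFrom(init_graph)

# Bind the initialization output "W_init" to the tensor "W" in the inference graph.
binding = training_info.initialization_binding.add()
binding.key = "W"          # tensor (initializer) in the inference graph
binding.value = "W_init"   # output of the initialization graph
```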
training-algorithm-step
training-algorithm-step defines a field called algorithm. It is a graph that represents one step of a training algorithm. Given the required inputs, it computes outputs that update tensors in its own graph or in the main computation graph. update_binding contains key-value pairs of strings that assign those outputs to specific tensors.
In general, this graph contains a loss node, a gradient node, an optimizer node, an increment of the iteration count, and some calls to the inference graph. The field algorithm.node is the only place where the user can use the GraphCall operator.
Loss node
NegativeLogLikelihoodLoss
SoftmaxCrossEntropyLoss
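For example, a SoftmaxCrossEntropyLoss node from the standard opset 12 can be created with the helper API (the tensor names logits, labels, and loss are illustrative):

```python
from onnx import helper

# Cross-entropy over raw scores; reduction can be "none", "sum", or "mean".
loss_node = helper.make_node(
    "SoftmaxCrossEntropyLoss",
    inputs=["logits", "labels"],
    outputs=["loss"],
    reduction="mean",
)
```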
Optimizer node
Adagrad
Adam
Momentum: SGD with standard momentum
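These optimizers live in the new ai.onnx.preview.training domain. Below is a hedged sketch of a Momentum node that updates a single tensor W; the tensor names are assumptions:

```python
from onnx import helper

# One SGD-with-momentum step for a single tensor "W".
# Inputs: learning rate R, iteration count T, then W, its gradient, and its momentum buffer.
momentum_node = helper.make_node(
    "Momentum",
    inputs=["R", "T", "W", "dY_dW", "W_velocity"],
    outputs=["W_new", "W_velocity_new"],
    domain="ai.onnx.preview.training",
    alpha=0.9,              # momentum decay factor
    beta=1.0,               # gradient coefficient
    mode="standard",        # or "nesterov"
    norm_coefficient=0.0,   # L2 regularization coefficient
)
```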
Gradient node
The gradient node only defines the information necessary to compute the gradients for the whole graph. For example, in the following graph, the gradient node defines its inputs as xs (the intermediate weights) and zs (the inputs of the graph), together with y (the output of the graph), and its outputs as dY/dW and dY/dZ, whose order corresponds to the inputs in xs.
It does not define any logic about how to compute dY/dW and dY/dZ.
```
W --> Conv --> H --> Gemm --> Y
```
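A hedged sketch of such a Gradient node, assuming the trainable tensors are W and Z, the graph input is X, and the graph output is Y (matching the description above):

```python
from onnx import helper

# Declares which gradients are needed; it carries no logic for computing them.
gradient_node = helper.make_node(
    "Gradient",
    inputs=["W", "Z", "X"],          # current values of xs followed by zs
    outputs=["dY_dW", "dY_dZ"],      # one gradient per tensor in xs, same order
    domain="ai.onnx.preview.training",
    xs=["W", "Z"],                   # differentiated (trainable) tensors
    zs=["X"],                        # non-differentiated graph inputs
    y="Y",                           # the output being differentiated
)
```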
GraphCall node
The GraphCall operator invokes a graph inside TrainingInfoProto’s algorithm field. The GraphCall inputs and outputs are bound to those of the invoked graph by position.
Based on the above inference graph, GraphCall can be used like this:
(diagram: the training graph invokes the inference graph via GraphCall, with W as a global and mutable variable shared from the inference graph)
The previous section’s inference graph is called by GraphCall(graph_name="MyInferenceGraph"), and it uses a new batch of inputs (X_1, Z_1) to compute Y_1.
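As a sketch reusing those names, the call itself is a single node whose inputs and outputs are bound positionally to the inference graph:

```python
from onnx import helper

# Invoke the main inference graph (ModelProto.graph) on a fresh batch.
call_node = helper.make_node(
    "GraphCall",
    inputs=["X_1", "Z_1"],            # bound by position to the inference graph's inputs
    outputs=["Y_1"],                  # bound by position to the inference graph's outputs
    domain="ai.onnx.preview.training",
    graph_name="MyInferenceGraph",    # name of the graph stored in ModelProto.graph
)
```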
Gradient defines the gradients the graph should compute; finally, it gets W_new and T_new.
Then it uses the following update_binding to update the tensors:
```
update_binding: {"W": "W_new", "T": "T_new"}
```
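In the Python API this binding is a list of key-value entries on the TrainingInfoProto; a minimal sketch:

```python
import onnx

training_info = onnx.TrainingInfoProto()
# Key: tensor to be updated in the inference graph; value: output of the algorithm graph.
for key, value in {"W": "W_new", "T": "T_new"}.items():
    entry = training_info.update_binding.add()
    entry.key, entry.value = key, value
```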
API
```python
# handle ONNX
```
Save and load: how to load an ONNX model, and how to save the loss and SGD nodes into it.
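Below is a hedged end-to-end sketch of what that could look like with the onnx 1.7 Python package: load an existing model, build an algorithm graph from GraphCall, loss, Gradient, and Momentum nodes, and save the result. All tensor names, shapes, and the graph name are assumptions, not something confirmed by the ONNX docs.

```python
import onnx
from onnx import helper, TensorProto

# Assumed inference model: input "X", output "Y", trainable initializer "W" of
# shape [64, 32], momentum buffer "W_velocity", graph name "MyInferenceGraph".
model = onnx.load("model.onnx")

# One training step: call the inference graph, compute the loss, its gradient, and an SGD update.
call_node = helper.make_node("GraphCall", ["X"], ["Y"],
                             domain="ai.onnx.preview.training",
                             graph_name="MyInferenceGraph")
loss_node = helper.make_node("SoftmaxCrossEntropyLoss", ["Y", "labels"], ["loss"],
                             reduction="mean")
grad_node = helper.make_node("Gradient", ["W", "X"], ["dloss_dW"],
                             domain="ai.onnx.preview.training",
                             xs=["W"], zs=["X"], y="loss")
sgd_node = helper.make_node("Momentum",
                            ["R", "T", "W", "dloss_dW", "W_velocity"],
                            ["W_new", "W_velocity_new"],
                            domain="ai.onnx.preview.training",
                            alpha=0.9, beta=1.0, mode="standard", norm_coefficient=0.0)

algorithm = helper.make_graph(
    nodes=[call_node, loss_node, grad_node, sgd_node],
    name="TrainingStep",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 32]),
        helper.make_tensor_value_info("labels", TensorProto.INT64, [None]),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("T", TensorProto.INT64, []),
    ],
    outputs=[
        helper.make_tensor_value_info("W_new", TensorProto.FLOAT, [64, 32]),
        helper.make_tensor_value_info("W_velocity_new", TensorProto.FLOAT, [64, 32]),
    ],
)

# Wrap the step in a TrainingInfoProto and bind "W_new" back onto "W".
training_info = onnx.TrainingInfoProto()
training_info.algorithm.CopyFrom(algorithm)
entry = training_info.update_binding.add()
entry.key, entry.value = "W", "W_new"

model.training_info.add().CopyFrom(training_info)
onnx.save(model, "model_with_training.onnx")
```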