diff --git a/proto/net.proto b/proto/net.proto index 87ee8a1..71c329b 100644 --- a/proto/net.proto +++ b/proto/net.proto @@ -91,6 +91,11 @@ message Weights { optional Layer ip1_mov_b = 14; optional Layer ip2_mov_w = 15; optional Layer ip2_mov_b = 16; + + // Policy Uncertainty head + // Currently only supports convolutional form. + optional ConvBlock unc1 = 17; + optional ConvBlock unc = 18; } message TrainingParams { @@ -133,7 +138,6 @@ message NetworkFormat { // Networks with PolicyFormat and ValueFormat specified NETWORK_CLASSICAL_WITH_HEADFORMAT = 3; NETWORK_SE_WITH_HEADFORMAT = 4; - NETWORK_ONNX = 5; } optional NetworkStructure network = 3; @@ -160,6 +164,14 @@ message NetworkFormat { MOVES_LEFT_V1 = 1; } optional MovesLeftFormat moves_left = 6; + + // Policy Uncertainity head architecture + enum PolicyUncFormat { + POLICY_UNC_NONE = 0; + // TODO: POLICY_UNC_CLASSICAL = 1; + POLICY_UNC_CONVOLUTION = 2; + } + optional PolicyUncFormat unc = 7; } message Format { @@ -174,34 +186,11 @@ message Format { optional NetworkFormat network_format = 2; } -message OnnxModel { - enum DataType { - UNKNOWN_DATATYPE = 0; - FLOAT = 1; - FLOAT16 = 10; - BFLOAT16 = 16; - } - - // Serialized OnnxProto model. - optional bytes model = 1; - optional DataType data_type = 2; - // Name of the input tensor to populate. - optional string input_planes = 3; - // Names of the output tensors to get results from. - // If some feature is not present, corresponding values are not set. - optional string output_value = 4; - optional string output_wdl = 5; - optional string output_policy = 6; - optional string output_mlh = 7; -} - message Net { optional fixed32 magic = 1; optional string license = 2; optional EngineVersion min_version = 3; optional Format format = 4; optional TrainingParams training_params = 5; - // Either weights or onnx_model is set, but not both. optional Weights weights = 10; - optional OnnxModel onnx_model = 11; }