ref: b5c811738513ce3e53d2cfcefc3a863bd7f21701
parent: 088f2d7992396de22fff27b5ee6cc22a2abc4404
parent: 64abf8daf0d1e5900091002802a3b8efb90f0e75
author: Hui Su <huisu@google.com>
date: Tue Feb 5 17:20:10 EST 2019
Merge "Improve the partition split prediction model"
--- a/vp9/encoder/vp9_encodeframe.c
+++ b/vp9/encoder/vp9_encodeframe.c
@@ -3481,9 +3481,9 @@
}
// Use a neural net model to prune partition-none and partition-split search.
-// The model uses prediction residue variance and quantization step size as
-// input features.
-#define FEATURES 6
+// Features used: QP; spatial block size contexts; variance of prediction
+// residue after simple_motion_search.
+#define FEATURES 12
static void ml_predict_var_rd_paritioning(const VP9_COMP *const cpi,
MACROBLOCK *const x,
PC_TREE *const pc_tree,
@@ -3502,28 +3502,27 @@
uint8_t *const pred_buf = pred_buffer;
#endif // CONFIG_VP9_HIGHBITDEPTH
const int speed = cpi->oxcf.speed;
- int i;
float thresh = 0.0f;
switch (bsize) {
case BLOCK_64X64:
- nn_config = &vp9_var_rd_part_nnconfig_64;
- thresh = speed > 0 ? 3.5f : 3.0f;
+ nn_config = &vp9_part_split_nnconfig_64;
+ thresh = speed > 0 ? 2.8f : 3.0f;
break;
case BLOCK_32X32:
- nn_config = &vp9_var_rd_part_nnconfig_32;
+ nn_config = &vp9_part_split_nnconfig_32;
thresh = speed > 0 ? 3.5f : 3.0f;
break;
case BLOCK_16X16:
- nn_config = &vp9_var_rd_part_nnconfig_16;
- thresh = speed > 0 ? 3.5f : 4.0f;
+ nn_config = &vp9_part_split_nnconfig_16;
+ thresh = speed > 0 ? 3.8f : 4.0f;
break;
case BLOCK_8X8:
- nn_config = &vp9_var_rd_part_nnconfig_8;
+ nn_config = &vp9_part_split_nnconfig_8;
if (cm->width >= 720 && cm->height >= 720)
thresh = speed > 0 ? 2.5f : 2.0f;
else
- thresh = speed > 0 ? 3.5f : 2.0f;
+ thresh = speed > 0 ? 3.8f : 2.0f;
break;
default: assert(0 && "Unexpected block size."); return;
}
@@ -3542,6 +3541,7 @@
ref_mv.row = ref_mv.col = 0;
else
ref_mv = pc_tree->mv;
+ vp9_setup_src_planes(x, cpi->Source, mi_row, mi_col);
simple_motion_search(cpi, x, bsize, mi_row, mi_col, ref_mv, ref, pred_buf);
pc_tree->mv = x->e_mbd.mi[0]->mv[0].as_mv;
}
@@ -3560,8 +3560,8 @@
float score;
// Generate model input features.
- features[feature_idx++] = logf((float)(dc_q * dc_q) / 256.0f + 1.0f);
- vp9_setup_src_planes(x, cpi->Source, mi_row, mi_col);
+ features[feature_idx++] = logf((float)dc_q + 1.0f);
+
// Get the variance of the residue as input features.
{
const int bs = 4 * num_4x4_blocks_wide_lookup[bsize];
@@ -3575,7 +3575,19 @@
const unsigned int var =
cpi->fn_ptr[bsize].vf(src, src_stride, pred, pred_stride, &sse);
const float factor = (var == 0) ? 1.0f : (1.0f / (float)var);
+ const MACROBLOCKD *const xd = &x->e_mbd;
+ const int has_above = !!xd->above_mi;
+ const int has_left = !!xd->left_mi;
+ const BLOCK_SIZE above_bsize = has_above ? xd->above_mi->sb_type : bsize;
+ const BLOCK_SIZE left_bsize = has_left ? xd->left_mi->sb_type : bsize;
+ int i;
+ features[feature_idx++] = (float)has_above;
+ features[feature_idx++] = (float)b_width_log2_lookup[above_bsize];
+ features[feature_idx++] = (float)b_height_log2_lookup[above_bsize];
+ features[feature_idx++] = (float)has_left;
+ features[feature_idx++] = (float)b_width_log2_lookup[left_bsize];
+ features[feature_idx++] = (float)b_height_log2_lookup[left_bsize];
features[feature_idx++] = logf((float)var + 1.0f);
for (i = 0; i < 4; ++i) {
const int x_idx = (i & 1) * bs / 2;
@@ -3604,7 +3616,6 @@
}
}
#undef FEATURES
-#undef LABELS
static int get_rdmult_delta(VP9_COMP *cpi, BLOCK_SIZE bsize, int mi_row,
int mi_col, int orig_rdmult) {
--- a/vp9/encoder/vp9_partition_models.h
+++ b/vp9/encoder/vp9_partition_models.h
@@ -966,175 +966,209 @@
#undef FEATURES
#endif // CONFIG_ML_VAR_PARTITION
-#define FEATURES 6
+#define FEATURES 12
#define LABELS 1
-static const float vp9_var_rd_part_nn_weights_64_layer0[FEATURES * 8] = {
- -0.100129f, 0.128867f, -1.375086f, -2.268096f, -1.470368f, -2.296274f,
- 0.034445f, -0.062993f, -2.151904f, 0.523215f, 1.611269f, 1.530051f,
- 0.418182f, -1.330239f, 0.828388f, 0.386546f, -0.026188f, -0.055459f,
- -0.474437f, 0.861295f, -2.208743f, -0.652991f, -2.985873f, -1.728956f,
- 0.388052f, -0.420720f, 2.015495f, 1.280342f, 3.040914f, 1.760749f,
- -0.009062f, 0.009623f, 1.579270f, -2.012891f, 1.629662f, -1.796016f,
- -0.279782f, -0.288359f, 1.875618f, 1.639855f, 0.903020f, 0.906438f,
- 0.553394f, -1.621589f, 0.185063f, 0.605207f, -0.133560f, 0.588689f,
+#define NODES 8
+static const float vp9_part_split_nn_weights_64_layer0[FEATURES * NODES] = {
+ -0.609728f, -0.409099f, -0.472449f, 0.183769f, -0.457740f, 0.081089f,
+ 0.171003f, 0.578696f, -0.019043f, -0.856142f, 0.557369f, -1.779424f,
+ -0.274044f, -0.320632f, -0.392531f, -0.359462f, -0.404106f, -0.288357f,
+ 0.200620f, 0.038013f, -0.430093f, 0.235083f, -0.487442f, 0.424814f,
+ -0.232758f, -0.442943f, 0.229397f, -0.540301f, -0.648421f, -0.649747f,
+ -0.171638f, 0.603824f, 0.468497f, -0.421580f, 0.178840f, -0.533838f,
+ -0.029471f, -0.076296f, 0.197426f, -0.187908f, -0.003950f, -0.065740f,
+ 0.085165f, -0.039674f, -5.640702f, 1.909538f, -1.434604f, 3.294606f,
+ -0.788812f, 0.196864f, 0.057012f, -0.019757f, 0.336233f, 0.075378f,
+ 0.081503f, 0.491864f, -1.899470f, -1.764173f, -1.888137f, -1.762343f,
+ 0.845542f, 0.202285f, 0.381948f, -0.150996f, 0.556893f, -0.305354f,
+ 0.561482f, -0.021974f, -0.703117f, 0.268638f, -0.665736f, 1.191005f,
+ -0.081568f, -0.115653f, 0.272029f, -0.140074f, 0.072683f, 0.092651f,
+ -0.472287f, -0.055790f, -0.434425f, 0.352055f, 0.048246f, 0.372865f,
+ 0.111499f, -0.338304f, 0.739133f, 0.156519f, -0.594644f, 0.137295f,
+ 0.613350f, -0.165102f, -1.003731f, 0.043070f, -0.887896f, -0.174202f,
};
-static const float vp9_var_rd_part_nn_bias_64_layer0[8] = {
- 0.659717f, 0.120912f, 0.329894f, -1.586385f,
- 1.715839f, 0.085754f, 2.038774f, 0.268119f,
+static const float vp9_part_split_nn_bias_64_layer0[NODES] = {
+ 1.182714f, 0.000000f, 0.902019f, 0.953115f,
+ -1.372486f, -1.288740f, -0.155144f, -3.041362f,
};
-static const float vp9_var_rd_part_nn_weights_64_layer1[8 * LABELS] = {
- -3.445586f, 2.375620f, 1.236970f, 0.804030f,
- -2.448384f, 2.827254f, 2.291478f, 0.790252f,
+static const float vp9_part_split_nn_weights_64_layer1[NODES * LABELS] = {
+ 0.841214f, 0.456016f, 0.869270f, 1.692999f,
+ -1.700494f, -0.911761f, 0.030111f, -1.447548f,
};
-static const float vp9_var_rd_part_nn_bias_64_layer1[LABELS] = {
- -1.16608453f,
+static const float vp9_part_split_nn_bias_64_layer1[LABELS] = {
+ 1.17782545f,
};
-static const NN_CONFIG vp9_var_rd_part_nnconfig_64 = {
+static const NN_CONFIG vp9_part_split_nnconfig_64 = {
FEATURES, // num_inputs
LABELS, // num_outputs
1, // num_hidden_layers
{
- 8,
+ NODES,
}, // num_hidden_nodes
{
- vp9_var_rd_part_nn_weights_64_layer0,
- vp9_var_rd_part_nn_weights_64_layer1,
+ vp9_part_split_nn_weights_64_layer0,
+ vp9_part_split_nn_weights_64_layer1,
},
{
- vp9_var_rd_part_nn_bias_64_layer0,
- vp9_var_rd_part_nn_bias_64_layer1,
+ vp9_part_split_nn_bias_64_layer0,
+ vp9_part_split_nn_bias_64_layer1,
},
};
-static const float vp9_var_rd_part_nn_weights_32_layer0[FEATURES * 8] = {
- 0.022420f, -0.032201f, 1.228065f, -2.767655f, 1.928743f, 0.566863f,
- 0.459229f, 0.422048f, 0.833395f, 0.822960f, -0.232227f, 0.586895f,
- 0.442856f, -0.018564f, 0.227672f, -1.291306f, 0.119428f, -0.776563f,
- -0.042947f, 0.183129f, 0.592231f, 1.174859f, -0.503868f, 0.270102f,
- -0.330537f, -0.036340f, 1.144630f, 1.783710f, 1.216929f, 2.038085f,
- 0.373782f, -0.430258f, 1.957002f, 1.383908f, 2.012261f, 1.585693f,
- -0.394399f, -0.337523f, -0.238335f, 0.007819f, -0.368294f, 0.437875f,
- -0.318923f, -0.242000f, 2.276263f, 1.501432f, 0.645706f, 0.344774f,
+static const float vp9_part_split_nn_weights_32_layer0[FEATURES * NODES] = {
+ -0.105488f, -0.218662f, 0.010980f, -0.226979f, 0.028076f, 0.743430f,
+ 0.789266f, 0.031907f, -1.464200f, 0.222336f, -1.068493f, -0.052712f,
+ -0.176181f, -0.102654f, -0.973932f, -0.182637f, -0.198000f, 0.335977f,
+ 0.271346f, 0.133005f, 1.674203f, 0.689567f, 0.657133f, 0.283524f,
+ 0.115529f, 0.738327f, 0.317184f, -0.179736f, 0.403691f, 0.679350f,
+ 0.048925f, 0.271338f, -1.538921f, -0.900737f, -1.377845f, 0.084245f,
+ 0.803122f, -0.107806f, 0.103045f, -0.023335f, -0.098116f, -0.127809f,
+ 0.037665f, -0.523225f, 1.622185f, 1.903999f, 1.358889f, 1.680785f,
+ 0.027743f, 0.117906f, -0.158810f, 0.057775f, 0.168257f, 0.062414f,
+ 0.086228f, -0.087381f, -3.066082f, 3.021855f, -4.092155f, 2.550104f,
+ -0.230022f, -0.207445f, -0.000347f, 0.034042f, 0.097057f, 0.220088f,
+ -0.228841f, -0.029405f, -1.507174f, -1.455184f, 2.624904f, 2.643355f,
+ 0.319912f, 0.585531f, -1.018225f, -0.699606f, 1.026490f, 0.169952f,
+ -0.093579f, -0.142352f, -0.107256f, 0.059598f, 0.043190f, 0.507543f,
+ -0.138617f, 0.030197f, 0.059574f, -0.634051f, -0.586724f, -0.148020f,
+ -0.334380f, 0.459547f, 1.620600f, 0.496850f, 0.639480f, -0.465715f,
};
-static const float vp9_var_rd_part_nn_bias_32_layer0[8] = {
- -0.023846f, -1.348117f, 1.365007f, -1.644164f,
- 0.062992f, 1.257980f, -0.098642f, 1.388472f,
+static const float vp9_part_split_nn_bias_32_layer0[NODES] = {
+ -1.125885f, 0.753197f, -0.825808f, 0.004839f,
+ 0.583920f, 0.718062f, 0.976741f, 0.796188f,
};
-static const float vp9_var_rd_part_nn_weights_32_layer1[8 * LABELS] = {
- 3.016729f, 0.622684f, -1.021302f, 1.490383f,
- 1.702046f, -2.964618f, 0.689045f, 1.711754f,
+static const float vp9_part_split_nn_weights_32_layer1[NODES * LABELS] = {
+ -0.458745f, 0.724624f, -0.479720f, -2.199872f,
+ 1.162661f, 1.194153f, -0.716896f, 0.824080f,
};
-static const float vp9_var_rd_part_nn_bias_32_layer1[LABELS] = {
- -1.28798676f,
+static const float vp9_part_split_nn_bias_32_layer1[LABELS] = {
+ 0.71644074f,
};
-static const NN_CONFIG vp9_var_rd_part_nnconfig_32 = {
+static const NN_CONFIG vp9_part_split_nnconfig_32 = {
FEATURES, // num_inputs
LABELS, // num_outputs
1, // num_hidden_layers
{
- 8,
+ NODES,
}, // num_hidden_nodes
{
- vp9_var_rd_part_nn_weights_32_layer0,
- vp9_var_rd_part_nn_weights_32_layer1,
+ vp9_part_split_nn_weights_32_layer0,
+ vp9_part_split_nn_weights_32_layer1,
},
{
- vp9_var_rd_part_nn_bias_32_layer0,
- vp9_var_rd_part_nn_bias_32_layer1,
+ vp9_part_split_nn_bias_32_layer0,
+ vp9_part_split_nn_bias_32_layer1,
},
};
-static const float vp9_var_rd_part_nn_weights_16_layer0[FEATURES * 8] = {
- -0.726813f, -0.026748f, 1.376946f, 1.467961f, 1.961810f, 1.690412f,
- 0.596484f, -0.261486f, -0.310905f, -0.366311f, -1.300086f, -0.534336f,
- 0.040520f, -0.032391f, -1.194214f, 2.438063f, -3.915334f, 1.997270f,
- 0.673696f, -0.676393f, 1.654886f, 1.553838f, 1.129691f, 1.360201f,
- 0.255001f, 0.336442f, -0.487759f, -0.634555f, 0.479170f, -0.110475f,
- -0.661852f, -0.158872f, -0.350243f, -0.303957f, -0.045018f, 0.586151f,
- -0.262463f, 0.228079f, -1.688776f, -1.594502f, -2.261078f, -1.802535f,
- 0.034748f, -0.028476f, 2.713258f, 0.212446f, -1.529202f, -2.560178f,
+static const float vp9_part_split_nn_weights_16_layer0[FEATURES * NODES] = {
+ -0.003629f, -0.046852f, 0.220428f, -0.033042f, 0.049365f, 0.112818f,
+ -0.306149f, -0.005872f, 1.066947f, -2.290226f, 2.159505f, -0.618714f,
+ -0.213294f, 0.451372f, -0.199459f, 0.223730f, -0.321709f, 0.063364f,
+ 0.148704f, -0.293371f, 0.077225f, -0.421947f, -0.515543f, -0.240975f,
+ -0.418516f, 1.036523f, -0.009165f, 0.032484f, 1.086549f, 0.220322f,
+ -0.247585f, -0.221232f, -0.225050f, 0.993051f, 0.285907f, 1.308846f,
+ 0.707456f, 0.335152f, 0.234556f, 0.264590f, -0.078033f, 0.542226f,
+ 0.057777f, 0.163471f, 0.039245f, -0.725960f, 0.963780f, -0.972001f,
+ 0.252237f, -0.192745f, -0.836571f, -0.460539f, -0.528713f, -0.160198f,
+ -0.621108f, 0.486405f, -0.221923f, 1.519426f, -0.857871f, 0.411595f,
+ 0.947188f, 0.203339f, 0.174526f, 0.016382f, 0.256879f, 0.049818f,
+ 0.057836f, -0.659096f, 0.459894f, 0.174695f, 0.379359f, 0.062530f,
+ -0.210201f, -0.355788f, -0.208432f, -0.401723f, -0.115373f, 0.191336f,
+ -0.109342f, 0.002455f, -0.078746f, -0.391871f, 0.149892f, -0.239615f,
+ -0.520709f, 0.118568f, -0.437975f, 0.118116f, -0.565426f, -0.206446f,
+ 0.113407f, 0.558894f, 0.534627f, 1.154350f, -0.116833f, 1.723311f,
};
-static const float vp9_var_rd_part_nn_bias_16_layer0[8] = {
- 0.495983f, 1.858545f, 0.162974f, 1.992247f,
- -2.698863f, 0.110020f, 0.550830f, 0.420941f,
+static const float vp9_part_split_nn_bias_16_layer0[NODES] = {
+ 0.013109f, -0.034341f, 0.679845f, -0.035781f,
+ -0.104183f, 0.098055f, -0.041130f, 0.160107f,
};
-static const float vp9_var_rd_part_nn_weights_16_layer1[8 * LABELS] = {
- 1.768409f, -1.394240f, 1.076846f, -1.762808f,
- 1.517405f, 0.535195f, -0.426827f, 1.002272f,
+static const float vp9_part_split_nn_weights_16_layer1[NODES * LABELS] = {
+ 1.499564f, -0.403259f, 1.366532f, -0.469868f,
+ 0.482227f, -2.076697f, 0.527691f, 0.540495f,
};
-static const float vp9_var_rd_part_nn_bias_16_layer1[LABELS] = {
- -1.65894794f,
+static const float vp9_part_split_nn_bias_16_layer1[LABELS] = {
+ 0.01134653f,
};
-static const NN_CONFIG vp9_var_rd_part_nnconfig_16 = {
+static const NN_CONFIG vp9_part_split_nnconfig_16 = {
FEATURES, // num_inputs
LABELS, // num_outputs
1, // num_hidden_layers
{
- 8,
+ NODES,
}, // num_hidden_nodes
{
- vp9_var_rd_part_nn_weights_16_layer0,
- vp9_var_rd_part_nn_weights_16_layer1,
+ vp9_part_split_nn_weights_16_layer0,
+ vp9_part_split_nn_weights_16_layer1,
},
{
- vp9_var_rd_part_nn_bias_16_layer0,
- vp9_var_rd_part_nn_bias_16_layer1,
+ vp9_part_split_nn_bias_16_layer0,
+ vp9_part_split_nn_bias_16_layer1,
},
};
-static const float vp9_var_rd_part_nn_weights_8_layer0[FEATURES * 8] = {
- -0.804900f, -1.214983f, 0.840202f, 0.686566f, 0.155804f, 0.025542f,
- -1.244635f, -0.368403f, 0.364150f, 1.081073f, 0.552387f, 0.452715f,
- 0.652968f, -0.293058f, 0.048967f, 0.021240f, -0.662981f, 0.424700f,
- 0.008293f, -0.013088f, 0.747007f, -1.453907f, -1.498226f, 1.593252f,
- -0.239557f, -0.143766f, 0.064311f, 1.320998f, -0.477411f, 0.026374f,
- 0.730884f, -0.675124f, 0.965521f, 0.863658f, 0.809186f, 0.812280f,
- 0.513131f, 0.185102f, 0.211354f, 0.793666f, 0.121714f, -0.015383f,
- -0.650980f, -0.046581f, 0.911141f, 0.806319f, 0.974773f, 0.815893f,
+static const float vp9_part_split_nn_weights_8_layer0[FEATURES * NODES] = {
+ -0.668875f, -0.159078f, -0.062663f, -0.483785f, -0.146814f, -0.608975f,
+ -0.589145f, 0.203704f, -0.051007f, -0.113769f, -0.477511f, -0.122603f,
+ -1.329890f, 1.403386f, 0.199636f, -0.161139f, 2.182090f, -0.014307f,
+ 0.015755f, -0.208468f, 0.884353f, 0.815920f, 0.632464f, 0.838225f,
+ 1.369483f, -0.029068f, 0.570213f, -0.573546f, 0.029617f, 0.562054f,
+ -0.653093f, -0.211910f, -0.661013f, -0.384418f, -0.574038f, -0.510069f,
+ 0.173047f, -0.274231f, -1.044008f, -0.422040f, -0.810296f, 0.144069f,
+ -0.406704f, 0.411230f, -0.144023f, 0.745651f, -0.595091f, 0.111787f,
+ 0.840651f, 0.030123f, -0.242155f, 0.101486f, -0.017889f, -0.254467f,
+ -0.285407f, -0.076675f, -0.549542f, -0.013544f, -0.686566f, -0.755150f,
+ 1.623949f, -0.286369f, 0.170976f, 0.016442f, -0.598353f, -0.038540f,
+ 0.202597f, -0.933582f, 0.599510f, 0.362273f, 0.577722f, 0.477603f,
+ 0.767097f, 0.431532f, 0.457034f, 0.223279f, 0.381349f, 0.033777f,
+ 0.423923f, -0.664762f, 0.385662f, 0.075744f, 0.182681f, 0.024118f,
+ 0.319408f, -0.528864f, 0.976537f, -0.305971f, -0.189380f, -0.241689f,
+ -1.318092f, 0.088647f, -0.109030f, -0.945654f, 1.082797f, 0.184564f,
};
-static const float vp9_var_rd_part_nn_bias_8_layer0[8] = {
- 0.176134f, 0.651308f, 2.007761f, 0.068812f,
- 1.061517f, 1.487161f, -2.308147f, 1.099828f,
+static const float vp9_part_split_nn_bias_8_layer0[NODES] = {
+ -0.237472f, 2.051396f, 0.297062f, -0.730194f,
+ 0.060472f, -0.565959f, 0.560869f, -0.395448f,
};
-static const float vp9_var_rd_part_nn_weights_8_layer1[8 * LABELS] = {
- 0.683032f, 1.326393f, -1.661539f, 1.438920f,
- 1.118023f, -2.237380f, 1.518468f, 2.010416f,
+static const float vp9_part_split_nn_weights_8_layer1[NODES * LABELS] = {
+ 0.568121f, 1.575915f, -0.544309f, 0.751595f,
+ -0.117911f, -1.340730f, -0.739671f, 0.661216f,
};
-static const float vp9_var_rd_part_nn_bias_8_layer1[LABELS] = {
- -1.65423989f,
+static const float vp9_part_split_nn_bias_8_layer1[LABELS] = {
+ -0.63375306f,
};
-static const NN_CONFIG vp9_var_rd_part_nnconfig_8 = {
+static const NN_CONFIG vp9_part_split_nnconfig_8 = {
FEATURES, // num_inputs
LABELS, // num_outputs
1, // num_hidden_layers
{
- 8,
+ NODES,
}, // num_hidden_nodes
{
- vp9_var_rd_part_nn_weights_8_layer0,
- vp9_var_rd_part_nn_weights_8_layer1,
+ vp9_part_split_nn_weights_8_layer0,
+ vp9_part_split_nn_weights_8_layer1,
},
{
- vp9_var_rd_part_nn_bias_8_layer0,
- vp9_var_rd_part_nn_bias_8_layer1,
+ vp9_part_split_nn_bias_8_layer0,
+ vp9_part_split_nn_bias_8_layer1,
},
};
+#undef NODES
#undef FEATURES
#undef LABELS