4

TensorFlow的模型准确率计量方法

 3 years ago
source link: https://yinguobing.com/metrics-in-tensorflow/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

封面图片: Piret Ilver

TensorFlow中的 metrics 模块提供了一系列模型指标评估方法。当使用 Keras 的 Model.fit 函数训练时,可以直接在编译模型时传入该类别的一个实例即可实现自动计算。但是当手动实现训练循环时,需要自行实现评估逻辑,手动更新计量指标。这里以 CategoricalAccuracy (类别准确率)为例,说明具体的使用方法与注意事项。

CategoricalAccuracy 可以用来计算分类模型的准确率。它需要至少两个输入变量:独热标签与预测值。例如训练时某个分类任务的标签为 [0, 0, 1] ,模型的输出为 [0.1, 0,1, 0.9] 。使用以下代码可以获得当前的准确率计量结果。

m = tf.keras.metrics.CategoricalAccuracy()
m.update_state([0, 0, 1], [0.1, 0.1, 0.9])
m.result()

# 输出结果为 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
计算分类模型当前准确率的代码

即当前模型的准确率为1.0。

需要注意的是准确率的计量结果是累计的。假设第二次评估的标签与预测值分比为 [0, 0, 1][0.1, 0.9, 0.1] 时,即模型做出了错误的预测,继续使用如下代码获取最新的评估结果。

m.update_state([0, 0, 1], [0.1, 0.9, 0.1])
m.result()

# 输出结果为
<tf.Tensor: shape=(), dtype=float32, numpy=0.5>
更新模型的累计准确率

综合两次评估可以得出当前模型的准确率为0.5。

在实际训练中,我们希望计量一段训练过程中的模型准确率,并以此为依据决定是否保存该模型。因此计量过程存在开始与结束节点。当通过结束节点之后,需要重置该计量对象。

m.reset_state()
m.result()

# 输出结果为
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
模型准确率重置为0

重置的具体节点则可以灵活设定。如果你的数据集不大,可以将epoch结尾作为重置节点。如果你的数据集太过庞大,则可以在保存模型后重置。只要牢记你设定该计量对象的目的便不难做出抉择。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK