均方误差损失函数(MSE)和交叉熵损失函数详解

为什么需要损失函数

前面的文章我们已经从模型角度介绍了损失函数,对于神经网络的训练,首先根据特征输入和初始的参数,前向传播计算出预测结果,然后与真实结果进行比较,得到它们之间的差值。

损失函数又可称为代价函数或目标函数,是用来衡量算法模型预测结果和真实标签之间吻合程度(误差)的函数。通常会选择非负数作为预测值和真实值之间的误差,误差越小,则模型越好。

     有了这个损失函数,我们便可以采用优化算法更新网络参数,使得训练样本的平均损失最小。

而损失函数根据任务的不同,也可以分为不同的类型,下面进行介绍。

 

均方误差损失函数(MSE)

其中f(xi)是第i个样本的模型预测值,Yi是第i个样本的真实标签值,二者差值求平方,一共有n个样本,平方和求平均。

在回归问题中,均方误差损失函数用于度量样本点到回归曲线的距离,通过最小化平方损失使样本点可以更好地拟合回归曲线。由于无参数、计算成本低和具有明确物理意义等优点,MSE已成为一种优秀的距离度量方法。尽管MSE在图像和语音处理方面表现较弱,但它仍是评价信号质量的标准。

代码实现:

import numpy as np

# 自定义实现

def MSELoss(x:list,y:list):

    """    x:list,代表模型预测的一组数据    y:list,代表真实样本对应的一组数据    """

    assert len(x)==len(y)

    x=np.array(x)

    y=np.array(y)

    loss=np.sum(np.square(x - y)) / len(x)

    return loss

#计算过程举例x=[1,2]y=[0,1]loss=((1-0)**2 + (2-1)**2)÷2=(1+1)÷2=1

# pytorch版本

loss = nn.MSELoss()

predict = torch.randn(3, 5, requires_grad=True)

target = torch.randn(3, 5)

output = loss(predict, target)

从代码中可以看到,MSELoss需要的两个参数分别是真实标签值和模型预测值,两者可以是任意形状的张量,但二者形状和维度需要一致。就是说每个样本的预测值和标签值可以是任意维度的张量,这点要注意,在实际应用中时要认真考虑标签的形状。

 

交叉熵损失

pytorch中的CrossEntropyLoss()函数实际就是先把输出结果进行sigmoid,随后再放到传统的交叉熵函数中,就会得到结果。

交叉熵是信息论中的一个概念,最初用于估算平均编码长度,引入机器学习后,用于评估当前训练得到的概率分布与真实分布的差异情况。为了使神经网络的每一层输出从线性组合转为非线性逼近,以提高模型的预测精度,在以交叉熵为损失函数的神经网络模型中一般选用tanh、sigmoid、softmax或ReLU作为激活函数。

交叉熵损失函数刻画了实际输出概率与期望输出概率之间的相似度,也就是交叉熵的值越小,两个概率分布就越接近,特别是在正负样本不均衡的分类问题中,常用交叉熵作为损失函数。目前,交叉熵损失函数是卷积神经网络中最常使用的分类损失函数,它可以有效避免梯度消散。在二分类情况下也叫做对数损失函数。

一般的交叉熵用数学公式表示是:

-Q(x) log P(x)

其中Q(x)是真实值,P(x)是预测值。

当p(x)和Q(x)是矩阵的时候,就分别对其计算,然后求和即可

在pytorch中的交叉熵损失CrossEntropyLoss 包含了两部分,softmax和交叉熵计算,下面分别介绍这两部分

假设有 N 个样本,每个样本属于 C 个类别之一。对于第 i 个样本,它的真实类别标签为 y_i,模型的预测输出 logits 为xi​=(xi1​,xi2​,…,xiC​),其中xic表示第i个样本在第c 类别上的原始输出分数(logits)(注意这里是预测分数值,不是概率值)。

交叉熵损失的计算步骤如下:

(1)预测概率分布
对 logits 进行 softmax 操作,将预测输出其转换为概率分布:

其中 pic表示第i个样本属于第c类别的预测概率。

   此时预测输出的概率分布是f(xi)=(pi1,pi2,…,piC)

  1. 真实概率分布:

对于样本i,其真实分布会根据归属的类别自动创建一个one-hot概率分布,即所属类别的位置为1,其它均为0,则会输出一个one-hot概率分布Q(xi)=(qi1,qi2,…,qiC)。比如5个类别,第i个样本的真实类别为3,则Q(xi)[0,0,1,0,0]。

实际计算的时候不难发现target中为0经过乘法都是0了,因此最后只剩下正确类型的这个损失差距 最后公式可以演变成 - log Q(x)

(3)负对数似然(Negative Log-Likelihood)
对于单个样本,计算负对数似然:

其中是第i 个样本的交叉熵损失,但事实上,只在真实类别位置处概率为1,其余位置均为0,因此,可以进一步简化为

 其中,yi代表第i个样本在真实类别j=yi处的预测概率。其本质是利用真实概率分布筛选了预测概率分布在真实类别的概率值,并求负对数似然。

对于N个样本,则对这N个样本的交叉熵损失函数求和再求平均即可。

  1. 代码解析

cross_loss = torch.nn.CrossEntropyLoss(reduction='none')

#注意这里的预测输入是N*C,其中N是样本数,C是类别数,此时还不是概率,所以使用交叉熵损失函数的网络最后不需要softmax,损失函数自带。

input = torch.tensor([[4, 14, 19, 15],

                       [18, 6, 14, 7],

                       [18, 5, 3, 16]], dtype=torch.float)

#真实标签是每个样本的类别(1*N),api会自动生成one-hot概率分布(N*C)

    target = torch.tensor([0, 3, 2])

  #然后计算损失函数值

  loss = cross_loss(input, target)

    torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

参数

    参数说明:

1. weight:

  • CE 和 BCE 系列都有此参数,用于为每个类别的 loss 设置权重,常用于类别不均衡问题;
  • weight 必须是 float 类型的 1D tensor,长度和类别长度一致:weight = torch.from_numpy(np.array([0.6, 0.2, 0.2])).float().to(device)
  • 注意:weight 加起来未必一定要等于 1,类 c 对应的 weight 为 W_c = (N-N_c) / N,数目越多的类,weight 越小,weight 越大,此类得到的 loss 被放大;

2. ignore_index:

  • 其中 BCE 系列没有此参数,此参数用于指定忽略某些类别的 loss;

3. size_average:

  • 该参数指定 loss 是否在一个 batch 内平均,即是否除以 N,目前此参数已经被弃用

4. reduce:

  • 目前此参数已经被弃用

5. reduction:

  • 此参数在新版本中是为了取代 ”size_average“ 和 "reduce" 参数的;
  • mean (default):返回 N 个 loss 的平均值;
  • sum:返回 N 个 loss 的 sum;
  • None:直接返回一个 batch 中的 N 个 loss;

6. pos_weight:

  • 只有 BCEWithLogits 系列有次参数;
  • 与 weight 参数的区别是:WIP;

(5)nn.CrossEntropyLoss=nn.LogSoftmax(dim=1)+nn.NLLLoss()

(5)多维交叉熵

文本类数据通常是三维数据,预测通常是(batch_size,seq_length,num_vocab_size),而target是(batch_size,seq_length),此时需要预测的形状,通常使用permute操作成 (batch_size,num_vocab_size,seq_length)

参考资料

https://zhuanlan.zhihu.com/p/261059231

交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数 - 不愿透漏姓名的王建森 - 博客园

MSELoss — PyTorch 2.5 documentation


http://www.niftyadmin.cn/n/5795815.html

相关文章

c++数据结构算法复习基础--12--排序算法-常见笔试面试问题

1、STL里sort算法用的是什么排序算法? 快速排序算法。 插入排序(待排序序列个数<32时,系统默认32)。 递归层数太深,转成堆排序。 #include<algorithm> //算法库,头文件使用了快速排序: sort原码: 小到大 _EXPORT_STD template <class _RanIt> _CON…

中国人工智能学会技术白皮书

中国人工智能学会的技术白皮书具有多方面的重要作用&#xff0c;是极具权威性和价值的参考资料。 看看编委会和编写组的阵容&#xff0c;还是很让人觉得靠谱的 如何下载这份资料呢&#xff1f;下面跟着步骤来吧 步骤一&#xff1a;进入中国智能学会官网。百度搜索“中国智能学…

移除链表元素(最优解)

题目来源 203. 移除链表元素 - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,6,3,4…

【杂谈】虚拟机与EasyConnect运行巧设:Reqable助力指定应用流量专属化

场景 公司用的是EasyConnect&#xff0c;这个软件非常好用&#xff0c;也非常稳定&#xff0c;但是有个缺点&#xff0c;就是会无条件拦截本机所有流量&#xff0c;而且会加入所有运行的exe程序&#xff0c;实现流量全部走代理。 准备材料 一个windows/Linux 桌面版虚拟机Ea…

Webpack学习笔记(5)

1.拆分开发环境和生产环境配置 很多配置在开发环境和生产环境存在不一致的情况&#xff0c;比如开发环境没有必要设置缓存&#xff0c;生产环境需要设置公共路径等等。 2.公共路径 使用publicPath配置项&#xff0c;可以通过它指定应用程序中所有资源的基础路径。 webpack.…

Eureka服务注册源码

spring-cloud-starter-netflix-eureka-client 版本是3.0.3 核心装备类&#xff1a; EurekaClientAutoConfiguration EurekaDiscoveryClientConfiguration 核心类&#xff0c;以及引用的关系如下 EurekaRegistration - EurekaInstanceConfigBean 实例配置bean- ApplicationInfo…

JaxaFx学习(三)

目录&#xff1a; &#xff08;1&#xff09;JavaFx MVVM架构实现 &#xff08;2&#xff09;javaFX知识点 &#xff08;3&#xff09;JavaFx的MVC架构 &#xff08;4&#xff09;JavaFx事件处理机制 &#xff08;5&#xff09;多窗体编程 &#xff08;6&#xff09;数据…

element-puls封装表单验证

项目场景&#xff1a; 提示&#xff1a;这里简述项目相关背景&#xff1a; 在做项目中会有一些简单的表单非空验证&#xff0c;这些验证比较简单&#xff0c;就是代码看着有点多&#xff0c;做起来浪费时间&#xff0c;所以我们可以将这个方法封装起来&#xff0c;然后挂载全…