当前位置: 首页 > news >正文

yolov26改进 | 损失函数改进篇 | 最新ShapeIoU、InnerShapeIoU损失助力细节涨点(含三十余种损失函数改进方法)

开始讲解之前推荐一下我的专栏,本专栏的内容支持(分类、检测、分割、追踪、关键点检测),专栏目前为限时折扣,欢迎大家订阅本专栏,本专栏每周更新5-7篇最新机制,更有包含我所有改进的文件和交流群提供给大家,本人定期在群内分享发表论文方法和经验。


一、本文介绍

本文给大家带来的改进机制是损失函数的改进机制标题虽然提到了ShapeIoU和InnnerShapeIoU但是本文的内容包括过去到现在的百分之九十以上的损失函数的实现,同时使用方法非常简单,在本文的末尾还会教大家在改进模型时何时添加损失函数才能达到最好的效果,同时在开始讲解之前推荐一下我的专栏,本专栏的内容支持(分类、检测、分割、追踪、关键点检测),专栏目前为限时折扣,欢迎大家订阅本专栏,本专栏每周更新3-5篇最新机制,更有包含我所有改进的文件和交流群提供给大家,本文支持的损失函数共有如下图片所示

欢迎大家订阅我的专栏一起学习YOLO!

专栏链接:YOLOv26有效涨点专栏包含:Conv、注意力机制、主干/Backbone、损失函数、优化器、后处理等改进机制


目录

一、本文介绍

二、ShapeIoU

三、核心代码

四、 损失函数使用方式

4.1 修改一

4.2 修改二

4.3 步骤三(未必要修改)

4.4 步骤四

五、总结


二、ShapeIoU

官方论文地址:官方论文地址

官方代码地址:官方代码地址


这幅图展示了在目标检测任务中,两种不同情况或方法下的边界框回归的对比。

GT (Ground Truth): 用桃色框表示,指的是图像中物体实际的位置和形状。在目标检测中,算法试图尽可能准确地预测这个框。

Anchor: 蓝色框代表一个预定义的框,是算法预设的一系列框,用于与GT框进行匹配,寻找最佳的候选框。

在图中,我们看到四个不同的情况(A、B、C、D),每个都显示了一个anchor与GT的对比,并给出了IoU(交并比)的数值。IoU是一个常用的度量,用来评估预测边界框与真实边界框之间的重叠程度。

论文中给了一堆公式,大家有兴趣的可以看看。


三、核心代码

下面的代码的使用方式看章节四。

def bbox_iou( box1: torch.Tensor, box2: torch.Tensor, xywh: bool = True, GIoU: bool = False, DIoU: bool = False, CIoU: bool = False, eps: float = 1e-7, *, SIoU: bool = False, EIoU: bool = False, WIoU: bool = False, ShapeIoU: bool = False, Focal: bool = False, Inner: bool = False, ratio: float = 0.7, alpha: float = 1.0, gamma: float = 0.5, scale: bool = False, shape_scale: float = 0.0, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """Calculate IoU / GIoU / DIoU / CIoU / EIoU / SIoU / WIoU / ShapeIoU between bounding boxes. Args: box1 (torch.Tensor): Boxes with last dimension 4. box2 (torch.Tensor): Boxes with last dimension 4. xywh (bool): If True, boxes are in (x, y, w, h) format. If False, boxes are in (x1, y1, x2, y2) format. GIoU (bool): Enable Generalized IoU. DIoU (bool): Enable Distance IoU. CIoU (bool): Enable Complete IoU. SIoU (bool): Enable SIoU. EIoU (bool): Enable EIoU. WIoU (bool): Enable WIoU. ShapeIoU (bool): Enable ShapeIoU. Focal (bool): Enable Focal-IoU style output. Inner (bool): Enable Inner-IoU. ratio (float): Inner box scale ratio. alpha (float): Power factor used by some extended losses. gamma (float): Focal exponent. scale (bool): Enable WIoU v2 / v3 dynamic scaling. shape_scale (float): ShapeIoU shape scale parameter. eps (float): Small value to avoid division by zero. Returns: torch.Tensor or Tuple[torch.Tensor, ...] """ # ------------------------------------------------------------ # 1. Get coordinates of bounding boxes # ------------------------------------------------------------ if xywh: # Input format: x, y, w, h (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1) w1 = w1.clamp(min=eps) h1 = h1.clamp(min=eps) w2 = w2.clamp(min=eps) h2 = h2.clamp(min=eps) else: # Input format: x1, y1, x2, y2 b1_x1_ori, b1_y1_ori, b1_x2_ori, b1_y2_ori = box1.chunk(4, -1) b2_x1_ori, b2_y1_ori, b2_x2_ori, b2_y2_ori = box2.chunk(4, -1) w1 = (b1_x2_ori - b1_x1_ori).clamp(min=eps) h1 = (b1_y2_ori - b1_y1_ori).clamp(min=eps) w2 = (b2_x2_ori - b2_x1_ori).clamp(min=eps) h2 = (b2_y2_ori - b2_y1_ori).clamp(min=eps) x1 = (b1_x1_ori + b1_x2_ori) / 2 y1 = (b1_y1_ori + b1_y2_ori) / 2 x2 = (b2_x1_ori + b2_x2_ori) / 2 y2 = (b2_y1_ori + b2_y2_ori) / 2 # ------------------------------------------------------------ # 2. Normal box or Inner box # ------------------------------------------------------------ if Inner: # Inner-IoU: # 用 ratio 缩小 box,只在缩小后的 inner box 上计算 IoU。 # ratio=0.7 表示使用原框中心区域的 70% 宽高。 r = torch.as_tensor(ratio, dtype=box1.dtype, device=box1.device).clamp(min=eps) w1_half, h1_half = (w1 * r) / 2, (h1 * r) / 2 w2_half, h2_half = (w2 * r) / 2, (h2 * r) / 2 b1_x1, b1_x2 = x1 - w1_half, x1 + w1_half b1_y1, b1_y2 = y1 - h1_half, y1 + h1_half b2_x1, b2_x2 = x2 - w2_half, x2 + w2_half b2_y1, b2_y2 = y2 - h2_half, y2 + h2_half area1 = w1 * h1 * r * r area2 = w2 * h2 * r * r else: # Normal IoU box w1_half, h1_half = w1 / 2, h1 / 2 w2_half, h2_half = w2 / 2, h2 / 2 b1_x1, b1_x2 = x1 - w1_half, x1 + w1_half b1_y1, b1_y2 = y1 - h1_half, y1 + h1_half b2_x1, b2_x2 = x2 - w2_half, x2 + w2_half b2_y1, b2_y2 = y2 - h2_half, y2 + h2_half area1 = w1 * h1 area2 = w2 * h2 # ------------------------------------------------------------ # 3. Intersection and union # ------------------------------------------------------------ inter = ( (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(min=0) * (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(min=0) ) union = area1 + area2 - inter + eps iou = inter / union iou = torch.nan_to_num(iou, nan=0.0, posinf=1.0, neginf=0.0) iou = iou.clamp(min=0.0, max=1.0) # Focal weight focal_weight = iou.clamp(min=0, max=1).pow(gamma) # ------------------------------------------------------------ # 4. Extended IoU losses # ------------------------------------------------------------ if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU or ShapeIoU: cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) if CIoU or DIoU or EIoU or SIoU or WIoU or ShapeIoU: c2 = cw.pow(2) + ch.pow(2) + eps rho2 = ( (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2) ) / 4 # CIoU if CIoU: v = (4 / math.pi**2) * ( (w2 / (h2 + eps)).atan() - (w1 / (h1 + eps)).atan() ).pow(2) with torch.no_grad(): ciou_alpha = v / (v - iou + (1 + eps)) out = iou - (rho2 / c2 + v * ciou_alpha) return (out, focal_weight) if Focal else out # EIoU if EIoU: rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)).pow(2) rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)).pow(2) cw2 = (cw.pow(2) + eps).pow(alpha) ch2 = (ch.pow(2) + eps).pow(alpha) out = iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) return (out, focal_weight) if Focal else out # SIoU if SIoU: s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps sigma = (s_cw.pow(2) + s_ch.pow(2)).sqrt().clamp(min=eps) sin_alpha_1 = s_cw.abs() / sigma sin_alpha_2 = s_ch.abs() / sigma threshold = math.sqrt(2) / 2 sin_alpha = torch.where( sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1, ) sin_alpha = sin_alpha.clamp(min=-1 + eps, max=1 - eps) angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2) rho_x = (s_cw / (cw + eps)).pow(2) rho_y = (s_ch / (ch + eps)).pow(2) distance_gamma = angle_cost - 2 distance_cost = ( 2 - torch.exp(distance_gamma * rho_x) - torch.exp(distance_gamma * rho_y) ) omega_w = (w1 - w2).abs() / torch.max(w1, w2).clamp(min=eps) omega_h = (h1 - h2).abs() / torch.max(h1, h2).clamp(min=eps) shape_cost = ( (1 - torch.exp(-omega_w)).pow(4) + (1 - torch.exp(-omega_h)).pow(4) ) out = iou - (0.5 * (distance_cost + shape_cost) + eps).pow(alpha) return (out, focal_weight) if Focal else out # ShapeIoU if ShapeIoU: # Shape-Distance shape_scale_tensor = torch.as_tensor( shape_scale, dtype=box1.dtype, device=box1.device, ) w2_s = w2.clamp(min=eps).pow(shape_scale_tensor) h2_s = h2.clamp(min=eps).pow(shape_scale_tensor) shape_den = (w2_s + h2_s).clamp(min=eps) ww = 2 * w2_s / shape_den hh = 2 * h2_s / shape_den cw_shape = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) ch_shape = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) c2_shape = cw_shape.pow(2) + ch_shape.pow(2) + eps center_distance_x = ( (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) ) / 4 center_distance_y = ( (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2) ) / 4 center_distance = hh * center_distance_x + ww * center_distance_y distance = center_distance / c2_shape # Shape-Shape omega_w = hh * (w1 - w2).abs() / torch.max(w1, w2).clamp(min=eps) omega_h = ww * (h1 - h2).abs() / torch.max(h1, h2).clamp(min=eps) shape_cost = ( (1 - torch.exp(-omega_w)).pow(4) + (1 - torch.exp(-omega_h)).pow(4) ) out = iou - distance - 0.5 * shape_cost return (out, focal_weight) if Focal else out # WIoU if WIoU: if Focal: raise RuntimeError("WIoU does not support Focal=True.") distance_weight = torch.exp(rho2 / c2) if scale: wiou_scale = WIoU_Scale(1 - iou) return ( wiou_scale._scaled_loss(), (1 - iou) * distance_weight, iou, ) return iou, distance_weight # DIoU out = iou - rho2 / c2 return (out, focal_weight) if Focal else out # GIoU c_area = cw * ch + eps out = iou - (((c_area - union) / c_area) + eps).pow(alpha) return (out, focal_weight) if Focal else out # Normal IoU return (iou, focal_weight) if Focal else iou

四、 损失函数使用方式

4.1 修改一

第一步我们需要找到如下的文件ultralytics/utils/metrics.py,找到如下的代码,下面的图片是原先的代码部分截图的正常样子,然后我们将整个代码块一将下面的个方法(这里这是部分截图)内容全部替换


4.2 修改二

第二步我们找到另一个文件如下->"ultralytics/utils/loss.py",我们按图进行修改.

iou = bbox_iou( pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, GIoU=False, DIoU=False, CIoU=True, EIoU=False, SIoU=False, WIoU=False, ShapeIoU=False, Focal=False, Inner=False, ratio=0.7, alpha=1.0, gamma=0.5, scale=False, shape_scale=0.0, eps=1e-7) # 默认使用 CIoU;GIoU / DIoU / CIoU / EIoU / SIoU / WIoU / ShapeIoU 建议一次只开一个。 # Inner=False 表示不启用 Inner-IoU;若使用 Inner-CIoU,只需设置 Inner=True,ratio 控制内框比例。 # ShapeIoU=False 表示不启用 ShapeIoU;若使用 ShapeIoU,将 CIoU=False, ShapeIoU=True。 # Focal 和 WIoU 会改变返回值结构,若开启需要同步修改后续 loss 计算。

4.3 步骤三(未必要修改)

我们找到另一个文件如下->"ultralytics/utils/loss.py"(步骤二的下一行),此处是使用Focus和WIoU时候需要修改的代码,你不使用跳过此步骤直接进行步骤三即可,

if type(iou) is tuple: if len(iou) == 2: # increased the weight of low/high IoU loss_iou = ((1 - iou[1].detach().squeeze()) * (1 - iou[0].squeeze()) * weight).sum() / target_scores_sum # Focal # lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())* weight).sum() / target_scores_sum # Focal-inv # 这里有两种方法,大家可以自行尝试,这里的Focal-inv也是文章中提出的. else: loss_iou = (iou[0] * iou[1] * weight).sum() / target_scores_sum else: loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum # iou loss

4.4 步骤四

我们还需要修改一处,找到如下的文件''ultralytics/utils/tal.py''然后找到其中下面图片的代码,用我给的代码替换红框内的代码。

return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, GIoU=False, DIoU=False, CIoU=True, EIoU=False, SIoU=False, WIoU=False, ShapeIoU=False, Focal=False, Inner=True, ratio=0.7, alpha=1.0, gamma=0.5, eps=1e-7, scale=False, shape_scale=0.0).squeeze(-1).clamp_(0)

此处和loss.py里面的最好是使用同一个参数(但非必须),但Focus和WIoU此处不能使用需要注意.


五、总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv26改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏链接:YOLOv26有效涨点专栏包含:Conv、注意力机制、主干/Backbone、损失函数、优化器、后处理等改进机制

http://www.cnnetsun.cn/news/2907849.html

相关文章:

  • 3步掌握d2s-editor:零基础玩转暗黑破坏神2存档修改
  • 如何快速掌握AI图层分离:5步提升设计效率的完整指南
  • 什么是 supremum pseudo-record?
  • FLEXPART模式实战:如何用后向轨迹分析锁定污染源(附Python后处理脚本)
  • 别再手动PS了!用Python+OpenCV给论文配图加局部放大镜,5分钟搞定
  • 第1章:架构基础
  • 如何免费获取抖音无水印高清视频:douyin-downloader完整指南
  • 生产级机器学习系统:防御性设计与系统性风险治理
  • 从零样本到思维分支:LLM推理增强的工业级落地路径
  • Docker分层构建缓存原理详解:零基础快速吃透镜像加速机制
  • MCU模拟比较器与DAC实战:低功耗监控与自动波形生成
  • SPI驱动非标准字长外设:硬件打包与软件模拟方案详解
  • BERTScore深度解析:为什么这个文本评估指标能碾压传统方法?
  • 小红书无水印下载终极指南:3分钟掌握批量采集技巧
  • 嵌入式定时器与DAC实战:从抗噪滤波到自动波形生成
  • 别再只用qemu-img了!QEMU快照的两种玩法(磁盘/检查点)与实战避坑指南
  • 终极指南:在Linux上安装Realtek 8922AE WiFi 7网卡驱动的完整教程
  • 抖音下载器开源项目实战教程:从零搭建24小时自动采集系统完整指南
  • 深入解析MC56F81xxxL中断与eDMA:从原理到实战配置指南
  • i.MX21 SSI接口AC97模式详解:寄存器配置与多通道音频驱动开发
  • 深入解析NXP LS1046A SEC队列接口与错误处理寄存器
  • 3步精通:开源工具高效下载MOOC课程
  • SAP UI5 没有 NgModule,但有自己的装配秩序
  • MC68SZ328 UART与Memory Stick主机控制器深度解析与实战配置
  • MC68377 QADC64模块详解:队列式ADC原理、寄存器配置与嵌入式数据采集实战
  • Windows本地实时语音转文字终极指南:5分钟搭建你的隐私安全助手
  • Linux jbd2_journal_recover日志恢复与superblock标记
  • Linux jbd2_journal_commit_transaction日志提交与forget链表
  • 【毕业设计】基于 SpringBoot 的数据资产备案与登记管理系统研究 适配企业数字化转型的数据资产登记系统开发与实践(源码+文档+远程调试,全bao定制等)
  • 深入解析MC68377 CTM9 DASM:输出比较与PWM模式实战指南