详解RT-DETR网络结构/数据集获取/环境搭建/训练/推理/验证/导出/部署

论文地址:RT-DETR论文地址

代码地址:RT-DETR官方下载地址

目录

一、本文介绍

二、RT-DETR的网络结构

2.1、模型概览

2.2、高效混合编码器

2.3、IoU感知查询选择

2.4、 可扩展的RT-DETR

三、RT-DERT的环境搭建

四、免费数据集获取

五、获取RT-DERT

5.1 渠道一官方版本

5.2 渠道二YOLOv8版本

六、RT-DERT的官方版本训练方法

6.1 步骤一 

6.2 步骤二 

七、RT-DERT的YOLO版本训练方法

7.1 方式一

7.2 方式二(推荐)

7.3 方式三 

八、RT-DERT的验证方法

8.1 参数讲解 

8.2 验证方法 

九、RT-DERT的推理方法

9.1 参数讲解

9.2 推理方法 

十、RT-DERT的导出方法

10.1 官方版本导出方法

10.2 YOLO版本导出方法

十一、全文总结


一、本文介绍

RT-DETR(Real-Time DEtection TRansformer)是一种新提出的目标检测模型,它利用Transformer的自注意力机制来处理图像数据,与YOLO系列有显著不同。不同于YOLO通过连续的卷积层直接对图像区域进行分析,RT-DETR采用Transformer架构中的自注意力机制,这允许模型更有效地理解图像中不同部分之间的关系。这种方法使得RT-DETR在处理图像中的复杂场景和多对象环境时,能够展现出更高的准确性和效率。此外,RT-DETR在保持高精度的同时,也针对实时处理进行了优化,使其适合需要快速响应的应用场景。这一结构上的创新使RT-DETR在目标检测领域成为了一种具有突破性的模型,展现出与传统方法不同的优势(从文章的实验结果来看RT-DETR在实时的目标检测上确实打败了YOLO),本文给大家带来的是关于RT-DETR网络结构、数据集获取、环境搭建、训练、推理、验证、导出、部署的讲解(同时RT-DETR支持官方版本训练和YOLOv8集成版本的训练我都会分别介绍)。 

下面是关于DETR系列的发展历程->

二、RT-DETR的网络结构

本文主要讲的是如何训练部署等方法具体的网络结构讲解可以看我的另一篇博客地址如下->

RT-DERT阅读笔记: RT-DETR论文阅读笔记(包括YOLO版本训练和官方版本训练)

 

2.1、模型概览

RT-DETR包括一个主干网络(backbone)、一个混合编码器(hybrid encoder)和一个带有辅助预测头的变换器解码器(transformer decoder)。模型架构的概览如下面的图片所示。

具体来说,我们利用主干网络的最后三个阶段的输出特征 {S3, S4, S5} 作为编码器的输入。混合编码器通过内尺度交互(intra-scale interaction)和跨尺度融合(cross-scale fusion)将多尺度特征转换成一系列图像特征。随后,采用IoU感知查询选择(IoU-aware query selection)从编码器输出序列中选择一定数量的图像特征,作为解码器的初始对象查询。最后,带有辅助预测头的解码器迭代优化对象查询,生成边框和置信度分数。

2.2、高效混合编码器

 计算瓶颈分析。为了加速训练收敛和提高性能,Zhu等人提出引入多尺度特征,并提出变形注意力机制来减少计算量。然而,尽管注意力机制的改进减少了计算开销,但输入序列长度的显著增加仍使编码器成为计算瓶颈,阻碍了DETR的实时实现。如[21]所报告,编码器占了49%的GFLOPs,但在Deformable-DETR中仅贡献了11%的AP。为了克服这一障碍,我们分析了多尺度变换器编码器中存在的计算冗余,并设计了一系列变体来证明内尺度和跨尺度特征的同时交互在计算上是低效的。

高级特征是从包含图像中对象丰富语义信息的低级特征中提取出来的。直觉上,在连接的多尺度特征上执行特征交互是多余的。为了验证这一观点,我们重新思考了编码器结构,并设计了一系列具有不同编码器的变体,如下图所示。

这一系列变体通过将多尺度特征交互分解为内尺度交互和跨尺度融合的两步操作,逐渐提高了模型精度,同时显著降低了计算成本(详细指标参见下表)。

我们首先移除DINO-R50中的多尺度变换器编码器作为基线A。接下来,插入不同形式的编码器,基于基线A生成一系列变体,具体如下:

- A → B:变体B插入了一个单尺度变换器编码器,它使用一个变换器块层。每个尺度的特征共享编码器进行内尺度特征交互,然后连接输出的多尺度特征。
- B → C:变体C基于B引入了跨尺度特征融合,并将连接的多尺度特征送入编码器进行特征交互。
- C → D:变体D将内尺度交互和跨尺度融合的多尺度特征解耦。首先使用单尺度变换器编码器进行内尺度交互,然后使用类似PANet的结构进行跨尺度融合。
- D → E:变体E在D的基础上进一步优化

了内尺度交互和跨尺度融合的多尺度特征,采用了我们设计的高效混合编码器(详见下文)。

混合设计。基于上述分析,我们重新思考了编码器的结构,并提出了一种新型的高效混合编码器,所提出的编码器由两个模块组成,即基于注意力的内尺度特征交互模块(AIFI)和基于CNN的跨尺度特征融合模块(CCFM)。AIFI基于变体D进一步减少了计算冗余,它只在S5上执行内尺度交互。我们认为,将自注意力操作应用于具有更丰富语义概念的高级特征,可以捕捉图像中概念实体之间的联系,这有助于后续模块检测和识别图像中的对象。同时,由于缺乏语义概念,低级特征的内尺度交互是不必要的,存在与高级特征交互重复和混淆的风险。为了验证这一观点,我们仅在变体D中对S5执行内尺度交互。CCFM也是基于变体D优化的,将由卷积层组成的几个融合块插入到融合路径中。融合块的作用是将相邻特征融合成新的特征,其结构如图4所示。融合块包含N个RepBlocks,两个路径的输出通过逐元素加法融合。我们可以将此过程表示如下:

Q = K = V = \text{Flatten}(S_5)

F_5 = \text{Reshape}(\text{Attn}(Q, K, V))

\text{Output} = \text{CCFM}(\{S_3, S_4, F_5\})

式中,Attn代表多头自注意力,Reshape代表将特征的形状恢复为与S5相同,这是Flatten的逆操作。

2.3、IoU感知查询选择

DETR中的对象查询是一组可学习的嵌入,由解码器优化并由预测头映射到分类分数和边界框。然而,这些对象查询难以解释和优化,因为它们没有明确的物理含义。后续工作改进了对象查询的初始化,并将其扩展到内容查询和位置查询(锚点)。其中,提出了查询选择方案,它们共同的特点是利用分类分数从编码器中选择排名靠前的K个特征来初始化对象查询(或仅位置查询)。然而,由于分类分数和位置置信度的分布不一致,一些预测框虽有高分类分数,但与真实框(GT)不接近,这导致选择了分类分数高但IoU分数低的框,而丢弃了分类分数低但IoU分数高的框。这降低了检测器的性能。为了解决这个问题,我们提出了IoU感知查询选择,通过在训练期间对模型施加约束,使其对IoU分数高的特征产生高分类分数,对IoU分数低的特征产生低分类分数。因此,模型根据分类分数选择的排名靠前的K个编码

器特征的预测框,既有高分类分数又有高IoU分数。我们重新制定了检测器的优化目标如下:

L(\hat{y}, y) = L_{box}(\hat{b}, b) + L_{cls}(\hat{c}, \hat{b}, y, b)
                = L_{box}(\hat{b}, b) + L_{cls}(\hat{c}, c, IoU)

 其中,\hat{y}y分别代表预测和真实值,hat{y} = \{\hat{c}, \hat{b}\} 和 y = \{c, b\}c 和b 分别代表类别和边界框。我们将IoU分数引入分类分支的目标函数中(类似于VFL),以实现对正样本分类和定位的一致性约束。

效果分析。为了分析所提出的IoU感知查询选择的有效性,我们可视化了在val2017数据集上,由查询选择选出的编码器特征的分类分数和IoU分数,如图6所示。具体来说,我们首先根据分类分数选择排名靠前的K(实验中K=300)个编码器特征,然后可视化分类分数大于0.5的散点图。红点和蓝点分别计算自应用传统查询选择和IoU感知查询选择的模型。点越接近图的右上方,相应特征的质量越高,即分类标签和边界框更有可能描述图像中的真实对象。根据可视化结果,我们发现最显著的特点是大量蓝点集中在图的右上方,而红点集中在右下方。这表明,经IoU感知查询选择训练的模型可以产生更多高质量的编码器特征。

此外,我们对两种类型点的分布特征进行了定量分析。图中蓝点比红点多138%,即更多的红点的分类分数小于或等于0.5,可以被认为是低质量特征。然后,我们分析了分类分数大于0.5的特征的IoU分数,发现有120%的蓝点比红点的IoU分数大于0.5。定量结果进一步证明,IoU感知查询选择可以为对象查询提供更多具有准确分类(高分类分数)和精确位置(高IoU分数)的编码器特征,从而提高检测器的准确度。

2.4、 可扩展的RT-DETR

为了提供可扩展的RT-DETR版本,我们用HGNetv2替换了ResNet主干网络。我们使用深度乘数和宽度乘数一起缩放主干网络和混合编码器。因此,我们得到了两个版本的RT-DETR,具有不同的参数数量和FPS。对于我们的混合编码器,我们通过调整CCFM中RepBlocks的数量和编码器的嵌入维度来控制深度乘数和宽度乘数。值得注意的是,我们提出的不同规模的RT-DETR保持了同质的解码器,这便于使用高精度大型DETR模型进行轻量化检测器的蒸馏。这将是一个可探索的未来方向。

三、RT-DERT的环境搭建

大家如果没有搭建环境可以看我的另一篇博客,里面讲述了如何搭建pytorch环境(内容十分详细我每次重新更换系统都要看一遍)。

Win11上Pytorch的安装并在Pycharm上调用PyTorch最新超详细过程并附详细的系统变量添加过程,可解决pycharm中pip不好使的问题

在我们配置好环境之后,在之后模型获取完成之后,我们可以进行配置的安装我们可以在命令行下输入如下命令进行环境的配置。

pip install -r requirements.txt

输入如上命令之后我们就可以看到命令行在安装模型所需的库了。 

四、免费数据集获取

在我们开始训练之前,我们需要一份数据集,如何获取一个COCO的数据集大家可以看我的另一篇博客从YOLO官方指定的数据集网站Roboflow下载数据模型训练。

(这里需要注意的是RT-DETR官方的版本Pytorch下的只支持COCO数据集训练目前,所以你没有COCO版本的数据集可以用以下教程下载一个非常的快)

超详细教程YoloV8官方推荐免费数据集网站Roboflow一键导出Voc、COCO、Yolo、Csv等格式

我在上面随便下载了一个 数据集用它导出yolov8的数据集,以及自动给转换成txt的格式yaml文件也已经配置好了,我们直接用就可以。 

8673527d34eb42348770158c69de678f.png

五、获取RT-DERT

RT-DERT有两个获取的方式一个就是官方的版本(但是官方版本不如YOLO系列这么发展的成熟开源的版本里面还有许多存在的bug和功能存在限制),另一个版本就是集成在YOLOv8最新的ultralytics库里面的RT-DERT(我看网上的教程和改进基本上都是基于这个方式)

两个方式我都会讲具体选哪一个根据你个人的实际情况来定。 

5.1 渠道一官方版本

这里提供官方版本的RT-DETR下载的渠道,另一种方式就是通过Git进行克隆具体怎么选择看大家了。

方式一:

官方的RT-DERT下载地址:官方下载地址

方式二:  

通过Git的方式进行下载->

git clone https://github.com/lyuwenyu/RT-DETR.git

5.2 渠道二YOLOv8版本

这个版本的下载方式就是下载YOLOv8方式如下->

方式一:

官方YOLOv8下载地址:官方下载地址

方式二: 

(这里需要注意Git完或者下载完压缩包之后一定要进行用pip install -e . 否则会报识别不了YOLO的错误)

git clone https://github.com/ultralytics/ultralytics
cd ultralytics
pip install -e .

六、RT-DERT的官方版本训练方法

官方版本的训练方式又分两种,下面图片中使用两种不同的框架下实现的RT-DETR,我会拿PyTorch版本来进行举例另一种操作相同。

(前面提到过RT-DETR的Pytorch版本只支持COCO的数据集训练所以开始之前大家需要有一个COCO的数据集)

6.1 步骤一 

我们找到如下文件“rtdetr_pytorch/configs/dataset/coco_detection.yml”,内容如下->

6.2 步骤二 

我们找到下面的这个文件"RT-DETR-main/rtdetr_pytorch/tools/train.py" 

"RT-DETR-main/rtdetr_pytorch/tools/train.py" 文件得末尾如下图的右面所示我们找到左边的配置文件填写到右边 的config参数下。

之后我们运行整个文件即可开始训练 

PS:首次训练需要下载权重

训练过程如下-> 

 训练结果会保存在以下的文件的地址(后面我们导出需要这个文件)->

PS:这里只讲了Pytorch版本的训练方式, Paddle版本的训练方式一样就不重复了。

七、RT-DERT的YOLO版本训练方法 

这里讲的是YOLO版本的RT-DETR训练方法,一种方法是通过命令行另一种方法是通过创建文件来训练,共有参数如下->

RT-DETR的训练可以采用命令行的模型,下面是RT-DETR集成在ultralytics官方给定的训练/预测/验证/导出方式: 

yolo task=detect    mode=train    model=RT-DETR的权重或者yaml文件     args...
          classify       predict                                     args...
          segment        val                                         args...
                         export                                      format=onnx  args...

7.1 方式一

我们可以通过命令直接进行训练在其中指定参数,但是这样的方式,我们每个参数都要在其中打出来。命令如下:

python">yolo task=detect mode=train model=ResNet18_vd_pretrained_from_paddle.pth data=data.yaml batch=16 epochs=100 imgsz=640 workers=0 device=0

需要注意的是如果你是Windows系统的电脑其中的Workers最好设置成0否则容易报线程的错误。

7.2 方式二(推荐)

通过指定cfg直接进行训练,我们配置好ultralytics/cfg/default.yaml这个文件之后,可以直接执行这个文件进行训练,这样就不用在命令行输入其它的参数了。

python">yolo cfg=ultralytics/cfg/default.yaml

7.3 方式三 

 我们可以通过创建py文件来进行训练,这样的好处就是不用在终端上打命令,这也能省去一些工作量,我们在根目录下创建一个名字为run.py的文件,在其中输入代码

python">from ultralytics import RTDETR

# Load a model
model = RTDETR("ultralytics/cfg/models/rt-detr/rtdetr-l.yaml")  # build a new model from scratch

# Use the model
model.train(data="fire.v1i.yolov8/data.yaml", cfg="ultralytics/cfg/default.yaml", epochs=100)  # train the model

训练截图如下->

八、RT-DERT的验证方法

8.1 参数讲解 

验证的参数如下->

参数名类型参数讲解
1valbool用于控制是否在训练过程中进行验证/测试。
2splitstr用于指定用于验证/测试的数据集划分。可以选择 'val'、'test' 或 'train' 中的一个作为验证/测试数据集
3save_jsonbool用于控制是否将结果保存为 JSON 文件
4save_hybirdbool用于控制是否保存标签和附加预测结果的混合版本
5conffloat/optional用于设置检测时的目标置信度阈值
6ioufloat用于设置非极大值抑制(NMS)的交并比(IoU)阈值。
7max_detint用于设置每张图像的最大检测数。
8halfbool用于控制是否使用半精度(FP16)进行推断。
9dnnbool,用于控制是否使用 OpenCV DNN 进行 ONNX 推断。
10plotsbool用于控制在训练/验证过程中是否保存绘图结果。

8.2 验证方法 

 验证我们划分的验证集/测试集的情况,也就是评估我们训练出来的best.pt模型好与坏

python">yolo task=detect mode=val model=best.pt data=data.yaml device=0

九、RT-DERT的推理方法

我们训练好自己的模型之后,都会生成一个模型文件,保存在你设置的目录下,当我们再次想要实验该模型的效果之后就可以调用该模型进行推理了,我们也可以用官方的预训练权重来进行推理。

推理的方式和训练一样我们这里就选一种来进行举例其它的两种方式都是一样的操作只是需要改一下其中的一些参数即可:

9.1 参数讲解

参数名类型参数讲解
0sourcestr/optinal用于指定图像或视频的目录
1showbool用于控制是否在可能的情况下显示结果
2save_txtbool用于控制是否将结果保存为 .txt 文件
3save_confbool用于控制是否在保存结果时包含置信度分数
4save_cropbool用于控制是否将带有结果的裁剪图像保存下来
5show_labelsbool用于控制在绘图结果中是否显示目标标签
6show_confbool用于控制在绘图结果中是否显示目标置信度分数
7vid_strideint/optional用于设置视频的帧率步长
8stream_bufferbool用于控制是否缓冲所有流式帧(True)或返回最新的帧(False)
9line_widthint/list[int]/optional用于设置边界框的线宽度,如果缺失则自动设置
10visualizebool用于控制是否可视化模型的特征
11augmentbool用于控制是否对预测源应用图像增强
12agnostic_nmsbool用于控制是否使用无关类别的非极大值抑制(NMS)
13classesint/list[int]/optional用于按类别筛选结果
14retina_masksbool用于控制是否使用高分辨率分割掩码
15boxesbool用于控制是否在分割预测中显示边界框。

9.2 推理方法 

python">yolo task=detect mode=predict model=best.pt source=images device=0

 这里需要需要注意的是我们用模型进行推理的时候可以选择照片也可以选择一个视频的格式都可以。支持的视频格式有 

  • MP4(.mp4):这是一种常见的视频文件格式,通常具有较高的压缩率和良好的视频质量

  • AVI(.avi):这是一种较旧但仍广泛使用的视频文件格式。它通常具有较大的文件大小

  • MOV(.mov):这是一种常见的视频文件格式,通常与苹果设备和QuickTime播放器相关

  • MKV(.mkv):这是一种开放的多媒体容器格式,可以容纳多个视频、音频和字幕轨道

  • FLV(.flv):这是一种用于在线视频传输的流式视频文件格式

十、RT-DERT的导出方法

10.1 官方版本导出方法

官方版本目前只支持导出ONNX格式, 方法如下->

我们找到如下文件“RT-DETR-main/rtdetr_pytorch/tools/export_onnx.py”。

我们将我们训练完的权重文件输入到resume里,然后运行整个文件就会在同级目录下导出onnx文件如下->

10.2 YOLO版本导出方法

 当我们进行部署的时候可以进行文件导出,然后在进行部署。

YOLOv8支持的输出格式有如下

1. ONNX(Open Neural Network Exchange):ONNX 是一个开放的深度学习模型表示和转换的标准。它允许在不同的深度学习框架之间共享模型,并支持跨平台部署。导出为 ONNX 格式的模型可以在支持 ONNX 的推理引擎中进行部署和推理。

2. TensorFlow SavedModel:TensorFlow SavedModel 是 TensorFlow 框架的标准模型保存格式。它包含了模型的网络结构和参数,可以方便地在 TensorFlow 的推理环境中加载和使用。

3. PyTorch JIT(Just-In-Time):PyTorch JIT 是 PyTorch 的即时编译器,可以将 PyTorch 模型导出为优化的 Torch 脚本或 Torch 脚本模型。这种格式可以在没有 PyTorch 环境的情况下进行推理,并且具有更高的性能。

4. Caffe Model:Caffe 是一个流行的深度学习框架,它使用自己的模型表示格式。导出为 Caffe 模型的文件可以在 Caffe 框架中进行部署和推理。

5. TFLite(TensorFlow Lite):TFLite 是 TensorFlow 的移动和嵌入式设备推理框架,支持在资源受限的设备上进行高效推理。模型可以导出为 TFLite 格式,以便在移动设备或嵌入式系统中进行部署。

6. Core ML(Core Machine Learning):Core ML 是苹果的机器学习框架,用于在 iOS 和 macOS 上进行推理。模型可以导出为 Core ML 格式,以便在苹果设备上进行部署。

这些格式都提供了不同的优势和适用场景。选择合适的导出格式应该考虑到目标平台和部署环境的要求,以及所使用的深度学习框架的支持情况。

模型输出的参数有如下

参数名类型参数解释
0formatstr导出模型的格式
1kerasbool表示是否使用Keras
2optimizebool用于在导出TorchScript模型时进行优化,以便在移动设备上获得更好的性能
3int8bool用于在导出CoreML或TensorFlow模型时进行INT8量化
4dynamicbool用于在导出CoreML或TensorFlow模型时进行INT8量化
5simplifybool用于在导出ONNX模型时进行模型简化
6opsetint/optional用于指定导出ONNX模型时的opset版本
7workspaceint用于指定TensorRT模型的工作空间大小,以GB为单位
8nmsbool用于在导出CoreML模型时添加非极大值抑制(NMS)

命令行命令如下: 

yolo task=detect mode=export model=best.pt format=onnx  

十一、全文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的RT-DETR改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:RT-DETR改进专栏——论文收割机——持续复现各种顶会改进机制


http://www.niftyadmin.cn/n/5217929.html

相关文章

C++前缀和算法:统计美丽子字符串

题目 给你一个字符串 s 和一个正整数 k 。 用 vowels 和 consonants 分别表示字符串中元音字母和辅音字母的数量。 如果某个字符串满足以下条件,则称其为 美丽字符串 : vowels consonants,即元音字母和辅音字母的数量相等。 (vowels * cons…

6.3 Windows驱动开发:内核枚举IoTimer定时器

内核I/O定时器(Kernel I/O Timer)是Windows内核中的一个对象,它允许内核或驱动程序设置一个定时器,以便在指定的时间间隔内调用一个回调函数。通常,内核I/O定时器用于周期性地执行某个任务,例如检查驱动程序…

[C/C++]数据结构 堆的详解

一:概念 堆通常是一个可以被看做一棵完全二叉树的数组对象,它是一颗完全二叉树,堆存储的所有元素按完全二叉树的顺序存储方式存储在一个一维数组中,并且需要满足每个父亲结点总小于其子节点(或者每个父亲结点总大于其子节点) 堆可以分为两种: 小堆: 任意一个父亲节点都小于其子…

探索接口测试:SOAP、RestFul规则、JMeter及市面上的接口测试工具

引言 在当今软件开发领域,接口测试扮演着至关重要的角色。随着系统变得日益复杂和互联,对于内部和外部接口的测试变得愈发关键。接口测试不仅仅是验证接口的正确性,更是确保系统的稳定性、安全性和性能优越性的关键一环。 本篇博客将带您深入…

【docker系列】docker高阶篇

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

网络爬虫(Python:Selenium、Scrapy框架;爬虫与反爬虫笔记)

网络爬虫(Python:Selenium、Scrapy框架;爬虫与反爬虫笔记) SeleniumWebDriver 对象提供的相关方法定位元素ActionChains的基本使用selenium显示等待和隐式等待显示等待隐式等待 Scrapy(异步网络爬虫框架)Sc…

智能优化算法应用:基于蚁狮算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于蚁狮算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于蚁狮算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.蚁狮算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…

Java 基础学习(一)Java环境搭建和基本数据类型

1 Java 开发环境搭建 1.1 Java 编程语言 1.1.1 什么是Java编程语言 语言是人类进行沟通交流的各种表达符号,方便人与人之间进行沟通与信息交换;而计算机编程语言则是人与计算机之间进行信息交流沟通的一种特殊语言,也有语法规则、字符、符…