01-PyTorch加载数据初认识(dataset运用)
一、先看整体结构
这是一个标准的 PyTorch 自定义数据集模板,核心分为 3 个部分:
- 类定义 +
__init__:初始化路径和数据列表 __getitem__:按索引读取单张图片和标签__len__:返回数据集总长度
二、逐行代码讲解
1. 导入依赖
python运行
from torch.utils.data import Dataset from PIL import Image import osDataset:PyTorch 提供的抽象基类,所有自定义数据集都要继承它,这样才能被DataLoader识别;Image:来自 PIL 库,用来读取、处理图片;os:用来拼接文件路径、读取目录下的文件名,处理本地文件系统。
2. 类定义与初始化方法__init__
python运行
class MyData(Dataset): def __init__(self, root_dir, label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir, self.label_dir) self.img_path = os.listdir(self.path)class MyData(Dataset):定义一个新的类MyData,继承自Dataset;def __init__(self, root_dir, label_dir):类的构造函数,创建数据集对象时会自动执行,接收两个参数:root_dir:数据集的根目录,比如dataset/train;label_dir:类别目录,比如ants(代表蚂蚁的图片文件夹);
self.root_dir = root_dir:把根目录保存到实例变量中,后续可以在类的其他方法里调用;self.label_dir = label_dir:把类别目录保存到实例变量中;self.path = os.path.join(self.root_dir, self.label_dir):拼接根目录和类别目录,得到完整的图片文件夹路径,比如dataset/train/ants;self.img_path = os.listdir(self.path):读取dataset/train/ants目录下的所有文件名,存入self.img_path列表,后续可以按索引读取。
3. 核心方法__getitem__
python
运行
def __getitem__(self, idx): img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) img = Image.open(img_item_path) label = self.label_dir return img, labeldef __getitem__(self, idx):PyTorch 规定的方法,按索引读取数据,idx就是索引(从 0 开始);img_name = self.img_path[idx]:根据索引idx,从self.img_path列表中取出对应的图片文件名;img_item_path = os.path.join(self.root_dir, self.label_dir, img_name):拼接根目录、类别目录和图片文件名,得到单张图片的完整路径,比如dataset/train/ants/001.jpg;img = Image.open(img_item_path):用 PIL 读取图片,得到一个 Image 对象;label = self.label_dir:把类别目录名(比如ants)作为标签;return img, label:返回图片和对应的标签,后续模型训练时会接收这两个值。
4. 长度方法__len__
python
运行
def __len__(self): return len(self.img_path)def __len__(self):PyTorch 规定的方法,返回数据集的总样本数;return len(self.img_path):self.img_path是图片文件名列表,len(self.img_path)就是图片总数,比如dataset/train/ants目录下有 124 张图片,就返回 124。
三、代码执行流程(结合你的控制台)
python
运行
root_dir = "dataset/train" ants_label_dir = "ants" ants_dataset = MyData(root_dir, ants_label_dir)- 创建
MyData对象,传入根目录和类别目录; - 自动执行
__init__:拼接路径、读取图片列表; - 当你调用
len(ants_dataset)时,会执行__len__,返回图片总数; - 当你调用
ants_dataset[0]时,会执行__getitem__(0),返回第 1 张图片和标签。
四、补充说明与小优化
- 标签处理:这段代码里直接用
label = self.label_dir,后续训练时,模型需要的是数字标签,比如ants=0、bees=1,可以改成:python
运行
# 比如 ants 标签设为 0 label = 0 - 路径拼接:
os.path.join是跨平台的,Windows、Linux 都能正常拼接路径,避免手动写/或\出错; - 遥感影像适配:如果你后续要处理
.tif格式的遥感影像,把Image.open换成rasterio.open即可,核心逻辑不变。
