当前位置: 首页 > news >正文

别再手动提特征了!用Python+PyTorch搭建你的第一个智能故障诊断模型(以轴承振动数据为例)

用Python+PyTorch实现轴承振动数据的智能故障诊断:从零搭建CNN模型

轴承作为旋转机械的核心部件,其健康状态直接影响设备寿命。传统振动分析依赖工程师手动提取时频域特征,既耗时又难以捕捉复杂故障模式。我们以凯斯西储大学轴承数据集为例,用PyTorch构建端到端的智能诊断系统,实现振动信号到故障类别的自动映射。

1. 环境准备与数据加载

工欲善其事,必先利其器。建议使用Python 3.8+环境,通过conda创建虚拟环境避免依赖冲突:

conda create -n fault_diagnosis python=3.8 conda activate fault_diagnosis pip install torch torchvision pandas scikit-learn matplotlib

凯斯西储大学轴承数据集包含正常状态和多种故障类型(内圈、外圈、滚动体故障),采样频率12kHz。我们使用pandas加载预处理后的CSV版本:

import pandas as pd from sklearn.model_selection import train_test_split data = pd.read_csv('bearing_vibration.csv') X = data.iloc[:, :-1].values # 振动信号序列 y = data.iloc[:, -1].values # 故障标签 # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, stratify=y, random_state=42 )

提示:工业现场数据往往存在类别不平衡问题,可采用过采样或加权损失函数处理

2. 数据预处理与增强策略

原始振动信号需要转化为适合CNN输入的格式。我们采用滑动窗口技术生成样本片段:

import numpy as np def create_sequences(data, labels, window_size=1024, step=512): sequences = [] sequence_labels = [] for i in range(0, len(data) - window_size, step): seq = data[i:i+window_size] sequences.append(seq) sequence_labels.append(labels[i + window_size // 2]) return np.array(sequences), np.array(sequence_labels) X_train_seq, y_train_seq = create_sequences(X_train, y_train) X_test_seq, y_test_seq = create_sequences(X_test, y_test)

为提升模型泛化能力,引入三种数据增强技术:

  • 随机缩放:振幅±10%波动
  • 随机偏移:添加±0.1倍标准差噪声
  • 随机裁剪:从1024点中随机选取896点
class VibrationAugment: def __call__(self, sample): if np.random.rand() > 0.5: sample *= 1 + np.random.uniform(-0.1, 0.1) if np.random.rand() > 0.5: sample += np.random.normal(0, 0.1*sample.std()) if np.random.rand() > 0.5: start = np.random.randint(0, 128) sample = sample[start:start+896] sample = np.pad(sample, (0, 1024-896)) return sample

3. 构建1D-CNN诊断模型

针对振动信号特性,设计包含并行卷积支路的网络结构:

import torch import torch.nn as nn import torch.nn.functional as F class FaultDiagnosisCNN(nn.Module): def __init__(self, num_classes): super().__init__() # 高频特征支路 self.branch1 = nn.Sequential( nn.Conv1d(1, 16, kernel_size=64, stride=8, padding=28), nn.BatchNorm1d(16), nn.ReLU(), nn.MaxPool1d(2) ) # 低频特征支路 self.branch2 = nn.Sequential( nn.Conv1d(1, 16, kernel_size=256, stride=32, padding=112), nn.BatchNorm1d(16), nn.ReLU(), nn.MaxPool1d(2) ) # 特征融合 self.fc = nn.Sequential( nn.Linear(16*2*31, 64), nn.ReLU(), nn.Dropout(0.5), nn.Linear(64, num_classes) ) def forward(self, x): x = x.unsqueeze(1) # [batch, 1, seq_len] h1 = self.branch1(x) h2 = self.branch2(x) h = torch.cat([h1, h2], dim=1) return self.fc(h.view(h.size(0), -1))

模型设计关键点:

  • 双支路结构:分别捕捉不同时间尺度的故障特征
  • 宽卷积核:适应振动信号的周期性特点
  • BatchNorm:加速训练并提升稳定性
  • Dropout:防止过拟合

4. 模型训练与性能优化

采用带热启动的余弦退火学习率策略,配合标签平滑技术:

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts model = FaultDiagnosisCNN(num_classes=len(np.unique(y))) criterion = nn.CrossEntropyLoss(label_smoothing=0.1) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2) for epoch in range(100): model.train() for batch_x, batch_y in train_loader: optimizer.zero_grad() outputs = model(batch_x) loss = criterion(outputs, batch_y) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step()

评估指标除准确率外,还应关注:

  • 混淆矩阵:识别易混淆故障类型
  • F1-score:处理类别不平衡
  • ROC-AUC:综合评估分类性能
from sklearn.metrics import classification_report model.eval() with torch.no_grad(): y_pred = model(X_test_tensor).argmax(dim=1) print(classification_report(y_test, y_pred.cpu()))

5. 模型解释与工业部署

使用Grad-CAM可视化关键故障特征区域:

class GradCAM: def __init__(self, model, target_layer): self.model = model self.target_layer = target_layer self.gradients = None def save_gradient(self, grad): self.gradients = grad def __call__(self, x): feature_maps = None def hook_fn(module, input, output): nonlocal feature_maps feature_maps = output output.register_hook(self.save_gradient) handle = self.target_layer.register_forward_hook(hook_fn) output = self.model(x) handle.remove() one_hot = torch.zeros_like(output) one_hot[0, output.argmax()] = 1 self.model.zero_grad() output.backward(gradient=one_hot, retain_graph=True) weights = self.gradients.mean(dim=2, keepdim=True) cam = (weights * feature_maps).sum(dim=1, keepdim=True) cam = F.relu(cam) cam = F.interpolate(cam, size=x.shape[2], mode='linear') return cam.squeeze().cpu().numpy()

工业部署建议:

  • ONNX导出:实现跨平台部署
  • TensorRT优化:提升推理速度
  • 边缘设备适配:使用LibTorch在嵌入式系统运行
dummy_input = torch.randn(1, 1, 1024) torch.onnx.export(model, dummy_input, "fault_diagnosis.onnx", input_names=["vibration"], output_names=["fault_class"])

6. 持续改进方向

在实际项目中,我们发现以下优化手段效果显著:

  • 混合精度训练:减少显存占用
  • 知识蒸馏:压缩模型尺寸
  • 多传感器融合:结合温度、电流信号
  • 半监督学习:利用未标注数据
from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() with autocast(): outputs = model(batch_x) loss = criterion(outputs, batch_y) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
http://www.cnnetsun.cn/news/2694635.html

相关文章:

  • 告别重复劳动:用CodeFuse插件5分钟搞定Java/Python单元测试生成(附避坑指南)
  • 现在不看就晚了:Sora 2.4即将废弃的录制协议v1.7——30天倒计时内必须迁移的5个接口、2个事件钩子与1套兼容性验证清单
  • Windows上安装APK的终极方案:告别模拟器,体验原生安卓应用
  • 编写个人家庭应急物资管理系统,分类统计保质期,储备量,适配家庭突发应急场景。
  • 开发小区垃圾分类智能指引程序,识别垃圾品类,精准引导分类投放,贴合社区治理。
  • 超越振动信号:用IMS轴承数据集玩转5种故障预测模型(附PyTorch/Sklearn代码)
  • 自制2.4GHz全波偶极天线:原理、制作与WiFi信号增强实战
  • Unity Addressables热更实战:从本地模拟到远程服务器部署的保姆级流程(含Hosting服务)
  • 戴尔新款 XPS 13 7 月上市,低价对标 MacBook Neo,轻薄优势下能否突围?
  • Sora 2背景音乐自动裁剪失效?揭秘底层时间码映射机制:如何用Python脚本动态生成合规.wav头文件
  • 测试文章123
  • PyMobileDevice3终极指南:Python控制iOS设备的完整实战教程
  • 如何在Windows上快速安装安卓应用:APK-Installer完整实战指南
  • 霞鹜文楷:终极免费开源中文字体解决方案,轻松解决你的中文排版难题
  • Fibronectin CS-1 Fragment (1978-1985) ;EILDVPST
  • 告别混乱开发:用平头哥CDK的组件池功能管理你的多芯片项目
  • 2026实测:AI生成UI设计稿后,如何优雅集成到PageAdmin CMS?(附标签替换代码)
  • 阴阳师自动化脚本OnmyojiAutoScript:3分钟快速上手,彻底解放双手!
  • 解密Godot游戏资源:专业PCK文件提取工具深度解析
  • 人工处理数据的代价你算过吗?2026企业避坑指南:从Token黑洞到智能体进化
  • 别再为libcurl编译发愁了!Windows/Linux双平台保姆级编译指南(含OpenSSL依赖处理)
  • 基于ESP8266与WS2812B的便携式RGB补光灯DIY全流程解析
  • 如何彻底告别游戏鼠标消失问题:YoloMouse完整使用指南
  • 新手司机福音:低速出库时,FCTA/FCTB如何帮你避免“鬼探头”事故?
  • 机器学习高效学习路径:从基础到实战的完整框架与心法
  • SBTI刷屏引热议:在哪测才靠谱
  • Ansaldo P681T 信号调理板
  • 如何在电脑上免费畅玩任天堂Switch游戏?yuzu模拟器完整指南
  • 别再到处找教程了!5分钟搞定Python调用ChatGPT API的完整流程(附代码)
  • 基于ESP32的硬件加密保险箱:低成本实现超级加密与HMAC完整性验证