
Swin Transformer实战从权重加载警告到95%花卉分类精度的优化之路1. 理解_IncompatibleKeys警告的本质当你在花卉分类任务中加载预训练的Swin Transformer模型时控制台出现_IncompatibleKeys(missing_keys[head.weight, head.bias], unexpected_keys[layers.0.blocks.1.attn_mask...])这样的警告这实际上是PyTorch在告诉你模型权重加载时的关键信息。这个警告包含两部分重要信息missing_keys模型期望但预训练权重中没有提供的参数unexpected_keys预训练权重中存在但当前模型不需要的参数在Swin Transformer的案例中missing_keys通常指向分类头(head)的权重而unexpected_keys则多与注意力掩码(attn_mask)相关。这种现象的产生主要有三个原因分类头维度不匹配预训练模型通常是在ImageNet(1000类)上训练的而你的花卉分类可能只有5个类别模型结构微调你可能修改了原始Swin的结构比如移除了某些层版本差异使用的预训练权重与模型代码版本不完全兼容# 典型权重加载代码示例 model create_model(num_classes5) # 你的花卉类别数 weights_dict torch.load(pretrained_weights)[model] # 删除分类头权重避免不匹配 for k in list(weights_dict.keys()): if head in k: del weights_dict[k] model.load_state_dict(weights_dict, strictFalse) # strictFalse允许部分加载提示strictFalse参数是关键它允许模型只加载匹配的权重忽略不匹配的部分。这在迁移学习中是非常常见的做法。2. 权重加载问题的系统解决方案2.1 分类头不匹配的专业处理当遇到分类头维度不匹配时有几种专业处理方案方案对比表方案适用场景优点缺点完全替换分类头新任务类别数与预训练差异大完全适配新任务需要从头训练分类头部分权重初始化新类别数大于预训练利用部分预训练知识实现较复杂特征提取器模式小数据集避免破坏预训练特征无法进行端到端优化对于花卉分类这样的任务推荐采用以下最佳实践def init_class_head(model, pretrained_weights, num_classes): 智能初始化分类头 # 加载预训练权重(不含head) weights_dict torch.load(pretrained_weights)[model] pretrained_num_classes weights_dict[head.weight].shape[0] if num_classes pretrained_num_classes: # 类别数相同直接使用预训练head model.load_state_dict(weights_dict) else: # 类别数不同初始化新head del_keys [k for k in weights_dict.keys() if head in k] for k in del_keys: del weights_dict[k] model.load_state_dict(weights_dict, strictFalse) # 智能初始化新head if hasattr(model, head): if isinstance(model.head, nn.Linear): nn.init.trunc_normal_(model.head.weight, std0.02) if model.head.bias is not None: nn.init.zeros_(model.head.bias) return model2.2 注意力掩码问题的深入解析Swin Transformer中的attn_mask是实现移位窗口(shifted window)注意力的关键。预训练权重中保存的这些掩码可能与当前模型的输入尺寸相关导致出现unexpected_keys警告。解决方案确认模型输入尺寸与预训练时一致(通常是224x224)如果必须改变输入尺寸需要重新计算这些掩码def recompute_attn_masks(model, img_size224): 重新计算适应新尺寸的注意力掩码 for layer in model.layers: if hasattr(layer, blocks): for block in layer.blocks: if hasattr(block, attn_mask): # 删除旧的掩码 del block.attn_mask # 重新计算新尺寸下的掩码 H, W img_size // (2 ** i), img_size // (2 ** i) block.attn_mask compute_mask(H, W, block.window_size) return model3. 从93%到95%的精度提升策略3.1 数据增强的进阶技巧基础的数据增强如RandomResizedCrop和RandomHorizontalFlip已经能带来不错的效果但要突破精度瓶颈需要更精细的策略高级增强组合from torchvision import transforms train_transform transforms.Compose([ transforms.RandomResizedCrop(224, scale(0.8, 1.0)), transforms.RandomApply([ transforms.ColorJitter(brightness0.4, contrast0.4, saturation0.2, hue0.1) ], p0.8), transforms.RandomGrayscale(p0.2), transforms.RandomApply([transforms.GaussianBlur(kernel_size(5,5))], p0.5), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(p0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), transforms.RandomErasing(p0.5, scale(0.02, 0.1), ratio(0.3, 3.3)) ])注意增强强度需要根据数据集大小调整。小数据集(1万样本)适合更强增强大数据集则应减弱增强强度。3.2 模型微调的黄金法则分层学习率策略 Swin Transformer的不同层应该使用不同的学习率。通常浅层小学习率(保持预训练特征)深层中等学习率分类头最大学习率# 分层设置学习率 param_groups [ {params: model.patch_embed.parameters(), lr: base_lr * 0.1}, {params: model.layers[0].parameters(), lr: base_lr * 0.5}, {params: model.layers[1].parameters(), lr: base_lr * 0.7}, {params: model.layers[2].parameters(), lr: base_lr}, {params: model.layers[3].parameters(), lr: base_lr}, {params: model.head.parameters(), lr: base_lr * 2} ] optimizer optim.AdamW(param_groups, weight_decay0.05)优化器选择对比优化器最佳学习率适合场景训练时间AdamW1e-4到5e-4大多数情况中等SGDmomentum1e-2到1e-1大数据集长LAMB1e-3到5e-3超大模型短3.3 损失函数的艺术除了标准的CrossEntropyLoss可以尝试Label Smoothing缓解过拟合criterion nn.CrossEntropyLoss(label_smoothing0.1)Focal Loss处理类别不平衡class FocalLoss(nn.Module): def __init__(self, alpha1, gamma2): super().__init__() self.alpha alpha self.gamma gamma def forward(self, inputs, targets): ce_loss F.cross_entropy(inputs, targets, reductionnone) pt torch.exp(-ce_loss) loss self.alpha * (1-pt)**self.gamma * ce_loss return loss.mean()混合损失组合多个损失函数def hybrid_loss(preds, targets, alpha0.5): ce F.cross_entropy(preds, targets) focal FocalLoss()(preds, targets) return alpha * ce (1 - alpha) * focal4. 超参数优化的系统方法4.1 学习率与batch size的协同优化学习率预热(warmup)策略from torch.optim.lr_scheduler import LambdaLR def get_warmup_scheduler(optimizer, warmup_steps, total_steps): def lr_lambda(current_step): if current_step warmup_steps: return float(current_step) / float(max(1, warmup_steps)) return max( 0.0, float(total_steps - current_step) / float(max(1, total_steps - warmup_steps)) ) return LambdaLR(optimizer, lr_lambda)batch size与学习率的关系 当增大batch size时学习率也应相应调整。经验公式new_lr base_lr * (new_bs / base_bs)^0.5其中base_bs是原始batch size(如32)base_lr是对应的学习率(如1e-4)。4.2 正则化技术组合拳有效的正则化组合DropPath (Stochastic Depth)# 在Swin TransformerBlock中添加 self.drop_path DropPath(drop_path_rate) if drop_path_rate 0 else nn.Identity()Weight Decay AdamW优化器默认包含weight decay推荐值0.05MixUp数据增强def mixup_data(x, y, alpha0.4): if alpha 0: lam np.random.beta(alpha, alpha) else: lam 1 batch_size x.size()[0] index torch.randperm(batch_size).to(x.device) mixed_x lam * x (1 - lam) * x[index] y_a, y_b y, y[index] return mixed_x, y_a, y_b, lam4.3 训练监控与早停策略实现智能早停需要监控验证集指标class EarlyStopper: def __init__(self, patience3, min_delta0): self.patience patience self.min_delta min_delta self.counter 0 self.min_validation_loss float(inf) def early_stop(self, validation_loss): if validation_loss self.min_validation_loss: self.min_validation_loss validation_loss self.counter 0 elif validation_loss (self.min_validation_loss self.min_delta): self.counter 1 if self.counter self.patience: return True return False使用示例early_stopper EarlyStopper(patience5, min_delta0.01) for epoch in range(epochs): train_loss train_one_epoch() val_loss validate() if early_stopper.early_stop(val_loss): break5. 模型部署与性能优化5.1 模型量化加速将FP32模型量化为INT8可以显著提升推理速度# 动态量化 quantized_model torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtypetorch.qint8 ) # 静态量化(需要校准数据) model.eval() model.qconfig torch.quantization.get_default_qconfig(fbgemm) quantized_model torch.quantization.prepare(model, inplaceFalse) # 用校准数据运行几次 quantized_model torch.quantization.convert(quantized_model)量化前后性能对比指标FP32模型INT8模型提升幅度模型大小107MB27MB75%减小推理时间45ms12ms3.75倍准确率95.2%94.8%-0.4%5.2 TorchScript优化将模型转换为TorchScript可以提高部署效率# 追踪模式 traced_model torch.jit.trace(model, example_input) # 脚本模式(适合有控制流的模型) scripted_model torch.jit.script(model) # 保存优化后模型 torch.jit.save(traced_model, swin_transformer_optimized.pt)5.3 ONNX导出与跨平台部署torch.onnx.export( model, dummy_input, swin_transformer.onnx, input_names[input], output_names[output], dynamic_axes{ input: {0: batch_size}, output: {0: batch_size} }, opset_version13 )提示导出ONNX前确保模型在eval()模式并处理好所有动态尺寸问题。可以使用Netron工具可视化检查导出的模型结构。