发布时间:2023-04-19 文章分类:电脑百科 投稿人:樱花 字号: 默认 | | 超大 打印

Optimizer

optimizer.param_groups用法的示例分析

日期:2022年7月25日

pytorch版本: 1.11.0

对于param_groups的探索

optimizer.param_groups: 是一个list,其中的元素为字典;

optimizer.param_groups[0]:长度为7的字典,包括[‘params’, ‘lr’, ‘betas’, ‘eps’, ‘weight_decay’, ‘amsgrad’, ‘maximize’]这7个参数;

下面用的Adam优化器创建了一个optimizer变量:

>>> optimizer.param_groups[0].keys()
>>> dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad', 'maximize'])

可以自己把训练参数分别赋予不同的学习率,这样子list里就不止一个元素了,而是多个字典了。

以网上的例子来继续试验:

import torch
import torch.optim as optim
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)
# 输出
>>> 
[{'params': [tensor([[-0.1002,  0.3526, -1.2212],
        			 [-0.4659,  0.0498, -0.2905],
        			 [ 1.1862, -0.6085,  0.4965]], requires_grad=True)],
  'lr': 0.001, 
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False}]

以下主要是Optimizer这个类有个add_param_group的方法

# Per the docs, the add_param_group method accepts a param_group parameter that is a dict. Example of use:
import torch
import torch.optim as optim
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)
# 输出
>>> [{'params': [tensor([[-1.5916, -1.6110, -0.5739],
        [ 0.0589, -0.5848, -0.9199],
        [-0.4206, -2.3198, -0.2062]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False}]
o.add_param_group({'params': w2})
print(o.param_groups)
# 输出
>>> [{'params': [tensor([[-1.5916, -1.6110, -0.5739],
        [ 0.0589, -0.5848, -0.9199],
        [-0.4206, -2.3198, -0.2062]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False}, 
     {'params': [tensor([[-0.5546, -1.2646,  1.6420],
        [ 0.0730, -0.0460, -0.0865],
        [ 0.3043,  0.4203, -0.3607]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False}]

平时写代码如何动态修改学习率(常规操作)

for param_group in optimizer.param_groups:
    param_group["lr"] = lr 

补充:pytorch中的优化器总结

SGD优化器为例:

from torch import nn as nn
import torch as t
from torch.autograd import Variable as V
from torch import optim  # 优化器
# 定义一个LeNet网络
class LeNet(t.nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.features = t.nn.Sequential(
            t.nn.Conv2d(3, 6, 5),
            t.nn.ReLU(),
            t.nn.MaxPool2d(2, 2),
            t.nn.Conv2d(6, 16, 5),
            t.nn.ReLU(),
            t.nn.MaxPool2d(2, 2)
        )
        # 由于调整shape并不是一个class层,
        # 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型
        self.classifiter = t.nn.Sequential(
            t.nn.Linear(16*5*5, 120),
            t.nn.ReLU(),
            t.nn.Linear(120, 84),
            t.nn.ReLU(),
            t.nn.Linear(84, 10)
        )
    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 16*5*5)
        x = self.classifiter(x)
        return x
net = LeNet()
# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad()  # 梯度清零,相当于net.zero_grad()
input = V(t.randn(1, 3, 32, 32))
output = net(input)
output.backward(output)  
optimizer.step()  # 执行优化

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# 为不同子网络设置不同的学习率,在finetune中经常用到
# 如果对某个参数不指定学习率,就使用默认学习率
optimizer = optim.SGD(
    [{'params': net.features.parameters()},  # 学习率为1e-5
     {'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5
)

2.以网络层对象为单位进行分组,并设定学习率

# 只为两个全连接层设置较大的学习率,其余层的学习率较小
# 以层为单位,为不同层指定不同的学习率
# 提取指定层对象
special_layers = nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
# 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())
optimizer = t.optim.SGD([
    {'params': base_params},
    {'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

参考:
https://blog.csdn.net/weixin_43593330/article/details/108490956
https://www.cnblogs.com/hellcat/p/8496727.html
https://www.yisu.com/zixun/456082.html