AI 项目中常见的 Python 代码片段

文章目录
  1. 1. 使用 getattr 以字符串形式调用 . 后面的方法
  2. 2. BatchNorm 层, 别学辣
  3. 3. Micro tricks
    1. 3.1. 用下划线分割的数字
    2. 3.2. 函数参数中的 *
  4. 4. CLIP-related
    1. 4.1. Top 7 prompts
  5. 5. 好文章

本文将摘录在深度学习项目中常见的一些 Python 代码片段。

使用 getattr 以字符串形式调用 . 后面的方法

How to call Python function by name dynamically using a string?

比如你本来是想调用形如下面这一串东西:

1
model = torchvision.models.resnet18(pretrained=configs.use_trained_models)

但可惜的是,在写代码阶段你还暂时不能确实到底是 resnet18 还是 resnet50,你希望能用类似下面这种形式调用——但很显然这样是错误的:

1
2
model_name = 'resnet18'
model = torchvision.models[model_name](pretrained=configs.use_trained_models)

事实上正确的写法是:

1
2
model_name = 'resnet18'
model = getattr(torchvision.models, model_name)(pretrained=configs.use_trained_models)

BatchNorm 层, 别学辣

在下游数据集微调的时候往往会把 BatchNorm 层冻结。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def train(self, mode=True):
"""
Override the default train() to freeze the BN parameters
"""
super(MyNet, self).train(mode)
if self.freeze_bn:
print("Freezing Mean/Var of BatchNorm2D.")
if self.freeze_bn_affine:
print("Freezing Weight/Bias of BatchNorm2D.")
if self.freeze_bn:
for m in self.backbone.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.freeze_bn_affine:
m.weight.requires_grad = False
m.bias.requires_grad = False

Micro tricks

用下划线分割的数字

1
x = 140_000

其实就是 x = 140000,你可以在任意位置加下划线,只是为了人好数位数罢了。

函数参数中的 *

  • 单星号 *:任意个 tuple 格式的参数
  • 双星号 **:任意个 dict 格式的参数

有关 OpenAI CLIP 模型的一切。

Top 7 prompts

K 自官方 Notebook

1
2
3
4
5
6
7
itap of a {}.
a bad photo of the {}.
a origami {}.
a photo of the large {}.
a {} in a video game.
art of the {}.
a photo of the small {}.

好文章