Linear Regression

这里举一个回归模型的例子,展示几种模型可视化的方法,分别是

  • print
  • torchinfo.summary
  • torchsummary
  • torchviz
  • torchview
  • netron工具

首先创建一个简单的回归模型

import torch

# Regression model made of two fully connected layers with a ReLU in between
class LinearRegression(torch.nn.Module):
    """Small feed-forward regressor: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the layers.

        Args:
            input_dim: number of input features of the first linear layer.
            hidden_dim: output width of the first layer / input width of the second.
            output_dim: number of output features of the second linear layer.
        """
        super().__init__()
        nn = torch.nn
        # Registration order (fc1, relu, fc2) determines how the model prints.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply fc1, ReLU and fc2 to ``x`` in that order and return the result."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Layer widths (input, hidden, output) and the batch size used by the examples
input_dim, hidden_dim, output_dim = 1, 64, 1
batch_size = 32

# Instantiate the network
model = LinearRegression(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)

# Show the module structure
print(model)

# LinearRegression(
#  (fc1): Linear(in_features=1, out_features=64, bias=True)
#  (relu): ReLU()
#  (fc2): Linear(in_features=64, out_features=1, bias=True)
# )

Model Summary

打印模型的简要信息,包括可训练参数等

import torchinfo


# Print the model summary. col_names selects the columns to show:
# input shape, output shape, parameter count and trainable flag per layer.
torchinfo.summary(
    model,
    input_size=(batch_size, input_dim),
    col_names=["input_size", "output_size", "num_params", "trainable"]
)
============================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
============================================================================================================================================
LinearRegression                         [32, 1]                   [32, 1]                   --                        True
├─Linear: 1-1                            [32, 1]                   [32, 64]                  128                       True
├─ReLU: 1-2                              [32, 64]                  [32, 64]                  --                        --
├─Linear: 1-3                            [32, 64]                  [32, 1]                   65                        True
============================================================================================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.01
============================================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.02
============================================================================================================================================

还可以打印简单的信息

import torchsummary

# NOTE(review): the output printed below shows shapes such as [-1, 32, 64], i.e. an extra
# leading -1 appears in front of the batch_size passed here. Presumably torchsummary adds
# its own batch dimension to input_size — confirm against the torchsummary docs.
torchsummary.summary(model, input_size=(batch_size, input_dim))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [-1, 32, 64]             128
              ReLU-2               [-1, 32, 64]               0
            Linear-3                [-1, 32, 1]              65
================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.03
Params size (MB): 0.00
Estimated Total Size (MB): 0.03
----------------------------------------------------------------

Torchviz

可以导出graphviz图,需要计算机已安装graphviz;可执行 `dot -V` 验证graphviz是否安装成功

import torchviz

# Build a sample input batch
example_input = torch.randn(batch_size, input_dim)

model = LinearRegression(input_dim, hidden_dim, output_dim)

# Run a forward pass and draw the resulting computation graph with torchviz
output = model(example_input)
dot = torchviz.make_dot(output, params=dict(model.named_parameters()), show_attrs=False, show_saved=True)
# render() returns the path of the file it wrote; keep it so the image displayed
# below is always the one that was just rendered (linear_regression_torchviz.png).
image_path = dot.render("linear_regression_torchviz", format="png", cleanup=True, view=False)

# Comment out the next two lines when not running inside Jupyter
from IPython.display import Image, display
display(Image(image_path))
Torchviz图示

Torchview

可以绘制规范的线框图,便于展示模型的层次

from torchview import draw_graph

# Fresh model instance for this example
model = LinearRegression(input_dim, hidden_dim, output_dim)

# device='meta' -> no memory is consumed for visualization
# save_graph / filename are passed so the drawing is also written to disk
# (NOTE(review): exact output location/format is decided by torchview — confirm in its docs)
model_graph = draw_graph(
    model,
    input_size=(batch_size, input_dim), 
    expand_nested=True,
    save_graph=True, 
    filename="linear_regression_torchview",
    device='meta')
# Explicitly render the graph object to a PNG with the same base name
model_graph.visual_graph.render("linear_regression_torchview", format="png")
# Bare expression: presumably shown inline when this is the last line of a notebook cell
model_graph.visual_graph

torchview

netron工具

可以看出详细的计算图,但是需要将模型导出成ONNX格式

# Export the model to ONNX so it can be viewed with Netron.
# This snippet only writes linear_regression.onnx; opening that file in Netron
# is a separate step that is not performed here.
model = LinearRegression(input_dim, hidden_dim, output_dim)
# Sample input passed to the exporter together with the model
example_input = torch.randn(batch_size, input_dim)
torch.onnx.export(model, example_input, "linear_regression.onnx", verbose=True)

Transformer

如下是PyTorch的Transformer图示

# Transformer model visualization
import torch
from torch.nn import Transformer
from torch.nn import TransformerDecoder
from torch.nn import TransformerDecoderLayer

# Only the encoder/decoder layer counts are set explicitly; every other
# constructor argument is left at the library default.
transformer_model = Transformer(num_encoder_layers=2, num_decoder_layers=2)

# Sample source/target tensors. Their shapes equal the input_size tuples used by
# the summary and graph calls later in this document.
# NOTE(review): src and tgt themselves are not referenced again in the visible code.
src = torch.rand(10,32,512)
tgt = torch.rand(20,32,512)

# Print the module hierarchy
print(transformer_model)

Model Summary

import torchinfo

# Summarize the Transformer. Two input sizes are given, one per input tensor;
# they have the same shapes as the src and tgt tensors created above.
torchinfo.summary(
    transformer_model,
    input_size=((10,32,512), (20,32,512)),
    col_names=["input_size", "output_size", "num_params", "trainable"]
)
=================================================================================================================================================
Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Trainable
=================================================================================================================================================
Transformer                                   [10, 32, 512]             [20, 32, 512]             --                        True
├─TransformerEncoder: 1-1                     [10, 32, 512]             [10, 32, 512]             --                        True
│    └─ModuleList: 2-1                        --                        --                        --                        True
│    │    └─TransformerEncoderLayer: 3-1      [10, 32, 512]             [10, 32, 512]             3,152,384                 True
│    │    └─TransformerEncoderLayer: 3-2      [10, 32, 512]             [10, 32, 512]             3,152,384                 True
│    └─LayerNorm: 2-2                         [10, 32, 512]             [10, 32, 512]             1,024                     True
├─TransformerDecoder: 1-2                     [20, 32, 512]             [20, 32, 512]             --                        True
│    └─ModuleList: 2-3                        --                        --                        --                        True
│    │    └─TransformerDecoderLayer: 3-3      [20, 32, 512]             [20, 32, 512]             4,204,032                 True
│    │    └─TransformerDecoderLayer: 3-4      [20, 32, 512]             [20, 32, 512]             4,204,032                 True
│    └─LayerNorm: 2-4                         [20, 32, 512]             [20, 32, 512]             1,024                     True
=================================================================================================================================================
Total params: 14,714,880
Trainable params: 14,714,880
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 126.18
=================================================================================================================================================
Input size (MB): 1.97
Forward/backward pass size (MB): 64.23
Params size (MB): 33.64
Estimated Total Size (MB): 99.84
=================================================================================================================================================

Torchview

from torchview import draw_graph

# Rebuild the Transformer with one encoder layer and one decoder layer
# (the summary earlier used two of each) — presumably to keep the drawing compact.
transformer_model = Transformer(num_encoder_layers=1, num_decoder_layers=1)

# Two input sizes are given, one per input tensor of the model.
# device='meta' is the same setting the earlier torchview example uses to avoid
# consuming memory for visualization.
model_graph = draw_graph(
    transformer_model,
    input_size=((10,32,512), (20,32,512)), 
    expand_nested=True,
    save_graph=True,
    filename="transformer_torchview",
    device='meta')
# Explicitly render the graph object to a PNG with the same base name
model_graph.visual_graph.render("transformer_torchview", format="png")
# Bare expression: presumably shown inline when this is the last line of a notebook cell
model_graph.visual_graph

transformer_torchview

单纯decoder图示

import torch
from torch.nn import Transformer
from torch.nn import TransformerDecoder
from torch.nn import TransformerDecoderLayer
# Build a standalone decoder stack: two copies of one decoder layer
decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
decoder_model = TransformerDecoder(decoder_layer=decoder_layer, num_layers=2)


# The two input sizes are fed to the decoder's forward() in order (tgt, memory).
# Use a decoder-specific file name so this drawing does not overwrite the
# "transformer_torchview" output produced for the full Transformer.
model_graph = draw_graph(
    decoder_model,
    input_size=((10,32,512), (20,32,512)),
    expand_nested=True,
    save_graph=True,
    filename="decoder_torchview",
    device='meta')
# Explicitly render the graph object to a PNG with the same base name
model_graph.visual_graph.render("decoder_torchview", format="png")
# Bare expression: presumably shown inline when this is the last line of a notebook cell
model_graph.visual_graph

decoder_torchview

Ref

pytorch模型网络可视化画图工具合集(文后附上完整代码) | by MLTalks | Medium

Netron