环境配置

  1. 查看matlab与python的版本对应关系 Versions of Python Compatible with MATLAB Products by Release
    我用的是matlab_r2021a,支持python3.8。
  2. 找到python解释器的位置
    我之前把pytorch装在了虚拟环境dl中,所以要找到这个虚拟环境下python解释器的位置。
    1
    2
    3
    conda activate dl
    where python
    /Users/wenxin/opt/anaconda3/envs/dl/bin/python
  3. 将matlab中的python版本设置为指定路径的python解释器
    1
    2
    3
    4
    5
    6
    7
    8
    >> pyenv('Version', '/Users/wenxin/opt/anaconda3/envs/dl/bin/python')
    >> pyversion

    version: '3.8'
    executable: '/Users/wenxin/opt/anaconda3/envs/dl/bin/python'
    library: '/Users/wenxin/opt/anaconda3/envs/dl/lib/libpython3.8.dylib'
    home: '/Users/wenxin/opt/anaconda3/envs/dl'
    isloaded: 0

Python端

simple_model.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 神经网络模型
import torch
from torch import nn

class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)

def forward(self, x):
return self.linear(x)

# 将状态字典保存为.pth文件(只保存模型的权重和偏置,而不保存模型结构)
# model = SimpleModel()
# torch.save(model.state_dict(), 'Code/simple_model.pth')

inference.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 网络预测部分
import torch
import numpy as np
from simple_model import SimpleModel

def prediction(params_path, data):
model = SimpleModel()
model.load_state_dict(torch.load(params_path, map_location=torch.device('cpu')))
model.eval() # 将模型设置为评估模式
data = torch.tensor(data, dtype=torch.float32)

with torch.no_grad():
pred = model(data)
return pred.numpy()

Matlab端

test_simple_model.m

1
2
3
4
5
6
7
8
9
in_data = ones(1, 10) * 0.6;
pred = matpy('simple_model.pth', in_data);
disp(pred);

function pred = matpy(params_path, data)
model = py.importlib.import_module('inference');
pred = model.prediction(pyargs('params_path', params_path, 'data', data));
pred = double(pred);
end

参考文章