65、基于卷积神经网络的调制分类(matlab)

news2024/10/5 19:14:42

 1、基于卷积神经网络的调制分类的原理及流程

基于卷积神经网络(CNN)的调制分类是一种常见的信号处理任务,用于识别或分类不同调制方式的信号。下面是基于CNN的调制分类的原理和流程:

原理:

  • CNN是一种深度学习模型,通过卷积层、池化层和全连接层等结构来提取数据中的特征。在调制分类任务中,CNN可以学习到调制信号的特征以区分不同的调制方式。
  • 输入到CNN模型的数据是经过预处理和特征提取后的信号样本,通常是时域信号或频域信号。CNN将这些信号作为输入,并通过网络中的不同层来提取特征并完成调制分类任务。

流程:

  1. 数据准备:准备好用于训练和测试的信号样本数据集,每个样本包含一个已知调制方式的信号。

  2. 数据预处理:对信号数据进行预处理,可能包括归一化、降噪、平滑处理等,以确保数据质量。

  3. 数据特征提取:将信号数据转换为适合CNN输入的格式,例如在时域或频域下进行信号特征提取,将其转换为矩阵形式。

  4. 构建CNN模型:定义CNN模型的结构,包括卷积层、池化层、激活函数层和全连接层等。可以根据具体需求自定义网络结构。

  5. 模型训练:使用训练集数据对CNN模型进行训练,通过反向传播算法不断调整模型参数以使模型输出尽可能接近真实标签。

  6. 模型评估:使用测试集数据评估训练好的模型性能,包括准确率、召回率等指标,对模型进行优化和调整。

  7. 模型应用:将训练好的CNN模型用于未知信号的调制分类,通过模型预测得到信号的调制方式。

  8. 参数调优:根据模型评估结果,调整模型结构、超参数等进行优化,以提高调制分类的准确性和性能。

在Matlab中,可以使用深度学习工具箱等相关工具进行CNN模型的搭建和训练

2、基于卷积神经网络的调制分类的说明

使用卷积神经网络 (CNN) 进行调制分类

生成合成的、通道减损波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类

 

3、使用 CNN 预测调制类型

1)调制数据类型

二相相移键控 (BPSK)

四相相移键控 (QPSK)

八相相移键控 (8-PSK)

十六相正交调幅 (16-QAM)

六十四相正交调幅 (64-QAM)

四相脉冲振幅调制 (PAM4)

高斯频移键控 (GFSK)

连续相位频移键控 (CPFSK)

广播 FM (B-FM)

双边带振幅调制 (DSB-AM)

单边带振幅调制 (SSB-AM)

2)实现代码

modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
  "B-FM", "DSB-AM", "SSB-AM"]));

3)加载训练网络代码

 

load trainedModulationClassificationNetwork
trainedNet
trainedNet = 
  dlnetwork with properties:

         Layers: [19×1 nnet.cnn.layer.Layer]
    Connections: [18×2 table]
     Learnables: [22×3 table]
          State: [10×3 table]
     InputNames: {'Input Layer'}
    OutputNames: {'SoftMax'}
    Initialized: 1

  View summary with summary.

4、加载训练的网络

1)说明

经过训练的 CNN 接受 1024 个通道减损采样,并预测每个帧的调制类型

生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。

以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。

randi:生成随机位

pammod (Communications Toolbox):PAM4 调制位

rcosdesign (Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器

filter:脉冲确定符号的形状

comm.RicianChannel (Communications Toolbox):应用莱斯多径通道

comm.PhaseFrequencyOffset (Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移

interp1:应用时钟偏移引起的计时漂移

awgn (Communications Toolbox):添加 AWGN

2)实现代码

rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));

% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4] / 200e3, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4);

frequencyShifter = comm.PhaseFrequencyOffset(...
  'SampleRate', fs);

% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);

% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;

% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);

% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);

% Add noise
rx = awgn(outTimeDrift,SNR,0);

% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);

% Classification
scores1 = predict(trainedNet,unknownFrames);
prediction1 = scores2label(scores1,modulationTypes);

3)返回分类器预测 

prediction1
prediction1 = 7×1 categorical
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 
     PAM4 

4) 分类器还返回一个包含每一帧分数的向量

代码

helperModClassPlotScores(scores1,modulationTypes)

视图效果

3aa5555d0b43495989ff2829144c3a1c.png

5、生成用于训练的波形

1)说明1

为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。

网络训练阶段使用训练和验证帧

使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。

2)代码实现


trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

 3)说明2

创建通道减损

让每帧通过通道并具有

  • AWGN

  • 莱斯多径衰落

  • 时钟偏移,导致中心频率偏移和采样时间漂移

由于本示例中的网络基于单个帧作出决定,因此每个帧必须通过独立的通道。

AWGN

通道增加 SNR 为 30 dB 的 AWGN。使用 awgn (Communications Toolbox) 函数实现通道。

莱斯多径

通道使用 comm.RicianChannel (Communications Toolbox) System object™ 通过莱斯多径衰落通道传递信号。假设延迟分布为 [0 1.8 3.4] 个样本,对应的平均路径增益为 [0 -2 -10] dB。K 因子为 4,最大多普勒频移为 4 Hz,等效于 902 MHz 的步行速度。使用以下设置实现通道。

时钟偏移

时钟偏移是发射机和接收机的内部时钟源不准确造成的。

代码

maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

4)说明3

 

频率偏移

基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移。使用 comm.PhaseFrequencyOffset (Communications Toolbox) 实现通道。

采样率偏移

基于时钟偏移因子 C,对每帧进行采样率偏移。使用 interp1 函数实现通道,以 C×fs 的新速率对帧进行重新采样。

合并后的通道

使用 helperModClassTestChannel 对象对帧应用所有三种通道减损。

代码

channel = helperModClassTestChannel(...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
channel = 
  helperModClassTestChannel with properties:

                    SNR: 30
        CenterFrequency: 902000000
             SampleRate: 200000
             PathDelays: [0 9.0000e-06 1.7000e-05]
       AveragePathGains: [0 -2 -10]
                KFactor: 4
    MaximumDopplerShift: 4
     MaximumClockOffset: 5

5)波形生成

说明

创建一个循环,它为每种调制类型生成通道减损的帧并将这些帧及其对应标签存储在 MAT 文件中。通过将数据保存到文件中,您无需每次运行此示例时都生成数据。您还可以更高效地共享数据。

从每帧的开头删除随机数量的样本,以去除瞬变并确保帧相对于符号边界具有随机起点。

代码 

rng(12)
tic
numModulationTypes = length(modulationTypes);
channelInfo = info(channel);
transDelay = 50;
pool = getPoolSafe();
if ~isa(pool,"parallel.ClusterPool")
  dataDirectory = fullfile(tempdir,"ModClassDataFiles");
else
  dataDirectory = uigetdir("","Select network location to save data files");
end
disp("Data file directory is " + dataDirectory)

fileNameRoot = "frame";

% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
  files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
  if length(files) == numModulationTypes*numFramesPerModType
    dataFilesExist = true;
  end
end

if ~dataFilesExist
  disp("Generating data and saving in data files...")
  [success,msg,msgID] = mkdir(dataDirectory);
  if ~success
    error(msgID,msg)
  end
  for modType = 1:numModulationTypes
    elapsedTime = seconds(toc);
    elapsedTime.Format = 'hh:mm:ss';
    fprintf('%s - Generating %s frames\n', ...
      elapsedTime, modulationTypes(modType))
    
    label = modulationTypes(modType);
    numSymbols = (numFramesPerModType / sps);
    dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
    modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
    if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
      % Analog modulation types use a center frequency of 100 MHz
      channel.CenterFrequency = 100e6;
    else
      % Digital modulation types use a center frequency of 902 MHz
      channel.CenterFrequency = 902e6;
    end
    
    for p=1:numFramesPerModType
      % Generate random data
      x = dataSrc();
      
      % Modulate
      y = modulator(x);
      
      % Pass through independent channels
      rxSamples = channel(y);
      
      % Remove transients from the beginning, trim to size, and normalize
      frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
      
      % Save data file
      fileName = fullfile(dataDirectory,...
        sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
      save(fileName,"frame","label")
    end
  end
else
  disp("Data files exist. Skip data generation.")
end

Generating data and saving in data files...
00:00:09 - Generating 16QAM frames
00:00:11 - Generating 64QAM frames
00:00:13 - Generating 8PSK frames
00:00:15 - Generating B-FM frames
00:00:17 - Generating BPSK frames
00:00:20 - Generating CPFSK frames
00:00:22 - Generating DSB-AM frames
00:00:24 - Generating GFSK frames
00:00:26 - Generating PAM4 frames
00:00:28 - Generating QPSK frames
00:00:30 - Generating SSB-AM frames

6)效果显示

 实虚部振幅代码

helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)

视图效果

382b3112ff3f47059e05480ea885a7cf.png

帧代码 

helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)

视图效果

e98df04bc485429e80b22d6ac0596484.png

7)创建数据存储代码

frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);

8) 拆分为训练、验证和测试代码

splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);

9) 将数据导入内存代码

% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDS, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
validFrames = transform(validDS, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);

% Read the training and validation labels into the memory
trainLabels = transform(trainDS, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDS, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);

6、训练 CNN

1)说明

使用的 CNN 由五个卷积层和一个全连接层组成。除最后一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层。在最后一个卷积层中,最大池化层被一个全局平均池化层取代。输出层具有 softmax 激活。

2)实现代码

modClassNet = helperModClassCNN(modulationTypes,sps,spf);

3)配置网络代码 

maxEpochs = 20;
miniBatchSize = 1024;
trainingPlots = "none";
metrics = [];
verbose = true;
validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize);
options = trainingOptions('sgdm', ...
  InitialLearnRate = 3e-1, ...
  MaxEpochs = maxEpochs, ...
  MiniBatchSize = miniBatchSize, ...
  Shuffle = 'every-epoch', ...
  Plots = trainingPlots, ...
  Verbose = verbose, ...
  ValidationData = {rxValidFrames,rxValidLabels}, ...
  ValidationFrequency = validationFrequency, ...
  ValidationPatience = 5, ...
  Metrics = metrics, ...
  LearnRateSchedule = 'piecewise', ...
  LearnRateDropPeriod = 6, ...
  LearnRateDropFactor = 0.75, ...
  OutputNetwork='best-validation-loss');

4)训练网络代码 

if trainNow == true
  elapsedTime = seconds(toc);
  elapsedTime.Format = 'hh:mm:ss';
  fprintf('%s - Training the network\n', elapsedTime)
  trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options);
else
  load trainedModulationClassificationNetwork
end

 5)训练结果评估代码

elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
% Read the test frames into the memory
testFrames = transform(testDS, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);

% Read the test labels into the memory
testLabels = transform(testDS, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);

scores = predict(trainedNet,cat(3,rxTestFrames{:}));
rxTestPred = scores2label(scores,modulationTypes);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")

7、使用 SDR 进行测试

1)说明

使用 helperModClassSDRTest 函数,通过空口信号测试经过训练的网络的性能。要执行此测试,您必须有专用的 SDR 用于发送和接收。

2)代码实现

radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if helperIsPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        helperModClassSDRTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        helperModClassSDRTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end

3)视图效果
e57411c7e7f0475f80bf53e17aee8802.png

8、总结

基于卷积神经网络(CNN)的调制分类在Matlab中可以通过深度学习工具箱等相关工具来实现。下面是对基于CNN的调制分类在Matlab中的关键步骤的总结:

总结步骤:

  1. 数据准备:准备带有标签的调制信号数据集,确保每个样本包含一个已知调制方式的信号。

  2. 数据预处理:对信号数据进行预处理,包括归一化、降噪等操作,以保证数据的质量。

  3. 数据特征提取:将信号数据转换为适合CNN输入的格式,可以在时域或频域下提取信号特征,并将其表示为矩阵形式。

  4. 构建CNN模型:定义CNN模型的结构,包括卷积层、池化层、激活函数层和全连接层等。可以根据具体需求自定义网络结构。

  5. 模型训练:使用训练集数据对CNN模型进行训练,通过反向传播算法不断调整模型参数以优化模型性能。

  6. 模型评估:使用测试集数据评估训练好的CNN模型的性能,包括准确率、召回率等指标,对模型进行优化和调整。

  7. 模型应用:将训练好的CNN模型用于未知信号的调制分类,通过模型预测得到信号的调制方式。

  8. 参数调优:根据模型评估结果,调整模型结构、超参数等进行优化,以提高调制分类的准确性和性能。

通过以上步骤,可以在Matlab中实现基于CNN的调制分类任务,从而对不同调制方式的信号进行准确分类和识别。在实际应用中,可以根据具体问题的需求对模型进行定制和调整,以获得更好的性能和效果。

9、源代码

代码

%% 基于卷积神经网络的调制分类
%使用卷积神经网络 (CNN) 进行调制分类
%生成合成的、通道减损波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类

%% 使用 CNN 预测调制类型
%可识别以下八种数字调制类型和三种模拟调制类型
%二相相移键控 (BPSK)
%四相相移键控 (QPSK)
%八相相移键控 (8-PSK)
%十六相正交调幅 (16-QAM)
%六十四相正交调幅 (64-QAM)
%四相脉冲振幅调制 (PAM4)
%高斯频移键控 (GFSK)
%连续相位频移键控 (CPFSK)
%广播 FM (B-FM)
%双边带振幅调制 (DSB-AM)
%单边带振幅调制 (SSB-AM)
modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ...
  "16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
  "B-FM", "DSB-AM", "SSB-AM"]));

%% 加载训练的网络

load trainedModulationClassificationNetwork
trainedNet

%经过训练的 CNN 接受 1024 个通道减损采样,并预测每个帧的调制类型
%生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。
%以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。
%randi:生成随机位
%pammod (Communications Toolbox):PAM4 调制位
%rcosdesign (Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器
%filter:脉冲确定符号的形状
%comm.RicianChannel (Communications Toolbox):应用莱斯多径通道
%comm.PhaseFrequencyOffset (Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移
%interp1:应用时钟偏移引起的计时漂移
%awgn (Communications Toolbox):添加 AWGN
% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));

% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
  'SampleRate', fs, ...
  'PathDelays', [0 1.8 3.4] / 200e3, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4);

frequencyShifter = comm.PhaseFrequencyOffset(...
  'SampleRate', fs);

% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);

% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;

% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);

% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);

% Add noise
rx = awgn(outTimeDrift,SNR,0);

% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);

% Classification
scores1 = predict(trainedNet,unknownFrames);
prediction1 = scores2label(scores1,modulationTypes);
%返回分类器预测
prediction1
%分类器还返回一个包含每一帧分数的向量
%分数对应于每个帧具有预测的调制类型的概率。绘制分数图。
helperModClassPlotScores(scores1,modulationTypes)

%% 生成用于训练的波形
%为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。
%网络训练阶段使用训练和验证帧
%使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。

trainNow = false;
if trainNow == true
  numFramesPerModType = 10000;
else
  numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;

sps = 8;                % Samples per symbol
spf = 1024;             % Samples per frame
fs = 200e3;             % Sample rate
fc = [902e6 100e6];     % Center frequencies

%创建通道减损:AWGN/莱斯多径衰落/时钟偏移,导致中心频率偏移和采样时间漂移
%AWGN:通道增加 SNR 为 30 dB 的 AWGN。使用 awgn (Communications Toolbox) 函数实现通道
%莱斯多径:通道使用 comm.RicianChannel (Communications Toolbox) System object™ 通过莱斯多径衰落通道传递信号。
%时钟偏移:时钟偏移是发射机和接收机的内部时钟源不准确造成的。
maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);

%频率偏移:基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移
%采样率偏移:基于时钟偏移因子 C,对每帧进行采样率偏移。
%合并后的通道:使用 helperModClassTestChannel 对象对帧应用所有三种通道减损
channel = helperModClassTestChannel(...
  'SampleRate', fs, ...
  'SNR', SNR, ...
  'PathDelays', [0 1.8 3.4] / fs, ...
  'AveragePathGains', [0 -2 -10], ...
  'KFactor', 4, ...
  'MaximumDopplerShift', 4, ...
  'MaximumClockOffset', 5, ...
  'CenterFrequency', 902e6)
%使用 info 对象函数查看有关通道的基本信息
chInfo = info(channel)

%波形生成
% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(12)
tic
numModulationTypes = length(modulationTypes);
channelInfo = info(channel);
transDelay = 50;
pool = getPoolSafe();
if ~isa(pool,"parallel.ClusterPool")
  dataDirectory = fullfile(tempdir,"ModClassDataFiles");
else
  dataDirectory = uigetdir("","Select network location to save data files");
end
disp("Data file directory is " + dataDirectory)

fileNameRoot = "frame";

% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
  files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
  if length(files) == numModulationTypes*numFramesPerModType
    dataFilesExist = true;
  end
end

if ~dataFilesExist
  disp("Generating data and saving in data files...")
  [success,msg,msgID] = mkdir(dataDirectory);
  if ~success
    error(msgID,msg)
  end
  for modType = 1:numModulationTypes
    elapsedTime = seconds(toc);
    elapsedTime.Format = 'hh:mm:ss';
    fprintf('%s - Generating %s frames\n', ...
      elapsedTime, modulationTypes(modType))
    
    label = modulationTypes(modType);
    numSymbols = (numFramesPerModType / sps);
    dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
    modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
    if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
      % Analog modulation types use a center frequency of 100 MHz
      channel.CenterFrequency = 100e6;
    else
      % Digital modulation types use a center frequency of 902 MHz
      channel.CenterFrequency = 902e6;
    end
    
    for p=1:numFramesPerModType
      % Generate random data
      x = dataSrc();
      
      % Modulate
      y = modulator(x);
      
      % Pass through independent channels
      rxSamples = channel(y);
      
      % Remove transients from the beginning, trim to size, and normalize
      frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
      
      % Save data file
      fileName = fullfile(dataDirectory,...
        sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
      save(fileName,"frame","label")
    end
  end
else
  disp("Data files exist. Skip data generation.")
end
%显示波形
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)

helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)

%创建数据存储
%使用 signalDatastore 对象来管理包含生成的复杂波形的文件
frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);
%拆分为训练、验证和测试
splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);
%将数据导入内存
%神经网络训练是迭代进行
% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDS, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
validFrames = transform(validDS, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);

% Read the training and validation labels into the memory
trainLabels = transform(trainDS, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDS, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);

%% 训练 CNN
%CNN 由五个卷积层和一个全连接层组成
%一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层
modClassNet = helperModClassCNN(modulationTypes,sps,spf);
%配置 TrainingOptionsSGDM 以使用小批量大小为 1024 的 SGDM 求解器
maxEpochs = 20;
miniBatchSize = 1024;
trainingPlots = "none";
metrics = [];
verbose = true;
validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize);
options = trainingOptions('sgdm', ...
  InitialLearnRate = 3e-1, ...
  MaxEpochs = maxEpochs, ...
  MiniBatchSize = miniBatchSize, ...
  Shuffle = 'every-epoch', ...
  Plots = trainingPlots, ...
  Verbose = verbose, ...
  ValidationData = {rxValidFrames,rxValidLabels}, ...
  ValidationFrequency = validationFrequency, ...
  ValidationPatience = 5, ...
  Metrics = metrics, ...
  LearnRateSchedule = 'piecewise', ...
  LearnRateDropPeriod = 6, ...
  LearnRateDropFactor = 0.75, ...
  OutputNetwork='best-validation-loss');
%训练神网络
if trainNow == true
  elapsedTime = seconds(toc);
  elapsedTime.Format = 'hh:mm:ss';
  fprintf('%s - Training the network\n', elapsedTime)
  trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options);
else
  load trainedModulationClassificationNetwork
end

%通过获得测试帧的分类准确度来评估经过训练的网络
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
% Read the test frames into the memory
testFrames = transform(testDS, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);

% Read the test labels into the memory
testLabels = transform(testDS, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);

scores = predict(trainedNet,cat(3,rxTestFrames{:}));
rxTestPred = scores2label(scores,modulationTypes);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")

%% 使用 SDR 进行测试
%使用 helperModClassSDRTest 函数,通过空口信号测试经过训练的网络的性能。
radioPlatform = "ADALM-PLUTO";

switch radioPlatform
  case "ADALM-PLUTO"
    if helperIsPlutoSDRInstalled() == true
      radios = findPlutoRadio();
      if length(radios) >= 2
        helperModClassSDRTest(radios);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
  case {"USRP B2xx","USRP X3xx","USRP N2xx"}
    if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
      txRadio = findPlutoRadio();
      rxRadio = findsdru();
      switch radioPlatform
        case "USRP B2xx"
          idx = contains({rxRadio.Platform}, {'B200','B210'});
        case "USRP X3xx"
          idx = contains({rxRadio.Platform}, {'X300','X310'});
        case "USRP N2xx"
          idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
      end
      rxRadio = rxRadio(idx);
      if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
        helperModClassSDRTest(rxRadio);
      else
        disp('Selected radios not found. Skipping over-the-air test.')
      end
    end
end





工程文件

https://download.csdn.net/download/XU157303764/89498445

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1885272.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

root密码忘了怎么办(从系统引导过程解决)

目录 1.Linux系统密码忘记 2.系统引导过程 2.1 systemd 2.2 GRUB和GRUB2 2.3 运行级别 3.修复MBR扇区故障和GRUB引导故障 3.1 MBR扇区故障 3.2 GRUB引导故障 1.Linux系统密码忘记 我们在生活中经常遇到这类困扰,就是某个账号还是账户密码忘了,这…

Llama也能做图像生成?文生图模型已开源

导读 基于next-token prediction的图像生成方法首次在ImageNet benchmark超越了LDM, DiT等扩散模型,证明了最原始的自回归模型架构同样可以实现极具竞争力的图像生成性能。 Llama也能做图像生成?文生图模型已开源 香港大学、字节跳动提出了基于自回归模…

【AI大模型】大型模型飞跃升级—文档图像识别领域迎来技术巨变_图像识别大模型

写在前面 2023年12月31日,第十九届中国图象图形学学会青年科学家会议在广州举行,由中国图象图形学学会主办。 该会议的目标是促进青年科学家之间的交流与合作,以提升我国在图像图形领域的科研水平和创新能力。 由中国图象图形学学会和上海合合…

如何将音频文件发送至摄像头

目前再很多互联互通的场景下,如AI盒子再从摄像头上取视频分析,分析出发生某个事件,需要反向通过摄像头的喇叭播放语音,发出告警提示,使用场景如下 盒子上对于此类场景的需求往往不能满足,或者为这个需求需要…

Day8: 232.用栈实现队列 225. 用队列实现栈 20. 有效的括号 1047. 删除字符串中的所有相邻重复项

题目232. 用栈实现队列 - 力扣(LeetCode) class MyQueue { public:MyQueue() {}void push(int x) { // 出栈input.push(x);}int pop() {// 如果出栈为空,把入栈元素全都转移到出栈if (output.empty()) {while (!input.empty()) {int itop i…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第52课-语音控制机器人

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第52课-语音控制机器人 使用dtns.network德塔世界(开源的智体世界引擎),策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript编写的智体世界引擎…

彭涛 | 2024年6月小结

6月是忙碌的一个月,换办公室,买家具,群发售,新小伙伴入职等等 1、出海小报童 这个月时间主要做小报童,从刚开始设计内容大纲,到写作,后续拉新花费了大量时间。 比如我们要去调研同行&#xff0c…

新能源行业必会基础知识-----电力市场概论笔记-----中长期合约电力市场

新能源行业知识体系-------主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/139946830 目录 1. 合约市场2. 双边交易3. 集中交易4. 挂牌交易及互联网中长期电力交易平台5. 中长期交易的优势 1. 合约市场 什么是合约市场 …

从选题到定稿:软考高级系统架构设计师论文写作全攻略

一、论文考试概述 软考系统架构设计师考试的最后一门是论文写作,安排在下午进行,时长两小时,要求撰写约3000字的论文,以45分为及格线。时间紧迫,不容过多犹豫与思考,因此需迅速选定并着手撰写。论文题目通…

【数据结构】C语言实现二叉树

C语言实现二叉树 导读一、二叉树的数据类型二、二叉树的初始化2.1 补充知识点——传址传参2.2 补充知识点——指针传参 三、二叉树的创建3.1 通过添加结点创建BST3.2 通过结点序列创建二叉树3.2.1 由遍历序列手算构建二叉树3.2.1.1 构建步骤3.2.1.2 习题演练3.2.1.3 小结 3.2.2…

在C#/Net中使用Mqtt

net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…

yolov5实例分割跑通以及C#读取yolov5_Seg实例分割转换onnx进行检测部署

一、首先需要训练yolov5_seg的模型,可以去网上学习,或者你直接用我的, 训练环境和yolov5—7.0的环境一样,你可以直接拷过来用。 yolov5_seg算法 链接:https://pan.baidu.com/s/1m-3lFWRHwg5t8MmIOKm4FA 提取码&…

第十四届蓝桥杯省赛C++B组D题【飞机降落】题解(AC)

解题思路 这道题目要求我们判断给定的飞机是否都能在它们的油料耗尽之前降落。为了寻找是否存在合法的降落序列,我们可以使用深度优先搜索(DFS)的方法,尝试所有可能的降落顺序。 首先,我们需要理解题目中的条件。每架…

tcpdump命令详解及使用实例

1、抓所有网卡数据包,保存到指定路径 tcpdump -i any -w /oemdata/123.pcap&一、tcpdump简介 tcpdump可以将网络中传送的数据包完全截获下来提供分析。它支持针对网络层、协议、主机、网络或端口的过滤,并提供and、or、not等逻辑语句来去掉无用的信…

Python中爬虫编程的常见问题及解决方案

Python中爬虫编程的常见问题及解决方案 引言: 随着互联网的发展,网络数据的重要性日益突出。爬虫编程成为大数据分析、网络安全等领域中必备的技能。然而,爬虫编程不仅需要良好的编程基础,还需要面对着各种常见的问题。本文将介绍…

Qt中文乱码如何解决

目录 一、使用建议 二、其它设置 一、使用建议 Qt对中文的支持不是很友好,使用QtCreator会出现各种乱七八糟的中文代码问题,如何处理这种问题? (1)粘贴别人的代码时,先在记事本里粘贴一遍,再…

【Python机器学习】gradio库(快速创建简单的 Web 界面来演示机器学习模型)

文章目录 1. 主要特点2. 安装 Gradio3. 基于tensorflow的例子4. 基于Pytorch的例子4.1 步骤4.2 代码4.3 使用说明Gradio 是一个 Python 库,用于快速创建简单的 Web 界面来演示机器学习模型。它被广泛用于各种应用,如音频、文本、图像处理和更多。Gradio 使得任何人都可以轻松…

深圳比创达电子EMC|EMC电磁兼容性行业:推动电子产品向更高发展

随着科技的飞速发展,电子产品在我们的日常生活中无处不在,从智能手机到智能家居,从医疗设备到工业自动化,这些设备的普及和更新换代对电磁兼容性(EMC)提出了更高的要求。 一、EMC电磁兼容性行业的概述 EM…

深度学习笔记: 最详尽解释预测系统的分类指标(精确率、召回率和 F1 值)

欢迎收藏Star我的Machine Learning Blog:https://github.com/purepisces/Wenqing-Machine_Learning_Blog。如果收藏star, 有问题可以随时与我交流, 谢谢大家! 预测系统的分类指标(精确率、召回率和 F1 值) 简介 让我们来谈谈预测系统的分类指标以及对精确率、召回…

气象站的气象工具都有哪些呢?

气象站,作为观测和记录天气现象的重要基地,拥有一系列专业的气象工具。这些工具不仅能够帮助我们深入了解大气的运动规律,还能为天气预报、气候研究等提供宝贵的数据支持。 风速风向仪也是气象站重要的工具。它通常由风向变送器和风速变送器组…