“生成音乐“ 【循环神经网络】

news2024/9/22 5:25:09

前言

本文介绍循环神经网络的进阶案例,通过搭建和训练一个模型,来对钢琴的音符进行预测,通过重复调用模型来进而生成一段音乐;

使用到Maestro的钢琴MIDI文件 ,每个文件由不同音符组成,音符用三个量来表示:音高pitch、步长step、持续时间duration。通过搭建和训练循环神经网络模型,输入一系列音符能预测下一个音符。

下图是一个钢琴MIDI文件,由不同音符组成:

思路流程

  1. 导入数据集
  2. 探索集数据,并进行数据预处理
  3. 构建模型(搭建神经网络结构、编译模型)
  4. 训练模型(把数据输入模型、评估准确性、作出预测、验证预测)  
  5. 使用训练好的模型
  6. 优化模型、重新构建模型、训练模型、使用模型

一、导入数据集

 使用到Maestro的钢琴MIDI文件 ,每个文件中由不同音轨组成,音轨中包含了一些音符;我们可以通过遍历每个音轨中的音符,获取音符的开始时间、结束时间、音高、音量等信息,进行音乐分析和处理。

我们到Maestro下载maestro-v2.0.0-midi.zip文件,它包含1282个钢琴MIDI文件,大约58M左右;解压后能看到如下的文件。

2004

2006

2008

2009

2011

2013

2014

2015

2017

2018

LICENSE

maestro-v2.0.0.csv

maestro-v2.0.0.json

README

我们可以使用电脑播放器打开文件夹中钢琴MIDI文件,比如:maestro-v2.0.0/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi文件,能听到一段钢琴音乐。

二、探索集数据,并进行数据预处理

2.1 解析MIDI文件

我们使用 pretty_midi 库创建和解析MIDI文件,首先安装一下它,执行如下的命令:

!pip install pretty_midi

在notebook jupytre中播放MIDI音频文件,需要安装pyfluidsynth库,执行如下的命令:

!sudo apt install -y fluidsynth

写一个程序解析MIDI文件,

import glob
import pretty_midi


# 加载maestro-v2.0.0目录下的每个midi文件
filenames = glob.glob(str('./maestro-v2.0.0/*/*.mid*'))
print('Number of files:', len(filenames))

# 使用pretty_midi库解析单个MIDI文件,并检查音符的格式
sample_file = filenames[1]
print(sample_file)
pm = pretty_midi.PrettyMIDI(sample_file)

# 对MIDI文件进行检查
print('Number of instruments:', len(pm.instruments))
instrument = pm.instruments[0]
instrument_name = pretty_midi.program_to_instrument_name(instrument.program)
print('Instrument name:', instrument_name)

2.2 提取音符

在训练模型时,将使用三个变量来表示音符:pitch、step 和 duration。

  • pitch是音符的音高,以MIDI音高值表示,范围是0到127,0表示最低音高,127表示最高音高。
  • step是是从上一个音符或曲目的开始经过的时间。
  • duration是音符的持续时间,以“ticks”为单位表示,一个tick表示MIDI时间分辨率中的最小时间单位,具体的时间取决于MIDI文件的时间分辨率参数。

所以我们需要对每个MIDI文件进行提取音符。

上面打开的xxxMIDI文件,查看它的5个音符

# 查看xxxMIDI文件的10个音符
for i, note in enumerate(instrument.notes[:5]):
  note_name = pretty_midi.note_number_to_name(note.pitch)
  duration = note.end - note.start
  print(f'{i}: pitch={note.pitch}, note_name={note_name},'
        f' duration={duration:.4f}')

能看到如下的信息

0: pitch=78, note_name=F#5, duration=0.0292

1: pitch=66, note_name=F#4, duration=0.0333

2: pitch=71, note_name=B4, duration=0.0292

3: pitch=83, note_name=B5, duration=0.0365

4: pitch=73, note_name=C#5, duration=0.0333

写一个函数来从MIDI文件中提取音符

import pandas as pd
import collections
import numpy as np

# 从MIDI文件中提取音符
def midi_to_notes(midi_file: str) -> pd.DataFrame:
  pm = pretty_midi.PrettyMIDI(midi_file)
  instrument = pm.instruments[0]
  notes = collections.defaultdict(list)

  # 按开始时间对笔记排序
  sorted_notes = sorted(instrument.notes, key=lambda note: note.start)
  prev_start = sorted_notes[0].start

  for note in sorted_notes:
    start = note.start
    end = note.end
    notes['pitch'].append(note.pitch)
    notes['start'].append(start)
    notes['end'].append(end)
    notes['step'].append(start - prev_start)
    notes['duration'].append(end - start)
    prev_start = start

  return pd.DataFrame({name: np.array(value) for name, value in notes.items()})

通过midi_to_notes函数,提取一个MIDI文件中提取音符

raw_notes = midi_to_notes('./xxx.midi')
raw_notes.head()

 比如,MIDI文件名称为:maestro-v2.0.0/2004/MIDI-Unprocessed_XP_14_R1_2004_01-03_ORIG_MID--AUDIO_14_R1_2004_03_Track03_wav.midi

pitch是音高。duration 是音符将播放多长时间(以秒为单位),是音符结束时间(end)和音符开始时间(start)之间的差值。step 是从前一个音符开始所经过的时间。

pitchstartendstepduration
0781.0666671.0958330.0000000.029167
1661.0718751.1052080.0052080.033333
2831.2177081.2541670.1458330.036458
3711.2208331.2500000.0031250.029167
4851.3562501.4072920.1354170.051042

解释音符名称可能比解释音高更容易,可以使用下面的函数将数字音高值转换为音符名称。音符名称显示了音符类型、变音记号和八度数(例如 C#4)。

get_note_names = np.vectorize(pretty_midi.note_number_to_name)
sample_note_names = get_note_names(raw_notes['pitch'])
sample_note_names[:10]

输出信息

array(['F#5', 'F#4', 'B5', 'B4', 'C#6', 'C#5', 'D#6', 'B0', 'D#5', 'B1'], dtype='<U3')

2.3 绘制音轨

从MIDI文件中提取音符后,写一个函数来绘制pitch音高、duration持续时间

from matplotlib import pyplot as plt
from typing import Dict, List, Optional, Sequence, Tuple
import seaborn as sns

# 绘制pitch音高、duration持续时间
def plot_piano_roll(notes: pd.DataFrame, count: Optional[int] = None):
  if count:
    title = f'First {count} notes'
  else:
    title = f'Whole track'
    count = len(notes['pitch'])
  plt.figure(figsize=(20, 4))
  plot_pitch = np.stack([notes['pitch'], notes['pitch']], axis=0)
  plot_start_stop = np.stack([notes['start'], notes['end']], axis=0)
  plt.plot(
      plot_start_stop[:, :count], plot_pitch[:, :count], color="b", marker=".")
  plt.xlabel('Time [s]')
  plt.ylabel('Pitch')
  _ = plt.title(title)


# 查看MIDI文件30个音符的分布情况
plot_piano_roll(raw_notes, count=30)

# 绘制整个音轨的音符
plot_piano_roll(raw_notes)

查看MIDI文件50个音高和持续时间的情况

绘制整个音轨的音符

 

2.4 检查音符分布

检查每个音符变量的分布,通过如下函数实现

def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):
  plt.figure(figsize=[15, 5])
  plt.subplot(1, 3, 1)
  sns.histplot(notes, x="pitch", bins=20)

  plt.subplot(1, 3, 2)
  max_step = np.percentile(notes['step'], 100 - drop_percentile)
  sns.histplot(notes, x="step", bins=np.linspace(0, max_step, 21))

  plt.subplot(1, 3, 3)
  max_duration = np.percentile(notes['duration'], 100 - drop_percentile)
  sns.histplot(notes, x="duration", bins=np.linspace(0, max_duration, 21))

# 查看音符的分布
plot_distributions(raw_notes)

能看到如下的音符分布:

 

2.5 创建训练数据集

通过从MIDI文件中提取音符来创建训练数据集,音符用三个变量来表示:pitch(音高)、step(音符名)和 duration(持续时间)。

对于成批的音符序列训练模型;每个样本将包含一系列音符作为输入特征,下一个音符作为标签。通过这种方式,模型将被训练来预测序列中的下一个音符。

以下代码是创建训练数据集的:

key_order = ['pitch', 'step', 'duration']
train_notes = np.stack([all_notes[key] for key in key_order], axis=1)
notes_ds = tf.data.Dataset.from_tensor_slices(train_notes)


# 每个示例将由一系列音符组成作为输入特征,并将下一个音符作为标签。
# 通过这种方式,模型将被训练以预测序列中的下一个音符。
def create_sequences(
    dataset: tf.data.Dataset, 
    seq_length: int,
    vocab_size = 128,
) -> tf.data.Dataset:

  seq_length = seq_length+1

  windows = dataset.window(seq_length, shift=1, stride=1,
                              drop_remainder=True)

  flatten = lambda x: x.batch(seq_length, drop_remainder=True)
  sequences = windows.flat_map(flatten)
  
  def scale_pitch(x):
    x = x/[vocab_size,1.0,1.0]
    return x

  def split_labels(sequences):
    inputs = sequences[:-1]
    labels_dense = sequences[-1]
    labels = {key:labels_dense[i] for i,key in enumerate(key_order)}
    return scale_pitch(inputs), labels

  return sequences.map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)

设置每个示例的序列长度。尝试使用不同的长度(例如,50、100、150),以确定哪种长度最适合数据,或使用超参数调整。词汇表的大小(Vocab_Size)设置为128,表示Pretty_MIDI支持的所有音调。

seq_length = 25
vocab_size = 128
seq_ds = create_sequences(notes_ds, seq_length, vocab_size)
seq_ds.element_spec

batch_size = 64
buffer_size = n_notes - seq_length  
train_ds = (seq_ds
            .shuffle(buffer_size)
            .batch(batch_size, drop_remainder=True)
            .cache()
            .prefetch(tf.data.experimental.AUTOTUNE))

三、构建模型

模型输入是序列的音符,输出是一个音符;即:通过输入一段连续的音符,预测下一个音符。

输入:输入的维度是nx3,n是指音符的个数长度,3是指音符使用pitch(音高)、step(音符名)和 duration(持续时间)三个变量来表示;

输出:预测一个音符,设置模型输出的维度是3,表示音符的3个变量。

模型主体:LSTM结构。

损失函数:对于pitch和duration,使用基于均方误差的自定义损失函数。

def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):
  mse = (y_true - y_pred) ** 2
  positive_pressure = 10 * tf.maximum(-y_pred, 0.0)
  return tf.reduce_mean(mse + positive_pressure)

下面是搭建网络的代码:

# 设置输入
input_shape = (seq_length, 3)
learning_rate = 0.005

# 模型输入层
inputs = tf.keras.Input(input_shape)

# 使用循环神经网络的变体LSTM层
x = tf.keras.layers.LSTM(128)(inputs)

# 输出层
outputs = {
  'pitch': tf.keras.layers.Dense(128, name='pitch')(x),
  'step': tf.keras.layers.Dense(1, name='step')(x),
  'duration': tf.keras.layers.Dense(1, name='duration')(x),
}

# 构建模型
model = tf.keras.Model(inputs, outputs)

# 定义损失函数
loss = {
      'pitch': tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True),
      'step': mse_with_positive_pressure,
      'duration': mse_with_positive_pressure,
}

# 模型训练的优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# 编译模型
model.compile(
    loss=loss,
    loss_weights={
        'pitch': 0.05,
        'step': 1.0,
        'duration':1.0,
    },
    optimizer=optimizer,
)

# 设置训练模型时的回调函数
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='./training_checkpoints/ckpt_{epoch}',
        save_weights_only=True),
    tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        patience=5,
        verbose=1,
        restore_best_weights=True),
]

查看一下网络模型:tf.keras.utils.plot_model(model)

或者用这样方式看看:model.summary()

 

四、训练模型

这里我们输入准备好的训练集数据,指定训练模型时的回调函数(保存模型权重、自动早停),模型一共训练50轮。

# 模型训练50轮
epochs = 50

# 开始训练模型
history = model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
)

下图是训练过程的截图,能看到模型训练到42轮时停止了,因为使用EarlyStopping()函数,模型的损失足够小了,就不再训练了。

通常loss越小越好,训练完模型后,画一下损失值的变化过程:

 

五、使用模型

使用 model.predict( )  函数,进行预测音符。但要使用模型生成音符,首先需要提供音符的起始序列。

下面的函数,从一系列音符中生成一个音符

def predict_next_note(
    notes: np.ndarray, 
    keras_model: tf.keras.Model, 
    temperature: float = 1.0) -> int:
  """使用经过训练的序列模型生成标签ID"""

  assert temperature > 0

  # 添加批次维度
  inputs = tf.expand_dims(notes, 0)

  predictions = model.predict(inputs)
  pitch_logits = predictions['pitch']
  step = predictions['step']
  duration = predictions['duration']

  pitch_logits /= temperature
  pitch = tf.random.categorical(pitch_logits, num_samples=1)
  pitch = tf.squeeze(pitch, axis=-1)
  duration = tf.squeeze(duration, axis=-1)
  step = tf.squeeze(step, axis=-1)

  # `step` 和 `duration` 值应该是非负数
  step = tf.maximum(0, step)
  duration = tf.maximum(0, duration)

  return int(pitch), float(step), float(duration)

举一个例子,生成一些音符

temperature = 2.0
num_predictions = 120

sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)

# 音符的初始序列; 音高被归一化,类似于训练序列
input_notes = (
    sample_notes[:seq_length] / np.array([vocab_size, 1, 1]))

generated_notes = []
prev_start = 0
for _ in range(num_predictions):
  pitch, step, duration = predict_next_note(input_notes, model, temperature)
  start = prev_start + step
  end = start + duration
  input_note = (pitch, step, duration)
  generated_notes.append((*input_note, start, end))
  input_notes = np.delete(input_notes, 0, axis=0)
  input_notes = np.append(input_notes, np.expand_dims(input_note, 0), axis=0)
  prev_start = start

generated_notes = pd.DataFrame(
    generated_notes, columns=(*key_order, 'start', 'end'))

# 查看成的generated_notes前5个音符
generated_notes.head(5)

# 查看成的generated_notes的音轨情况
plot_piano_roll(generated_notes)

# 查看生成的generated_notes 音符的分布情况
plot_distributions(generated_notes)

查看生成的前5个音符,

pitchstepdurationstartend
0370.0956330.0920780.0956330.187710
1770.0974170.6094620.1930490.802511
2760.0890490.4556260.2820990.737724
3940.0965750.4439370.3786730.822611
4970.1094040.3766040.4880770.864681

 

查看成的generated_notes的音轨情况

 查看生成的generated_notes 音符的分布情况

 

本文只供大家参考和学习,谢谢~

其它推荐文章:

[1] 【神经网络】综合篇——人工神经网络、卷积神经网络、循环神经网络、生成对抗网络

[2] 手把手搭建一个【卷积神经网络】

[3] “花朵分类“ 手把手搭建【卷积神经网络】

[4] 一篇文章“简单”认识《循环神经网络》

[5] 神经网络学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/357567.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

千锋教育嵌入式物联网教程之系统编程篇学习-04

目录 alarm函数 raise函数 abort函数 pause函数 转折点 signal函数 可重入函数 信号集 sigemptyset() sigfillset sigismember()​ sigaddset()​ sigdelset()​ 代码讲解 信号阻塞集 sigprocmask()​ alarm函数 相当于一个闹钟&#xff0c;默认动作是终止调用alarm函数的进…

HSCSEC 2023 个人练习

&#x1f60b; 大家好&#xff0c;我是YAy_17&#xff0c;是一枚爱好网安的小白。本人水平有限&#xff0c;欢迎各位大佬指点&#xff0c;欢迎关注&#x1f601;&#xff0c;一起学习 &#x1f497; &#xff0c;一起进步 ⭐ 。⭐ 此后如竟没有炬火&#xff0c;我便是唯一的光。…

聊一聊国际化i18n

i18n 概述i18n 是国际化的缩写&#xff0c;其完整的写法是 Internationalization&#xff0c;翻译为国际化。国际化是指在软件开发中对于不同语言和地区的支持。目的是为了让一款软件可以在不同的语言和地区环境下正常运行&#xff0c;使其适应全球各地的用户。这通常包括对语言…

Simulink 自动代码生成电机控制:低阶滑模观测器仿真实现及生成代码在开发板上运行

目录 理论参考 仿真实现 运行演示 总结 前段实时搭过高阶的滑模观测器&#xff0c;相比于高阶的&#xff0c;普通的滑模观测器计算量小更适合计算能力低的MCU&#xff0c;这里参考Microchip的16位MCU所使用的观测器&#xff0c;通过Simulink建模仿真实现系统控制&#xff0…

【查看多个长图】如何方便地在安卓手机上查看多个长图?如何更便捷地浏览长图合集

经常我会看到有些知识分享是通过长图形式进行。 往往在手机本地的图片浏览器中不能很方便地查看很多长图&#xff08;能放大&#xff0c;但是横向滑动时&#xff0c;无法保证同样的放缩比例浏览同一个文件夹&#xff09;。 我推荐下面一个APP和曲折解决办法。 1、perfect vi…

Error: Timeout trying to fetch resolutions from npm

总目录&#xff1a; 如何使用VSCode插件codesight扫描出前端项目的风险依赖包并借助 npm-force-resolutions 修复之&#xff1f;blackduck issue fix 文章目录问题描述【最终解决】我搜索到的解决方案npmjs 该依赖各版本列表及对应的被下载次数github issue 说降级到0.0.3就可以…

(十五)、从插件市场引入问题反馈页面【uniapp+uinicloud多用户社区博客实战项目(完整开发文档-从零到完整项目)】

1&#xff0c;插件市场问题反馈页面 插件市场链接 dloud插件插件市场中找到问题反馈插件&#xff1a; 首先确保登录了dcloud账号。 使用hbuilderX导入插件到自己项目中。 选择合并导入。 从插件市场导入意见反馈页面的路径地址如下&#xff1a; 2&#xff0c;点击跳转到…

论文阅读_AlphaGo_Zero

论文信息 name_en: Mastering the game of Go without human knowledge name_ch: 在没有人类知识的情况下掌握围棋游戏 paper_addr: http://www.nature.com/articles/nature24270 doi: 10.1038/nature24270 date_publish: 2017-10-01 tags: [‘深度学习’,‘强化学习’] if: 6…

【C++封装】C++面向对象模型

文章内容如下&#xff1a; 1&#xff09;成员变量和函数的存储 2&#xff09;this指针 3&#xff09;const修饰成员函数 4&#xff09;有元 一。成员变量和函数的存储 C实现了封装&#xff0c;数据(-变量)和处理数据的操作(-函数)是分开存储的&#xff0c;C中的非静态数据…

SpringBoot Notes

文章目录1 SpringBootWeb快速入门1.1Spring官网1.2 Web分析2. HTTP协议2.1 HTTP介绍34 SpringBootWeb请求响应5 响应6 分层解耦6.1 三层架构6.1.1 三层架构介绍6.1.2 基于三层架构的程序执行流程&#xff1a;6.1.3 代码拆分6.2 分层解耦6.2.1 内聚、耦合6.2.2 解耦思路6.3 IOC&…

[LeetCode周赛复盘] 第 333 场周赛20230219

[LeetCode周赛复盘] 第 333 场周赛20230219 一、本周周赛总结二、 [Easy] 6362. 合并两个二维数组 - 求和法1. 题目描述2. 思路分析3. 代码实现三、[Medium] 6365. 将整数减少到零需要的最少操作数1. 题目描述2. 思路分析3. 代码实现四、[Medium] 6364. 无平方子集计数1. 题目描…

操作系统闲谈08——系统调用、中断、异常

操作系统闲谈08——系统调用、中断、异常 一、系统调用 IDT - GDT - 系统调用表 找到对应系统调用号将系统调用号以及一些现场信息存入寄存器eax中&#xff08;ebx、ecx、edx存放其他信息&#xff09;&#xff0c;然后触发软中断&#xff08;x86中&#xff0c;0x80为中断号&…

设计模式 状态机

前言 本文梳理状态机概念&#xff0c;在实操中状态机和状态模式类似&#xff0c;只是被封装起来&#xff0c;可以很方便的实现状态初始化和状态转换。 概念 有限状态机&#xff08;finite-state machine&#xff09;又称有限状态自动机&#xff08;英语&#xff1a;finite-s…

ThreadLocal知识点总结

什么是ThreadLocal&#xff1f;它的作用是什么&#xff1f; ThreadLocal是线程Thread中属性threadLocals的管理者。 ThreadLocal是Java中lang包下的一个类&#xff0c;可以用于在多线程环境中为每个线程维护独立的变量副本。它的作用是让每个线程都拥有自己的数据副本&#xff…

Java面向对象的特性:封装,继承与多态

Java面向对象的特性 在学习Java的过程是必须要知道的Java三大特性&#xff1a;封装、继承、多态。如果要分为四类的话&#xff0c;加上抽象特性。 封装 1.封装概述 是面向对像三大特征之一&#xff08;封装&#xff0c;继承&#xff0c;多态&#xff09; 是面向对象编程语言对客…

语音增强学习路线图Roadmap

语音增强算是比较难的研究领域&#xff0c;从入门到精通有很多台阶&#xff0c;本文介绍一些有价值的书籍&#xff0c;值得反复阅读。主要分为基础类和进阶类书籍&#xff0c;大多都是理论和实践相结合的书籍&#xff0c;编程实践是抓手,让知识和基础理论变扎实。基础书籍《信号…

RT-Thread初识学习-01

1. RT-Thread 简介 1.1 RT-Thread 是什么 据不完全统计&#xff0c;世界有成千上万个 RTOS&#xff08;Real-time operating system&#xff0c;实时操作系统&#xff09;&#xff0c;RT-Thread 就是其中一个优秀的作品。 RT-Thread 内核的第一个版本是熊谱翔先生在 2006 年…

分布式-分布式存储笔记

读写分离 什么时候需要读写分离 互联网大部分业务场景都是读多写少的&#xff0c;读和写的请求对比可能差了不止一个数量级。为了不让数据库的读成为业务瓶颈&#xff0c;同时也为了保证写库的成功率&#xff0c;一般会采用读写分离的技术来保证。 读写分离的实现是把访问的压…

LeetCode-384-打乱数组

1、列表随机 为了能够初始化数组&#xff0c;我们使用nums保存当前的数组&#xff0c;利用orignal保存初始化数组。为了实现等可能随机打乱&#xff0c;考虑到随机数本质上是基于随机数种子的伪随机&#xff0c;我们采用如下的方式实现等可能随机&#xff1a;我们将所有元素压…

MySQL备份恢复(十二)

文章目录1. MySQL数据损坏类型1.1 物理损坏1.2 逻辑损坏2. DBA运维人员备份/恢复职责2.1 设计备份/容灾策略2.1.1 备份策略2.1.2 容灾策略2.2 定期的备份/容灾检查2.3 定期的故障恢复演练2.4 数据损坏时的快速准确恢复2.5 数据迁移工作3. MySQL常用备份工具3.1 逻辑备份方式3.2…