2025.1.26 Machine Learning Notes: C-RNN-GAN Paper Reading

Weekly Report: 2025.1.26

  • Literature Reading
    • Paper Information
    • Abstract
    • Innovations
    • Network Architecture
    • Experiments
    • Conclusion
    • Limitations and Future Work
  • Summary

Literature Reading

Paper Information

  • Title: C-RNN-GAN: Continuous recurrent neural networks with adversarial training
  • Venue: NIPS 2016, Constructive Machine Learning Workshop (CML)
  • Author: Olof Mogren
  • Published: 2016
  • Link: https://arxiv.org/pdf/1611.09904

Abstract

The purpose of Generative Adversarial Networks (GANs) is to generate data, while Recurrent Neural Networks (RNNs) are often used for generating data sequences. Currently, there have been many studies using RNNs for music generation, but most of them employ symbolic representations. In this paper, the authors investigate the feasibility of using adversarial training to generate sequences of continuous data, and evaluate it using classical music MIDI files. They propose the C-RNN-GAN (Continuous Recurrent Neural Network GAN), a neural network architecture that uses adversarial training to model the joint probability of the entire sequence and generate high-quality data sequences. By training this model on classical music MIDI format sequences and assessing it with metrics such as scale consistency and range, the authors demonstrate that adversarial training is a viable method for training networks, and the proposed model offers a new approach for the generation of continuous data.

Innovations

The novelty of this work lies in the C-RNN-GAN model, which applies adversarial training to continuous sequence data. Each musical event is represented with four real-valued scalars (tone length, frequency, intensity, and time since the previous tone), and the whole network is trained end-to-end with backpropagation.
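
As a concrete illustration, here is a minimal sketch (my own, not the paper's code) of this four-scalar event representation; the class and field names are hypothetical.

from dataclasses import dataclass

@dataclass
class ToneEvent:
    """One music event, following the paper's four real-valued scalars."""
    tone_length: float       # duration of the tone, in normalized ticks
    frequency: float         # pitch, encoded as sound frequency in Hz
    intensity: float         # how loudly the tone is played (MIDI velocity)
    time_since_prev: float   # time since the previous tone started

# An illustrative A4 quarter note (the values are made up):
event = ToneEvent(tone_length=96.0, frequency=440.0, intensity=80.0, time_since_prev=96.0)
print(event)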

Network Architecture

The paper proposes the C-RNN-GAN model, which consists of two main parts: a generator (Generator) and a discriminator (Discriminator), as shown in the figure below.
The generator (G) produces music sequences from random input (noise). It contains LSTM layers and fully connected layers. Its input is random noise (e.g., a random vector); its output is a generated music sequence.
The discriminator (D) distinguishes generated music sequences from real ones. It is composed of a Bi-LSTM (bidirectional long short-term memory network). Its input is a real or generated music sequence; its output is a probability that the input sequence is real music.
During training, G and D play against each other and are trained alternately: the generator tries to fool the discriminator, while the discriminator tries to tell real music from generated music.
[Figure: C-RNN-GAN architecture, with an LSTM generator and a Bi-LSTM discriminator]
The loss functions of G and D are as follows:

$$L_{G} = \frac{1}{m} \sum_{i=1}^{m} \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)$$

$$L_{D} = \frac{1}{m} \sum_{i=1}^{m} \left[ -\log D\left(x^{(i)}\right) - \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right) \right]$$

where $z^{(i)}$ is a sequence of uniform random vectors in $[0,1]^{k}$, $x^{(i)}$ is a sequence from the training data, and $k$ is the dimensionality of the data in the random sequence. The input to each cell in G is a random vector, concatenated with the output of the previous cell.
This is essentially the same adversarial setup as the original GAN we read about earlier, so I will not repeat the details; a small numeric sketch of the two losses follows.
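
For intuition, here is a minimal NumPy sketch (my own) that evaluates the two losses above on hand-picked discriminator outputs:

import numpy as np

d_real = np.array([0.9, 0.8, 0.95])  # D(x^(i)): scores on real sequences
d_fake = np.array([0.2, 0.1, 0.3])   # D(G(z^(i))): scores on generated sequences

L_G = np.mean(np.log(1.0 - d_fake))                    # generator loss
L_D = np.mean(-np.log(d_real) - np.log(1.0 - d_fake))  # discriminator loss
print('L_G = {:.3f}, L_D = {:.3f}'.format(L_G, L_D))

Minimizing L_D pushes D(x) toward 1 and D(G(z)) toward 0, while G is updated to make D(G(z)) large.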

Experiments

The training data was collected from the web as classical music files in MIDI format, containing famous classical works. Each MIDI event is loaded and stored together with its duration, tone, intensity (velocity), and the time since the previous tone started. Tone data is represented internally by the corresponding sound frequency. All data is normalized to a tick resolution of 384 ticks per quarter note. The dataset contains 3,697 MIDI files from 160 different classical composers, and the authors evaluate the generated music with several metrics.

Evaluation metrics used in the experiments (a rough sketch of two of them follows the list):
Polyphony: measures how often two tones are played simultaneously.
Scale consistency: computed as the fraction of tones that belong to a standard scale, reported for the best-matching scale.
Repetitions: measures the degree of repetition within a sample, considering only the tones and their order, not the timing.
Tone span: the number of half-tone steps between the lowest and the highest tone in a sample.
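
Below is a rough, self-contained sketch (my own, not the paper's midi_statistics code) of two of these metrics over a list of MIDI tone numbers; the real evaluation also covers polyphony and repetitions, and matches scales beyond the major scale.

C_MAJOR = {0, 2, 4, 5, 7, 9, 11}  # pitch classes of the major scale

def scale_consistency(tones):
    """Fraction of tones in the best-matching transposition of a major scale."""
    best = 0.0
    for shift in range(12):
        in_scale = sum((t - shift) % 12 in C_MAJOR for t in tones)
        best = max(best, in_scale / len(tones))
    return best

def tone_span(tones):
    """Number of half-tone steps between the lowest and highest tone."""
    return max(tones) - min(tones)

sample = [60, 62, 64, 65, 67, 69, 71, 72]            # C major scale in MIDI numbers
print(scale_consistency(sample), tone_span(sample))  # -> 1.0 12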

Model parameters:
The LSTM networks in both the generator (G) and the discriminator (D) are 2 layers deep, and each LSTM cell has 350 hidden units.
D is bidirectional, while G is unidirectional. The output of each LSTM cell in D is fed into a fully connected layer whose weights are shared across time steps, and the sigmoid outputs of the individual cells are then averaged.
Training uses backpropagation through time (BPTT) and mini-batch stochastic gradient descent. The learning rate is set to 0.1, and L2 regularization is applied to the weights of both G and D. The model is pretrained for 6 epochs with a squared-error loss for predicting the next event in the training sequence. The input to each LSTM cell is a random vector v concatenated with the output of the previous time step; v is uniformly distributed in [0,1]^k, with k chosen to be the number of features per tone. During pretraining, the authors apply a curriculum over sequence length: they start with short sequences randomly sampled from the training data and train the model on increasingly long sequences; a small sketch of this schedule follows.
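
The schedule itself is simple; this sketch mirrors the line new_songlength = int(min((i+1)*4, songlength_ceiling)) from the training script reproduced further down.

def curriculum_songlength(epoch, ceiling=100):
    """Sequence length grows linearly with the epoch until it reaches the ceiling."""
    return min((epoch + 1) * 4, ceiling)

for epoch in (0, 1, 5, 24, 30):
    print(epoch, curriculum_songlength(epoch))  # 4, 8, 24, 100, 100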

Experimental results:
As training progresses, the complexity of the music generated by C-RNN-GAN increases. The number of unique tones shows a weak upward trend, and scale consistency stabilizes after 10 to 15 epochs.
3-tone repetitions trend upward over the first 25 epochs and then remain at a lower level; this metric correlates with the number of tones used.
[Figure: evaluation metrics for C-RNN-GAN over training epochs]
The baseline (a recurrent network similar to the generator) does not reach the level of variation of C-RNN-GAN. It consistently uses far fewer unique tones; its scale consistency is similar to C-RNN-GAN's, but its tone span is more tightly tied to its number of unique tones than C-RNN-GAN's is, indicating that the tones it uses are less varied.
[Figure: evaluation metrics for the baseline over training epochs]
C-RNN-GAN-3 (the "3" means up to three tones output per LSTM cell) achieves a higher polyphony score than both C-RNN-GAN and the baseline.
After passing through a state with many all-zero outputs around epochs 50 to 55, it reaches higher values for tone span, number of unique tones, intensity span, and 3-tone repetitions.
[Figure: evaluation metrics for C-RNN-GAN-3 over training epochs]
Real music has an intensity span similar to the generated music. Its scale consistency is slightly higher but varies more, its polyphony score is similar to C-RNN-GAN-3's, and its 3-tone repetitions are much higher, although this is hard to compare because song lengths differ (the authors normalized by dividing by the ratio of real-music length to generated-music length).
[Figure: evaluation metrics for real music, for comparison]
The experimental results show that adversarial training helps the model learn music that is more varied, covers a wider tone span, and uses a larger intensity range. Letting each LSTM cell output more than one tone helps the model generate music with a higher polyphony score: although C-RNN-GAN's output is polyphonic, it scores low on the polyphony metric used in the evaluation, whereas C-RNN-GAN-3, which allows up to three simultaneous tones per LSTM cell, scores better. Timing varies considerably between samples but is roughly constant within a single piece.
Code: https://github.com/olofmogren/c-rnn-gan


"""
模型参数:
learning_rate - 学习率的初始值
max_grad_norm - 梯度的最大允许范数
num_layers - LSTM 层的数量
songlength - LSTM 展开的步数
hidden_size - LSTM 单元的数量
epochs_before_decay - 使用初始学习率训练的轮数
max_epoch - 训练的总轮数
keep_prob - Dropout 层中保留权重的概率
lr_decay - 在 "epochs_before_decay" 之后每个轮数的学习率衰减
batch_size - 批量大小
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time, datetime, os, sys
import pickle as pkl
from subprocess import call, Popen

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline


import music_data_utils
from midi_statistics import get_all_stats

flags = tf.flags
logging = tf.logging

flags.DEFINE_string("datadir", None, "保存和加载 MIDI 音乐文件的目录")
flags.DEFINE_string("traindir", None, "保存检查点和 gnuplot 文件的目录")
flags.DEFINE_integer("epochs_per_checkpoint", 2, "每个检查点之间进行的训练轮数")
flags.DEFINE_boolean("log_device_placement", False, "输出设备放置的信息")
flags.DEFINE_string("call_after", None, "退出后调用的命令")
flags.DEFINE_integer("exit_after", 1440, "运行多少分钟后退出")
flags.DEFINE_integer("select_validation_percentage", None, "选择作为验证集的数据的随机百分比")
flags.DEFINE_integer("select_test_percentage", None, "选择作为测试集的数据的随机百分比")
flags.DEFINE_boolean("sample", False, "从模型中采样输出。假设训练已经完成。将采样输出保存到文件中")
flags.DEFINE_integer("works_per_composer", None, "限制每个作曲家加载的作品数量")
flags.DEFINE_boolean("disable_feed_previous", False, "在生成器中,将前一个单元的输出作为下一个单元的输入")
flags.DEFINE_float("init_scale", 0.05, "权重的初始缩放值")
flags.DEFINE_float("learning_rate", 0.1, "学习率")
flags.DEFINE_float("d_lr_factor", 0.5, "学习率衰减因子")
flags.DEFINE_float("max_grad_norm", 5.0, "梯度的最大允许范数")
flags.DEFINE_float("keep_prob", 0.5, "保留权重的概率。1表示不使用 Dropout")
flags.DEFINE_float("lr_decay", 1.0, "在 'epochs_before_decay' 之后每个轮数的学习率衰减")
flags.DEFINE_integer("num_layers_g", 2, "生成器 G 中堆叠的循环单元数量")
flags.DEFINE_integer("num_layers_d", 2, "判别器 D 中堆叠的循环单元数量")
flags.DEFINE_integer("songlength", 100, "限制歌曲输入的事件数量")
flags.DEFINE_integer("meta_layer_size", 200, "元信息模块的隐藏层大小")
flags.DEFINE_integer("hidden_size_g", 100, "生成器 G 的循环部分的隐藏层大小")
flags.DEFINE_integer("hidden_size_d", 100, "判别器 D 的循环部分的隐藏层大小,默认与 G 相同")
flags.DEFINE_integer("epochs_before_decay", 60, "开始衰减之前进行的轮数")
flags.DEFINE_integer("max_epoch", 500, "停止训练之前的总轮数")
flags.DEFINE_integer("batch_size", 20, "批量大小")
flags.DEFINE_integer("biscale_slow_layer_ticks", 8, "Biscale 慢层的刻度")
flags.DEFINE_boolean("multiscale", False, "多尺度 RNN")
flags.DEFINE_integer("pretraining_epochs", 6, "进行语言模型风格预训练的轮数")
flags.DEFINE_boolean("pretraining_d", False, "在预训练期间训练 D")
flags.DEFINE_boolean("initialize_d", False, "初始化 D 的变量,无论检查点中是否有已训练的版本")
flags.DEFINE_boolean("ignore_saved_args", False, "告诉程序忽略已保存的参数,而是使用命令行参数")
flags.DEFINE_boolean("pace_events", False, "在解析输入数据时,如果某个四分音符位置没有音符,则插入一个虚拟事件")
flags.DEFINE_boolean("minibatch_d", False, "为小批量增加核特征以提高多样性")
flags.DEFINE_boolean("unidirectional_d", False, "使用单向 RNN 而不是双向 RNN 作为 D")
flags.DEFINE_boolean("profiling", False, "性能分析。在 plots 目录中写入 timeline.json 文件")
flags.DEFINE_boolean("float16", False, "使用 float16 数据类型,否则,使用 float32")
flags.DEFINE_boolean("adam", False, "使用 Adam 优化器")
flags.DEFINE_boolean("feature_matching", False, "生成器 G 的特征匹配目标")
flags.DEFINE_boolean("disable_l2_regularizer", False, "对权重进行 L2 正则化")
flags.DEFINE_float("reg_scale", 1.0, "L2 正则化系数")
flags.DEFINE_boolean("synthetic_chords", False, "使用合成生成的和弦进行训练(每个事件三个音符)")
flags.DEFINE_integer("tones_per_cell", 1, "每个 RNN 单元输出的最大音符数量")
flags.DEFINE_string("composer", None, "指定一个作曲家,并仅在此作曲家的作品上训练模型")
flags.DEFINE_boolean("generate_meta", False, "将作曲家和流派作为输出的一部分生成")
flags.DEFINE_float("random_input_scale", 1.0, "随机输入的缩放比例(1表示与生成的特征大小相同)")
flags.DEFINE_boolean("end_classification", False, "仅在 D 的末尾进行分类。否则,在每个时间步进行分类并取平均值")

FLAGS = flags.FLAGS

model_layout_flags = ['num_layers_g', 'num_layers_d', 'meta_layer_size', 'hidden_size_g', 'hidden_size_d', 'biscale_slow_layer_ticks', 'multiscale', 'multiscale', 'disable_feed_previous', 'pace_events', 'minibatch_d', 'unidirectional_d', 'feature_matching', 'composer']

def make_rnn_cell(rnn_layer_sizes,
                  dropout_keep_prob=1.0,
                  attn_length=0,
                  base_cell=tf.contrib.rnn.BasicLSTMCell,
                  state_is_tuple=True,
                  reuse=False):
"""
根据给定的超参数创建一个RNN单元。
  参数:
    rnn_layer_sizes:一个整数列表,表示 RNN 每层的大小。
    dropout_keep_prob:一个浮点数,表示保留任何给定子单元输出的概率。
    attn_length:注意力向量的大小。
    base_cell:用于子单元的基础 tf.contrib.rnn.RNNCell。
    state_is_tuple:一个布尔值,指定是否使用隐藏矩阵和单元矩阵的元组作为状态,而不是拼接矩阵。
  return:
      一个基于给定超参数的 tf.contrib.rnn.MultiRNNCell。
  """
  cells = []
  for num_units in rnn_layer_sizes:
    cell = base_cell(num_units, state_is_tuple=state_is_tuple, reuse=reuse)
    cell = tf.contrib.rnn.DropoutWrapper(
        cell, output_keep_prob=dropout_keep_prob)
    cells.append(cell)

  cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=state_is_tuple)
  if attn_length:
    cell = tf.contrib.rnn.AttentionCellWrapper(
        cell, attn_length, state_is_tuple=state_is_tuple, reuse=reuse)

  return cell
def restore_flags(save_if_none_found=True):
  if FLAGS.traindir:
    saved_args_dir = os.path.join(FLAGS.traindir, 'saved_args')
    if save_if_none_found:
      try: os.makedirs(saved_args_dir)
      except: pass
    for arg in FLAGS.__flags:
      if arg not in model_layout_flags:
        continue
      if FLAGS.ignore_saved_args and os.path.exists(os.path.join(saved_args_dir, arg+'.pkl')):
        print('{:%Y-%m-%d %H:%M:%S}: saved_args: Found {} setting from saved state, but using CLI args ({}) and saving (--ignore_saved_args).'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))
      elif os.path.exists(os.path.join(saved_args_dir, arg+'.pkl')):
        with open(os.path.join(saved_args_dir, arg+'.pkl'), 'rb') as f:
          setattr(FLAGS, arg, pkl.load(f))
          print('{:%Y-%m-%d %H:%M:%S}: saved_args: {} from saved state ({}), ignoring CLI args.'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))
      elif save_if_none_found:
        print('{:%Y-%m-%d %H:%M:%S}: saved_args: Found no {} setting from saved state, using CLI args ({}) and saving.'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))
        with open(os.path.join(saved_args_dir, arg+'.pkl'), 'wb') as f:
            print(getattr(FLAGS, arg),arg)
            pkl.dump(getattr(FLAGS, arg), f)
      else:
        print('{:%Y-%m-%d %H:%M:%S}: saved_args: Found no {} setting from saved state, using CLI args ({}) but not saving.'.format(datetime.datetime.today(), arg, getattr(FLAGS, arg)))

# Returns the configured floating-point data type.
def data_type():
  return tf.float16 if FLAGS.float16 else tf.float32
  #return tf.float16


def my_reduce_mean(what_to_take_mean_over):
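  # NOTE: this early return makes everything below it dead code; the function
  # just returns the first element of the flattened tensor.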
  return tf.reshape(what_to_take_mean_over, shape=[-1])[0]
  denom = 1.0
  #print(what_to_take_mean_over.get_shape())
  for d in what_to_take_mean_over.get_shape():
    #print(d)
    if type(d) == tf.Dimension:
      denom = denom*d.value
    else:
      denom = denom*d
  return tf.reduce_sum(what_to_take_mean_over)/denom

def linear(inp, output_dim, scope=None, stddev=1.0, reuse_scope=False):
  norm = tf.random_normal_initializer(stddev=stddev, dtype=data_type())
  const = tf.constant_initializer(0.0, dtype=data_type())
  with tf.variable_scope(scope or 'linear') as scope:
    scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
    if reuse_scope:
      scope.reuse_variables()
    #print('inp.get_shape(): {}'.format(inp.get_shape()))
    w = tf.get_variable('w', [inp.get_shape()[1], output_dim], initializer=norm, dtype=data_type())
    b = tf.get_variable('b', [output_dim], initializer=const, dtype=data_type())
  return tf.matmul(inp, w) + b

def minibatch(inp, num_kernels=25, kernel_dim=10, scope=None, msg='', reuse_scope=False):
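  # Minibatch discrimination: append per-sample kernel features that measure
  # similarity to the other samples in the batch, helping D penalize a
  # collapsed, low-diversity generator.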
  with tf.variable_scope(scope or 'minibatch_d') as scope:
    scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
    if reuse_scope:
      scope.reuse_variables()
  
    inp = tf.Print(inp, [inp],
            '{} inp = '.format(msg), summarize=20, first_n=20)
    x = tf.sigmoid(linear(inp, num_kernels * kernel_dim, scope))
    activation = tf.reshape(x, (-1, num_kernels, kernel_dim))
    activation = tf.Print(activation, [activation],
            '{} activation = '.format(msg), summarize=20, first_n=20)
    diffs = tf.expand_dims(activation, 3) - \
                tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0)
    diffs = tf.Print(diffs, [diffs],
            '{} diffs = '.format(msg), summarize=20, first_n=20)
    abs_diffs = tf.reduce_sum(tf.abs(diffs), 2)
    abs_diffs = tf.Print(abs_diffs, [abs_diffs],
            '{} abs_diffs = '.format(msg), summarize=20, first_n=20)
    minibatch_features = tf.reduce_sum(tf.exp(-abs_diffs), 2)
    minibatch_features = tf.Print(minibatch_features, [tf.reduce_min(minibatch_features), tf.reduce_max(minibatch_features)],
            '{} minibatch_features (min,max) = '.format(msg), summarize=20, first_n=20)
  return tf.concat( [inp, minibatch_features],1)

class RNNGAN(object):
  """定义RNN-GAN模型."""
  def __init__(self, is_training, num_song_features=None, num_meta_features=None):
    batch_size = FLAGS.batch_size
    self.batch_size =  batch_size
    songlength = FLAGS.songlength
    self.songlength = songlength  # self.global_step = tf.Variable(0, trainable=False)

    print('songlength: {}'.format(self.songlength))
    self._input_songdata = tf.placeholder(shape=[batch_size, songlength, num_song_features], dtype=data_type())
    self._input_metadata = tf.placeholder(shape=[batch_size, num_meta_features], dtype=data_type())
    #_split = tf.split(self._input_songdata,songlength,1)[0]
    print("self._input_songdata",self._input_songdata, 'songlength',songlength)
    #print(tf.squeeze(_split,[1]))
    songdata_inputs = [tf.squeeze(input_, [1])
              for input_ in tf.split(self._input_songdata,songlength,1)]
  
    
    with tf.variable_scope('G') as scope:
      scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
      #lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.hidden_size_g, forget_bias=1.0, state_is_tuple=True)
      if is_training and FLAGS.keep_prob < 1:
        #lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
        #    lstm_cell, output_keep_prob=FLAGS.keep_prob)
        cell = make_rnn_cell([FLAGS.hidden_size_g]*FLAGS.num_layers_g,dropout_keep_prob=FLAGS.keep_prob)
      else:
         cell = make_rnn_cell([FLAGS.hidden_size_g]*FLAGS.num_layers_g)	  

      #cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range( FLAGS.num_layers_g)], state_is_tuple=True)
      self._initial_state = cell.zero_state(batch_size, data_type())

      # TODO: (possibly temporarily) disabling meta info
      if FLAGS.generate_meta:
        metainputs = tf.random_uniform(shape=[batch_size, int(FLAGS.random_input_scale*num_meta_features)], minval=0.0, maxval=1.0)
        meta_g = tf.nn.relu(linear(metainputs, FLAGS.meta_layer_size, scope='meta_layer', reuse_scope=False))
        meta_softmax_w = tf.get_variable("meta_softmax_w", [FLAGS.meta_layer_size, num_meta_features])
        meta_softmax_b = tf.get_variable("meta_softmax_b", [num_meta_features])
        meta_logits = tf.nn.xw_plus_b(meta_g, meta_softmax_w, meta_softmax_b)
        meta_probs = tf.nn.softmax(meta_logits)

      random_rnninputs = tf.random_uniform(shape=[batch_size, songlength, int(FLAGS.random_input_scale*num_song_features)], minval=0.0, maxval=1.0, dtype=data_type())
      
      random_rnninputs = [tf.squeeze(input_, [1]) for input_ in tf.split( random_rnninputs,songlength,1)]
      
      # REAL GENERATOR:
      state = self._initial_state
      # as we feed the output as the input to the next, we 'invent' the initial 'output'.
      generated_point = tf.random_uniform(shape=[batch_size, num_song_features], minval=0.0, maxval=1.0, dtype=data_type())
      outputs = []
      self._generated_features = []
      for i,input_ in enumerate(random_rnninputs):
        if i > 0: scope.reuse_variables()
        concat_values = [input_]
        if not FLAGS.disable_feed_previous:
          concat_values.append(generated_point)
        if FLAGS.generate_meta:
          concat_values.append(meta_probs)
        if len(concat_values):
          input_ = tf.concat(axis=1, values=concat_values)
        input_ = tf.nn.relu(linear(input_, FLAGS.hidden_size_g,
                            scope='input_layer', reuse_scope=(i!=0)))
        output, state = cell(input_, state)
        outputs.append(output)
        #generated_point = tf.nn.relu(linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0)))
        generated_point = linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0))
        self._generated_features.append(generated_point)
      
     
      # PRETRAINING GENERATOR, will feed inputs, not generated outputs:
      scope.reuse_variables()
      # as we feed the output as the input to the next, we 'invent' the initial 'output'.
      prev_target = tf.random_uniform(shape=[batch_size, num_song_features], minval=0.0, maxval=1.0, dtype=data_type())
      outputs = []
      self._generated_features_pretraining = []
      for i,input_ in enumerate(random_rnninputs):
        concat_values = [input_]
        if not FLAGS.disable_feed_previous:
          concat_values.append(prev_target)
        if FLAGS.generate_meta:
          concat_values.append(self._input_metadata)
        if len(concat_values):
          input_ = tf.concat(axis=1, values=concat_values)
        input_ = tf.nn.relu(linear(input_, FLAGS.hidden_size_g, scope='input_layer', reuse_scope=(i!=0)))
        output, state = cell(input_, state)
        outputs.append(output)
        #generated_point = tf.nn.relu(linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0)))
        generated_point = linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0))
        self._generated_features_pretraining.append(generated_point)
        prev_target = songdata_inputs[i]
      
      #outputs, state = tf.nn.rnn(cell, transformed, initial_state=self._initial_state)

      #self._generated_features = [tf.nn.relu(linear(output, num_song_features, scope='output_layer', reuse_scope=(i!=0))) for i,output in enumerate(outputs)]

    self._final_state = state

    # These are used both for pretraining and for D/G training further down.
    self._lr = tf.Variable(FLAGS.learning_rate, trainable=False, dtype=data_type())
    self.g_params = [v for v in tf.trainable_variables() if v.name.startswith('model/G/')]
    if FLAGS.adam:
      g_optimizer = tf.train.AdamOptimizer(self._lr)
    else:
      g_optimizer = tf.train.GradientDescentOptimizer(self._lr)
   
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    reg_constant = 0.1  # Choose an appropriate one.
    reg_loss = reg_constant * sum(reg_losses)
    reg_loss = tf.Print(reg_loss, reg_losses,
                  'reg_losses = ', summarize=20, first_n=20)
   
    # Pretraining: squared-error loss for predicting the next event in the training sequence.
    print(tf.transpose(tf.stack(self._generated_features_pretraining), perm=[1, 0, 2]).get_shape())
    print(self._input_songdata.get_shape())
    self.rnn_pretraining_loss = tf.reduce_mean(tf.squared_difference(x=tf.transpose(tf.stack(self._generated_features_pretraining), perm=[1, 0, 2]), y=self._input_songdata))
    if not FLAGS.disable_l2_regularizer:
      self.rnn_pretraining_loss = self.rnn_pretraining_loss+reg_loss
    
    
    pretraining_grads, _ = tf.clip_by_global_norm(tf.gradients(self.rnn_pretraining_loss, self.g_params), FLAGS.max_grad_norm)
    self.opt_pretraining = g_optimizer.apply_gradients(zip(pretraining_grads, self.g_params))


    # The discriminator tries to tell the difference between samples from the
    # true data distribution (self.x) and the generated samples (self.z).
    #
    # Here we create two copies of the discriminator network (that share parameters),
    # as you cannot use the same network with different inputs in TensorFlow.
    with tf.variable_scope('D') as scope:
      scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
      # Make list of tensors. One per step in recurrence.
      # Each tensor is batchsize*numfeatures.
      # TODO: (possibly temporarily) disabling meta info
      print('self._input_songdata shape {}'.format(self._input_songdata.get_shape()))
      print('generated data shape {}'.format(self._generated_features[0].get_shape()))
      # TODO: (possibly temporarily) disabling meta info
      if FLAGS.generate_meta:
        songdata_inputs = [tf.concat([self._input_metadata, songdata_input],1) for songdata_input in songdata_inputs]
      #print(songdata_inputs[0])
      #print(songdata_inputs[0])
      #print('metadata inputs shape {}'(self._input_metadata.get_shape()))
      #print('generated metadata shape {}'.format(meta_probs.get_shape()))
      self.real_d,self.real_d_features = self.discriminator(songdata_inputs, is_training, msg='real')
      scope.reuse_variables()
      # TODO: (possibly temporarily) disabling meta info
      if FLAGS.generate_meta:
        generated_data = [tf.concat([meta_probs, songdata_input],1) for songdata_input in self._generated_features]
      else:
        generated_data = self._generated_features
      if songdata_inputs[0].get_shape() != generated_data[0].get_shape():
        print('songdata_inputs shape {} != generated data shape {}'.format(songdata_inputs[0].get_shape(), generated_data[0].get_shape()))
      self.generated_d,self.generated_d_features = self.discriminator(generated_data, is_training, msg='generated')

    # Define the loss for discriminator and generator networks (see the original
    # paper for details), and create optimizers for both
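    # NOTE: the literal 1e-1000000 underflows to 0.0, so these clip bounds are
    # effectively [0.0, 1.0] and log(0) can still occur if D saturates.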
    self.d_loss = tf.reduce_mean(-tf.log(tf.clip_by_value(self.real_d, 1e-1000000, 1.0)) \
                                 -tf.log(1 - tf.clip_by_value(self.generated_d, 0.0, 1.0-1e-1000000)))
    self.g_loss_feature_matching = tf.reduce_sum(tf.squared_difference(self.real_d_features, self.generated_d_features))
    self.g_loss = tf.reduce_mean(-tf.log(tf.clip_by_value(self.generated_d, 1e-1000000, 1.0)))

    if not FLAGS.disable_l2_regularizer:
      self.d_loss = self.d_loss+reg_loss
      self.g_loss_feature_matching = self.g_loss_feature_matching+reg_loss
      self.g_loss = self.g_loss+reg_loss
    self.d_params = [v for v in tf.trainable_variables() if v.name.startswith('model/D/')]

    if not is_training:
      return

    d_optimizer = tf.train.GradientDescentOptimizer(self._lr*FLAGS.d_lr_factor)
    d_grads, _ = tf.clip_by_global_norm(tf.gradients(self.d_loss, self.d_params),
                                        FLAGS.max_grad_norm)
    self.opt_d = d_optimizer.apply_gradients(zip(d_grads, self.d_params))
    if FLAGS.feature_matching:
      g_grads, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss_feature_matching,
                                                       self.g_params),
                                        FLAGS.max_grad_norm)
    else:
      g_grads, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss, self.g_params),
                                        FLAGS.max_grad_norm)
    self.opt_g = g_optimizer.apply_gradients(zip(g_grads, self.g_params))

    self._new_lr = tf.placeholder(shape=[], name="new_learning_rate", dtype=data_type())
    self._lr_update = tf.assign(self._lr, self._new_lr)

  def discriminator(self, inputs, is_training, msg=''):
    # RNN discriminator:
    #for i in xrange(len(inputs)):
    #  print('shape inputs[{}] {}'.format(i, inputs[i].get_shape()))
    #inputs[0] = tf.Print(inputs[0], [inputs[0]],
    #        '{} inputs[0] = '.format(msg), summarize=20, first_n=20)
    if is_training and FLAGS.keep_prob < 1:
      inputs = [tf.nn.dropout(input_, FLAGS.keep_prob) for input_ in inputs]
    
    #lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.hidden_size_d, forget_bias=1.0, state_is_tuple=True)
    if is_training and FLAGS.keep_prob < 1:
      #lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
      #lstm_cell, output_keep_prob=FLAGS.keep_prob)
      cell_fw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d,dropout_keep_prob=FLAGS.keep_prob)
      
      cell_bw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d,dropout_keep_prob=FLAGS.keep_prob)
    else:
      cell_fw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d)
      
      cell_bw = make_rnn_cell([FLAGS.hidden_size_d]* FLAGS.num_layers_d)
    #cell_fw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range( FLAGS.num_layers_d)], state_is_tuple=True)
    self._initial_state_fw = cell_fw.zero_state(self.batch_size, data_type())
    if not FLAGS.unidirectional_d:
      #lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.hidden_size_g, forget_bias=1.0, state_is_tuple=True)
      #if is_training and FLAGS.keep_prob < 1:
      #  lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
      #      lstm_cell, output_keep_prob=FLAGS.keep_prob)
      #cell_bw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range( FLAGS.num_layers_d)], state_is_tuple=True)
      self._initial_state_bw = cell_bw.zero_state(self.batch_size, data_type())
      print("cell_fw",cell_fw.output_size)
      #print("cell_bw",cell_bw.output_size)
      #print("inputs",inputs)
      #print("initial_state_fw",self._initial_state_fw)
      #print("initial_state_bw",self._initial_state_bw)
      outputs, state_fw, state_bw = tf.contrib.rnn.static_bidirectional_rnn(cell_fw, cell_bw, inputs, initial_state_fw=self._initial_state_fw, initial_state_bw=self._initial_state_bw)
      #outputs[0] = tf.Print(outputs[0], [outputs[0]],
      #        '{} outputs[0] = '.format(msg), summarize=20, first_n=20)
      #state = tf.concat(state_fw, state_bw)
      #endoutput = tf.concat(concat_dim=1, values=[outputs[0],outputs[-1]])
    else:
      outputs, state = tf.nn.rnn(cell_fw, inputs, initial_state=self._initial_state_fw)
      #state = self._initial_state
	  
      #outputs, state = cell_fw(tf.convert_to_tensor (inputs),state)
      #endoutput = outputs[-1]

    if FLAGS.minibatch_d:
      outputs = [minibatch(tf.reshape(outp, shape=[FLAGS.batch_size, -1]), msg=msg, reuse_scope=(i!=0)) for i,outp in enumerate(outputs)]
    # decision = tf.sigmoid(linear(outputs[-1], 1, 'decision'))
    if FLAGS.end_classification:
      decisions = [tf.sigmoid(linear(output, 1, 'decision', reuse_scope=(i!=0))) for i,output in enumerate([outputs[0], outputs[-1]])]
      decisions = tf.stack(decisions)
      decisions = tf.transpose(decisions, perm=[1,0,2])
      print('shape, decisions: {}'.format(decisions.get_shape()))
    else:
      decisions = [tf.sigmoid(linear(output, 1, 'decision', reuse_scope=(i!=0))) for i,output in enumerate(outputs)]
      decisions = tf.stack(decisions)
      decisions = tf.transpose(decisions, perm=[1,0,2])
      print('shape, decisions: {}'.format(decisions.get_shape()))
    decision = tf.reduce_mean(decisions, reduction_indices=[1,2])
    decision = tf.Print(decision, [decision],
            '{} decision = '.format(msg), summarize=20, first_n=20)
    return (decision,tf.transpose(tf.stack(outputs), perm=[1,0,2]))
      

  
  def assign_lr(self, session, lr_value):
    session.run(self._lr_update, feed_dict={self._new_lr: lr_value})

  @property
  def generated_features(self):
    return self._generated_features

  @property
  def input_songdata(self):
    return self._input_songdata

  @property
  def input_metadata(self):
    return self._input_metadata

  @property
  def targets(self):
    return self._targets

  @property
  def initial_state(self):
    return self._initial_state

  @property
  def cost(self):
    return self._cost

  @property
  def final_state(self):
    return self._final_state

  @property
  def lr(self):
    return self._lr

  @property
  def train_op(self):
    return self._train_op



def run_epoch(session, model, loader, datasetlabel, eval_op_g, eval_op_d, pretraining=False, verbose=False, run_metadata=None, pretraining_d=False):
  """Runs the model on the given data."""
  #epoch_size = ((len(data) // model.batch_size) - 1) // model.songlength
  epoch_start_time = time.time()
  g_loss, d_loss = 10.0, 10.0
  g_losses, d_losses = 0.0, 0.0
  iters = 0
  #state = session.run(model.initial_state)
  time_before_graph = None
  time_after_graph = None
  times_in_graph = []
  times_in_python = []
  #times_in_batchreading = []
  loader.rewind(part=datasetlabel)
  [batch_meta, batch_song] = loader.get_batch(model.batch_size, model.songlength, part=datasetlabel)

  run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)

  while batch_meta is not None and batch_song is not None:
    op_g = eval_op_g
    op_d = eval_op_d
    if datasetlabel == 'train' and not pretraining: # and not FLAGS.feature_matching:
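      # Heuristic for balancing G and D: the side that is too far ahead has its
      # update op replaced with tf.no_op() for this batch.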
      if d_loss == 0.0 and g_loss == 0.0:
        print('Both G and D train loss are zero. Exiting.')
        break
        #saver.save(session, checkpoint_path, global_step=m.global_step)
        #break
      elif d_loss == 0.0:
        #print('D train loss is zero. Freezing optimization. G loss: {:.3f}'.format(g_loss))
        op_g = tf.no_op()
      elif g_loss == 0.0: 
        #print('G train loss is zero. Freezing optimization. D loss: {:.3f}'.format(d_loss))
        op_d = tf.no_op()
      elif g_loss < 2.0 or d_loss < 2.0:
        if g_loss*.7 > d_loss:
          #print('G train loss is {:.3f}, D train loss is {:.3f}. Freezing optimization of D'.format(g_loss, d_loss))
          op_g = tf.no_op()
        #elif d_loss*.7 > g_loss:
          #print('G train loss is {:.3f}, D train loss is {:.3f}. Freezing optimization of G'.format(g_loss, d_loss))
        op_d = tf.no_op()
    #fetches = [model.cost, model.final_state, eval_op]
    if pretraining:
      if pretraining_d:
        fetches = [model.rnn_pretraining_loss, model.d_loss, op_g, op_d]
      else:
        fetches = [model.rnn_pretraining_loss, tf.no_op(), op_g, op_d]
    else:
      fetches = [model.g_loss, model.d_loss, op_g, op_d]
    feed_dict = {}
    feed_dict[model.input_songdata.name] = batch_song
    feed_dict[model.input_metadata.name] = batch_meta
    #print(batch_song)
    #print(batch_song.shape)
    
    #for i, (c, h) in enumerate(model.initial_state):
    #  feed_dict[c] = state[i].c
    #  feed_dict[h] = state[i].h
    #cost, state, _ = session.run(fetches, feed_dict)
    time_before_graph = time.time()
    if iters > 0:
      times_in_python.append(time_before_graph-time_after_graph)
    if run_metadata:
      g_loss, d_loss, _, _ = session.run(fetches, feed_dict, options=run_options, run_metadata=run_metadata)
    else:
      g_loss, d_loss, _, _ = session.run(fetches, feed_dict)
    time_after_graph = time.time()
    if iters > 0:
      times_in_graph.append(time_after_graph-time_before_graph)
    g_losses += g_loss
    if not pretraining:
      d_losses += d_loss
    iters += 1

    if verbose and iters % 10 == 9:
      songs_per_sec = float(iters * model.batch_size)/float(time.time() - epoch_start_time)
      avg_time_in_graph = float(sum(times_in_graph))/float(len(times_in_graph))
      avg_time_in_python = float(sum(times_in_python))/float(len(times_in_python))
      #avg_time_batchreading = float(sum(times_in_batchreading))/float(len(times_in_batchreading))
      if pretraining:
        print("{}: {} (pretraining) batch loss: G: {:.3f}, avg loss: G: {:.3f}, speed: {:.1f} songs/s, avg in graph: {:.1f}, avg in python: {:.1f}.".format(datasetlabel, iters, g_loss, float(g_losses)/float(iters), songs_per_sec, avg_time_in_graph, avg_time_in_python))
      else:
        print("{}: {} batch loss: G: {:.3f}, D: {:.3f}, avg loss: G: {:.3f}, D: {:.3f} speed: {:.1f} songs/s, avg in graph: {:.1f}, avg in python: {:.1f}.".format(datasetlabel, iters, g_loss, d_loss, float(g_losses)/float(iters), float(d_losses)/float(iters),songs_per_sec, avg_time_in_graph, avg_time_in_python))
    #batchtime = time.time()
    [batch_meta, batch_song] = loader.get_batch(model.batch_size, model.songlength, part=datasetlabel)
    #times_in_batchreading.append(time.time()-batchtime)

  if iters == 0:
    return (None,None)

  g_mean_loss = g_losses/iters
  if pretraining and not pretraining_d:
    d_mean_loss = None
  else:
    d_mean_loss = d_losses/iters
  return (g_mean_loss, d_mean_loss)


def sample(session, model, batch=False):
  """Samples from the generative model."""
  #state = session.run(model.initial_state)
  fetches = [model.generated_features]
  feed_dict = {}
  generated_features, = session.run(fetches, feed_dict)
  #print( generated_features)
  print( generated_features[0].shape)
  # The following worked when batch_size=1.
  # generated_features = [np.squeeze(x, axis=0) for x in generated_features]
  # If batch_size != 1, we just pick the first sample. Wastefull, yes.
  returnable = []
  if batch:
    for batchno in range(generated_features[0].shape[0]):
      returnable.append([x[batchno,:] for x in generated_features])
  else:
    returnable = [x[0,:] for x in generated_features]
  return returnable

def main(_):
  if not FLAGS.datadir:
    raise ValueError("Must set --datadir to midi music dir.")
  if not FLAGS.traindir:
    raise ValueError("Must set --traindir to dir where I can save model and plots.")
 
  restore_flags()
 
  summaries_dir = None
  plots_dir = None
  generated_data_dir = None
  summaries_dir = os.path.join(FLAGS.traindir, 'summaries')
  plots_dir = os.path.join(FLAGS.traindir, 'plots')
  generated_data_dir = os.path.join(FLAGS.traindir, 'generated_data')
  try: os.makedirs(FLAGS.traindir)
  except: pass
  try: os.makedirs(summaries_dir)
  except: pass
  try: os.makedirs(plots_dir)
  except: pass
  try: os.makedirs(generated_data_dir)
  except: pass
  directorynames = FLAGS.traindir.split('/')
  experiment_label = ''
  while not experiment_label:
    experiment_label = directorynames.pop()
  
  global_step = -1
  if os.path.exists(os.path.join(FLAGS.traindir, 'global_step.pkl')):
    with open(os.path.join(FLAGS.traindir, 'global_step.pkl'), 'rb') as f:
      global_step = pkl.load(f)
  global_step += 1

  songfeatures_filename = os.path.join(FLAGS.traindir, 'num_song_features.pkl')
  metafeatures_filename = os.path.join(FLAGS.traindir, 'num_meta_features.pkl')
  synthetic=None
  if FLAGS.synthetic_chords:
    synthetic = 'chords'
    print('Training on synthetic chords!')
  if FLAGS.composer is not None:
    print('Single composer: {}'.format(FLAGS.composer))
  loader = music_data_utils.MusicDataLoader(FLAGS.datadir, FLAGS.select_validation_percentage, FLAGS.select_test_percentage, FLAGS.works_per_composer, FLAGS.pace_events, synthetic=synthetic, tones_per_cell=FLAGS.tones_per_cell, single_composer=FLAGS.composer)
  if FLAGS.synthetic_chords:
    # This is just a print out, to check the generated data.
    batch = loader.get_batch(batchsize=1, songlength=400)
    loader.get_midi_pattern([batch[1][0][i] for i in range(batch[1].shape[1])])

  num_song_features = loader.get_num_song_features()
  print('num_song_features:{}'.format(num_song_features))
  num_meta_features = loader.get_num_meta_features()
  print('num_meta_features:{}'.format(num_meta_features))

  train_start_time = time.time()
  checkpoint_path = os.path.join(FLAGS.traindir, "model.ckpt")

  songlength_ceiling = FLAGS.songlength

  if global_step < FLAGS.pretraining_epochs:
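    # Sequence-length curriculum for pretraining; note that the second
    # assignment below immediately overrides the first schedule.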
    FLAGS.songlength = int(min(((global_step+10)/10)*10,songlength_ceiling))
    FLAGS.songlength = int(min((global_step+1)*4,songlength_ceiling))
 
  with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) as session:
    with tf.variable_scope("model", reuse=None) as scope:
      scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
      m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)


    if FLAGS.initialize_d:
      vars_to_restore = {}
      for v in tf.trainable_variables():
        if v.name.startswith('model/G/'):
          print(v.name[:-2])
          vars_to_restore[v.name[:-2]] = v
      saver = tf.train.Saver(vars_to_restore)
      ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)
      if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path,end=" ")
        saver.restore(session, ckpt.model_checkpoint_path)
        session.run(tf.initialize_variables([v for v in tf.trainable_variables() if v.name.startswith('model/D/')]))
      else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
      saver = tf.train.Saver(tf.all_variables())
    else:
      saver = tf.train.Saver(tf.all_variables())
      ckpt = tf.train.get_checkpoint_state(FLAGS.traindir)
      if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        saver.restore(session, ckpt.model_checkpoint_path)
      else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())

    run_metadata = None
    if FLAGS.profiling:
      run_metadata = tf.RunMetadata()
    if not FLAGS.sample:
      train_g_loss,train_d_loss = 1.0,1.0
      for i in range(global_step, FLAGS.max_epoch):
        lr_decay = FLAGS.lr_decay ** max(i - FLAGS.epochs_before_decay, 0.0)

        if global_step < FLAGS.pretraining_epochs:
          #new_songlength = int(min(((i+10)/10)*10,songlength_ceiling))
          new_songlength = int(min((i+1)*4,songlength_ceiling))
        else:
          new_songlength = songlength_ceiling
        if new_songlength != FLAGS.songlength:
          print('Changing songlength, now training on {} events from songs.'.format(new_songlength))
          FLAGS.songlength = new_songlength
          with tf.variable_scope("model", reuse=True) as scope:
            scope.set_regularizer(tf.contrib.layers.l2_regularizer(scale=FLAGS.reg_scale))
            m = RNNGAN(is_training=True, num_song_features=num_song_features, num_meta_features=num_meta_features)

        if not FLAGS.adam:
          m.assign_lr(session, FLAGS.learning_rate * lr_decay)

        save = False
        do_exit = False

        print("Epoch: {} Learning rate: {:.3f}, pretraining: {}".format(i, session.run(m.lr), (i<FLAGS.pretraining_epochs)))
        if i<FLAGS.pretraining_epochs:
          opt_d = tf.no_op()
          if FLAGS.pretraining_d:
            opt_d = m.opt_d
          train_g_loss,train_d_loss = run_epoch(session, m, loader, 'train', m.opt_pretraining, opt_d, pretraining = True, verbose=True, run_metadata=run_metadata, pretraining_d=FLAGS.pretraining_d)
          if FLAGS.pretraining_d:
            try:
              print("Epoch: {} Pretraining loss: G: {:.3f}, D: {:.3f}".format(i, train_g_loss, train_d_loss))
            except:
              print(train_g_loss)
              print(train_d_loss)
          else:
            print("Epoch: {} Pretraining loss: G: {:.3f}".format(i, train_g_loss))
        else:
          train_g_loss,train_d_loss = run_epoch(session, m, loader, 'train', m.opt_d, m.opt_g, verbose=True, run_metadata=run_metadata)
          try:
            print("Epoch: {} Train loss: G: {:.3f}, D: {:.3f}".format(i, train_g_loss, train_d_loss))
          except:
            print("Epoch: {} Train loss: G: {}, D: {}".format(i, train_g_loss, train_d_loss))
        valid_g_loss,valid_d_loss = run_epoch(session, m, loader, 'validation', tf.no_op(), tf.no_op())
        try:
          print("Epoch: {} Valid loss: G: {:.3f}, D: {:.3f}".format(i, valid_g_loss, valid_d_loss))
        except:
          print("Epoch: {} Valid loss: G: {}, D: {}".format(i, valid_g_loss, valid_d_loss))
        
        if train_d_loss == 0.0 and train_g_loss == 0.0:
          print('Both G and D train loss are zero. Exiting.')
          save = True
          do_exit = True
        if i % FLAGS.epochs_per_checkpoint == 0:
          save = True
        if FLAGS.exit_after > 0 and time.time() - train_start_time > FLAGS.exit_after*60:
          print("%s: Has been running for %d seconds. Will exit (exiting after %d minutes)."%(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), (int)(time.time() - train_start_time), FLAGS.exit_after))
          save = True
          do_exit = True

        if save:
          saver.save(session, checkpoint_path, global_step=i)
          with open(os.path.join(FLAGS.traindir, 'global_step.pkl'), 'wb') as f:
            pkl.dump(i, f)
          if FLAGS.profiling:
            # Create the Timeline object, and write it to a json
            tl = timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open(os.path.join(plots_dir, 'timeline.json'), 'w') as f:
              f.write(ctf)
          print('{}: Saving done!'.format(i))

        step_time, loss = 0.0, 0.0
        if train_d_loss is None: #pretraining
          train_d_loss = 0.0
          valid_d_loss = 0.0
          valid_g_loss = 0.0
        if not os.path.exists(os.path.join(plots_dir, 'gnuplot-input.txt')):
          with open(os.path.join(plots_dir, 'gnuplot-input.txt'), 'w') as f:
            f.write('# global-step learning-rate train-g-loss train-d-loss valid-g-loss valid-d-loss\n')
        with open(os.path.join(plots_dir, 'gnuplot-input.txt'), 'a') as f:
          try:
            f.write('{} {:.4f} {:.2f} {:.2f} {:.3} {:.3f}\n'.format(i, m.lr.eval(), train_g_loss, train_d_loss, valid_g_loss, valid_d_loss))
          except:
            f.write('{} {} {} {} {} {}\n'.format(i, m.lr.eval(), train_g_loss, train_d_loss, valid_g_loss, valid_d_loss))
        if not os.path.exists(os.path.join(plots_dir, 'gnuplot-commands-loss.txt')):
          with open(os.path.join(plots_dir, 'gnuplot-commands-loss.txt'), 'a') as f:
            f.write('set terminal postscript eps color butt "Times" 14\nset yrange [0:400]\nset output "loss.eps"\nplot \'gnuplot-input.txt\' using ($1):($3) title \'train G\' with linespoints, \'gnuplot-input.txt\' using ($1):($4) title \'train D\' with linespoints, \'gnuplot-input.txt\' using ($1):($5) title \'valid G\' with linespoints, \'gnuplot-input.txt\' using ($1):($6) title \'valid D\' with linespoints, \n')
        if not os.path.exists(os.path.join(plots_dir, 'gnuplot-commands-midistats.txt')):
          with open(os.path.join(plots_dir, 'gnuplot-commands-midistats.txt'), 'a') as f:
            f.write('set terminal postscript eps color butt "Times" 14\nset yrange [0:127]\nset xrange [0:70]\nset output "midistats.eps"\nplot \'midi_stats.gnuplot\' using ($1):(100*$3) title \'Scale consistency, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($6) title \'Tone span, halftones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($10) title \'Unique tones\' with linespoints, \'midi_stats.gnuplot\' using ($1):($23) title \'Intensity span, units\' with linespoints, \'midi_stats.gnuplot\' using ($1):(100*$24) title \'Polyphony, %\' with linespoints, \'midi_stats.gnuplot\' using ($1):($12) title \'3-tone repetitions\' with linespoints\n')
        try:
          Popen(['gnuplot','gnuplot-commands-loss.txt'], cwd=plots_dir)
          Popen(['gnuplot','gnuplot-commands-midistats.txt'], cwd=plots_dir)
        except:
          print('failed to run gnuplot. Please do so yourself: gnuplot gnuplot-commands.txt cwd={}'.format(plots_dir))
        
        song_data = sample(session, m, batch=True)
        midi_patterns = []
        print('formatting midi...')
        midi_time = time.time()
        for d in song_data:
          midi_patterns.append(loader.get_midi_pattern(d))
        print('done. time: {}'.format(time.time()-midi_time))
        
        filename = os.path.join(generated_data_dir, 'out-{}-{}-{}.mid'.format(experiment_label, i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))
        loader.save_midi_pattern(filename, midi_patterns[0])
  
        stats = []
        print('getting stats...')
        stats_time = time.time()
        for p in midi_patterns:
          stats.append(get_all_stats(p))
        print('done. time: {}'.format(time.time()-stats_time))
        #print(stats)
        stats = [stat for stat in stats if stat is not None]
        if len(stats):
          stats_keys_string = ['scale']
          stats_keys = ['scale_score', 'tone_min', 'tone_max', 'tone_span', 'freq_min', 'freq_max', 'freq_span', 'tones_unique', 'repetitions_2', 'repetitions_3', 'repetitions_4', 'repetitions_5', 'repetitions_6', 'repetitions_7', 'repetitions_8', 'repetitions_9', 'estimated_beat', 'estimated_beat_avg_ticks_off', 'intensity_min', 'intensity_max', 'intensity_span', 'polyphony_score', 'top_2_interval_difference', 'top_3_interval_difference', 'num_tones']
          statsfilename = os.path.join(plots_dir, 'midi_stats.gnuplot')
          if not os.path.exists(statsfilename):
            with open(statsfilename, 'a') as f:
              f.write('# Average numbers over one minibatch of size {}.\n'.format(FLAGS.batch_size))
              f.write('# global-step {} {}\n'.format(' '.join([s.replace(' ', '_') for s in stats_keys_string]), ' '.join(stats_keys)))
          with open(statsfilename, 'a') as f:
            f.write('{} {} {}\n'.format(i, ' '.join(['{}'.format(stats[0][key].replace(' ', '_')) for key in stats_keys_string]), ' '.join(['{:.3f}'.format(sum([s[key] for s in stats])/float(len(stats))) for key in stats_keys])))
          print('Saved {}.'.format(filename))
          
        if do_exit:
          if FLAGS.call_after is not None:
            print("%s: Will call \"%s\" before exiting."%(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), FLAGS.call_after))
            res = call(FLAGS.call_after.split(" "))
            print ('{}: call returned {}.'.format(datetime.datetime.today().strftime('%Y-%m-%d %H:%M:%S'), res))
          exit()
        sys.stdout.flush()


      test_g_loss,test_d_loss = run_epoch(session, m, loader, 'test', tf.no_op(), tf.no_op())
      print("Test loss G: %.3f, D: %.3f" %(test_g_loss, test_d_loss))

    song_data = sample(session, m)
    filename = os.path.join(generated_data_dir, 'out-{}-{}-{}.mid'.format(experiment_label, i, datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')))
    loader.save_data(filename, song_data)
    print('Saved {}.'.format(filename))



if __name__ == "__main__":
  tf.app.run()

Conclusion

The authors propose C-RNN-GAN, a recurrent neural network for continuous data trained within a generative adversarial framework. The experimental results show that adversarial training helps the model learn more varied patterns. Although the generated music still falls short of the music in the training data, C-RNN-GAN brings generated music closer to real music.

Limitations and Future Work

Although the model can generate music, the output still falls short of what humans judge to be music; future work could investigate the causes of the gap between generated and real music. The authors suggest further optimizing the model architecture to improve the quality of the generated music, and exploring applications of the model to other kinds of continuous sequence data.

Summary

This week I read a paper on generating sequence data with GANs, as groundwork for reading the TimeGAN paper next. From it I learned how the C-RNN-GAN model uses adversarial training to generate continuous sequence data: the generator (G) contains LSTM layers and fully connected layers, while the discriminator (D) is composed of a Bi-LSTM (bidirectional long short-term memory network); that is, D is bidirectional and G is unidirectional. The authors' experiments also demonstrate the advantages of C-RNN-GAN. Although the model is reasonably effective at sequence generation, it still has shortcomings: a gap remains between generated and real sequences, the architecture can be optimized further, and application to other scenarios is still open. These limitations and outlooks provide references and ideas for my follow-up research on data augmentation.
