Training a Speech Keyword Model with Alibaba Cloud PAI-DSW

news2025/1/13 13:11:59

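I have used Google's Colab for online model training before, but needing a VPN to reach it is always a hassle. Today I tried, for the first time, online training with Interactive Modeling (DSW) on PAI, Alibaba Cloud's AI platform.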

The training notebook I used is based on TensorFlow's "Simple audio recognition: Recognizing keywords" tutorial: simple_audio_pi/simple_audio_train_numpy.ipynb at main · mkvenkit/simple_audio_pi (github.com)

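After uploading the script, you can open it directly and the notebook is displayed.

Working with the notebook is much the same as on other platforms, so I will not go into detail.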

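I made a few small changes to the script. Downloading Google's mini_speech_commands from the Alibaba Cloud platform is extremely slow, so I skipped the download step, uploaded a copy we had downloaded earlier, and added its path in the script: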

data_dir = pathlib.Path("eiq-model-zoo-main/tasks/audio/command-recognition/micro-speech-LSTM/data/mini_speech_commands")
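One way to let the same notebook work with either an uploaded copy or the tutorial's default location is to pick the first directory that actually exists. This is only a sketch; `pick_data_dir` and the candidate list are hypothetical names that do not appear in the original notebook.

```python
import pathlib

def pick_data_dir(candidates):
    """Return the first candidate directory that exists, or None."""
    for candidate in candidates:
        path = pathlib.Path(candidate)
        if path.is_dir():
            return path
    return None

# Hypothetical usage: try the uploaded copy first, then the tutorial default.
# data_dir = pick_data_dir([
#     "eiq-model-zoo-main/tasks/audio/command-recognition/micro-speech-LSTM/data/mini_speech_commands",
#     "data/mini_speech_commands",
# ])
# If data_dir is None, fall back to the download cell from the tutorial.
```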

In addition, we adjusted the dataset: only yes and no are kept as separate classes, and all other speech files are grouped under unknown.
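The post does not show how the regrouping was done, so the following is only a sketch of one possible approach. `regroup_commands` is a hypothetical helper, and prefixing merged files with their original word is my own assumption to avoid name collisions (the file names in the output further down are not prefixed, so the author evidently did it differently).

```python
import pathlib
import shutil

def regroup_commands(src_dir, dst_dir, keep=("yes", "no"), other="unknown"):
    """Copy WAV files into dst_dir, merging every word not in `keep` into `other`."""
    src_dir, dst_dir = pathlib.Path(src_dir), pathlib.Path(dst_dir)
    counts = {}
    for word_dir in sorted(p for p in src_dir.iterdir() if p.is_dir()):
        label = word_dir.name if word_dir.name in keep else other
        target = dst_dir / label
        target.mkdir(parents=True, exist_ok=True)
        for wav in sorted(word_dir.glob("*.wav")):
            # Prefix merged files with their original word so that identical
            # file names from different word folders do not overwrite each other.
            name = wav.name if label != other else f"{word_dir.name}_{wav.name}"
            shutil.copy2(wav, target / name)
            counts[label] = counts.get(label, 0) + 1
    return counts
```

The returned dictionary gives the number of files per class, which is worth checking because a large unknown class can dominate the two keyword classes.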

The dataset is small, so training finished in only 43 seconds.

The program's execution results are attached below.

#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Simple audio recognition: Recognizing keywords


This tutorial will show you how to build a basic speech recognition network that recognizes eight different words. It's important to know that real speech and audio recognition systems are much more complex, but like MNIST for images, it should give you a basic understanding of the techniques involved. Once you've completed this tutorial, you'll have a model that tries to classify a one second audio clip as "down", "go", "left", "no", "right", "stop", "up" and "yes".

Setup

Import necessary modules and dependencies.

import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf

from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras import layers
from tensorflow.keras import models
from IPython import display


# Set seed for experiment reproducibility
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

from scipy.io import wavfile
from scipy import signal
2024-07-30 20:49:22.876703: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-30 20:49:22.985427: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2024-07-30 20:49:22.985444: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
print(tf.__version__)
2.9.3

Import the Speech Commands dataset

You'll write a script to download a portion of the Speech Commands dataset. The original dataset consists of over 105,000 WAV audio files of people saying thirty different words. This data was collected by Google and released under a CC BY license, and you can help improve it by contributing five minutes of your own voice.

You'll be using a portion of the dataset to save time with data loading. Extract the mini_speech_commands.zip and load it in using the tf.data API.

data_dir = pathlib.Path('data/mini_speech_commands')
if not data_dir.exists():
  tf.keras.utils.get_file(
      'mini_speech_commands.zip',
      origin="http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip",
      extract=True,
      cache_dir='.', cache_subdir='data')
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip
  3235840/182082353 [..............................] - ETA: 1:42:46


---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)

Cell In[4], line 3
      1 data_dir = pathlib.Path('data/mini_speech_commands')
      2 if not data_dir.exists():
----> 3   tf.keras.utils.get_file(
      4       'mini_speech_commands.zip',
      5       origin="http://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip",
      6       extract=True,
      7       cache_dir='.', cache_subdir='data')


File /opt/conda/lib/python3.10/site-packages/keras/utils/data_utils.py:283, in get_file(fname, origin, untar, md5_hash, file_hash, cache_subdir, hash_algorithm, extract, archive_format, cache_dir)
    281 try:
    282   try:
--> 283     urlretrieve(origin, fpath, DLProgbar())
    284   except urllib.error.HTTPError as e:
    285     raise Exception(error_msg.format(origin, e.code, e.msg))


File /opt/conda/lib/python3.10/site-packages/keras/utils/data_utils.py:84, in urlretrieve(url, filename, reporthook, data)
     82 response = urlopen(url, data)
     83 with open(filename, 'wb') as fd:
---> 84   for chunk in chunk_read(response, reporthook=reporthook):
     85     fd.write(chunk)


File /opt/conda/lib/python3.10/site-packages/keras/utils/data_utils.py:73, in urlretrieve.<locals>.chunk_read(response, chunk_size, reporthook)
     71 count = 0
     72 while True:
---> 73   chunk = response.read(chunk_size)
     74   count += 1
     75   if reporthook is not None:


File /opt/conda/lib/python3.10/http/client.py:466, in HTTPResponse.read(self, amt)
    463 if self.length is not None and amt > self.length:
    464     # clip the read to the "end of response"
    465     amt = self.length
--> 466 s = self.fp.read(amt)
    467 if not s and amt:
    468     # Ideally, we would raise IncompleteRead if the content-length
    469     # wasn't satisfied, but it might break compatibility.
    470     self._close_conn()


File /opt/conda/lib/python3.10/socket.py:705, in SocketIO.readinto(self, b)
    703 while True:
    704     try:
--> 705         return self._sock.recv_into(b)
    706     except timeout:
    707         self._timeout_occurred = True


KeyboardInterrupt: 

Check basic statistics about the dataset.

data_dir = pathlib.Path("eiq-model-zoo-main/tasks/audio/command-recognition/micro-speech-LSTM/data/mini_speech_commands")
commands = np.array(tf.io.gfile.listdir(str(data_dir)))
commands = commands[commands != 'README.md']
print('Commands:', commands)
Commands: ['no' 'yes' 'unknown']

Extract the audio files into a list and shuffle it.

filenames = tf.io.gfile.glob(str(data_dir) + '/*/*')
filenames = tf.random.shuffle(filenames)
num_samples = len(filenames)
print('Number of total examples:', num_samples)
print('Number of examples per label:',
      len(tf.io.gfile.listdir(str(data_dir/commands[0]))))
print('Example file tensor:', filenames[0])
Number of total examples: 5311
Number of examples per label: 1000
Example file tensor: tf.Tensor(b'eiq-model-zoo-main/tasks/audio/command-recognition/micro-speech-LSTM/data/mini_speech_commands/unknown/cd7f8c1b_nohash_2.wav', shape=(), dtype=string)

Split the files into training, validation and test sets using an 80:10:10 ratio, respectively. Note that the indices below are hard-coded for the tutorial's original 8,000-file dataset.

train_files = filenames[:6400]
val_files = filenames[6400: 6400 + 800]
test_files = filenames[-800:]

print('Training set size', len(train_files))
print('Validation set size', len(val_files))
print('Test set size', len(test_files))
Training set size 5311
Validation set size 0
Test set size 800
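The sizes printed above show that the hard-coded indices do not fit a 5,311-file dataset: the training slice takes every file, the validation set is empty, and the last 800 files used for testing are also part of the training set. That would explain the `val_loss` warnings during training and means the test accuracy reported later is probably optimistic. A minimal sketch of deriving the boundaries from the dataset size instead (`split_indices` is a hypothetical helper, not part of the notebook):

```python
def split_indices(num_samples, train_frac=0.8, val_frac=0.1):
    """Return (train_end, val_end) slice boundaries for a train/val/test split."""
    train_end = int(num_samples * train_frac)
    val_end = train_end + int(num_samples * val_frac)
    return train_end, val_end

train_end, val_end = split_indices(5311)
print(train_end, val_end - train_end, 5311 - val_end)  # 4248 531 532
```

With these boundaries the slices would be `filenames[:train_end]`, `filenames[train_end:val_end]` and `filenames[val_end:]`, which do not overlap.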

Reading audio files and their labels

The audio file will initially be read as a binary file, which you'll want to convert into a numerical tensor.

To load an audio file, you will use tf.audio.decode_wav, which returns the WAV-encoded audio as a Tensor and the sample rate.

A WAV file contains time series data with a set number of samples per second. Each sample represents the amplitude of the audio signal at that specific time. In a 16-bit system, like the files in mini_speech_commands, the values range from -32768 to 32767. The sample rate for this dataset is 16kHz. Note that tf.audio.decode_wav will normalize the values to the range [-1.0, 1.0].
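To make the normalization concrete, here is a NumPy sketch of mapping 16-bit samples to floats. It assumes the scale factor is 2^15, the usual convention for 16-bit PCM; the exact scaling inside tf.audio.decode_wav is not shown in this notebook, and `pcm16_to_float` is a hypothetical name.

```python
import numpy as np

def pcm16_to_float(samples):
    """Scale 16-bit PCM samples to floats in [-1.0, 1.0)."""
    return np.asarray(samples, dtype=np.int16).astype(np.float32) / 32768.0

print(pcm16_to_float([-32768, 0, 16384, 32767]))
```

The most negative sample maps to exactly -1.0, while the most positive one lands just below 1.0.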

def decode_audio(audio_binary):
  audio, _ = tf.audio.decode_wav(audio_binary)
  return tf.squeeze(audio, axis=-1)

The label for each WAV file is its parent directory.

def get_label(file_path):
  parts = tf.strings.split(file_path, os.path.sep)

  # Note: You'll use indexing here instead of tuple unpacking to enable this 
  # to work in a TensorFlow graph.
  return parts[-2] 
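Outside a TensorFlow graph, the same rule (the label is the parent directory) can be checked with plain Python. This is just an illustrative equivalent; `label_from_path` is a hypothetical name.

```python
import pathlib

def label_from_path(file_path):
    """Return the name of the directory that directly contains the file."""
    return pathlib.PurePosixPath(file_path).parent.name

print(label_from_path("mini_speech_commands/unknown/cd7f8c1b_nohash_2.wav"))  # unknown
```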

Let's define a method that will take in the filename of the WAV file and output a tuple containing the audio and labels for supervised training.

def get_waveform_and_label(file_path):
  label = get_label(file_path)
  audio_binary = tf.io.read_file(file_path)
  waveform = decode_audio(audio_binary)
  return waveform, label

You will now apply get_waveform_and_label to build your training set to extract the audio-label pairs and check the results. You'll build the validation and test sets using a similar procedure later on.

AUTOTUNE = tf.data.experimental.AUTOTUNE
files_ds = tf.data.Dataset.from_tensor_slices(train_files)
waveform_ds = files_ds.map(get_waveform_and_label, num_parallel_calls=AUTOTUNE)
print(AUTOTUNE)
-1

Let's examine a few audio waveforms with their corresponding labels.

rows = 3
cols = 3
n = rows*cols
fig, axes = plt.subplots(rows, cols, figsize=(10, 12))
for i, (audio, label) in enumerate(waveform_ds.take(n)):
  r = i // cols
  c = i % cols
  ax = axes[r][c]
  ax.plot(audio.numpy())
  ax.set_yticks(np.arange(-1.2, 1.2, 0.2))
  label = label.numpy().decode('utf-8')
  ax.set_title(label)

plt.show()

Spectrogram

You'll convert the waveform into a spectrogram, which shows frequency changes over time and can be represented as a 2D image. This can be done by applying the short-time Fourier transform (STFT) to convert the audio into the time-frequency domain.

A Fourier transform (tf.signal.fft) converts a signal to its component frequencies, but loses all time information. The STFT (tf.signal.stft) splits the signal into windows of time and runs a Fourier transform on each window, preserving some time information, and returning a 2D tensor that you can run standard convolutions on.

STFT produces an array of complex numbers representing magnitude and phase. However, you'll only need the magnitude for this tutorial, which can be derived by applying tf.abs on the output of tf.signal.stft.

Choose frame_length and frame_step parameters such that the generated spectrogram "image" is almost square. For more information on STFT parameters choice, you can refer to this video on audio signal processing.

You also want the waveforms to have the same length, so that when you convert it to a spectrogram image, the results will have similar dimensions. This can be done by simply zero padding the audio clips that are shorter than one second.
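The padding step can be sketched in NumPy as follows. `pad_to_length` is a hypothetical helper; like the notebook code, it assumes that no clip is longer than one second (16,000 samples) and makes that assumption explicit by raising an error.

```python
import numpy as np

def pad_to_length(waveform, target_len=16000):
    """Zero-pad a 1-D waveform at the end so it has exactly target_len samples."""
    waveform = np.asarray(waveform, dtype=np.float32)
    if waveform.shape[0] > target_len:
        raise ValueError("waveform is longer than target_len")
    return np.pad(waveform, (0, target_len - waveform.shape[0]))

padded = pad_to_length(np.ones(12000))
print(padded.shape)  # (16000,)
```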

def stft(x):
    f, t, spec = signal.stft(x.numpy(), fs=16000, nperseg=255, noverlap = 124, nfft=256)
    return tf.convert_to_tensor(np.abs(spec))

def get_spectrogram(waveform):
  # Padding for files with less than 16000 samples
  zero_padding = tf.zeros([16000] - tf.shape(waveform), dtype=tf.float32)

  # Concatenate audio with padding so that all audio clips will be of the 
  # same length
  waveform = tf.cast(waveform, tf.float32)
  equal_length = tf.concat([waveform, zero_padding], 0)
    
  spectrogram = tf.py_function(func=stft, inp=[equal_length], Tout=tf.float32)
       
  spectrogram.set_shape((129, 124))

  #spectrogram = tf.signal.stft(equal_length, frame_length=255, frame_step=128)
      
  #spectrogram = tf.abs(spectrogram)
  print("spectrogram:", spectrogram)

  return spectrogram
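The shape passed to `set_shape` can be derived from the STFT parameters. The arithmetic below is my reading of how scipy.signal.stft behaves with its default `boundary='zeros'` and `padded=True` settings, so treat it as a sketch rather than a specification; `scipy_stft_shape` is a hypothetical helper.

```python
def scipy_stft_shape(n_samples, nperseg, noverlap, nfft):
    """Predict (freq_bins, frames) for scipy.signal.stft on a real 1-D signal.

    Assumes SciPy's defaults boundary='zeros' and padded=True.
    """
    hop = nperseg - noverlap
    # boundary='zeros' extends the signal by nperseg // 2 samples on each side.
    extended = n_samples + 2 * (nperseg // 2)
    # padded=True appends zeros so the windows tile the extended signal exactly.
    extended += (-(extended - nperseg) % hop) % nperseg
    frames = (extended - nperseg) // hop + 1
    freq_bins = nfft // 2 + 1  # one-sided spectrum of a real signal
    return freq_bins, frames

print(scipy_stft_shape(16000, nperseg=255, noverlap=124, nfft=256))  # (129, 124)
```

The result agrees with the `(129, 124)` shape used in the notebook: a hop of 131 samples yields 124 frames, which is close to the 129 frequency bins and gives the nearly square "image" the text asks for.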

Next, you will explore the data. Compare the waveform, the spectrogram and the actual audio of one example from the dataset.

for waveform, label in waveform_ds.take(1):
  label = label.numpy().decode('utf-8')
  spectrogram = get_spectrogram(waveform)

print('Label:', label)
print('Waveform shape:', waveform.shape)
print('Spectrogram shape:', spectrogram.shape)
print('Audio playback')
display.display(display.Audio(waveform, rate=16000))
spectrogram: tf.Tensor(
[[1.3323117e-04 2.0556885e-05 5.4556225e-05 ... 8.4603002e-05
  4.9450118e-05 4.0375954e-07]
 [1.7885730e-04 1.2982974e-04 5.5361979e-05 ... 1.1477976e-04
  4.5945391e-05 4.0284965e-07]
 [2.0510760e-04 1.7932945e-04 1.7563498e-05 ... 1.4174559e-04
  1.1835613e-04 4.0013109e-07]
 ...
 [2.0882417e-06 3.5290028e-05 3.5315792e-05 ... 9.3307534e-07
  2.4691594e-06 4.4648058e-08]
 [2.5977388e-06 3.1138712e-05 1.2549801e-05 ... 5.2627638e-07
  2.0020402e-06 4.4576122e-08]
 [2.6915066e-06 2.8589993e-05 2.5866726e-05 ... 4.3330269e-07
  1.7054866e-06 4.4551889e-08]], shape=(129, 124), dtype=float32)
Label: unknown
Waveform shape: (16000,)
Spectrogram shape: (129, 124)
Audio playback
def plot_spectrogram(spectrogram, ax):
    # Convert the magnitudes to log scale. Rows are frequency bins and columns
    # are time frames, so time is already represented on the x-axis.
    log_spec = np.log(spectrogram)
    height = log_spec.shape[0]
    X = np.arange(16000, step=height + 1)
    Y = range(height)
    print(X.shape, Y, log_spec.shape)
    ax.pcolormesh(X, Y, log_spec)


fig, axes = plt.subplots(2, figsize=(12, 8))
timescale = np.arange(waveform.shape[0])
axes[0].plot(timescale, waveform.numpy())
axes[0].set_title('Waveform')
axes[0].set_xlim([0, 16000])
plot_spectrogram(spectrogram.numpy(), axes[1])
axes[1].set_title('Spectrogram')
plt.show()
(124,) range(0, 129) (129, 124)

Now transform the waveform dataset to have spectrogram images and their corresponding labels as integer IDs.

def get_spectrogram_and_label_id(audio, label):
  spectrogram = get_spectrogram(audio)
  spectrogram = tf.expand_dims(spectrogram, -1)
  label_id = tf.argmax(tf.cast(label == commands, "uint32"))
  return spectrogram, label_id
print(type(waveform_ds))
print(type(label), type(commands))
print(commands)
a = tf.constant((label == commands), "uint8")
label_id = tf.argmax(a)
print(label_id)
<class 'tensorflow.python.data.ops.dataset_ops.ParallelMapDataset'>
<class 'str'> <class 'numpy.ndarray'>
['no' 'yes' 'unknown']
tf.Tensor(2, shape=(), dtype=int64)
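The equality-plus-argmax trick can be reproduced in NumPy. One caveat worth knowing: `argmax` over an all-False mask returns 0, so a label that is not in `commands` would silently map to the first class. The sketch below (with the hypothetical helper `label_to_id`) adds an explicit check for that case.

```python
import numpy as np

commands = np.array(["no", "yes", "unknown"])

def label_to_id(label, commands):
    """Return the index of `label` in `commands` via an equality mask and argmax."""
    mask = (commands == label)
    if not mask.any():
        raise ValueError(f"label not in commands: {label!r}")
    return int(np.argmax(mask))

print(label_to_id("unknown", commands))  # 2
```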
spectrogram_ds = waveform_ds.map(
    get_spectrogram_and_label_id, num_parallel_calls=AUTOTUNE)
spectrogram: Tensor("EagerPyFunc:0", shape=(129, 124), dtype=float32, device=/job:localhost/replica:0/task:0)

Examine the spectrogram "images" for different samples of the dataset.

rows = 3
cols = 3
n = rows*cols
fig, axes = plt.subplots(rows, cols, figsize=(10, 10))
for i, (spectrogram, label_id) in enumerate(spectrogram_ds.take(n)):
  r = i // cols
  c = i % cols
  ax = axes[r][c]
  plot_spectrogram(np.squeeze(spectrogram.numpy()), ax)
  ax.set_title(commands[label_id.numpy()])
  ax.axis('off')
  
plt.show()
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)
(124,) range(0, 129) (129, 124)


/tmp/ipykernel_271/1663808029.py:4: RuntimeWarning: divide by zero encountered in log
  log_spec = np.log(spectrogram)
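The RuntimeWarning above most likely comes from spectrogram cells that are exactly zero, for example frames that fall entirely inside the zero padding, because log(0) is negative infinity. It only affects the plots, not training. A common remedy is to add a tiny offset before taking the logarithm; `safe_log` below is a hypothetical replacement for the bare `np.log` call.

```python
import numpy as np

def safe_log(spectrogram):
    """Natural log of a magnitude spectrogram with a tiny offset to avoid log(0)."""
    spectrogram = np.asarray(spectrogram, dtype=np.float64)
    return np.log(spectrogram + np.finfo(float).eps)

print(np.isfinite(safe_log(np.zeros((2, 3)))).all())  # True
```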

Build and train the model

Now you can build and train your model. But before you do that, you'll need to repeat the training set preprocessing on the validation and test sets.

def preprocess_dataset(files):
  files_ds = tf.data.Dataset.from_tensor_slices(files)
  output_ds = files_ds.map(get_waveform_and_label, num_parallel_calls=AUTOTUNE)
  output_ds = output_ds.map(
      get_spectrogram_and_label_id,  num_parallel_calls=AUTOTUNE)
  return output_ds
train_ds = spectrogram_ds
val_ds = preprocess_dataset(val_files)
test_ds = preprocess_dataset(test_files)
spectrogram: Tensor("EagerPyFunc:0", shape=(129, 124), dtype=float32, device=/job:localhost/replica:0/task:0)
spectrogram: Tensor("EagerPyFunc:0", shape=(129, 124), dtype=float32, device=/job:localhost/replica:0/task:0)

Batch the training and validation sets for model training.

batch_size = 64
train_ds = train_ds.batch(batch_size)
val_ds = val_ds.batch(batch_size)

Add dataset cache() and prefetch() operations to reduce read latency while training the model.

train_ds = train_ds.cache().prefetch(AUTOTUNE)
val_ds = val_ds.cache().prefetch(AUTOTUNE)

For the model, you'll use a simple convolutional neural network (CNN), since you have transformed the audio files into spectrogram images. The model also has the following additional preprocessing layers:

  • A Resizing layer to downsample the input to enable the model to train faster.
  • A Normalization layer to normalize each pixel in the image based on its mean and standard deviation.

For the Normalization layer, its adapt method would first need to be called on the training data in order to compute aggregate statistics (i.e. mean and standard deviation).
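As a rough illustration of what adapting and applying a normalization step involves, here is a NumPy sketch. It assumes statistics are computed per channel over all other axes, which would be consistent with the three non-trainable parameters in the model summary further down (one mean, one variance and a sample count for a single channel). The function names and the `eps` guard are hypothetical.

```python
import numpy as np

def adapt_stats(batch):
    """Per-channel mean and variance over all axes except the last one."""
    axes = tuple(range(batch.ndim - 1))
    return batch.mean(axis=axes), batch.var(axis=axes)

def normalize(batch, mean, var, eps=1e-7):
    """Shift and scale so each channel has roughly zero mean and unit variance."""
    return (batch - mean) / np.maximum(np.sqrt(var), eps)

rng = np.random.default_rng(42)
images = rng.uniform(0.0, 5.0, size=(8, 32, 32, 1))
mean, var = adapt_stats(images)
normalized = normalize(images, mean, var)
print(mean.shape, var.shape)  # (1,) (1,)
```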

for spectrogram, _ in spectrogram_ds.take(1):
  input_shape = spectrogram.shape
print('Input shape:', input_shape)
num_labels = len(commands)

norm_layer = preprocessing.Normalization()
norm_layer.adapt(spectrogram_ds.map(lambda x, _: x))

model = models.Sequential([
    layers.Input(shape=input_shape),
    preprocessing.Resizing(32, 32), 
    norm_layer,
    layers.Conv2D(32, 3, activation='relu'),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.25),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_labels),
])

model.summary()
Input shape: (129, 124, 1)
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 resizing (Resizing)         (None, 32, 32, 1)         0         
                                                                 
 normalization (Normalizatio  (None, 32, 32, 1)        3         
 n)                                                              
                                                                 
 conv2d (Conv2D)             (None, 30, 30, 32)        320       
                                                                 
 conv2d_1 (Conv2D)           (None, 28, 28, 64)        18496     
                                                                 
 max_pooling2d (MaxPooling2D  (None, 14, 14, 64)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 14, 14, 64)        0         
                                                                 
 flatten (Flatten)           (None, 12544)             0         
                                                                 
 dense (Dense)               (None, 128)               1605760   
                                                                 
 dropout_1 (Dropout)         (None, 128)               0         
                                                                 
 dense_1 (Dense)             (None, 3)                 387       
                                                                 
=================================================================
Total params: 1,624,966
Trainable params: 1,624,963
Non-trainable params: 3
_________________________________________________________________
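The parameter counts in the summary can be checked by hand: a convolution has one weight per kernel cell, input channel and filter plus one bias per filter, and a dense layer has one weight per input-unit pair plus one bias per unit. The helper names below are hypothetical.

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights (kernel x kernel x in_channels x filters) plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    """Weights (inputs x units) plus one bias per unit."""
    return inputs * units + units

trainable = (
    conv2d_params(3, 1, 32)            # conv2d
    + conv2d_params(3, 32, 64)         # conv2d_1
    + dense_params(14 * 14 * 64, 128)  # dense
    + dense_params(128, 3)             # dense_1
)
print(trainable)  # 1624963
```

Almost all of the parameters (1,605,760) sit in the first dense layer, because it is connected to the flattened 14 x 14 x 64 feature map.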
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
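Because the last Dense layer has no activation, the model outputs raw logits, which is why the loss is configured with `from_logits=True`. The NumPy sketch below shows the quantity being minimized: the mean negative log-softmax probability of the correct class. It is an illustration of the formula, not Keras's actual implementation, and the function name is hypothetical.

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean cross-entropy between integer labels and unnormalized logits."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

# Uniform logits over three classes give a loss of ln(3).
uniform_loss = sparse_categorical_crossentropy_from_logits([0, 1], [[0, 0, 0], [0, 0, 0]])
```

A loss near ln(3), about 1.10, therefore corresponds to guessing among the three classes at random.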
EPOCHS = 10
history = model.fit(
    train_ds, 
    validation_data=val_ds,  
    epochs=EPOCHS,
    callbacks=tf.keras.callbacks.EarlyStopping(verbose=1, patience=2),
)
Epoch 1/10
83/83 [==============================] - ETA: 0s - loss: 0.7221 - accuracy: 0.7098WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 8s 84ms/step - loss: 0.7221 - accuracy: 0.7098
Epoch 2/10
83/83 [==============================] - ETA: 0s - loss: 0.4803 - accuracy: 0.7998WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 47ms/step - loss: 0.4803 - accuracy: 0.7998
Epoch 3/10
82/83 [============================>.] - ETA: 0s - loss: 0.3632 - accuracy: 0.8514WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 47ms/step - loss: 0.3631 - accuracy: 0.8520
Epoch 4/10
83/83 [==============================] - ETA: 0s - loss: 0.3013 - accuracy: 0.8823WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 48ms/step - loss: 0.3013 - accuracy: 0.8823
Epoch 5/10
83/83 [==============================] - ETA: 0s - loss: 0.2520 - accuracy: 0.8970WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 47ms/step - loss: 0.2520 - accuracy: 0.8970
Epoch 6/10
83/83 [==============================] - ETA: 0s - loss: 0.2313 - accuracy: 0.9111WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 48ms/step - loss: 0.2313 - accuracy: 0.9111
Epoch 7/10
82/83 [============================>.] - ETA: 0s - loss: 0.2068 - accuracy: 0.9167WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 48ms/step - loss: 0.2060 - accuracy: 0.9172
Epoch 8/10
82/83 [============================>.] - ETA: 0s - loss: 0.1722 - accuracy: 0.9390WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 47ms/step - loss: 0.1714 - accuracy: 0.9396
Epoch 9/10
82/83 [============================>.] - ETA: 0s - loss: 0.1561 - accuracy: 0.9402WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 46ms/step - loss: 0.1549 - accuracy: 0.9409
Epoch 10/10
82/83 [============================>.] - ETA: 0s - loss: 0.1271 - accuracy: 0.9520WARNING:tensorflow:Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,accuracy
83/83 [==============================] - 4s 48ms/step - loss: 0.1267 - accuracy: 0.9518

Let's check the training and validation loss curves to see how your model has improved during training.

model.save('simple_audio_model_numpy.sav')
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 2 of 2). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: simple_audio_model_numpy.sav/assets


INFO:tensorflow:Assets written to: simple_audio_model_numpy.sav/assets
metrics = history.history
plt.plot(history.epoch, metrics['loss'], metrics['val_loss'])
plt.legend(['loss', 'val_loss'])
plt.show()
---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

Cell In[32], line 2
      1 metrics = history.history
----> 2 plt.plot(history.epoch, metrics['loss'], metrics['val_loss'])
      3 plt.legend(['loss', 'val_loss'])
      4 plt.show()


KeyError: 'val_loss'
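The KeyError follows from the empty validation set: Keras recorded only `loss` and `accuracy`, so `history.history` has no `val_loss` entry. A defensive way to plot is to draw only the curves that exist. In this sketch, `available_curves` and `plot_history` are hypothetical helpers, and matplotlib is imported lazily so the selection logic can be used on its own.

```python
def available_curves(history_dict, wanted=("loss", "val_loss")):
    """Return the requested metric names that were actually recorded."""
    return [name for name in wanted if name in history_dict]

def plot_history(epochs, history_dict):
    """Plot whichever of the loss curves exist (matplotlib imported lazily)."""
    import matplotlib.pyplot as plt
    for name in available_curves(history_dict):
        plt.plot(epochs, history_dict[name], label=name)
    plt.legend()
    plt.show()

print(available_curves({"loss": [0.72, 0.48], "accuracy": [0.71, 0.80]}))  # ['loss']
```

In the notebook this would be called as `plot_history(history.epoch, history.history)`.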

TensorFlow Lite

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model('simple_audio_model_numpy.sav') # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('simple_audio_model_numpy.tflite', 'wb') as f:
  f.write(tflite_model)
2024-07-30 21:04:52.280960: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2024-07-30 21:04:52.280984: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
2024-07-30 21:04:52.281613: I tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: simple_audio_model_numpy.sav
2024-07-30 21:04:52.283562: I tensorflow/cc/saved_model/reader.cc:81] Reading meta graph with tags { serve }
2024-07-30 21:04:52.283579: I tensorflow/cc/saved_model/reader.cc:122] Reading SavedModel debug info (if present) from: simple_audio_model_numpy.sav
2024-07-30 21:04:52.287957: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
2024-07-30 21:04:52.289493: I tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2024-07-30 21:04:52.357881: I tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: simple_audio_model_numpy.sav
2024-07-30 21:04:52.375648: I tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 94042 microseconds.
2024-07-30 21:04:52.408613: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:263] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.

Evaluate test set performance

Let's run the model on the test set and check performance.

test_audio = []
test_labels = []

for audio, label in test_ds:
  test_audio.append(audio.numpy())
  test_labels.append(label.numpy())

test_audio = np.array(test_audio)
test_labels = np.array(test_labels)
y_pred = np.argmax(model.predict(test_audio), axis=1)
y_true = test_labels

test_acc = sum(y_pred == y_true) / len(y_true)
print(f'Test set accuracy: {test_acc:.0%}')
25/25 [==============================] - 0s 7ms/step
Test set accuracy: 98%

Display a confusion matrix

A confusion matrix is helpful to see how well the model did on each of the commands in the test set.

confusion_mtx = tf.math.confusion_matrix(y_true, y_pred) 
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mtx, xticklabels=commands, yticklabels=commands, 
            annot=True, fmt='g')
plt.xlabel('Prediction')
plt.ylabel('Label')
plt.show()
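For reference, a confusion matrix is just a table of counts indexed by (true label, predicted label). The NumPy sketch below uses made-up labels purely for illustration; they are not results from this run.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true labels, columns are predicted labels."""
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when an index pair repeats.
    np.add.at(matrix, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return matrix

# Illustrative labels only: one sample of class 2 is mistaken for class 1.
cm = confusion_matrix([0, 1, 2, 2, 1], [0, 1, 2, 1, 1], num_classes=3)
print(cm)
```

The diagonal holds the correct predictions, so overall accuracy is the trace divided by the total count.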

Run inference on an audio file

Finally, verify the model's prediction output using an input audio file of someone saying "no." How well does your model perform?

sample_file = data_dir/'no/01bb6a2a_nohash_0.wav'

sample_ds = preprocess_dataset([str(sample_file)])

for spectrogram, label in sample_ds.batch(1):
  prediction = model(spectrogram)
  plt.bar(commands, tf.nn.softmax(prediction[0]))
  plt.title(f'Predictions for "{commands[label[0]]}"')
  plt.show()
spectrogram: Tensor("EagerPyFunc:0", shape=(129, 124), dtype=float32, device=/job:localhost/replica:0/task:0)

You can see that your model very clearly recognized the audio command as "no."

Next steps

This tutorial showed how you could do simple audio classification using a convolutional neural network with TensorFlow and Python.

  • To learn how to use transfer learning for audio classification, check out the Sound classification with YAMNet tutorial.

  • To build your own interactive web app for audio classification, consider taking the TensorFlow.js - Audio recognition using transfer learning codelab.

  • TensorFlow also has additional support for audio data preparation and augmentation to help with your own audio-based projects.
