Preface

Lately I've been watching a lot of pianists share their pieces on YouTube, and sometimes I'd really like to tweak a piece a little. But that means transcribing it by ear into FL Studio, and even with a MIDI keyboard to help, manual transcription is exhausting work. So I started wondering whether AI could do it for me, and that's when Google's MT3 model appeared in front of me.

Environment

In keeping with the principle of being lazy all the way, I ruled out a local virtual machine without hesitation: just getting an AMD deep-learning GPU driver to work inside a VM would be more than enough suffering. So I threw myself into Alibaba Cloud's arms instead.

Buying a GPU Instance

Here is my configuration (a CLI equivalent is sketched below the screenshot):

  • Region: Japan
  • Architecture: Heterogeneous computing
  • Instance family: GPU compute type gn6i
  • Hardware: 16 vCPUs (Intel Xeon Platinum 8163, Skylake) / 62 GiB RAM / 1 × NVIDIA T4
  • Billing: preemptible (spot) instance at a 10% discount rate, which works out to about ¥1.7 per hour, with the per-instance price cap auto-set to ¥2.5; the displayed historical release rate is 3%
  • OS: Ubuntu 22.04 LTS Server

vps.png
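
If you'd rather script the purchase than click through the console, the same box can be requested with the aliyun CLI. A rough sketch, where the instance type and image ID pattern are my assumptions matching the spec above:

$ aliyun ecs RunInstances \
    --RegionId ap-northeast-1 \
    --InstanceType ecs.gn6i-c16g1.4xlarge \
    --SpotStrategy SpotWithPriceLimit \
    --SpotPriceLimit 2.5 \
    --ImageId 'ubuntu_22_04_x64_20G_alibase_*.vhd'   # image ID pattern is an assumption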

Basic Environment Setup

Note that I deliberately chose the Japan region, which sidesteps most of the strange problems caused by dependencies that are blocked inside mainland China. From here I simply use Alibaba Cloud's own install script to configure the GPU driver and the CUDA environment.

$ apt update
$ apt upgrade
$ apt install cmake git wget clang build-essential ninja-build zlib1g-dev libtbb-dev ffmpeg
$ ubuntu-drivers devices   # list this machine's GPU
== /sys/devices/pci0000:00/0000:00:07.0 ==
modalias : pci:v000010DEd00001EB8sv000010DEsd000012A2bc03sc02i00
vendor   : NVIDIA Corporation
model    : TU104GL [Tesla T4]
driver   : nvidia-driver-450-server - distro non-free
driver   : nvidia-driver-470 - distro non-free recommended
driver   : nvidia-driver-418-server - distro non-free
driver   : nvidia-driver-470-server - distro non-free
driver   : xserver-xorg-video-nouveau - distro free builtin
$ bash install.sh
$ reboot

install.sh

#!/bin/sh

IS_INSTALL_RDMA="FALSE"
IS_INSTALL_AIACC_TRAIN="FALSE"
IS_INSTALL_AIACC_INFERENCE="FALSE"
DRIVER_VERSION="470.82.01"
CUDA_VERSION="11.4.1"
CUDNN_VERSION="8.2.4"
IS_INSTALL_RAPIDS="FALSE"

INSTALL_DIR="/root/auto_install"
auto_install_script="auto_install.sh"
script_download_url="https://mirrors.aliyun.com/opsx/ecs/linux/binary/script/"

mkdir $INSTALL_DIR && cd $INSTALL_DIR
wget -t 10 --timeout=10 ${script_download_url}${auto_install_script} && sh ${INSTALL_DIR}/${auto_install_script} $DRIVER_VERSION $CUDA_VERSION $CUDNN_VERSION $IS_INSTALL_AIACC_TRAIN $IS_INSTALL_AIACC_INFERENCE $IS_INSTALL_RDMA $IS_INSTALL_RAPIDS
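
After the reboot it's worth confirming that everything actually landed: nvidia-smi should list the Tesla T4, and nvcc should report CUDA 11.4 (you may need to add /usr/local/cuda/bin to PATH, assuming the script used the default install prefix):

$ nvidia-smi
$ nvcc --version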

Other Notes

Installing a pinned LLVM version: LLVM's source build is extremely picky about the Clang version, and after several failed attempts to compile it from source I fell back to installing an older release's package from an APT repository.

PS: MT3 depends heavily on Numba, and a mismatched LLVM version makes llvmlite fail to install: RuntimeError: Building llvmlite requires LLVM 10.0.x or 9.0.x, got '12.0.1'. Be sure to set LLVM_CONFIG to the right executable path

$ vim /etc/apt/sources.list # append the lines below
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse
$ apt update
$ apt install llvm-8
$ ln -s $(which llvm-config-8) /usr/bin/llvm-config
$ llvm-config --version
8.0.0
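
As the error message itself suggests, an alternative to the symlink is pointing the llvmlite build at the right binary through the LLVM_CONFIG environment variable. A minimal sketch:

$ LLVM_CONFIG=$(which llvm-config-8) pip install llvmlite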

Next, set up the runtime environment:

$ pip install jupyterlab
$ jupyter-lab --no-browser --ip "*" --allow-root  --NotebookApp.port_retries=0 # root only because this is a throwaway test box
[I 2022-07-30 22:59:55.666 ServerApp] jupyterlab | extension was successfully linked.
[I 2022-07-30 22:59:55.676 ServerApp] nbclassic | extension was successfully linked.
[I 2022-07-30 22:59:55.847 ServerApp] notebook_shim | extension was successfully linked.
[W 2022-07-30 22:59:55.861 ServerApp] WARNING: The Jupyter server is listening on all IP addresses and not using encryption. This is not recommended.
[I 2022-07-30 22:59:55.863 ServerApp] notebook_shim | extension was successfully loaded.
[I 2022-07-30 22:59:55.864 LabApp] JupyterLab extension loaded from /usr/local/lib/python3.10/dist-packages/jupyterlab
[I 2022-07-30 22:59:55.864 LabApp] JupyterLab application directory is /usr/local/share/jupyter/lab
[I 2022-07-30 22:59:55.867 ServerApp] jupyterlab | extension was successfully loaded.
[I 2022-07-30 22:59:55.870 ServerApp] nbclassic | extension was successfully loaded.
[I 2022-07-30 22:59:55.871 ServerApp] Serving notebooks from local directory: /root
[I 2022-07-30 22:59:55.871 ServerApp] Jupyter Server 1.18.1 is running at:
[I 2022-07-30 22:59:55.871 ServerApp] http://localhost:8888/lab?token=a63ff75619e54268109584a93f4048cd8ed72cfcdb49d603
[I 2022-07-30 22:59:55.871 ServerApp]  or http://127.0.0.1:8888/lab?token=a63ff75619e54268109584a93f4048cd8ed72cfcdb49d603
[I 2022-07-30 22:59:55.871 ServerApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).
[C 2022-07-30 22:59:55.874 ServerApp] 
    
    To access the server, open this file in a browser:
        file:///root/.local/share/jupyter/runtime/jpserver-7962-open.html
    Or copy and paste one of these URLs:
        http://localhost:8888/lab?token=a63ff75619e54268109584a93f4048cd8ed72cfcdb49d603
     or http://127.0.0.1:8888/lab?token=a63ff75619e54268109584a93f4048cd8ed72cfcdb49d603

Open the page with the token shown above, and the makeshift Jupyter setup is done.

vps.png
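
Note that JupyterLab started this way dies together with the SSH session, so for longer transcription jobs it's worth detaching it. A minimal sketch with nohup (tmux or a systemd unit would work just as well):

$ nohup jupyter-lab --no-browser --ip "*" --allow-root --NotebookApp.port_retries=0 > jupyter.log 2>&1 &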

Down to Business

Setting Up the Deep Learning Environment

Now it's time to write a little code. I've put the dependency installation and related steps together; just split them into cells and run them in the notebook. If you use CentOS or another distribution, adjust the package names for YUM yourself. Note that these cells are adapted from the official MT3 Colab notebook, so the /content/... paths and the google.colab import further down assume a Colab-style layout and may need small tweaks on our JupyterLab box.

!apt-get update -qq && apt-get install -qq fluidsynth build-essential libasound2-dev libjack-dev

!pip install nest-asyncio pyfluidsynth pytest-runner pybind11 
# pin CLU for python 3.7 compatibility
!pip install clu==0.0.7
# pin Orbax to use Checkpointer
!pip install orbax==0.0.2

# install t5x
!git clone --branch=main https://github.com/google-research/t5x
# pin T5X for python 3.7 compatibility
!cd t5x; git reset --hard 2e05ad41778c25521738418de805757bf2e41e9e; cd ..
!mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp
!sed -i 's:jax\[tpu\]:jax:' setup.py
!python3 -m pip install -e .

# install mt3
!git clone --branch=main https://github.com/magenta/mt3
!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp
!python3 -m pip install -e .

!wget http://download.magenta.tensorflow.org/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2
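
One thing the cells above don't fetch is the pretrained weights. The official MT3 notebook copies them from a public GCS bucket; a sketch assuming gsutil is available on the box (it is not preinstalled on a bare ECS image):

!gsutil -q -m cp -r gs://mt3/checkpoints .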

Importing Dependencies

#@title Imports and Definitions

import functools
import os

import numpy as np
import tensorflow.compat.v2 as tf

import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x

from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies

from google.colab import files

import nest_asyncio
nest_asyncio.apply()

SAMPLE_RATE = 16000
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'

def upload_audio(sample_rate):
  data = list(files.upload().values())
  if len(data) > 1:
    print('Multiple files uploaded; using only one.')
  return note_seq.audio_io.wav_data_to_samples_librosa(
    data[0], sample_rate=sample_rate)
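
Since google.colab only exists inside Colab, files.upload() will fail on our JupyterLab box. A minimal local-file alternative (load_audio is my own name, assuming the audio file has already been copied to the server):

def load_audio(path, sample_rate=SAMPLE_RATE):
  # Read raw bytes and decode/resample them with the same note_seq
  # helper that upload_audio() uses.
  with open(path, 'rb') as f:
    return note_seq.audio_io.wav_data_to_samples_librosa(
        f.read(), sample_rate=sample_rate)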



class InferenceModel(object):
  """Wrapper of T5X model for music transcription."""

  def __init__(self, checkpoint_path, model_type='mt3'):

    # Model Constants.
    if model_type == 'ismir2021':
      num_velocity_bins = 127
      self.encoding_spec = note_sequences.NoteEncodingSpec
      self.inputs_length = 512
    elif model_type == 'mt3':
      num_velocity_bins = 1
      self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
      self.inputs_length = 256
    else:
      raise ValueError('unknown model_type: %s' % model_type)

    gin_files = ['/content/mt3/gin/model.gin',
                 f'/content/mt3/gin/{model_type}.gin']

    self.batch_size = 8
    self.outputs_length = 1024
    self.sequence_length = {'inputs': self.inputs_length, 
                            'targets': self.outputs_length}

    self.partitioner = t5x.partitioning.PjitPartitioner(
        num_partitions=1)

    # Build Codecs and Vocabularies.
    self.spectrogram_config = spectrograms.SpectrogramConfig()
    self.codec = vocabularies.build_codec(
        vocab_config=vocabularies.VocabularyConfig(
            num_velocity_bins=num_velocity_bins))
    self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
    self.output_features = {
        'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
        'targets': seqio.Feature(vocabulary=self.vocabulary),
    }

    # Create a T5X model.
    self._parse_gin(gin_files)
    self.model = self._load_model()

    # Restore from checkpoint.
    self.restore_from_checkpoint(checkpoint_path)

  @property
  def input_shapes(self):
    return {
          'encoder_input_tokens': (self.batch_size, self.inputs_length),
          'decoder_input_tokens': (self.batch_size, self.outputs_length)
    }

  def _parse_gin(self, gin_files):
    """Parse gin files used to train the model."""
    gin_bindings = [
        'from __gin__ import dynamic_registration',
        'from mt3 import vocabularies',
        'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
    ]
    with gin.unlock_config():
      gin.parse_config_files_and_bindings(
          gin_files, gin_bindings, finalize_config=False)

  def _load_model(self):
    """Load up a T5X `Model` after parsing training gin config."""
    model_config = gin.get_configurable(network.T5Config)()
    module = network.Transformer(config=model_config)
    return models.ContinuousInputsEncoderDecoderModel(
        module=module,
        input_vocabulary=self.output_features['inputs'].vocabulary,
        output_vocabulary=self.output_features['targets'].vocabulary,
        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
        input_depth=spectrograms.input_depth(self.spectrogram_config))


  def restore_from_checkpoint(self, checkpoint_path):
    """Restore training state from checkpoint, resets self._predict_fn()."""
    train_state_initializer = t5x.utils.TrainStateInitializer(
      optimizer_def=self.model.optimizer_def,
      init_fn=self.model.get_initial_variables,
      input_shapes=self.input_shapes,
      partitioner=self.partitioner)

    restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
        path=checkpoint_path, mode='specific', dtype='float32')

    train_state_axes = train_state_initializer.train_state_axes
    self._predict_fn = self._get_predict_fn(train_state_axes)
    self._train_state = train_state_initializer.from_checkpoint_or_scratch(
        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))

  @functools.lru_cache()
  def _get_predict_fn(self, train_state_axes):
    """Generate a partitioned prediction function for decoding."""
    def partial_predict_fn(params, batch, decode_rng):
      return self.model.predict_batch_with_aux(
          params, batch, decoder_params={'decode_rng': None})
    return self.partitioner.partition(
        partial_predict_fn,
        in_axis_resources=(
            train_state_axes.params,
            t5x.partitioning.PartitionSpec('data',), None),
        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
    )

  def predict_tokens(self, batch, seed=0):
    """Predict tokens from preprocessed dataset batch."""
    prediction, _ = self._predict_fn(
        self._train_state.params, batch, jax.random.PRNGKey(seed))
    return self.vocabulary.decode_tf(prediction).numpy()

  def __call__(self, audio):
    """Infer note sequence from audio samples.
    
    Args:
      audio: 1-d numpy array of audio samples (16kHz) for a single example.

    Returns:
      A note_sequence of the transcribed audio.
    """
    ds = self.audio_to_dataset(audio)
    ds = self.preprocess(ds)

    model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
        ds, task_feature_lengths=self.sequence_length)
    model_ds = model_ds.batch(self.batch_size)

    inferences = (tokens for batch in model_ds.as_numpy_iterator()
                  for tokens in self.predict_tokens(batch))

    predictions = []
    for example, tokens in zip(ds.as_numpy_iterator(), inferences):
      predictions.append(self.postprocess(tokens, example))

    result = metrics_utils.event_predictions_to_ns(
        predictions, codec=self.codec, encoding_spec=self.encoding_spec)
    return result['est_ns']

  def audio_to_dataset(self, audio):
    """Create a TF Dataset of spectrograms from input audio."""
    frames, frame_times = self._audio_to_frames(audio)
    return tf.data.Dataset.from_tensors({
        'inputs': frames,
        'input_times': frame_times,
    })

  def _audio_to_frames(self, audio):
    """Compute spectrogram frames from audio."""
    frame_size = self.spectrogram_config.hop_width
    padding = [0, frame_size - len(audio) % frame_size]
    audio = np.pad(audio, padding, mode='constant')
    frames = spectrograms.split_audio(audio, self.spectrogram_config)
    num_frames = len(audio) // frame_size
    times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
    return frames, times

  def preprocess(self, ds):
    pp_chain = [
        functools.partial(
            t5.data.preprocessors.split_tokens_to_inputs_length,
            sequence_length=self.sequence_length,
            output_features=self.output_features,
            feature_key='inputs',
            additional_feature_keys=['input_times']),
        # Cache occurs here during training.
        preprocessors.add_dummy_targets,
        functools.partial(
            preprocessors.compute_spectrograms,
            spectrogram_config=self.spectrogram_config)
    ]
    for pp in pp_chain:
      ds = pp(ds)
    return ds

  def postprocess(self, tokens, example):
    tokens = self._trim_eos(tokens)
    start_time = example['input_times'][0]
    # Round down to nearest symbolic token step.
    start_time -= start_time % (1 / self.codec.steps_per_second)
    return {
        'est_tokens': tokens,
        'start_time': start_time,
        # Internal MT3 code expects raw inputs, not used here.
        'raw_inputs': []
    }

  @staticmethod
  def _trim_eos(tokens):
    tokens = np.array(tokens, np.int32)
    if vocabularies.DECODED_EOS_ID in tokens:
      tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
    return tokens
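
With the wrapper defined, transcription follows the official notebook's flow: restore a checkpoint, feed in 16 kHz samples, and write the result out as MIDI. A sketch, where the checkpoint path assumes the gsutil layout from earlier and load_audio is the helper sketched above:

# Restore the multi-instrument MT3 checkpoint (path per the gsutil copy above).
inference_model = InferenceModel('checkpoints/mt3/', model_type='mt3')

# 1-D float array of audio samples at 16 kHz.
audio = load_audio('input.wav', sample_rate=SAMPLE_RATE)

# Run inference, then save the estimated NoteSequence as a standard MIDI
# file, ready to drag into FL Studio.
est_ns = inference_model(audio)
note_seq.sequence_proto_to_midi_file(est_ns, 'transcribed.mid')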