From 0c2531d514a0ed9cf266ed3f1809e66203f96012 Mon Sep 17 00:00:00 2001 From: Vaibhav Srivastav Date: Tue, 20 Jun 2023 17:50:41 +0200 Subject: [PATCH] Created using Colaboratory --- ...odec_w_\360\237\244\227transformers.ipynb" | 453 ++++++++++++++++++ 1 file changed, 453 insertions(+) create mode 100644 "use_encodec_w_\360\237\244\227transformers.ipynb" diff --git "a/use_encodec_w_\360\237\244\227transformers.ipynb" "b/use_encodec_w_\360\237\244\227transformers.ipynb" new file mode 100644 index 0000000..5cf3ea6 --- /dev/null +++ "b/use_encodec_w_\360\237\244\227transformers.ipynb" @@ -0,0 +1,453 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "authorship_tag": "ABX9TyOnYveho6/q2HOKWBIW6lGR", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Encodec is now in 🤗 Transformers\n", + "\n", + "Want to train your own Bark- or MusicGen-style models using SoTA audio codebook embeddings powered by Encodec?\n", + "\n", + "Look no further! Try it out today, all in less than 10 lines of code ♥" + ], + "metadata": { + "id": "35pbpwlVWV24" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Set up our Colab development environment!\n", + "\n", + "We only need `datasets` and `transformers`. Since Encodec was merged after the last release, we'll install `transformers` from source." + ], + "metadata": { + "id": "kIw7dgFPXV9K" + } + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "w2rL5NkIJtco", + "outputId": "a7c0631e-7827-4dfc-c570-3ece2b69a768" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install -q datasets git+https://github.com/huggingface/transformers.git@main" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Import the `datasets` library and the Encodec model from `transformers`" + ], + "metadata": { + "id": "OEG8wdKhX9xZ" + } + }, + { + "cell_type": "code", + "source": [ + "from datasets import load_dataset, Audio\n", + "from transformers import EncodecModel, AutoProcessor" + ], + "metadata": { + "id": "Mkq4fdLWKgUB" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "For the purposes of this demonstration, we'll use a dummy dataset. However, you can swap it for any dataset on the 🤗 Hub or bring your own (see the sketch below).\n",
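+ "\n", + "For instance, here is a minimal sketch of bringing your own audio with the `audiofolder` loader (the `data_dir` path and the `my_dataset` name below are placeholders, not part of the original notebook):\n", + "\n", + "```python\n", + "from datasets import load_dataset\n", + "\n", + "# hypothetical example: point data_dir at a folder containing your own audio files\n", + "my_dataset = load_dataset(\"audiofolder\", data_dir=\"path/to/your/audio\", split=\"train\")\n", + "```\n", + "\n", + "The rest of the notebook then works the same once the \"audio\" column is cast to the model's sampling rate, as we do below."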
+ ], + "metadata": { + "id": "-Cy85TM-YJsr" + } + }, + { + "cell_type": "code", + "source": [ + "# dummy dataset; you can swap this with any dataset on the 🤗 Hub or bring your own\n", + "librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n", + "\n", + "# load the model + processor (for pre-processing the audio)\n", + "model = EncodecModel.from_pretrained(\"facebook/encodec_24khz\")\n", + "processor = AutoProcessor.from_pretrained(\"facebook/encodec_24khz\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RE__a8z5NS4F", + "outputId": "3dd1c165-24ab-4f83-8ddb-2afbb2a8a982" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:datasets.builder:Found cached dataset librispeech_asr_dummy (/root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)\n", + "Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "First, we need to pre-process the audio so that it matches the sampling rate expected by the model." + ], + "metadata": { + "id": "r6JAgL5OYxgF" + } + }, + { + "cell_type": "code", + "source": [ + "# cast the audio data to the correct sampling rate for the model\n", + "librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n", + "audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]" + ], + "metadata": { + "id": "wgE6dua5NYlS" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Let's give this audio file a listen!"
+ ], + "metadata": { + "id": "GsYEkaRwcCVy" + } + }, + { + "cell_type": "code", + "source": [ + "import IPython.display as ipd\n", + "\n", + "ipd.Audio(audio_sample, rate=processor.sampling_rate)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 73 + }, + "id": "Bq5dEeKyaX5a", + "outputId": "4bd15178-b316-455e-a03b-8c16457c1225" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {}, + "execution_count": 19 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# pre-process the audio inputs\n", + "inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors=\"pt\")" + ], + "metadata": { + "id": "XqJLed54Nbaa" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "To compress the audio and reconstruct the output waveform, we call `model.encode` and then pass its outputs to `model.decode`." + ], + "metadata": { + "id": "2LBmTsW6Z-3g" + } + }, + { + "cell_type": "code", + "source": [ + "# explicitly encode then decode the audio inputs\n", + "encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n", + "\n", + "# pass the encoder outputs to the decoder to get the compressed waveform\n", + "audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]" + ], + "metadata": { + "id": "sFg5cJiMNfcF" + }, + "execution_count": 20, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Since we did not specify a target `bandwidth`, the model compresses the audio to 1.5 kbps by default (the lowest supported bitrate, i.e. the strongest compression).\n", + "\n", + "You can change this by passing `bandwidth` to the `model.encode` call:\n", + "\n", + "```python\n", + "encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"], bandwidth=24.0)\n", + "```\n", + "\n", + "The model supports bandwidths of `1.5`, `3.0`, `6.0`, `12.0`, and `24.0` kbps.\n", + "\n", + "That's it! Now let's listen to the reconstructed waveform and compare it with the original." + ], + "metadata": { + "id": "n_jgxwexc2DO" + } + }, + { + "cell_type": "code", + "source": [ + "import IPython.display as ipd\n", + "\n", + "ipd.Audio(audio_values.detach().numpy()[0][0], rate=processor.sampling_rate)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 73 + }, + "id": "Lguq4cjBNjrI", + "outputId": "878096e4-b777-45ed-8e42-cb99674c5814" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {}, + "execution_count": 30 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Bonus: You can use Encodec to extract discrete codebook representations for your input audio!\n", + "\n", + "These discrete representations can then be used for downstream audio language modeling tasks like text-to-speech, text-to-music, and so on.\n", + "\n", + "P.S. [Bark](https://github.com/suno-ai/bark) and [MusicGen](https://github.com/facebookresearch/audiocraft) also use Encodec under the hood! 
⚡️" + ], + "metadata": { + "id": "vLFCqOQydCn_" + } + }, + { + "cell_type": "code", + "source": [ + "audio_codes = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_codes" + ], + "metadata": { + "id": "L1CoRU7SNknU" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "audio_codes" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mVJcobjofwBW", + "outputId": "cf11a4bd-5519-40c3-8741-d678e8d08a32" + }, + "execution_count": 31, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[[[ 62, 835, 835, 835, 835, 835, 835, 835, 408, 408, 835,\n", + " 835, 835, 835, 835, 408, 408, 408, 408, 408, 408, 408,\n", + " 408, 408, 408, 408, 408, 408, 408, 408, 408, 408, 408,\n", + " 408, 408, 408, 408, 408, 408, 408, 339, 228, 570, 991,\n", + " 681, 972, 969, 303, 38, 463, 738, 106, 855, 602, 142,\n", + " 511, 722, 860, 604, 876, 738, 106, 1014, 405, 405, 488,\n", + " 461, 461, 461, 293, 736, 933, 894, 723, 784, 837, 291,\n", + " 1000, 52, 1019, 488, 854, 872, 585, 991, 784, 723, 11,\n", + " 722, 722, 722, 681, 723, 723, 723, 681, 734, 825, 534,\n", + " 972, 303, 53, 53, 463, 463, 53, 373, 373, 136, 991,\n", + " 840, 765, 407, 303, 840, 916, 613, 393, 430, 876, 738,\n", + " 865, 408, 738, 904, 40, 103, 731, 731, 935, 935, 80,\n", + " 20, 953, 593, 99, 132, 457, 254, 899, 53, 53, 463,\n", + " 53, 373, 627, 73, 395, 80, 534, 807, 690, 871, 164,\n", + " 70, 936, 858, 30, 30, 971, 155, 327, 523, 950, 30,\n", + " 876, 339, 339, 463, 463, 339, 53, 935, 164, 131, 629,\n", + " 957, 131, 310, 811, 900, 311, 681, 394, 534, 373, 63,\n", + " 901, 690, 950, 291, 650, 723, 1014, 373, 408, 738, 106,\n", + " 106, 213, 405, 488, 731, 666, 722, 489, 405, 405, 593,\n", + " 830, 226, 723, 585, 990, 246, 246, 875, 731, 53, 53,\n", + " 53, 148, 53, 53, 373, 293, 293, 1014, 424, 303, 373,\n", + " 957, 976, 433, 53, 53, 53, 53, 53, 53, 53, 103,\n", + " 537, 408, 408, 738, 408, 408, 738, 1017, 95, 811, 879,\n", + " 1022, 310, 900, 722, 583, 310, 900, 604, 224, 1022, 143,\n", + " 722, 605, 868, 944, 311, 901, 926, 701, 291, 936, 131,\n", + " 724, 228, 604, 213, 681, 306, 540, 549, 30, 414, 504,\n", + " 142, 723, 511, 585, 1020, 142, 187, 950, 327, 650, 259,\n", + " 582, 955, 598, 228, 604, 432, 1019, 738, 106, 855, 341,\n", + " 341, 91, 91, 52, 900, 570, 80, 534, 255, 534, 182,\n", + " 534, 405, 327, 128, 327, 640, 112, 216, 291, 656, 628,\n", + " 724, 373, 408, 213, 935, 872, 293, 604, 699, 699, 900,\n", + " 957, 131, 310, 131, 276, 666, 779, 276, 629, 310, 570,\n", + " 726, 570, 991, 628, 373, 53, 53, 53, 53, 53, 106,\n", + " 780, 738, 835, 999, 395, 16, 635, 264, 658, 820, 253,\n", + " 825, 348, 405, 814, 33, 635, 38, 899, 53, 53, 103,\n", + " 1019, 780, 408, 408, 408, 1017, 1019, 194, 983, 656, 253,\n", + " 291, 936, 155, 936, 936, 871, 30, 347, 876, 738, 475,\n", + " 835, 475, 835, 106, 408, 408, 408, 408, 408, 408, 408,\n", + " 408, 408, 408, 408, 408, 408, 408, 62, 408, 408, 106,\n", + " 106, 835, 475, 475, 475, 835, 475, 835, 835, 835, 835],\n", + " [1007, 1007, 1007, 544, 424, 424, 1007, 424, 302, 424, 913,\n", + " 913, 913, 913, 913, 424, 518, 518, 518, 518, 518, 518,\n", + " 424, 302, 424, 518, 518, 302, 518, 424, 518, 518, 302,\n", + " 518, 518, 518, 518, 518, 518, 518, 740, 857, 793, 987,\n", + " 466, 875, 742, 847, 1010, 645, 993, 913, 857, 806, 1023,\n", + " 466, 317, 169, 669, 841, 363, 544, 282, 742, 766, 973,\n", + " 973, 973, 973, 729, 795, 201, 76, 822, 946, 672, 840,\n", + " 
161, 99, 835, 973, 268, 654, 429, 413, 413, 349, 380,\n", + " 320, 524, 813, 320, 437, 466, 466, 887, 742, 74, 587,\n", + " 1010, 35, 910, 638, 973, 930, 243, 1023, 1023, 752, 860,\n", + " 678, 29, 645, 429, 973, 982, 326, 35, 572, 859, 937,\n", + " 424, 913, 544, 937, 910, 2, 2, 2, 2, 133, 879,\n", + " 495, 639, 343, 951, 289, 801, 973, 984, 957, 940, 481,\n", + " 957, 928, 497, 648, 144, 147, 671, 990, 944, 318, 750,\n", + " 283, 75, 972, 199, 858, 809, 167, 921, 821, 507, 78,\n", + " 841, 859, 602, 519, 241, 974, 144, 571, 994, 666, 769,\n", + " 896, 752, 127, 404, 672, 908, 186, 403, 195, 742, 154,\n", + " 917, 64, 438, 52, 147, 892, 656, 868, 937, 896, 424,\n", + " 424, 1008, 973, 973, 26, 711, 554, 237, 42, 199, 914,\n", + " 546, 585, 914, 914, 160, 638, 471, 742, 1010, 268, 1010,\n", + " 1010, 973, 984, 742, 875, 890, 354, 788, 492, 639, 429,\n", + " 813, 74, 35, 207, 222, 1023, 1010, 831, 645, 957, 645,\n", + " 961, 518, 937, 544, 424, 424, 518, 544, 581, 602, 236,\n", + " 801, 165, 769, 1002, 693, 924, 868, 363, 959, 556, 729,\n", + " 303, 462, 67, 466, 14, 127, 349, 320, 944, 292, 782,\n", + " 579, 885, 859, 791, 887, 556, 193, 147, 64, 132, 42,\n", + " 45, 148, 169, 867, 615, 1005, 615, 35, 615, 615, 541,\n", + " 354, 954, 12, 355, 917, 1021, 363, 841, 544, 544, 497,\n", + " 320, 492, 969, 859, 214, 186, 16, 931, 917, 673, 857,\n", + " 636, 872, 739, 739, 822, 706, 931, 42, 794, 129, 92,\n", + " 458, 989, 518, 813, 837, 693, 890, 326, 840, 835, 765,\n", + " 829, 1021, 859, 229, 82, 579, 859, 49, 415, 519, 930,\n", + " 177, 675, 1014, 34, 890, 357, 973, 4, 638, 638, 648,\n", + " 518, 518, 424, 646, 859, 914, 195, 492, 639, 461, 701,\n", + " 670, 910, 189, 619, 176, 1021, 292, 1001, 292, 646, 541,\n", + " 857, 913, 913, 913, 518, 544, 740, 789, 593, 365, 892,\n", + " 365, 656, 742, 656, 742, 742, 742, 594, 302, 859, 519,\n", + " 974, 519, 519, 518, 424, 913, 913, 518, 518, 518, 518,\n", + " 518, 518, 518, 544, 518, 544, 544, 424, 740, 518, 544,\n", + " 544, 424, 544, 519, 519, 652, 519, 518, 424, 518, 518]]]])" + ] + }, + "metadata": {}, + "execution_count": 31 + } + ] + }, + { + "cell_type": "code", + "source": [ + "audio_codes.shape" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sIq7AqoQNqJp", + "outputId": "2b18bb9c-4cad-487c-fef0-00b7fa7b397c" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([1, 1, 2, 440])" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ] + } + ] +} \ No newline at end of file