Forced Alignment with Wav2Vec2 — Torchaudio 0.10.0 documentation

Author Moto Hira

This tutorial shows how to align a transcript to speech with torchaudio, using the CTC segmentation algorithm described in CTC-Segmentation of Large Corpora for German End-to-end Speech Recognition.

Overview

The alignment process consists of the following steps.

  1. Estimate the frame-wise label probability from the audio waveform.

  2. Generate the trellis matrix, which represents the probability of each label being aligned at each time step.

  3. Find the most likely path through the trellis matrix.

In this example, we use torchaudio’s Wav2Vec2 model for acoustic feature extraction.

Preparation

First, we import the necessary packages and fetch the data we work on. The output below lists the PyTorch version, the torchaudio version, and the device in use.

Out:

1.10.0+cpu
0.10.0+cpu
cpu

Generate frame-wise label probability

The first step is to generate the label class probability of each audio frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.

torchaudio provides easy access to pretrained models with associated labels.

Note

In the subsequent sections, we will compute the probability in log-domain to avoid numerical instability. For this purpose, we normalize the emission with torch.log_softmax().
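Why the log domain helps can be seen in a few lines of plain Python. The sketch below (standard library only, with a hypothetical `log_softmax` helper; it is not PyTorch's implementation) applies the log-sum-exp trick to one frame of raw scores, which is the same normalization `torch.log_softmax()` performs along the label dimension. It then shows a long product of small probabilities underflowing to zero while the sum of their logarithms stays finite.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax for one frame of raw scores (hypothetical helper)."""
    # Subtracting the max keeps exp() from overflowing (log-sum-exp trick).
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


# Raw scores for one hypothetical frame over three labels.
# math.exp(1000.0) on its own would raise OverflowError.
log_probs = log_softmax([1000.0, 999.0, 998.0])
print([math.exp(lp) for lp in log_probs])  # a proper distribution summing to ~1

# Multiplying many small per-frame probabilities underflows...
p = 1e-5
product = 1.0
for _ in range(100):
    product *= p
print(product)  # 0.0
# ...while summing their logarithms stays finite.
print(100 * math.log(p))
```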

Visualization

```python
import matplotlib.pyplot as plt

print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
```

Frame-wise class probability

Out:

('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')

Generate alignment probability (trellis)

From the emission matrix, we next generate the trellis, which represents the probability of each transcript label occurring at each time frame.

The trellis is a 2D matrix with a time axis and a label axis. The label axis represents the transcript that we are aligning. In the following, we use $t$ to denote the index in the time axis and $j$ to denote the index in the label axis. $c_j$ represents the label at label index $j$.

To compute the trellis value at time step $t+1$, we look at the trellis at time step $t$ and the emission at time step $t+1$. There are two paths that reach time step $t+1$ with label $c_{j+1}$. In the first, the label was already $c_{j+1}$ at $t$ and did not change from $t$ to $t+1$. In the second, the label was $c_j$ at $t$ and transitioned to the next label $c_{j+1}$ at $t+1$.

The following diagram illustrates this transition.

https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png

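Since we are looking for the most likely transitions, we take the more likely of the two paths as the value of $k_{(t+1, j+1)}$, that is

$$k_{(t+1, j+1)} = \max\left( k_{(t, j)}\, p(t+1, c_{j+1}),\; k_{(t, j+1)}\, p(t+1, \text{repeat}) \right)$$

where $k$ is the trellis matrix, and $p(t, c_j)$ is the probability of label $c_j$ at time step $t$. $\text{repeat}$ denotes the blank token from the CTC formulation. (For details of the CTC algorithm, please refer to Sequence Modeling with CTC [distill.pub].) Because the emission is in the log domain, `get_trellis` below stores log-probabilities in $k$ and replaces each product with a sum.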
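To make the recursion concrete, the following pure-Python sketch fills a trellis for a hypothetical 4-frame emission over three labels (blank at index 0, then "A" and "B") and a two-token transcript. It follows the same indexing as `get_trellis` (an extra leading row and column, with column 0 fixed at 0) but uses nested lists instead of tensors; the helper name `get_trellis_py` and all probabilities are invented for illustration.

```python
import math

NEG_INF = float("-inf")


def get_trellis_py(emission, tokens, blank_id=0):
    """List-based version of the trellis recursion (log domain)."""
    num_frame, num_tokens = len(emission), len(tokens)
    # Row t / column j of the trellis correspond to emission frame t-1 / token j-1.
    trellis = [[NEG_INF] * (num_tokens + 1) for _ in range(num_frame + 1)]
    for t in range(num_frame + 1):
        trellis[t][0] = 0.0
    for t in range(num_frame):
        for j in range(1, num_tokens + 1):
            stay = trellis[t][j] + emission[t][blank_id]                  # no label change
            change = trellis[t][j - 1] + emission[t][tokens[j - 1]]      # advance to next token
            trellis[t + 1][j] = max(stay, change)
    return trellis


# Hypothetical per-frame probabilities over labels (blank, A, B), stored as logs.
probs = [
    [0.1, 0.8, 0.1],  # frame 0: mostly "A"
    [0.7, 0.2, 0.1],  # frame 1: mostly blank
    [0.1, 0.1, 0.8],  # frame 2: mostly "B"
    [0.8, 0.1, 0.1],  # frame 3: mostly blank
]
emission = [[math.log(p) for p in row] for row in probs]
trellis = get_trellis_py(emission, tokens=[1, 2])
for row in trellis:
    print([f"{x:.3f}" for x in row])
```

The bottom-right cell ends up as the log of 0.8 * 0.7 * 0.8 * 0.8: emit "A", stay on a blank, emit "B", stay on a blank.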

```python
import torch

transcript = 'I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT'
dictionary = {c: i for i, c in enumerate(labels)}

tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence).
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.full((num_frame + 1, num_tokens + 1), -float('inf'))
    trellis[:, 0] = 0
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


trellis = get_trellis(emission, tokens)
```

Out:

[('I', 10), ('|', 4), ('H', 11), ('A', 7), ('D', 14), ('|', 4), ('T', 6), ('H', 11), ('A', 7), ('T', 6), ('|', 4), ('C', 19), ('U', 16), ('R', 13), ('I', 10), ('O', 8), ('S', 12), ('I', 10), ('T', 6), ('Y', 22), ('|', 4), ('B', 24), ('E', 5), ('S', 12), ('I', 10), ('D', 14), ('E', 5), ('|', 4), ('M', 17), ('E', 5), ('|', 4), ('A', 7), ('T', 6), ('|', 4), ('T', 6), ('H', 11), ('I', 10), ('S', 12), ('|', 4), ('M', 17), ('O', 8), ('M', 17), ('E', 5), ('N', 9), ('T', 6)]
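The transcript above is already written in the model's label alphabet: upper-case letters with `|` as the word separator. A hypothetical helper like the following (a sketch, not part of torchaudio) could convert ordinary text into that form and fail loudly on characters the label set does not contain, instead of raising a bare `KeyError` from the dictionary lookup. The `LABELS` tuple is copied from the labels printed earlier.

```python
LABELS = ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I',
          'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P',
          'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')


def encode_transcript(text, labels=LABELS):
    """Map plain text to (transcript, token ids) in the label alphabet (hypothetical helper)."""
    dictionary = {c: i for i, c in enumerate(labels)}
    # Upper-case and join words with the "|" word-boundary label.
    transcript = "|".join(text.upper().split())
    unknown = sorted({c for c in transcript if c not in dictionary})
    if unknown:
        raise ValueError(f"characters not in label set: {unknown}")
    return transcript, [dictionary[c] for c in transcript]


transcript, tokens = encode_transcript("I had that curiosity")
print(transcript)  # I|HAD|THAT|CURIOSITY
print(tokens[:5])  # [10, 4, 11, 7, 14]
```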

Visualization

```python
plt.imshow(trellis[1:, 1:].T, origin='lower')
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.colorbar()
plt.show()
```

In the above visualization, we can see that there is a trace of high probability crossing the matrix diagonally.
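The "- Inf" annotation marks the region the trellis can never reach: after $t$ frames at most $t$ transcript tokens can have been emitted, because each transition consumes one frame. The pure-Python sketch below checks that property on random log-probabilities; the helper `trellis_py`, the sizes, and the seed are arbitrary choices for illustration, and the recursion is the same one `get_trellis` implements.

```python
import math
import random

NEG_INF = float("-inf")


def trellis_py(emission, tokens, blank_id=0):
    """List-based trellis with the same recursion and padding as get_trellis."""
    T, J = len(emission), len(tokens)
    k = [[0.0] + [NEG_INF] * J for _ in range(T + 1)]
    for t in range(T):
        for j in range(1, J + 1):
            k[t + 1][j] = max(k[t][j] + emission[t][blank_id],
                              k[t][j - 1] + emission[t][tokens[j - 1]])
    return k


random.seed(0)
num_frames, num_labels, tokens = 6, 5, [1, 2, 3, 4]
emission = []
for _ in range(num_frames):
    # Random positive scores normalized into log-probabilities for one frame.
    raw = [random.uniform(0.1, 1.0) for _ in range(num_labels)]
    emission.append([math.log(x / sum(raw)) for x in raw])

k = trellis_py(emission, tokens)
for t, row in enumerate(k):
    # Cell (t, j) is finite exactly when j <= t.
    print(t, ["-inf" if v == NEG_INF else f"{v:.2f}" for v in row])
```

A practical consequence: with fewer frames than tokens, the last column is `-inf` everywhere, which is one way the `Failed to align` error in `backtrack` can arise.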

Find the most likely path (backtracking)

Once the trellis is generated, we will traverse it following the elements with high probability.

We start from the last label index at the time step with the highest probability, then traverse back in time, picking either stay ($c_{j+1} \to c_{j+1}$) or transition ($c_j \to c_{j+1}$) based on the post-transition probability $k_{(t, j+1)}\, p(t+1, \text{repeat})$ or $k_{(t, j)}\, p(t+1, c_{j+1})$.

Backtracking stops once the label index reaches the beginning of the transcript.

The trellis matrix is used for path-finding, but for the final probability of each segment, we take the frame-wise probability from the emission matrix.
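The backtracking logic can also be sketched without tensors. This pure-Python version (a sketch, not the tutorial's code; `trellis_py`, `backtrack_py`, and the toy probabilities are hypothetical) fills a trellis for a 4-frame emission over labels (blank, A, B), walks back from the row where the last column peaks, and records the emission-matrix probability of each chosen step.

```python
import math
from dataclasses import dataclass

NEG_INF = float("-inf")


@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def trellis_py(emission, tokens, blank_id=0):
    # Same recursion as get_trellis; row t / column j map to frame t-1 / token j-1.
    T, J = len(emission), len(tokens)
    k = [[0.0] + [NEG_INF] * J for _ in range(T + 1)]
    for t in range(T):
        for j in range(1, J + 1):
            k[t + 1][j] = max(k[t][j] + emission[t][blank_id],
                              k[t][j - 1] + emission[t][tokens[j - 1]])
    return k


def backtrack_py(k, emission, tokens, blank_id=0):
    j = len(tokens)
    # Start from the row where the last token's column scores highest.
    t_start = max(range(len(k)), key=lambda t: k[t][j])
    path = []
    for t in range(t_start, 0, -1):
        stayed = k[t - 1][j] + emission[t - 1][blank_id]
        changed = k[t - 1][j - 1] + emission[t - 1][tokens[j - 1]]
        label = tokens[j - 1] if changed > stayed else blank_id
        # Report the frame-wise probability from the emission, not the trellis.
        path.append(Point(j - 1, t - 1, math.exp(emission[t - 1][label])))
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        # Ran out of frames before reaching the first token.
        raise ValueError("Failed to align")
    return path[::-1]


probs = [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.8, 0.1, 0.1]]
emission = [[math.log(p) for p in row] for row in probs]
tokens = [1, 2]
path = backtrack_py(trellis_py(emission, tokens), emission, tokens)
for p in path:
    print(p)
```

Token 0 ("A") covers frames 0 and 1, and token 1 ("B") covers frame 2. The trailing blank frame 3 is not on the path because backtracking starts at trellis row 3 (emission frame 2), where the last column peaks.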

```python
from dataclasses import dataclass

import torch


@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[T-1]` is the emission at time frame `T` of trellis dimension.
        # Score for token staying the same from time frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from J-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        # The loop ran out of frames before reaching the first token.
        raise ValueError('Failed to align')
    return path[::-1]


path = backtrack(trellis, emission, tokens)
print(path)
```

Out:

[Point(token_index=0, time_index=30, score=0.9999842643737793), Point(token_index=0, time_index=31, score=0.9846950173377991), Point(token_index=0, time_index=32, score=0.9999707937240601), Point(token_index=0, time_index=33, score=0.1540004163980484), Point(token_index=1, time_index=34, score=0.9999173879623413), Point(token_index=1, time_index=35, score=0.6080269813537598), Point(token_index=2, time_index=36, score=0.9997720122337341), Point(token_index=2, time_index=37, score=0.9997130036354065), Point(token_index=3, time_index=38, score=0.9999357461929321), Point(token_index=3, time_index=39, score=0.9861581325531006), Point(token_index=4, time_index=40, score=0.9238582253456116), Point(token_index=4, time_index=41, score=0.9257349967956543), Point(token_index=4, time_index=42, score=0.015662744641304016), Point(token_index=5, time_index=43, score=0.9998378753662109), Point(token_index=6, time_index=44, score=0.9988442659378052), Point(token_index=6, time_index=45, score=0.10144233703613281), Point(token_index=7, time_index=46, score=0.9999426603317261), Point(token_index=7, time_index=47, score=0.9999946355819702), Point(token_index=8, time_index=48, score=0.9979603290557861), Point(token_index=8, time_index=49, score=0.036036331206560135), Point(token_index=8, time_index=50, score=0.06162545830011368), Point(token_index=9, time_index=51, score=4.3326534068910405e-05), Point(token_index=10, time_index=52, score=0.999980092048645), Point(token_index=10, time_index=53, score=0.9967095851898193), Point(token_index=10, time_index=54, score=0.9999257326126099), Point(token_index=11, time_index=55, score=0.9999982118606567), Point(token_index=11, time_index=56, score=0.9990689158439636), Point(token_index=11, time_index=57, score=0.9999996423721313), Point(token_index=11, time_index=58, score=0.9999996423721313), Point(token_index=11, time_index=59, score=0.8457557559013367), Point(token_index=12, time_index=60, score=0.9999995231628418), 
Point(token_index=12, time_index=61, score=0.999601423740387), Point(token_index=13, time_index=62, score=0.999998927116394), Point(token_index=13, time_index=63, score=0.0035246757324784994), Point(token_index=13, time_index=64, score=1.0), Point(token_index=13, time_index=65, score=1.0), Point(token_index=14, time_index=66, score=0.9999916553497314), Point(token_index=14, time_index=67, score=0.9971591234207153), Point(token_index=14, time_index=68, score=0.9999990463256836), Point(token_index=14, time_index=69, score=0.9999991655349731), Point(token_index=14, time_index=70, score=0.9999998807907104), Point(token_index=14, time_index=71, score=0.9999998807907104), Point(token_index=14, time_index=72, score=0.9999881982803345), Point(token_index=14, time_index=73, score=0.011426654644310474), Point(token_index=15, time_index=74, score=0.9999978542327881), Point(token_index=15, time_index=75, score=0.9996134042739868), Point(token_index=15, time_index=76, score=0.999998927116394), Point(token_index=15, time_index=77, score=0.9727553129196167), Point(token_index=16, time_index=78, score=0.999998927116394), Point(token_index=16, time_index=79, score=0.9949328303337097), Point(token_index=16, time_index=80, score=0.999998927116394), Point(token_index=16, time_index=81, score=0.9999121427536011), Point(token_index=17, time_index=82, score=0.9999775886535645), Point(token_index=17, time_index=83, score=0.6576985716819763), Point(token_index=17, time_index=84, score=0.9984292387962341), Point(token_index=18, time_index=85, score=0.9999874830245972), Point(token_index=18, time_index=86, score=0.9993745684623718), Point(token_index=18, time_index=87, score=0.9999988079071045), Point(token_index=18, time_index=88, score=0.10424679517745972), Point(token_index=19, time_index=89, score=0.9999969005584717), Point(token_index=19, time_index=90, score=0.3978584110736847), Point(token_index=20, time_index=91, score=0.9999933242797852), Point(token_index=20, time_index=92, 
score=1.6990968561003683e-06), Point(token_index=20, time_index=93, score=0.9861307740211487), Point(token_index=21, time_index=94, score=0.9999960660934448), Point(token_index=21, time_index=95, score=0.9992727637290955), Point(token_index=21, time_index=96, score=0.9993411898612976), Point(token_index=22, time_index=97, score=0.9999983310699463), Point(token_index=22, time_index=98, score=0.9999971389770508), Point(token_index=22, time_index=99, score=0.9999997615814209), Point(token_index=22, time_index=100, score=0.9999995231628418), Point(token_index=23, time_index=101, score=0.9999732971191406), Point(token_index=23, time_index=102, score=0.9983227849006653), Point(token_index=23, time_index=103, score=0.9999992847442627), Point(token_index=23, time_index=104, score=0.9999997615814209), Point(token_index=23, time_index=105, score=1.0), Point(token_index=23, time_index=106, score=1.0), Point(token_index=23, time_index=107, score=0.9998630285263062), Point(token_index=24, time_index=108, score=0.9999982118606567), Point(token_index=24, time_index=109, score=0.9988579750061035), Point(token_index=25, time_index=110, score=0.9999798536300659), Point(token_index=25, time_index=111, score=0.8572984933853149), Point(token_index=26, time_index=112, score=0.9999847412109375), Point(token_index=26, time_index=113, score=0.9870278835296631), Point(token_index=26, time_index=114, score=1.904349664982874e-05), Point(token_index=27, time_index=115, score=0.9999794960021973), Point(token_index=27, time_index=116, score=0.9998253583908081), Point(token_index=28, time_index=117, score=0.9999991655349731), Point(token_index=28, time_index=118, score=0.9999734163284302), Point(token_index=28, time_index=119, score=0.0009004566818475723), Point(token_index=29, time_index=120, score=0.9993478655815125), Point(token_index=29, time_index=121, score=0.9975456595420837), Point(token_index=29, time_index=122, score=0.00030501981382258236), Point(token_index=30, time_index=123, 
score=0.9999344348907471), Point(token_index=30, time_index=124, score=6.0791257965320256e-06), Point(token_index=31, time_index=125, score=0.9833147525787354), Point(token_index=32, time_index=126, score=0.9974580407142639), Point(token_index=32, time_index=127, score=0.0008236187277361751), Point(token_index=33, time_index=128, score=0.9965153932571411), Point(token_index=33, time_index=129, score=0.0174646507948637), Point(token_index=34, time_index=130, score=0.9989169836044312), Point(token_index=35, time_index=131, score=0.9999698400497437), Point(token_index=35, time_index=132, score=0.9999842643737793), Point(token_index=36, time_index=133, score=0.9997640252113342), Point(token_index=36, time_index=134, score=0.5096558928489685), Point(token_index=37, time_index=135, score=0.9998302459716797), Point(token_index=37, time_index=136, score=0.08524401485919952), Point(token_index=37, time_index=137, score=0.0040728189051151276), Point(token_index=38, time_index=138, score=0.9999814033508301), Point(token_index=38, time_index=139, score=0.012057782150804996), Point(token_index=38, time_index=140, score=0.9999979734420776), Point(token_index=38, time_index=141, score=0.0005778099875897169), Point(token_index=39, time_index=142, score=0.9999068975448608), Point(token_index=39, time_index=143, score=0.9999960660934448), Point(token_index=39, time_index=144, score=0.9999980926513672), Point(token_index=40, time_index=145, score=0.9999915361404419), Point(token_index=40, time_index=146, score=0.9971170425415039), Point(token_index=40, time_index=147, score=0.9981803894042969), Point(token_index=41, time_index=148, score=0.9999310970306396), Point(token_index=41, time_index=149, score=0.9879505634307861), Point(token_index=41, time_index=150, score=0.9997628331184387), Point(token_index=42, time_index=151, score=0.9999535083770752), Point(token_index=43, time_index=152, score=0.9999716281890869), Point(token_index=44, time_index=153, score=0.6811745762825012)]

Visualization

```python
def plot_trellis_with_path(trellis, path):
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for p in path:
        # Path points use emission/transcript indices; shift by one to
        # account for the extra leading row and column of the trellis.
        trellis_with_path[p.time_index + 1, p.token_index + 1] = float('nan')
    plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')


plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
```

The path found by backtracking

Looking good. This path contains repetitions of the same labels, so let’s merge them to make it closer to the original transcript.

When merging multiple path points, we simply take the average probability of the merged points as the segment's score.
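An alternative way to write the merge step, shown here as a sketch rather than the tutorial's implementation, uses `itertools.groupby`, which groups consecutive path points sharing a `token_index`. The `Point` and `Segment` classes are simplified copies of the ones in this tutorial, `merge_repeats_groupby` is a hypothetical name, and the toy path is made up.

```python
from dataclasses import dataclass
from itertools import groupby


@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


@dataclass
class Segment:
    label: str
    start: int  # first frame (inclusive)
    end: int    # last frame + 1 (exclusive)
    score: float


def merge_repeats_groupby(path, transcript):
    segments = []
    # groupby only merges *consecutive* points with the same token_index.
    for token_index, group in groupby(path, key=lambda p: p.token_index):
        points = list(group)
        score = sum(p.score for p in points) / len(points)  # average probability
        segments.append(Segment(transcript[token_index], points[0].time_index,
                                points[-1].time_index + 1, score))
    return segments


path = [Point(0, 30, 1.0), Point(0, 31, 0.5), Point(1, 32, 0.9), Point(2, 33, 0.8)]
for seg in merge_repeats_groupby(path, "I|H"):
    print(seg)
```

Passing `transcript` explicitly avoids relying on the global variable that `merge_repeats` reads.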

```python
from dataclasses import dataclass


# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path):
    # Note: reads the global `transcript` defined earlier.
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        # Advance i2 past every point that shares path[i1]'s token.
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(Segment(transcript[path[i1].token_index], path[i1].time_index, path[i2 - 1].time_index + 1, score))
        i1 = i2
    return segments


segments = merge_repeats(path)
for seg in segments:
    print(seg)
```

Out:

<span></span>I       (0.78): [   30,    34)
|       (0.80): [   34,    36)
H       (1.00): [   36,    38)
A       (0.99): [   38,    40)
D       (0.62): [   40,    43)
|       (1.00): [   43,    44)
T       (0.55): [   44,    46)
H       (1.00): [   46,    48)
A       (0.37): [   48,    51)
T       (0.00): [   51,    52)
|       (1.00): [   52,    55)
C       (0.97): [   55,    60)
U       (1.00): [   60,    62)
R       (0.75): [   62,    66)
I       (0.88): [   66,    74)
O       (0.99): [   74,    78)
S       (1.00): [   78,    82)
I       (0.89): [   82,    85)
T       (0.78): [   85,    89)
Y       (0.70): [   89,    91)
|       (0.66): [   91,    94)
B       (1.00): [   94,    97)
E       (1.00): [   97,   101)
S       (1.00): [  101,   108)
I       (1.00): [  108,   110)
D       (0.93): [  110,   112)
E       (0.66): [  112,   115)
|       (1.00): [  115,   117)
M       (0.67): [  117,   120)
E       (0.67): [  120,   123)
|       (0.50): [  123,   125)
A       (0.98): [  125,   126)
T       (0.50): [  126,   128)
|       (0.51): [  128,   130)
T       (1.00): [  130,   131)
H       (1.00): [  131,   133)
I       (0.75): [  133,   135)
S       (0.36): [  135,   138)
|       (0.50): [  138,   142)
M       (1.00): [  142,   145)
O       (1.00): [  145,   148)
M       (1.00): [  148,   151)
E       (1.00): [  151,   152)
N       (1.00): [  152,   153)
T       (0.68): [  153,   154)
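To see what `merge_repeats` does without running the model, the sketch below applies the same run-collapsing loop to a hand-made path. `Point` and `Segment` are hypothetical stand-ins whose fields mirror the attributes used above (`token_index`, `time_index`, `score`; `label`, `start`, `end`, `score`). The `__repr__` only imitates the printed format. `transcript` is passed in explicitly rather than read from a global. Each segment's `end` is exclusive, which is why the output reads `[start, end)`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Point:
    # Hypothetical stand-in for one step of the alignment path:
    # which transcript token is aligned to which frame, and its probability.
    token_index: int
    time_index: int
    score: float


@dataclass
class Segment:
    # Hypothetical stand-in for the tutorial's segment type.
    label: str
    start: int  # first frame (inclusive)
    end: int    # last frame + 1 (exclusive)
    score: float

    def __repr__(self):
        # Imitates the "label (score): [start, end)" lines printed above.
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"


def merge_repeats(path: List[Point], transcript: str) -> List[Segment]:
    """Collapse consecutive path points aligned to the same token into one segment."""
    segments = []
    i1 = 0
    while i1 < len(path):
        # Advance i2 past every point aligned to the same token as path[i1].
        i2 = i1
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        # Segment score: mean per-frame probability over the run.
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(Segment(
            transcript[path[i1].token_index],
            path[i1].time_index,
            path[i2 - 1].time_index + 1,
            score,
        ))
        i1 = i2
    return segments


# Toy path: token 0 ('H') occupies frames 3-4, token 1 ('I') occupies frames 5-7.
toy_path = [
    Point(0, 3, 1.0), Point(0, 4, 0.5),
    Point(1, 5, 0.9), Point(1, 6, 0.6), Point(1, 7, 0.9),
]
for seg in merge_repeats(toy_path, "HI"):
    print(seg)
```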

Visualization

<span></span><span>def</span> <span>plot_trellis_with_segments</span><span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>transcript</span></a><span>):</span>
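  <span># To draw the path on top of the trellis, set the path cells to NaN: imshow renders NaN cells as blank.</span>
  <span># The +1 offsets skip the extra leading row (frame) and column (token) of the trellis.</span>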
  <span>trellis_with_path</span> <span>=</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>.</span><span>clone</span><span>()</span>
  <span>for</span> <span>i</span><span>,</span> <span>seg</span> <span>in</span> <span>enumerate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>):</span>
    <span>if</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a> <span>!=</span> <span>'|'</span><span>:</span>
      <span>trellis_with_path</span><span>[</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a><span>+</span><span>1</span><span>:</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>end</span></a><span>+</span><span>1</span><span>,</span> <span>i</span><span>+</span><span>1</span><span>]</span> <span>=</span> <span>float</span><span>(</span><span>'nan'</span><span>)</span>

  <span>fig</span><span>,</span> <span>[</span><span>ax1</span><span>,</span> <span>ax2</span><span>]</span> <span>=</span> <span>plt</span><span>.</span><span>subplots</span><span>(</span><span>2</span><span>,</span> <span>1</span><span>,</span> <span>figsize</span><span>=</span><span>(</span><span>16</span><span>,</span> <span>9.5</span><span>))</span>
  <span>ax1</span><span>.</span><span>set_title</span><span>(</span><span>"Path, label and probability for each label"</span><span>)</span>
  <span>ax1</span><span>.</span><span>imshow</span><span>(</span><span>trellis_with_path</span><span>.</span><span>T</span><span>,</span> <span>origin</span><span>=</span><span>'lower'</span><span>)</span>
  <span>ax1</span><span>.</span><span>set_xticks</span><span>([])</span>

  <span>for</span> <span>i</span><span>,</span> <span>seg</span> <span>in</span> <span>enumerate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>):</span>
    <span>if</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a> <span>!=</span> <span>'|'</span><span>:</span>
      <span>ax1</span><span>.</span><span>annotate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a><span>,</span> <span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a> <span>+</span> <span>.7</span><span>,</span> <span>i</span> <span>+</span> <span>0.3</span><span>),</span> <span>weight</span><span>=</span><span>'bold'</span><span>)</span>
      <span>ax1</span><span>.</span><span>annotate</span><span>(</span><span>f</span><span>'</span><span>{</span><a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float"><span>seg</span><span>.</span><span>score</span></a><span>:</span><span>.2f</span><span>}</span><span>'</span><span>,</span> <span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a> <span>-</span> <span>.3</span><span>,</span> <span>i</span> <span>+</span> <span>4.3</span><span>))</span>

  <span>ax2</span><span>.</span><span>set_title</span><span>(</span><span>"Label probability with and without repetition"</span><span>)</span>
  <span>xs</span><span>,</span> <span>hs</span><span>,</span> <span>ws</span> <span>=</span> <span>[],</span> <span>[],</span> <span>[]</span>
  <span>for</span> <span>seg</span> <span>in</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>:</span>
    <span>if</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a> <span>!=</span> <span>'|'</span><span>:</span>
      <span>xs</span><span>.</span><span>append</span><span>((</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>end</span></a> <span>+</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a><span>)</span> <span>/</span> <span>2</span> <span>+</span> <span>.4</span><span>)</span>
      <span>hs</span><span>.</span><span>append</span><span>(</span><a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float"><span>seg</span><span>.</span><span>score</span></a><span>)</span>
      <span>ws</span><span>.</span><span>append</span><span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>end</span></a> <span>-</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a><span>)</span>
      <span>ax2</span><span>.</span><span>annotate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a><span>,</span> <span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a> <span>+</span> <span>.8</span><span>,</span> <span>-</span><span>0.07</span><span>),</span> <span>weight</span><span>=</span><span>'bold'</span><span>)</span>
  <span>ax2</span><span>.</span><span>bar</span><span>(</span><span>xs</span><span>,</span> <span>hs</span><span>,</span> <span>width</span><span>=</span><span>ws</span><span>,</span> <span>color</span><span>=</span><span>'gray'</span><span>,</span> <span>alpha</span><span>=</span><span>0.5</span><span>,</span> <span>edgecolor</span><span>=</span><span>'black'</span><span>)</span>

  <span>xs</span><span>,</span> <span>hs</span> <span>=</span> <span>[],</span> <span>[]</span>
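  <span># `path` is not a parameter of this function; this reads the global `path` computed earlier.</span>
  <span>for</span> <span>p</span> <span>in</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>path</span></a><span>:</span>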
    <span>label</span> <span>=</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>transcript</span></a><span>[</span><span>p</span><span>.</span><span>token_index</span><span>]</span>
    <span>if</span> <span>label</span> <span>!=</span> <span>'|'</span><span>:</span>
      <span>xs</span><span>.</span><span>append</span><span>(</span><span>p</span><span>.</span><span>time_index</span> <span>+</span> <span>1</span><span>)</span>
      <span>hs</span><span>.</span><span>append</span><span>(</span><span>p</span><span>.</span><span>score</span><span>)</span>

  <span>ax2</span><span>.</span><span>bar</span><span>(</span><span>xs</span><span>,</span> <span>hs</span><span>,</span> <span>width</span><span>=</span><span>0.5</span><span>,</span> <span>alpha</span><span>=</span><span>0.5</span><span>)</span>
  <span>ax2</span><span>.</span><span>axhline</span><span>(</span><span>0</span><span>,</span> <span>color</span><span>=</span><span>'black'</span><span>)</span>
  <span>ax2</span><span>.</span><span>set_xlim</span><span>(</span><span>ax1</span><span>.</span><span>get_xlim</span><span>())</span>
  <span>ax2</span><span>.</span><span>set_ylim</span><span>(</span><span>-</span><span>0.1</span><span>,</span> <span>1.1</span><span>)</span>

<span>plot_trellis_with_segments</span><span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>transcript</span></a><span>)</span>
<span>plt</span><span>.</span><span>tight_layout</span><span>()</span>
<span>plt</span><span>.</span><span>show</span><span>()</span>

Figure: "Path, label and probability for each label" (top) and "Label probability with and without repetition" (bottom).

Looks good. Now let’s merge the character segments into words. The Wav2Vec2 model uses '|' as the word boundary, so we merge consecutive segments up to each occurrence of '|'.

Finally, we cut the original audio into one clip per word and listen to the clips to check whether the segmentation is correct.

Out:

<span></span>I       (0.78): [   30,    34)
HAD     (0.84): [   36,    43)
THAT    (0.52): [   44,    52)
CURIOSITY       (0.89): [   55,    91)
BESIDE  (0.94): [   94,   115)
ME      (0.67): [  117,   123)
AT      (0.66): [  125,   128)
THIS    (0.70): [  130,   138)
MOMENT  (0.97): [  142,   154)
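The word merge described above can be sketched in pure Python as follows. This is a minimal sketch, not the tutorial's own implementation: `Segment` is a hypothetical stand-in with the `label`, `start`, `end`, and `score` fields used throughout, and weighting each character's score by its length in frames is an assumption. The usage example feeds in the first six character segments copied from the output above and recovers the `[36, 43)` span for `HAD`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    # Hypothetical stand-in with the fields used in this tutorial.
    label: str
    start: int  # inclusive frame index
    end: int    # exclusive frame index
    score: float

    @property
    def length(self) -> int:
        return self.end - self.start


def _to_word(chars: List[Segment]) -> Segment:
    """Join consecutive character segments into a single word segment."""
    total = sum(c.length for c in chars)
    # Assumed weighting: characters that span more frames count for more.
    score = sum(c.score * c.length for c in chars) / total
    label = "".join(c.label for c in chars)
    return Segment(label, chars[0].start, chars[-1].end, score)


def merge_words(segments: List[Segment], separator: str = "|") -> List[Segment]:
    """Split the character segments at `separator` and merge each run into a word."""
    words, current = [], []
    for seg in segments:
        if seg.label == separator:
            if current:
                words.append(_to_word(current))
            current = []
        else:
            current.append(seg)
    if current:  # flush the last word if the input does not end with a separator
        words.append(_to_word(current))
    return words


chars = [
    Segment("I", 30, 34, 0.78), Segment("|", 34, 36, 0.80),
    Segment("H", 36, 38, 1.00), Segment("A", 38, 40, 0.99),
    Segment("D", 40, 43, 0.62), Segment("|", 43, 44, 1.00),
]
words = merge_words(chars)
print("|".join(w.label for w in words))
for w in words:
    print(w.label, w.start, w.end, round(w.score, 3))
```

Joining the word labels with '|' should reproduce the transcript, which makes a cheap sanity check on the merge.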

Visualization

<span></span><span>def</span> <span>plot_alignments</span><span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>word_segments</span></a><span>,</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>):</span>
  <span>trellis_with_path</span> <span>=</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>.</span><span>clone</span><span>()</span>
  <span>for</span> <span>i</span><span>,</span> <span>seg</span> <span>in</span> <span>enumerate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>):</span>
    <span>if</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a> <span>!=</span> <span>'|'</span><span>:</span>
      <span>trellis_with_path</span><span>[</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a><span>+</span><span>1</span><span>:</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>end</span></a><span>+</span><span>1</span><span>,</span> <span>i</span><span>+</span><span>1</span><span>]</span> <span>=</span> <span>float</span><span>(</span><span>'nan'</span><span>)</span>

  <span>fig</span><span>,</span> <span>[</span><span>ax1</span><span>,</span> <span>ax2</span><span>]</span> <span>=</span> <span>plt</span><span>.</span><span>subplots</span><span>(</span><span>2</span><span>,</span> <span>1</span><span>,</span> <span>figsize</span><span>=</span><span>(</span><span>16</span><span>,</span> <span>9.5</span><span>))</span>

  <span>ax1</span><span>.</span><span>imshow</span><span>(</span><span>trellis_with_path</span><span>[</span><span>1</span><span>:,</span> <span>1</span><span>:]</span><span>.</span><span>T</span><span>,</span> <span>origin</span><span>=</span><span>'lower'</span><span>)</span>
  <span>ax1</span><span>.</span><span>set_xticks</span><span>([])</span>
  <span>ax1</span><span>.</span><span>set_yticks</span><span>([])</span>

  <span>for</span> <span>word</span> <span>in</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>word_segments</span></a><span>:</span>
    <span>ax1</span><span>.</span><span>axvline</span><span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>word</span><span>.</span><span>start</span></a> <span>-</span> <span>0.5</span><span>)</span>
    <span>ax1</span><span>.</span><span>axvline</span><span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>word</span><span>.</span><span>end</span></a> <span>-</span> <span>0.5</span><span>)</span>

  <span>for</span> <span>i</span><span>,</span> <span>seg</span> <span>in</span> <span>enumerate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>):</span>
    <span>if</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a> <span>!=</span> <span>'|'</span><span>:</span>
      <span>ax1</span><span>.</span><span>annotate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a><span>,</span> <span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a><span>,</span> <span>i</span> <span>+</span> <span>0.3</span><span>))</span>
      <span>ax1</span><span>.</span><span>annotate</span><span>(</span><span>f</span><span>'</span><span>{</span><a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float"><span>seg</span><span>.</span><span>score</span></a><span>:</span><span>.2f</span><span>}</span><span>'</span><span>,</span> <span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a> <span>,</span> <span>i</span> <span>+</span> <span>4</span><span>),</span> <span>fontsize</span><span>=</span><span>8</span><span>)</span>

  <span># The original waveform</span>
  <span>ratio</span> <span>=</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>.</span><span>size</span><span>(</span><span>0</span><span>)</span> <span>/</span> <span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>.</span><span>size</span><span>(</span><span>0</span><span>)</span> <span>-</span> <span>1</span><span>)</span>
  <span>ax2</span><span>.</span><span>plot</span><span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>)</span>
  <span>for</span> <span>word</span> <span>in</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>word_segments</span></a><span>:</span>
    <span>x0</span> <span>=</span> <span>ratio</span> <span>*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>word</span><span>.</span><span>start</span></a>
    <span>x1</span> <span>=</span> <span>ratio</span> <span>*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>word</span><span>.</span><span>end</span></a>
    <span>ax2</span><span>.</span><span>axvspan</span><span>(</span><span>x0</span><span>,</span> <span>x1</span><span>,</span> <span>alpha</span><span>=</span><span>0.1</span><span>,</span> <span>color</span><span>=</span><span>'red'</span><span>)</span>
    <span>ax2</span><span>.</span><span>annotate</span><span>(</span><span>f</span><span>'</span><span>{</span><a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float"><span>word</span><span>.</span><span>score</span></a><span>:</span><span>.2f</span><span>}</span><span>'</span><span>,</span> <span>(</span><span>x0</span><span>,</span> <span>0.8</span><span>))</span>

  <span>for</span> <span>seg</span> <span>in</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>:</span>
    <span>if</span> <a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a> <span>!=</span> <span>'|'</span><span>:</span>
      <span>ax2</span><span>.</span><span>annotate</span><span>(</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>seg</span><span>.</span><span>label</span></a><span>,</span> <span>(</span><a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>seg</span><span>.</span><span>start</span></a> <span>*</span> <span>ratio</span><span>,</span> <span>0.9</span><span>))</span>
  <span>xticks</span> <span>=</span> <span>ax2</span><span>.</span><span>get_xticks</span><span>()</span>
  <span>plt</span><span>.</span><span>xticks</span><span>(</span><span>xticks</span><span>,</span> <span>xticks</span> <span>/</span> <span>bundle</span><span>.</span><span>sample_rate</span><span>)</span>
  <span>ax2</span><span>.</span><span>set_xlabel</span><span>(</span><span>'time [second]'</span><span>)</span>
  <span>ax2</span><span>.</span><span>set_yticks</span><span>([])</span>
  <span>ax2</span><span>.</span><span>set_ylim</span><span>(</span><span>-</span><span>1.0</span><span>,</span> <span>1.0</span><span>)</span>
  <span>ax2</span><span>.</span><span>set_xlim</span><span>(</span><span>0</span><span>,</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>.</span><span>size</span><span>(</span><span>-</span><span>1</span><span>))</span>

<span>plot_alignments</span><span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>segments</span></a><span>,</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>word_segments</span></a><span>,</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>[</span><span>0</span><span>],)</span>
<span>plt</span><span>.</span><span>show</span><span>()</span>

<span># A trick to embed the resulting audio in the generated file.</span>
<span># `IPython.display.Audio` has to be the last call in a cell,</span>
<span># and there should be only one such call per cell.</span>
<span>def</span> <span>display_segment</span><span>(</span><span>i</span><span>):</span>
  <span>ratio</span> <span>=</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>.</span><span>size</span><span>(</span><span>1</span><span>)</span> <span>/</span> <span>(</span><a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>trellis</span></a><span>.</span><span>size</span><span>(</span><span>0</span><span>)</span> <span>-</span> <span>1</span><span>)</span>
  <span>word</span> <span>=</span> <a href="https://docs.python.org/3/library/stdtypes.html#list" title="builtins.list"><span>word_segments</span></a><span>[</span><span>i</span><span>]</span>
  <span>x0</span> <span>=</span> <span>int</span><span>(</span><span>ratio</span> <span>*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>word</span><span>.</span><span>start</span></a><span>)</span>
  <span>x1</span> <span>=</span> <span>int</span><span>(</span><span>ratio</span> <span>*</span> <a href="https://docs.python.org/3/library/functions.html#int" title="builtins.int"><span>word</span><span>.</span><span>end</span></a><span>)</span>
  <span>filename</span> <span>=</span> <span>f</span><span>"_assets/</span><span>{</span><span>i</span><span>}</span><span>_</span><span>{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>word</span><span>.</span><span>label</span></a><span>}</span><span>.wav"</span>
  <span>torchaudio</span><span>.</span><span>save</span><span>(</span><span>filename</span><span>,</span> <a href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="torch.Tensor"><span>waveform</span></a><span>[:,</span> <span>x0</span><span>:</span><span>x1</span><span>],</span> <span>bundle</span><span>.</span><span>sample_rate</span><span>)</span>
  <span>print</span><span>(</span><span>f</span><span>"</span><span>{</span><a href="https://docs.python.org/3/library/stdtypes.html#str" title="builtins.str"><span>word</span><span>.</span><span>label</span></a><span>}</span><span> (</span><span>{</span><a href="https://docs.python.org/3/library/functions.html#float" title="builtins.float"><span>word</span><span>.</span><span>score</span></a><span>:</span><span>.2f</span><span>}</span><span>): </span><span>{</span><span>x0</span> <span>/</span> <span>bundle</span><span>.</span><span>sample_rate</span><span>:</span><span>.3f</span><span>}</span><span> - </span><span>{</span><span>x1</span> <span>/</span> <span>bundle</span><span>.</span><span>sample_rate</span><span>:</span><span>.3f</span><span>}</span><span> sec"</span><span>)</span>
  <span>return</span> <span>IPython</span><span>.</span><span>display</span><span>.</span><span>Audio</span><span>(</span><span>filename</span><span>)</span>

Figure: output of plot_alignments (trellis with word boundaries, top; waveform with highlighted word spans, bottom).

Out:

<span></span>I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT

Out:

<span></span>I (0.78): 0.604 - 0.684 sec

Out:

<span></span>HAD (0.84): 0.724 - 0.865 sec

Out:

<span></span>THAT (0.52): 0.885 - 1.046 sec

Out:

<span></span>CURIOSITY (0.89): 1.107 - 1.831 sec

Out:

<span></span>BESIDE (0.94): 1.891 - 2.314 sec

Out:

<span></span>ME (0.67): 2.354 - 2.474 sec

Out:

<span></span>AT (0.66): 2.515 - 2.575 sec

Out:

<span></span>THIS (0.70): 2.615 - 2.776 sec

Out:

<span></span>MOMENT (0.97): 2.857 - 3.098 sec
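The timestamps printed by `display_segment` come from scaling frame indices by `ratio = waveform.size(1) / (trellis.size(0) - 1)`, that is, samples per emission frame. Below is a minimal pure-Python sketch of the same arithmetic. The function name `frames_to_seconds` and the toy numbers (2 s of 16 kHz audio mapped to 100 frames) are hypothetical. For this model the ratio works out to roughly 320 samples (about 20 ms at 16 kHz) per frame, which is why frame 30 maps to about 0.6 s above.

```python
from typing import Tuple


def frames_to_seconds(start_frame: int, end_frame: int, num_samples: int,
                      num_frames: int, sample_rate: int) -> Tuple[float, float]:
    """Convert a half-open [start_frame, end_frame) interval to seconds."""
    # Average number of waveform samples covered by one emission frame.
    ratio = num_samples / num_frames
    # Truncate to integer sample indices, as display_segment does.
    x0 = int(ratio * start_frame)
    x1 = int(ratio * end_frame)
    return x0 / sample_rate, x1 / sample_rate


# Hypothetical numbers: 32000 samples at 16 kHz mapped to 100 frames (320 samples/frame).
print(frames_to_seconds(30, 34, num_samples=32000, num_frames=100, sample_rate=16000))
# (0.6, 0.68)
```

The same `x0`/`x1` sample indices are what `display_segment` uses to slice `waveform[:, x0:x1]` before saving each clip.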

Conclusion

In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to perform CTC segmentation for forced alignment.

Total running time of the script: ( 0 minutes 2.415 seconds)