
Understanding transposed convolutions in PyTorch

source link: https://numbersmithy.com/understanding-transposed-convolutions-in-pytorch/


1 The problem

PyTorch’s documentation on the transposed convolution modules (nn.ConvTransposexd, x being 1, 2 or 3) is bloody confusing!

This is in large part due to their implicit switching of context when using terms like “input” and “output”, and their overloading of terms like “stride”.

The animated gifs they pointed to, although well-produced, still need some explanation in words.

Let’s work through a derivation and clarify what’s really happening.

2 Derivation and explanation

2.1 The output length equation

The formula given in the doc of nn.ConvTransposexd modules is:

$$ H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 $$

First, let’s introduce some simpler symbols and re-arrange the formula a bit:

$$ o = s(l-1) + d(f-1) + 1 - 2p + p_o \qquad (1) $$

where:

  • o: output length (in any dimension)
  • s: stride, defaults to 1
  • l: input length (in any dimension)
  • d: dilation, defaults to 1
  • f: filter/kernel size
  • p: padding, defaults to 0
  • p_o: padding onto the output (output_padding), defaults to 0
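As a quick sanity check, Eq 1 can be coded up and compared with the output length that nn.ConvTranspose1d actually produces. This is a minimal sketch assuming a single channel and no bias; the settings (l, f, s, d, p, p_o) are those of Figure 1 and of the transposed convs in the four code blocks of this post, and the helper name out_len is just an illustrative choice.

```python
import torch
import torch.nn as nn

def out_len(l, f, s=1, d=1, p=0, po=0):
    """Eq 1: o = s(l-1) + d(f-1) + 1 - 2p + po."""
    return s * (l - 1) + d * (f - 1) + 1 - 2 * p + po

# (l, f, s, d, p, po) for Figure 1 and the transposed convs in Code blocks 1-4
cases = [(5, 3, 1, 1, 0, 0), (3, 3, 2, 1, 0, 0), (2, 3, 2, 2, 0, 0),
         (3, 3, 2, 2, 1, 0), (3, 3, 2, 2, 1, 1)]

results = []
for l, f, s, d, p, po in cases:
    tconv = nn.ConvTranspose1d(1, 1, kernel_size=f, stride=s, padding=p,
                               output_padding=po, dilation=d, bias=False)
    with torch.no_grad():
        o_torch = tconv(torch.zeros(1, 1, l)).shape[-1]
    results.append((out_len(l, f, s, d, p, po), o_torch))

print(results)  # [(7, 7), (7, 7), (7, 7), (7, 7), (8, 8)]
```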

The definitions above are rather terse, as some of them need more detailed explanations. Let’s start by re-arranging Eq 1 again:

$$ o = [l + (s-1)(l-1)] + [f + (d-1)(f-1) - 1] - [2p] + [p_o] $$

This looks more complicated than before, but it will make more sense when we explain how it works.

NOTE that I’m using square brackets [] to create 4 groups of terms:

  • group-1: [l+(s−1)(l−1)], this is “a measure” of/about the input length. By input I mean input to the transposed conv layer.
  • group-2: [f+(d−1)(f−1)−1], this is “a measure” of/about the filter length.
  • group-3: [−2p], this is the most confusing term, mostly due to its negative sign. The way to understand it is to treat it as the extra length padded onto the input to the normal conv layer, NOT the input to the transposed conv layer. More on this later.
  • group-4: [p_o], this is the extra length padded onto the output from the transposed conv layer. This could be regarded as an extra addition to group-1, but it only takes effect when s>1. More on this later.
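Expanding the first two groups shows that this re-arranged form is the same as Eq 1:

```latex
\begin{aligned}
[l + (s-1)(l-1)] &= l + sl - s - l + 1 = s(l-1) + 1 \\
[f + (d-1)(f-1) - 1] &= f + df - d - f + 1 - 1 = d(f-1) \\
\Rightarrow\; o &= s(l-1) + 1 + d(f-1) - 2p + p_o
\end{aligned}
```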

2.2 Simple case of stride=1, dilation=1, paddings=0

Let’s deal with a simple scenario first.

When s=1, d=1, p=0, p_o=0, the size equation becomes:

$$ o = [l] + [f - 1] $$

In this case, group-1 is l, and group-2 f−1.

One way to understand it is to imagine the filter sliding across the input sequence. Figure 1 below shows a concrete example where l=5, f=3.


Figure 1: Schematic for 1d transposed convolution. Input sequence is shown as squares at the bottom, output as circles at the top. Filter is [1, 1, 1], represented as triangles.

The filter (triangles) starts from the leftmost end, with its last element overlapping the 1st element of the input (3), and ends at the position where its 1st element overlaps the last element of the input (15).

So, if we follow the end point of the filter, it steps through the l positions of the input while it overlaps points in the input (this is group-1), plus an extra f−1 positions outside the input where it has no overlap with the input (this is group-2).

This same counting method will be used throughout:

  • group-1 counts the number of steps when the end point of the filter overlaps with the input.
  • group-2 counts the extra steps where there is no overlap between the two.
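The counting above can be reproduced in a few lines of PyTorch. This is a minimal sketch assuming one channel, no bias and an arbitrary input [1, 2, 3, 4, 5] (Figure 1’s full input values are not reproduced here): padding f−1 zeros on each side of the input and sliding the (flipped) filter one step at a time gives the same l+(f−1)=7 values as nn.ConvTranspose1d.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

l, f = 5, 3
x = torch.arange(1.0, l + 1).view(1, 1, -1)  # arbitrary input [1, 2, 3, 4, 5], shape (1, 1, l)

tconv = nn.ConvTranspose1d(1, 1, kernel_size=f, bias=False)  # s=1, d=1, p=0, p_o=0
with torch.no_grad():
    tconv.weight.fill_(1.0)  # filter [1, 1, 1], as in Figure 1
    out_torch = tconv(x)

# Sliding picture: pad f-1 zeros on both sides of the input, then move the
# flipped filter 1 step at a time (flipping is a no-op for [1, 1, 1]).
w = tconv.weight.detach().view(-1).flip(0)
x_pad = F.pad(x.view(-1), (f - 1, f - 1))
out_manual = torch.stack([(x_pad[i:i + f] * w).sum()
                          for i in range(x_pad.numel() - f + 1)])

print(out_torch.view(-1))  # values 1, 3, 6, 9, 12, 9, 5
print(out_manual)          # same values: l + (f - 1) = 7 of them
```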

2.3 When stride > 1

To include the remaining s, d, p and p_o terms, we first need to get some terms and notations straight:

  • input_c: the input into a normal conv layer; we also use it to denote the length of the input sequence. Similarly for the next 3 terms.
  • output_c: the output from a normal conv layer.
  • input_tc: the input into a transposed conv layer.
  • output_tc: the output from a transposed conv layer.
  • s_c: the stride in a normal conv layer, e.g. s_c=2 means the filter moves 2 cells every time. This is the stride argument you give to the nn.Convxd() module.
  • s_tc: the stride argument you give to the nn.ConvTransposexd() module. BUT: it shouldn’t be understood as the filter step size in the transposed convolution; instead, treat it as the same as the filter step in the normal conv layer.

Below is a concrete example given in Code block 1. Figure 2 is the screenshot of outputs, and Figure 3 the schematic.

Listing 1: Code block
import torch
import torch.nn as nn

# Normal (forward) convolution: filter [1, 1, 1], stride 2, no padding
conv = nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=0, dilation=1, bias=False)
with torch.no_grad():
    conv.weight.fill_(1)

# Input sequence 0, 1, ..., 6, shaped (batch, channels, length) = (1, 1, 7)
x = torch.arange(7, dtype=torch.float32).view(1, 1, -1)
y = conv(x)

print('\n### Normal conv:\n\t', conv)
print('Input sequence x:\n\t', x)
print('Filter weights:\n\t', conv.weight.data)
print('Output sequence y:\n\t', y)

# Transposed convolution with the same kernel_size and stride maps y back to length 7
transconv = nn.ConvTranspose1d(1, 1, kernel_size=3, stride=2, padding=0, bias=False)
with torch.no_grad():
    transconv.weight.fill_(1)
x2 = transconv(y)

print('\n### Transposed conv:\n\t', transconv)
print('Input sequence y:\n\t', y)
print('Filter weights:\n\t', transconv.weight.data)
print('Output sequence x2:\n\t', x2)

Figure 2: Screenshot of Python code output from Code block 1

Figure 3: Schematic for (a) 1d normal convolution and (b) transposed convolution, corresponding to the example given in Code block 1. Filter is [1, 1, 1], represented as triangles. Hollow squares denote empty placeholders.

The top row of solid dots is the input to the normal conv layer, therefore input_c=7.

From top to bottom, the input is convolved with a filter of f=3 at stride s_c=2. This gives the bottom row, where solid squares are the outputs from the normal conv layer, therefore output_c=3.

Note that I’m adding some hollow squares in the 2nd row to denote the empty slots created by the s_c=2 stride.

PyTorch designs nn.ConvTransposexd so that Convxd and ConvTransposexd are inverses of each other (in terms of shape transformations). I found it very helpful to keep this in mind when trying to understand transposed convolutions in PyTorch.

So, the “inverse” operation ConvTransposexd should map from the bottom row to the top, with a consistent set of arguments.

That’s to say, the argument s_tc we give to nn.ConvTransposexd() is actually the same as the s_c we used in its “inverse” function nn.Convxd(), and it DOES NOT describe the filter movement step in the transposed convolution!

Let’s see whether this matches the equation. With s_tc=2, we “dilute”/“interleave” (I’m deliberately avoiding overloading the term “dilate”) input_tc with s_tc−1=1 empty slot (shown as hollow squares) between every pair of neighbouring elements. So there will be (s_tc−1)×(l−1) such empty placeholders added.

And the filter in the transposed convolution still moves 1 step at a time, regardless of the s_tc value. This gives the value for group-1: l+(s_tc−1)×(l−1).

Also because the filter moves 1 step at a time, the term from group-2 is still f−1.

So, for s_tc>1, d=1, p=0, p_o=0, the output size is

$$ o = [l + (s_{tc}-1)(l-1)] + [f - 1] $$

In the example shown in Figure 3, l=3 and f=3, so output_tc=[3+1×2]+[2]=7, and we indeed achieve “an inverse” operation.

The 2 key points here:

  1. “stride” should be understood as describing the interleaving of empty slots (stride−1 of them between neighbouring elements) into the input of the transposed conv layer, or equivalently the filter movement step in the normal conv layer.
  2. Even when “stride” > 1, the filter still moves 1 step at a time.

Both of these points are also illustrated in the animated gifs that PyTorch’s documentation points to.
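The two key points can be checked numerically. The sketch below takes y = [3, 9, 15], which is what Code block 1 produces for y (worked out by hand as 0+1+2, 2+3+4 and 4+5+6), interleaves it with s_tc−1 zeros, pads f−1 zeros on each side and moves the filter 1 step at a time. The helper interleave is hypothetical, not a PyTorch function; since the filter [1, 1, 1] is symmetric, no flipping is needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def interleave(y, s):
    """Hypothetical helper: insert s-1 zeros between neighbouring elements of a 1d tensor."""
    out = y.new_zeros((y.numel() - 1) * s + 1)
    out[::s] = y
    return out

s, f = 2, 3
y = torch.tensor([3.0, 9.0, 15.0])    # y from Code block 1 (worked out by hand)

y_int = interleave(y, s)              # [3, 0, 9, 0, 15]: l + (s-1)(l-1) = 5 cells
y_pad = F.pad(y_int, (f - 1, f - 1))  # f - 1 zeros on each side
w = torch.ones(f)                     # filter [1, 1, 1]; symmetric, so no flip needed
# the filter moves 1 step at a time, even though s = 2
x2_manual = torch.stack([(y_pad[i:i + f] * w).sum()
                         for i in range(y_pad.numel() - f + 1)])

tconv = nn.ConvTranspose1d(1, 1, kernel_size=f, stride=s, bias=False)
with torch.no_grad():
    tconv.weight.fill_(1.0)
    x2_torch = tconv(y.view(1, 1, -1)).view(-1)

print(x2_manual)  # values 3, 3, 12, 9, 24, 15, 15
print(x2_torch)   # same values
```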

2.4 When stride > 1, dilation > 1

This is a relatively easy part: the explanation given by PyTorch’s doc is actually rather to the point: “Spacing between kernel elements”.

This means that for a filter of length f, we insert d−1 empty slots into each of the f−1 intervals within the filter, so the dilated filter spans f+(d−1)(f−1) cells. Since group-2 counts that span minus 1, it becomes f+(d−1)(f−1)−1. Group-1 is not affected.
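For the filter used in the examples below (f=3, d=2), this works out as:

```latex
[1, 1, 1] \;\xrightarrow{d=2}\; [1, 0, 1, 0, 1], \qquad
f + (d-1)(f-1) = 3 + 1 \times 2 = 5, \qquad
\text{group-2} = 5 - 1 = 4 = d(f-1)
```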

Code block 2 below generates a concrete example; Figure 4 shows its output, and Figure 5 gives a schematic.

import torch
import torch.nn as nn

# Normal convolution, now with a dilated filter (dilation=2)
conv = nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=0, dilation=2, bias=False)
with torch.no_grad():
    conv.weight.fill_(1)

# Input sequence 0, 1, ..., 6, shaped (batch, channels, length) = (1, 1, 7)
x = torch.arange(7, dtype=torch.float32).view(1, 1, -1)
y = conv(x)

print('\n### Normal conv:\n\t', conv)
print('Input sequence x:\n\t', x)
print('Filter weights:\n\t', conv.weight.data)
print('Output sequence y:\n\t', y)

# Transposed convolution with the same stride and dilation
transconv = nn.ConvTranspose1d(1, 1, kernel_size=3, stride=2, padding=0, dilation=2, bias=False)
with torch.no_grad():
    transconv.weight.fill_(1)
x2 = transconv(y)

print('\n### Transposed conv:\n\t', transconv)
print('Input sequence y:\n\t', y)
print('Filter weights:\n\t', transconv.weight.data)
print('Output sequence x2:\n\t', x2)

Figure 4: Screenshot of Python code output from Code block 2

Figure 5: Schematic for (a) 1d normal convolution and (b) transposed convolution, corresponding to the example given in Code block 2. Filter is [1, 1, 1], represented as solid triangles, and dilated places in the filter are represented as hollow triangles. Hollow squares denote empty placeholders.

  • from top to bottom is the normal convolution.
  • from bottom to top is the transposed convolution.
  • input_c=7.
  • for the normal convolution: f=3, d=2, s_c=2.
  • this gives output_c=2.
  • for the transposed convolution: f=3, d=2, s_tc=2. Remember: the filter still moves 1 step at a time!
  • this gives output_tc=7. Again, we achieved “an inverse” operation.
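The bottom-to-top half of Figure 5 can be reproduced by hand. A minimal sketch, assuming y = [6, 12] (what Code block 2 produces for y, worked out by hand as 0+2+4 and 2+4+6): interleave y with s_tc−1 zeros, build the dilated filter [1, 0, 1, 0, 1], pad span−1 = 4 zeros on each side and slide one step at a time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

s, d, f = 2, 2, 3
y = torch.tensor([6.0, 12.0])             # y from Code block 2 (worked out by hand)

y_int = y.new_zeros((y.numel() - 1) * s + 1)
y_int[::s] = y                            # interleaved input [6, 0, 12]
w = y.new_zeros(f + (d - 1) * (f - 1))
w[::d] = 1.0                              # dilated filter [1, 0, 1, 0, 1]
span = w.numel()                          # 5 cells
y_pad = F.pad(y_int, (span - 1, span - 1))
# slide 1 step at a time; w is symmetric, so no flip is needed
x2_manual = torch.stack([(y_pad[i:i + span] * w).sum()
                         for i in range(y_pad.numel() - span + 1)])

tconv = nn.ConvTranspose1d(1, 1, kernel_size=f, stride=s, dilation=d, bias=False)
with torch.no_grad():
    tconv.weight.fill_(1.0)
    x2_torch = tconv(y.view(1, 1, -1)).view(-1)

print(x2_manual)  # values 6, 0, 18, 0, 18, 0, 12
print(x2_torch)   # same values
```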

2.5 When stride > 1, dilation > 1, padding > 0

I think this is the worst part of all. To quote PyTorch’s doc:

“padding (int or tuple, optional) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of each dimension in the input. Default: 0”

Not sure how you feel about it, but this makes NO sense to me.

The extra note helps (only by a little):

“The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sizes of the input. This is set so that when a Conv2d and a ConvTranspose2d are initialized with same parameters, they are inverses of each other in regard to the input and output shapes.”

The dilation * (kernel_size - 1) - padding part is awfully confusing; I think the doc would be better off without it.

This sentence does shed some light: “This is set so that when a Conv2d and a ConvTranspose2d are initialized with same parameters, they are inverses of each other in regard to the input and output shapes.”
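For what it’s worth, the doc’s phrase can also be checked literally, and it lines up with the counting used here: interleave the input with s−1 zeros, zero-pad each side by d(f−1)−p (group-2’s d(f−1) minus group-3’s p), then run an ordinary stride-1 convolution with the flipped kernel. This is a minimal sketch with the settings of Code block 3, a random input, the layer’s default random weights, one channel and no bias; it assumes d(f−1)−p ≥ 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
k, s, p, d = 3, 2, 1, 2                  # Code block 3 settings
y = torch.randn(1, 1, 3)                 # arbitrary input to the transposed conv
tconv = nn.ConvTranspose1d(1, 1, kernel_size=k, stride=s, padding=p, dilation=d, bias=False)

with torch.no_grad():
    # interleave s-1 zeros between neighbouring input elements
    y_int = y.new_zeros(1, 1, (y.shape[-1] - 1) * s + 1)
    y_int[..., ::s] = y
    # "dilation * (kernel_size - 1) - padding" zeros on both sides (assumed >= 0)
    pad = d * (k - 1) - p                # 3 here
    # ConvTranspose1d weight is (in, out, k); conv1d wants (out, in, k), flipped
    w = tconv.weight.transpose(0, 1).flip(-1)
    manual = F.conv1d(F.pad(y_int, (pad, pad)), w, dilation=d)
    reference = tconv(y)

print(manual.shape[-1], torch.allclose(manual, reference, atol=1e-5))  # 7 True
```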

So it’s helpful to look at the paired operations. Code block 3 below gives an example; Figure 6 shows its output, and Figure 7 a schematic.

import torch
import torch.nn as nn

# Normal convolution with padding=1: one zero is added to each end of its input
conv = nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=1, dilation=2, bias=False)
with torch.no_grad():
    conv.weight.fill_(1)

# Input sequence 0, 1, ..., 6, shaped (batch, channels, length) = (1, 1, 7)
x = torch.arange(7, dtype=torch.float32).view(1, 1, -1)
y = conv(x)

print('\n### Normal conv:\n\t', conv)
print('Input sequence x:\n\t', x)
print('Filter weights:\n\t', conv.weight.data)
print('Output sequence y:\n\t', y)

# Transposed convolution: the same padding=1 is copied over from the normal conv
transconv = nn.ConvTranspose1d(1, 1, kernel_size=3, stride=2, padding=1, dilation=2, bias=False)
with torch.no_grad():
    transconv.weight.fill_(1)
x2 = transconv(y)

print('\n### Transposed conv:\n\t', transconv)
print('Input sequence y:\n\t', y)
print('Filter weights:\n\t', transconv.weight.data)
print('Output sequence x2:\n\t', x2)

Figure 6: Screenshot of Python code output from Code block 3

Figure 7: Schematic for (a) 1d normal convolution and (b) transposed convolution, corresponding to the example given in Code block 3. Filter is [1, 1, 1], represented as solid triangles, dilated places in the filter are represented as hollow triangles. Hollow squares denote empty placeholders. Hollow circles denote padded inputs during the normal convolution, or removed outputs during the transposed convolution.

  • from top to bottom is the normal convolution
  • from bottom to top is the transposed convolution
  • input_c=7
  • for the normal convolution: f=3, d=2, s_c=2, p=1
  • this gives output_c=3
  • for the transposed convolution: f=3, d=2, s_tc=2, p=1. Remember: the filter still moves 1 step at a time!

This gives

$$ \text{output}_{tc} = [l + (s-1)(l-1)] + [f + (d-1)(f-1) - 1] - [2p] = [5] + [4] - [2] = 7. $$

Again, we achieved “an inverse” operation.

Therefore, by “padding”, they actually mean the padding added onto the input of the “forward”/“normal” convolution nn.Convxd(); you need to copy that same number into nn.ConvTransposexd() so that the 2 operations are “inverses” of each other.

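Let’s walk through the computations in more detail. The input to the transposed conv layer is y = [4, 9, 8], which becomes [4, 0, 9, 0, 8] after interleaving with s−1 = 1 empty slot, and the dilated filter is [1, 0, 1, 0, 1].

We still start the transposed convolution from the 1st dot product:

$$ [1,0,1,0,1] \cdot [\text{nan},\text{nan},\text{nan},\text{nan},4]^T = 4 $$

where nan denotes out-of-bound placeholders in input_tc, which contribute nothing to the sum.

But that output is NOT included, contributing a −p to the total count. We then move to the next window position (Remember: the filter still moves 1 step at a time!):

$$ [1,0,1,0,1] \cdot [\text{nan},\text{nan},\text{nan},4,0]^T = 0 $$

Then the next step:

$$ [1,0,1,0,1] \cdot [\text{nan},\text{nan},4,0,9]^T = 13 $$

And the next:

$$ [1,0,1,0,1] \cdot [\text{nan},4,0,9,0]^T = 0 $$

On the rightmost end, we should have

$$ [1,0,1,0,1] \cdot [8,\text{nan},\text{nan},\text{nan},\text{nan}]^T = 8 $$

But this output 8 is also not included, and it makes up the remaining part of the [−2p] term.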
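The cropping can be seen directly by comparing padding=0 and padding=1 with the other settings of Code block 3 fixed. A minimal sketch, assuming y = [4, 9, 8] (the y that Code block 3 produces, as in the dot products above) and the filter [1, 1, 1]: padding=0 keeps all 9 window positions, while padding=1 returns the same values with the first and last one removed. The helper make_tconv is hypothetical.

```python
import torch
import torch.nn as nn

def make_tconv(padding):
    """Hypothetical helper: Code block 3's transposed conv with filter [1, 1, 1]."""
    tconv = nn.ConvTranspose1d(1, 1, kernel_size=3, stride=2, padding=padding,
                               dilation=2, bias=False)
    with torch.no_grad():
        tconv.weight.fill_(1.0)
    return tconv

y = torch.tensor([[[4.0, 9.0, 8.0]]])    # y from Code block 3, shape (1, 1, 3)

with torch.no_grad():
    full = make_tconv(0)(y).view(-1)     # every window position is kept
    cropped = make_tconv(1)(y).view(-1)  # padding=1: one output dropped at each end

print(full)     # values 4, 0, 13, 0, 21, 0, 17, 0, 8
print(cropped)  # values 0, 13, 0, 21, 0, 17, 0
```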

2.6 When stride > 1, dilation > 1, padding > 0, output padding > 0

The extra p_o term, our group-4, is added to the previous 3 groups, completing Eq 1:

$$ o = [l + (s-1)(l-1)] + [f + (d-1)(f-1) - 1] - [2p] + [p_o] $$

PyTorch’s doc describes it as the “additional size added to one side of each dimension in the output shape”.

NOTE that this is NOT padding 0s onto the “diluted”/“interleaved” input_tc; otherwise the layer output would always have a rim of 0s around its edges.

This is again for the purpose of making normal and transposed convolutions “inverse” operations. During the normal convolution, the output size is computed as:

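$$ \text{output}_c = \left\lfloor \frac{\text{input}_c + 2p - d(f-1) - 1}{s_c} \right\rfloor + 1 $$

which reduces to ⌊(input_c + 2p − f)/s_c⌋ + 1 when d=1. Here ⌊ ⌋ is the floor function. So, different input_c values can be mapped onto the same output_c, e.g. with f=3, p=1, d=1, s_c=2:

$$ \left\lfloor \frac{7 + 2 \times 1 - 3}{2} \right\rfloor + 1 = \left\lfloor \frac{8 + 2 \times 1 - 3}{2} \right\rfloor + 1 = 4 $$

In such cases, output_padding allows one to add the extra few elements such that input_c=output_tc. Since this ambiguity only arises when s_c>1, output_padding is only needed when s_tc>1.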
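The ambiguity and its fix can be observed on shapes alone. The sketch below reuses the Code block 3 settings (kernel_size=3, stride=2, padding=1, dilation=2), where input lengths 7 and 8 both give a conv output of length 3. It also uses the output_size argument of the ConvTranspose1d forward call, which lets PyTorch work out output_padding from a requested length.

```python
import torch
import torch.nn as nn

args = dict(kernel_size=3, stride=2, padding=1, dilation=2, bias=False)  # Code block 3 settings

with torch.no_grad():
    conv = nn.Conv1d(1, 1, **args)
    # floor() maps two different input lengths onto the same output length
    len_from_7 = conv(torch.zeros(1, 1, 7)).shape[-1]
    len_from_8 = conv(torch.zeros(1, 1, 8)).shape[-1]

    y = torch.zeros(1, 1, 3)
    back_to_7 = nn.ConvTranspose1d(1, 1, **args)(y).shape[-1]
    back_to_8 = nn.ConvTranspose1d(1, 1, output_padding=1, **args)(y).shape[-1]
    # forward() also accepts output_size and derives output_padding from it
    via_output_size = nn.ConvTranspose1d(1, 1, **args)(y, output_size=[8]).shape[-1]

print(len_from_7, len_from_8)                 # 3 3
print(back_to_7, back_to_8, via_output_size)  # 7 8 8
```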

Using our last example (Code block 3) but with the extra output_padding=1 argument added, you can see that the trailing value 8, previously removed as part of the −2p term, is now kept, giving output_tc=8.

import torch
import torch.nn as nn

# Same normal convolution as in Code block 3
conv = nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=1, dilation=2, bias=False)
with torch.no_grad():
    conv.weight.fill_(1)

# Input sequence 0, 1, ..., 6, shaped (batch, channels, length) = (1, 1, 7)
x = torch.arange(7, dtype=torch.float32).view(1, 1, -1)
y = conv(x)

print('\n### Normal conv:\n\t', conv)
print('Input sequence x:\n\t', x)
print('Filter weights:\n\t', conv.weight.data)
print('Output sequence y:\n\t', y)

# output_padding=1 keeps one extra element at the end of the transposed conv output
transconv = nn.ConvTranspose1d(1, 1, kernel_size=3, stride=2, padding=1, dilation=2,
                               output_padding=1, bias=False)
with torch.no_grad():
    transconv.weight.fill_(1)
x2 = transconv(y)

print('\n### Transposed conv:\n\t', transconv)
print('Input sequence y:\n\t', y)
print('Filter weights:\n\t', transconv.weight.data)
print('Output sequence x2:\n\t', x2)

Figure 8: Screenshot of Python code output from Code block 4

Figure 9: Schematic for (a) 1d normal convolution and (b) transposed convolution, corresponding to the example given in Code block 4. Filter is [1, 1, 1], represented as solid triangles, dilated places in the filter are represented as hollow triangles. Hollow squares denote empty placeholders. Hollow circles denote padded inputs during the normal convolution, or removed outputs during the transposed convolution.

3 Summary

We walked through derivations of the computations in transposed convolutions in PyTorch, and clarified some confusing points in its documentation, many of which stem from implicit changes of context and overloaded terms.

It is helpful to keep in mind PyTorch’s design choice that normal conv layers and transposed conv layers are “inverse” operations to each other, in that they revert the shape of a tensor.

In fact, for some input arguments to a nn.ConvTransposexd module, it is easier to mentally read them as the input arguments to nn.Convxd, and think about them as:

“what arguments would a forward convolution use to get the current tensor at hand, that I am now feeding into a transposed convolution”.

These arguments include:

  • stride
  • padding

Despite the same names, these arguments mean rather different things in nn.Convxd and nn.ConvTransposexd, which creates great confusion around the output size formula. The overloaded argument names may help keep the code API consistent, but the explanations could certainly be made better.

With the above confusions cleared, we give a break-down of the formula given in PyTorch’s documentation:

$$ o = s(l-1) + d(f-1) + 1 - 2p + p_o = [l + (s-1)(l-1)] + [f + (d-1)(f-1) - 1] - [2p] + [p_o] $$

where:

  • o: output length (in any dimension).
  • s: the stride used in the normal or forward convolution. The input to a transposed conv layer is “diluted”/“interleaved” with s−1 0s between neighbouring elements. Defaults to 1.
  • l: length of input to the transposed conv layer.
  • d: dilation of the filter, i.e. each gap in the filter is interleaved with d−1 0s. Defaults to 1.
  • f: filter size (before dilation).
  • p: padding used in the normal convolution. p elements are removed from each end (2p in total) of the output from a transposed conv layer, effectively undoing the padding performed in the normal convolution. Defaults to 0.
  • p_o: extra length added to the output from the transposed conv layer. Only needed when s>1. This resolves the size ambiguity created by the floor function when computing the output size of a normal convolution.

The basic idea of the derivation is to count the output elements as 2 parts:

  1. counts the number of steps when the end point of the filter overlaps with the input. This corresponds to our group-1 term: [l+(s−1)(l−1)].
  2. counts the extra steps where there is no overlap between the filter and the input. This is our group-2 term: [f+(d−1)(f−1)−1].

The extra group-3 of −2p, and group-4 of po, are due to a design choice of PyTorch to make the normal and transposed convolutions inverse operations to each other.
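Putting the four groups together gives a small reference implementation. This is a sketch under stated assumptions (a single channel, no bias, groups=1; the function name convtranspose1d_by_counting is hypothetical), not how PyTorch computes it internally: interleave the input (group-1), dilate the filter, pad span−1 zeros on each side plus p_o extra zeros on the right (group-2 and group-4), slide the flipped filter 1 step at a time, then drop p elements from each end (group-3). It is checked against nn.ConvTranspose1d for a few random settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def convtranspose1d_by_counting(y, w, s=1, d=1, p=0, po=0):
    """Hypothetical single-channel, bias-free ConvTranspose1d built from the 4 groups.

    y: 1d input to the transposed conv (length l); w: 1d filter (length f).
    """
    l, f = y.numel(), w.numel()
    # group-1: interleave s-1 zeros -> l + (s-1)(l-1) cells
    y_int = y.new_zeros(l + (s - 1) * (l - 1))
    y_int[::s] = y
    # dilation: d-1 zeros in each of the f-1 gaps of the filter
    w_dil = w.new_zeros(f + (d - 1) * (f - 1))
    w_dil[::d] = w
    span = w_dil.numel()
    # group-2: span-1 zeros on each side; group-4: p_o extra zeros on the right
    y_pad = F.pad(y_int, (span - 1, span - 1 + po))
    # the flipped filter moves 1 step at a time, whatever s is
    w_flip = w_dil.flip(0)
    full = torch.stack([(y_pad[i:i + span] * w_flip).sum()
                        for i in range(y_pad.numel() - span + 1)])
    # group-3: drop p elements from each end, undoing the forward conv's padding
    return full[p:full.numel() - p]

torch.manual_seed(0)
checks = []
for s, d, p, po in [(1, 1, 0, 0), (2, 1, 0, 0), (2, 2, 0, 0),
                    (2, 2, 1, 0), (2, 2, 1, 1), (3, 2, 2, 1)]:
    y, w = torch.randn(4), torch.randn(3)
    tconv = nn.ConvTranspose1d(1, 1, kernel_size=3, stride=s, padding=p,
                               output_padding=po, dilation=d, bias=False)
    with torch.no_grad():
        tconv.weight.copy_(w.view(1, 1, -1))
        ref = tconv(y.view(1, 1, -1)).view(-1)
    ours = convtranspose1d_by_counting(y, w, s, d, p, po)
    checks.append(ref.shape == ours.shape and torch.allclose(ref, ours, atol=1e-5))

print(checks)  # expected: all True
```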

Author: guangzhi

Created: 2022-05-04 Wed 20:09


