+ 1
From a Tensor to an image (PyTorch)
I have a PyTorch Tensor that represents a square image. Let's say tensor.shape = (3, n, n), since I have 3 channels and my image is n×n pixels. What is the fastest way to plot it? I presume PyTorch already has something that does this, maybe in one line. What about using just NumPy?
2 Answers
0
NumPy would probably be the fastest, yeah. It's incredibly fast and works well with vectors, scalars, etc., but it won't provide the plotting visuals of matplotlib or seaborn. SciPy might also help, but it all depends on what your expected output is.
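As a rough sketch of the NumPy side, assuming the tensor lives on the CPU or can be moved there: NumPy can reorder the axes, but actually displaying the result still needs a plotting library. The helper name `chw_to_hwc` and the random input are made up for illustration.

```python
import numpy as np
import torch


def chw_to_hwc(tensor):
    """Return a (H, W, C) NumPy array from a (C, H, W) tensor."""
    array = tensor.detach().cpu().numpy()  # for a CPU tensor this shares memory, no copy
    return np.transpose(array, (1, 2, 0))  # move the channel axis to the end


# Hypothetical input: random values standing in for a real image
tensor = torch.rand(3, 32, 32)
array = chw_to_hwc(tensor)
print(array.shape)  # (32, 32, 3)
```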
0
OK, I came up with this solution, which seems pretty clean:
For CIFAR10 (for example), the image tensors have shape (3, n, n), but imshow expects the channel dimension in third place. Thus I get:
plt.imshow(tensor.permute(1, 2, 0))
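For anyone who wants something to paste and run, here is a minimal sketch of the same idea. The function name `show_rgb` and the random tensor (standing in for a CIFAR10 sample) are assumptions of mine, not part of any library.

```python
import matplotlib.pyplot as plt
import torch


def show_rgb(tensor):
    """Draw a (3, H, W) float tensor with values in [0, 1] as an RGB image."""
    # permute moves the channels last; numpy() hands matplotlib a plain array
    image = tensor.detach().cpu().permute(1, 2, 0).numpy()  # shape (H, W, 3)
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis("off")
    return image


# Hypothetical stand-in for a CIFAR10 sample
tensor = torch.rand(3, 32, 32)
image = show_rgb(tensor)
```

In a plain script, call plt.show() afterwards. Note that imshow expects RGB floats in [0, 1] (or integers in 0..255), so a tensor that has been normalized to some other range will be clipped unless you rescale it first.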
For MNIST I have an input tensor of shape (1, m, m), which needs to be squeezed on the first dimension:
plt.imshow(tensor.squeeze(0))
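The single-channel case can be sketched the same way. Again, `show_gray` and the random tensor (standing in for an MNIST digit) are hypothetical names for illustration. The only real difference is the colormap: for 2D data imshow applies its default colormap, so I pass cmap="gray" to get a grayscale picture.

```python
import matplotlib.pyplot as plt
import torch


def show_gray(tensor):
    """Draw a (1, H, W) tensor as a grayscale image."""
    # squeeze(0) drops the leading channel dimension of size 1
    image = tensor.detach().cpu().squeeze(0).numpy()  # shape (H, W)
    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray")
    ax.axis("off")
    return image


# Hypothetical stand-in for an MNIST sample
tensor = torch.rand(1, 28, 28)
image = show_gray(tensor)
```

As before, call plt.show() afterwards when running this as a plain script.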
Hope this may help someone else :)