---
title: PyTorch 02 - Dataloaders and Transforms
jupyter: python3
---
## Datasets & DataLoaders
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/trgardos/ml-549-fa24/blob/main/12-pytorch-02-dataloaders.ipynb)
::: {.callout-note}
Adapted from [PyTorch Quickstart](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
to use CIFAR10 dataset instead of FashionMNIST.
:::
* Code for processing data samples can get messy and hard to maintain.
* Ideally, we want our dataset code to be decoupled from our model training
  code for better readability and modularity.
## PyTorch Data Primitives
PyTorch provides two data primitives:
1. `torch.utils.data.DataLoader` and
2. `torch.utils.data.Dataset`
that allow you to use pre-loaded datasets as well as your own data.
* `Dataset` stores the samples and their corresponding labels, and
* `DataLoader` wraps an _iterable_ around the `Dataset` to enable easy
access to the samples.
::: {.callout-note}
An **iterable** is a Python object capable of returning its members one at a time.
It must implement the `__iter__` method or the `__getitem__` method. See
[Iterators](https://docs.python.org/3/tutorial/classes.html#iterators)
for more details.
:::
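For instance, here's a minimal illustration (not from the original tutorial) of an object that is iterable only because it implements `__getitem__`, which is essentially how map-style `Dataset`s behave:

```{python}
class Squares:
    """A tiny iterable: Python's for-loop falls back to calling
    __getitem__ with 0, 1, 2, ... until IndexError is raised."""
    def __getitem__(self, idx):
        if idx >= 5:
            raise IndexError
        return idx ** 2

print(list(Squares()))  # [0, 1, 4, 9, 16]
```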
## Pre-loaded Datasets
PyTorch provides a number of pre-loaded datasets (such as FashionMNIST or CIFAR10)
that subclass `torch.utils.data.Dataset` and implement
functions specific to the particular data.
They can be used to prototype and benchmark your model.
You can find them here:
* [Image Datasets](https://pytorch.org/vision/stable/datasets.html)
<details>
<summary>View sublist</summary>
<ul>
<li>image classification</li>
<li>object detection or segmentation</li>
<li>optical flow</li>
<li>stereo matching</li>
<li>image pairs</li>
<li>image captioning</li>
<li>video classification</li>
<li>video prediction</li>
</ul>
</details>
* [Text Datasets](https://pytorch.org/text/stable/datasets.html)
<details>
<summary>View sublist</summary>
<ul>
<li>text classification</li>
<li>language modeling</li>
<li>machine translation</li>
<li>sequence tagging</li>
<li>question answering</li>
<li>unsupervised learning</li>
</ul>
</details>
* [Audio Datasets](https://pytorch.org/audio/stable/datasets.html)
## Loading a Dataset
Here is an example of how to load the
[CIFAR10](https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10)
dataset from TorchVision.
We load the
[CIFAR10](https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10)
with the following parameters:
- `root` is the path where the train/test data is stored,
- `train=True` selects the training split; `train=False` selects the test split,
- `download=True` downloads the data from the internet if it's
not available at `root`.
- `transform` and `target_transform` specify the feature and label
transformations. More on that later.
```{python}
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
```
In this case we simply use the
[ToTensor](https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html)
transform, which converts the PIL image from (H x W x C) shape to a (C x H x W)
`torch.FloatTensor` and scales the pixel values from [0, 255] to the range [0.0, 1.0].
```{python}
img, label = training_data[0]
print(f"img.shape: {img.shape}")
print(f"img.dtype: {img.dtype}")
```
Let's also look at the `data` directory contents.
```sh
% ls -al data
total 333008
drwxr-xr-x@ 6 tomg staff 192 Oct 1 22:14 .
drwxr-xr-x 62 tomg staff 1984 Oct 17 11:09 ..
drwxr-xr-x@ 10 tomg staff 320 Jun 4 2009 cifar-10-batches-py
-rw-r--r--@ 1 tomg staff 170498071 Sep 4 19:51 cifar-10-python.tar.gz
```
And the contents of the `cifar-10-batches-py` directory.
```sh
% ls -al data/cifar-10-batches-py
total 363752
drwxr-xr-x@ 10 tomg staff 320 Jun 4 2009 .
drwxr-xr-x@ 6 tomg staff 192 Oct 1 22:14 ..
-rw-r--r--@ 1 tomg staff 158 Mar 31 2009 batches.meta
-rw-r--r--@ 1 tomg staff 31035704 Mar 31 2009 data_batch_1
-rw-r--r--@ 1 tomg staff 31035320 Mar 31 2009 data_batch_2
-rw-r--r--@ 1 tomg staff 31035999 Mar 31 2009 data_batch_3
-rw-r--r--@ 1 tomg staff 31035696 Mar 31 2009 data_batch_4
-rw-r--r--@ 1 tomg staff 31035623 Mar 31 2009 data_batch_5
-rw-r--r--@ 1 tomg staff 88 Jun 4 2009 readme.html
-rw-r--r--@ 1 tomg staff 31035526 Mar 31 2009 test_batch
```
You see in this case the images aren't stored individually, but rather
as combined batches. The CIFAR10 `Dataset` and `DataLoader` classes hide
these details from us.
### Iterating and Visualizing the Dataset
It's very important to understand the data you're working with. Visually inspecting
the dataset is a good way to get started.
We can index `Datasets` manually like a list: `training_data[index]`. In this
case we randomly sample images from the dataset.
We use `matplotlib` to visualize some samples in our training data.
```{python}
labels_map = {
    0: "plane",
    1: "car",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

figure = plt.figure(figsize=(6, 6))
cols, rows = 4, 4
for i in range(1, cols * rows + 1):
    # Randomly choose indices
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    #print(f"img.shape: {img.shape}")
    #print(f"label: {label}")
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.permute(1, 2, 0))
plt.show()
```
::: {.callout-tip}
Try re-running the above cell a few times to see different samples from the dataset.
:::
### Collecting Sample Data to Illustrate Custom Dataset
To illustrate creating a custom dataset, we will collect images from the CIFAR10
dataset and save them to a local directory. We will also save the labels to a CSV
file.
```{python}
import os
import pandas as pd
from torchvision import datasets
from torchvision.transforms import ToTensor
from PIL import Image

# Create directories to store images and annotations
os.makedirs('cifar10_images', exist_ok=True)
annotations_file = 'cifar10_annotations.csv'

# Load CIFAR10 dataset
cifar10 = datasets.CIFAR10(root='data', train=True, download=True, transform=ToTensor())

# Number of images to download
n_images = 10

# Store images and their labels
data = []
for i in range(n_images):
    img, label = cifar10[i]
    img = img.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
    img = (img * 255).byte().numpy()  # Convert to numpy array and scale to [0, 255]
    img = Image.fromarray(img)  # Convert to PIL Image
    img_filename = f'cifar10_images/img_{i}.png'
    img.save(img_filename)  # Save image
    data.append([img_filename, label])  # Append image path and label to data list

# Write annotations to CSV file
df = pd.DataFrame(data, columns=['image_path', 'label'])
df.to_csv(annotations_file, index=False)

print(f"Saved {n_images} images and their labels to {annotations_file}")
```
```{python}
# List the contents of the `cifar10_images` directory
import os
os.listdir('cifar10_images')
```
```{python}
import pandas as pd
# Read the annotations file
annotations = pd.read_csv('cifar10_annotations.csv')
# Display the first 10 lines of the annotations file
print(annotations.head(10))
```
## Creating a Custom Dataset Class
A custom Dataset class must implement three functions:
1. `__init__`,
2. `__len__`, and
3. `__getitem__`.
We'll look at an example implementation.
The CIFAR10 images are stored in a directory and their labels are stored
separately in a CSV file.
```{python}
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

class MyCustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = self.img_labels.iloc[idx, 0]
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```
### `__init__`
The `__init__` function is run once when instantiating the Dataset
object. We initialize the directory containing the images, the
annotations file, and both transforms (covered in more detail in the
next section).
```{.python}
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
```
### `__len__`

The `__len__` method returns the number of samples in our dataset; it's what Python's built-in `len()` calls.

Example:

```{.python}
def __len__(self):
    return len(self.img_labels)
```
### `__getitem__`

The `__getitem__` method loads and returns a sample from the dataset
at the given index `idx`.

Based on the index, it

* identifies the image's location on disk,
* reads the image into a tensor using `read_image`,
* retrieves the corresponding label from the CSV data in `self.img_labels`,
* calls the transform functions on them (if applicable), and
* returns the tensor image and corresponding label as a tuple.

```{.python}
def __getitem__(self, idx):
    img_path = self.img_labels.iloc[idx, 0]
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
```
<br>
**Let's try to use it.**
``` {python}
from torchvision.transforms import ToTensor
mydataset = MyCustomImageDataset(annotations_file='cifar10_annotations.csv', img_dir='cifar10_images') # , transform=ToTensor())
img, label = mydataset[0]
print(f"img.shape: {img.shape}")
print(f"label: {label}")
```
```{python}
import matplotlib.pyplot as plt

figure = plt.figure(figsize=(6, 6))
cols, rows = 2, 2
for i in range(1, cols * rows + 1):
    # Randomly choose indices
    sample_idx = torch.randint(len(mydataset), size=(1,)).item()
    img, label = mydataset[sample_idx]
    #print(f"img.shape: {img.shape}")
    #print(f"label: {label}")
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.permute(1, 2, 0))
plt.show()
```
------------------------------------------------------------------------
### Preparing your data for training with DataLoaders
The `Dataset` retrieves your dataset's features and labels one sample at
a time.
While training a model, we typically want to
* pass samples in "batches",
* reshuffle the data at every epoch to reduce model overfitting, and
* use Python's `multiprocessing` to speed up data retrieval.
`DataLoader` is an iterable that abstracts this complexity for us in an
easy API.
Because `MyCustomImageDataset` is a subclass of `Dataset`, we can use it
with a `DataLoader` just as we would use any other `Dataset`.
```{python}
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
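Because `MyCustomImageDataset` returns fixed-size image tensors, we can wrap it in a `DataLoader` exactly the same way. The snippet below is a small sketch (not part of the original tutorial); `num_workers` controls how many worker processes load batches in parallel.

```{python}
# Sketch: wrap the custom dataset from above in a DataLoader.
# num_workers > 0 would load batches in parallel worker processes;
# it's left at 0 here to keep the example portable.
custom_loader = DataLoader(mydataset, batch_size=4, shuffle=True, num_workers=0)

imgs, labels = next(iter(custom_loader))
print(f"imgs.shape: {imgs.shape}")  # e.g. torch.Size([4, 3, 32, 32])
print(f"labels: {labels}")
```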
### Iterate through the DataLoader
We have loaded the dataset into the `DataLoader` and can iterate
through it as needed.
Each iteration below returns a batch of
`train_features` and `train_labels` (containing `batch_size=64` features
and labels respectively).
Because we specified `shuffle=True`, after we
iterate over all batches the data is shuffled
(for finer-grained control over the data loading order, take a look at
[Samplers](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler)).
```{python}
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
```
Note from the size that we got a batch of 64 images.
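We can also check how many batches the loader yields per epoch, which is `ceil(len(dataset) / batch_size)`:

```{python}
# 50,000 training images / batch_size 64 -> 782 batches (the last one is partial)
print(f"Number of batches per epoch: {len(train_dataloader)}")
```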
```{python}
img = train_features[0].squeeze() # squeeze() removes dimension of size 1, e.g. (1, 3, 32, 32) -> (3, 32, 32)
label = train_labels[0]
plt.imshow(img.permute(1, 2, 0))
plt.show()
print(f"Label: {labels_map[label.item()]}")
```
## Transforms and Data Augmentation
Data does not always come in its final processed form that is required
for training machine learning algorithms.
We use **transforms** to manipulate the data into a form suitable for training, and also to perform data augmentation.

All TorchVision dataset classes have two parameters:

- `transform` to modify the features and
- `target_transform` to modify the labels.

They accept callables containing the transformation logic.
The CIFAR10 images are in PIL Image format, and the labels are
integers.
For training, we need the features as normalized tensors, and
the labels as one-hot encoded tensors.
To make these transformations, we use `ToTensor` and `Lambda`.
```{python}
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
)
```
### ToTensor()

[ToTensor](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor)
converts a PIL image or NumPy `ndarray` into a `FloatTensor` and scales
the image's pixel intensity values to the range $[0., 1.]$.
### Lambda Transforms
Lambda transforms apply any user-defined lambda function.
Here, we define a function to turn the integer into a one-hot encoded tensor.
It
* first creates a zero tensor of size 10 (the number of labels in our dataset) and
* calls [`scatter_`](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html)
which assigns a `value=1` on the index as given by the label `y`.
```{python}
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
```
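As a quick check (not in the original tutorial), applying it to an example label shows the one-hot encoding:

```{python}
# Label 3 ("cat") becomes a one-hot vector with a 1 at index 3.
print(target_transform(3))
```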
## Transforms V2 and Data Augmentation
The [torchvision.transforms](https://pytorch.org/vision/stable/transforms.html)
module offers many commonly-used transforms for image and video data.
Adapted from [Getting started with transforms v2](https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html#sphx-glr-auto-examples-transforms-plot-transforms-getting-started-py)
## Transforms V2 Example
We'll illustrate with an example.
First, a bit of setup
```{python}
from pathlib import Path
import torch
import matplotlib.pyplot as plt

plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image

import os
import urllib.request
import zipfile

# Check if the directory exists
if not os.path.exists('transform_v2_files'):
    # Download the zip file
    url = 'https://github.com/trgardos/ml-549-fa24/raw/refs/heads/main/transform_v2_files.zip?download='
    zip_path = 'transform_v2_files.zip'
    urllib.request.urlretrieve(url, zip_path)

    # Unzip the file
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall('.')

    # Remove the zip file
    os.remove(zip_path)

torch.manual_seed(1)

from transform_v2_files.transforms.helpers import plot

img = read_image(str(Path('transform_v2_files/assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
```
### The basics
The Torchvision transforms behave like a regular [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html):
* instantiate a transform,
* pass an input,
* get a transformed output:
```{python}
transform = v2.RandomCrop(size=(224, 224))
out = transform(img)
plot([img, out])
```
This transform takes a random crop of the image of size 224x224. Try
re-running the cell a few times to see different samples.
### Transforms for image classification
Here's a basic image classification transform pipeline:
```{python}
transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)
plot([img, out])
```
Again, try re-running the cell a few times to see different samples.
Such a transformation pipeline is typically passed as the `transform` argument
to a dataset class, e.g. `ImageNet(..., transform=transforms)`.
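As a sketch of what that looks like (swapping in CIFAR10, which downloads automatically, and adding `v2.ToImage()` since the pipeline above assumed a tensor input rather than the PIL images CIFAR10 returns):

```{.python}
from torchvision import datasets

cifar_transforms = v2.Compose([
    v2.ToImage(),  # PIL image -> tv_tensors.Image
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

ds_aug = datasets.CIFAR10(root="data", train=True, download=True, transform=cifar_transforms)
img_aug, _ = ds_aug[0]
print(img_aug.shape, img_aug.dtype)  # expected: torch.Size([3, 224, 224]) torch.float32
```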
See the [Transforms Gallery](https://pytorch.org/vision/stable/auto_examples/transforms/index.html#transforms-gallery) for examples of how to use augmentation transforms like `CutMix` and `MixUp`.
### Detection, Segmentation, Videos
Besides images, transforms can also transform bounding boxes, segmentation / detection masks, or videos.
Let's look at an example with bounding boxes.
```{python}
from torchvision import tv_tensors  # we'll describe this a bit later

boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])

out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])
```
This also works for:
* masks (`torchvision.tv_tensors.Mask`) for object segmentation or semantic
segmentation, or
* videos (`torchvision.tv_tensors.Video`)
We could have passed them to the transforms in exactly the same way.
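For example, here is a small sketch (not from the original) passing a dummy segmentation mask alongside the image; geometric transforms are applied consistently to both, while photometric ones only touch the image:

```{.python}
# Sketch: an all-zeros placeholder mask with the same spatial size as the image.
mask = tv_tensors.Mask(torch.zeros(img.shape[-2:], dtype=torch.uint8))

out_img, out_mask = transforms(img, mask)
print(type(out_mask), out_mask.shape)  # expected: a Mask of shape (224, 224)
```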
### What are TVTensors?
TVTensors are [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor) subclasses.
The available TVTensors are [`Image`](https://pytorch.org/vision/stable/generated/torchvision.tv_tensors.Image.html#torchvision.tv_tensors.Image),
[`BoundingBoxes`](https://pytorch.org/vision/stable/generated/torchvision.tv_tensors.BoundingBoxes.html#torchvision.tv_tensors.BoundingBoxes),
[`Mask`](https://pytorch.org/vision/stable/generated/torchvision.tv_tensors.Mask.html#torchvision.tv_tensors.Mask), and
[`Video`](https://pytorch.org/vision/stable/generated/torchvision.tv_tensors.Video.html#torchvision.tv_tensors.Video).
TVTensors look and feel just like regular tensors - they **are** tensors.
Everything that is supported on a plain `torch.Tensor` like `.sum()`
or any `torch.*` operator will also work on a TVTensor:
```{python}
img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
```
These TVTensor classes are at the core of the transforms:
in order to transform a given input, the transforms first look at the **class**
of the object, and dispatch to the appropriate implementation accordingly.
For more info, see the [TVTensors FAQ](https://pytorch.org/vision/stable/auto_examples/transforms/plot_tv_tensors.html#sphx-glr-auto-examples-transforms-plot-tv-tensors-py).
### What do I pass as input?
Above, we've seen two examples:
* one where we passed a single image as input i.e. ``out = transforms(img)``, and
* one where we passed both an image and bounding boxes, i.e.
``out_img, out_boxes = transforms(img, boxes)``.
Transforms support **arbitrary input structures**.
The input can be e.g. a single image, a tuple, an arbitrarily nested dictionary...
The same structure will be returned as output.
Below, we use the
same detection transforms, but pass a tuple (image, target_dict) as input and
we're getting the same structure as output:
```{python}
target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
    "this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
```
We passed a tuple so we get a tuple back, and the second element is the
transformed target dict.

Transforms don't really care about the structure of the input; as mentioned
above, they only care about the **type** of the objects and transform them
accordingly.

*Foreign* objects like strings or ints are simply passed through. This can be
useful e.g. if you want to associate a path with every single sample when
debugging!
## Transforms and Datasets intercompatibility
Roughly speaking, the output of the datasets must correspond to the input of
the transforms. How to do that depends on whether you're using the torchvision
[built-in datasets](https://pytorch.org/vision/stable/datasets.html#datasets),
or your own custom datasets.
### Using built-in datasets
If you're just doing image classification, you don't need to do anything. Just
use the `transform` argument of the dataset, e.g. `ImageNet(..., transform=transforms)`.
#### Compatibility with Older Datasets
Torchvision also supports datasets for object detection or segmentation like
[`torchvision.datasets.CocoDetection`](https://pytorch.org/vision/stable/generated/torchvision.datasets.CocoDetection.html#torchvision.datasets.CocoDetection).
Those datasets predate
the existence of the `transforms.v2` module and of the
TVTensors, so they don't return TVTensors out of the box.
An easy way to force those datasets to return TVTensors and to make them
compatible with v2 transforms is to use the
`wrap_dataset_for_transforms_v2` function:
```python
from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
dataset = CocoDetection(..., transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)
# Now the dataset returns TVTensors!
```
### Using your own datasets
If you have a custom dataset, then you'll need to convert your objects into
the appropriate TVTensor classes.
To create TVTensor instances refer to
[How do I create TVTensors?](https://pytorch.org/vision/stable/auto_examples/transforms/plot_tv_tensors.html#tv-tensor-creation)
for more details.
There are two main places where you can implement that conversion logic:

- At the end of the dataset's ``__getitem__`` method, before returning the
  sample, or by sub-classing the dataset (see the sketch below).
- As the very first step of your transforms pipeline.

Either way, the logic will depend on your specific dataset.
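Here is a sketch of the first option, adapting the `__getitem__` of our `MyCustomImageDataset` from earlier so it returns a TVTensor image (the label handling is unchanged):

```{.python}
from torchvision import tv_tensors

def __getitem__(self, idx):
    img_path = self.img_labels.iloc[idx, 0]
    # Wrap the decoded tensor in tv_tensors.Image so v2 transforms
    # dispatch on it as an image.
    image = tv_tensors.Image(read_image(img_path))
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label
```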
## V2 Transforms
See [V2 API Snapshot](https://pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended) for a more complete list.
### Further Reading
- [torch.utils.data API](https://pytorch.org/docs/stable/data.html)
- [torchvision.transforms API](https://pytorch.org/vision/stable/transforms.html)