A Generative Adversarial Network (GAN) can be used to generate data such as images and videos. This tutorial shows you how to train a GAN with MMEngine step by step!
> - [Building a DataLoader](#building-a-dataloader)
>   - [Building a Dataset](#building-a-dataset)
> - [Build a Generator Network and a Discriminator Network](#build-a-generator-network-and-a-discriminator-network)
> - [Build a Generative Adversarial Network Model](#build-a-generative-adversarial-network-model)
> - [Building an Optimizer](#building-an-optimizer)
> - [Training with Runner](#training-with-runner)
## Building a DataLoader
### Building a Dataset
First, we build a dataset class `MNISTDataset` for the MNIST dataset. It inherits from the base dataset class [BaseDataset](mmengine.dataset.BaseDataset) and overrides its `load_data_list` method so that the method returns a `list[dict]`, where each `dict` represents one data sample.
## Build a Generative Adversarial Network Model

The following code implements the basic GAN algorithm. To implement it with MMEngine, inherit from [BaseModel](mmengine.model.BaseModel) and implement the training process in `train_step`. A GAN trains the generator and the discriminator alternately: `train_discriminator` and `train_generator` perform the two updates, while `disc_loss` and `gen_loss` compute the discriminator loss and the generator loss, respectively.

For more details about `BaseModel`, refer to the [Model tutorial](../tutorials/model.md).

The helper function `set_requires_grad` freezes the discriminator's weights while the generator is being trained: with `requires_grad=False`, backpropagating the generator loss computes no gradients for the discriminator's parameters.
```python
def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
```
```python
# Assemble the GAN; 100 is the size of the noise vector fed to the generator.
model = GAN(generator, discriminator, 100, data_preprocessor)
```
## Building an Optimizer
MMEngine wraps optimizers with [OptimWrapper](mmengine.optim.OptimWrapper). Because a GAN uses two optimizers, one for the generator and one for the discriminator, we wrap each of them in an `OptimWrapper` and group the wrappers with [OptimWrapperDict](mmengine.optim.OptimWrapperDict).
## Training with Runner

The following code shows how to train the model with Runner. For more details about Runner, refer to the [Runner tutorial](../tutorials/runner.md).
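```python
from mmengine.runner import Runner

# Train epoch by epoch for 220 epochs.
train_cfg = dict(by_epoch=True, max_epochs=220)
runner = Runner(
    model,
    work_dir='runs/gan/',  # logs and checkpoints are written here
    train_dataloader=train_dataloader,
    train_cfg=train_cfg,
    optim_wrapper=opt_wrapper_dict)  # one OptimWrapper per network
runner.train()
```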
We have now completed an example of training a GAN. The following code can be used to view the images generated by the GAN we just trained.
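A minimal sketch of this step, assuming the trained generator maps 100-dimensional noise vectors to 1x28x28 images with values in [-1, 1]. An untrained stand-in network is used here so that the block runs on its own; with the tutorial's objects, the trained `generator` passed to `GAN(...)` should take its place. The helper `tile_images` is hypothetical: it maps pixels back to [0, 255] and arranges the samples in a grid.

```python
import torch
import torch.nn as nn


def tile_images(imgs, nrow):
    """Tile (N, 1, H, W) images in [-1, 1] into one uint8 image.

    Returns a (rows * H, nrow * W) tensor, where rows = N // nrow.
    """
    n, _, h, w = imgs.shape
    rows = n // nrow
    # Map [-1, 1] back to [0, 255] pixel values.
    pixels = ((imgs.clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)
    grid = pixels[: rows * nrow, 0].reshape(rows, nrow, h, w)
    # (rows, nrow, H, W) -> (rows, H, nrow, W) -> (rows * H, nrow * W)
    return grid.permute(0, 2, 1, 3).reshape(rows * h, nrow * w)


# Hypothetical stand-in for the trained generator (noise size 100).
generator = nn.Sequential(nn.Linear(100, 28 * 28), nn.Tanh(),
                          nn.Unflatten(1, (1, 28, 28)))
generator.eval()
with torch.no_grad():
    samples = generator(torch.randn(16, 100))  # (16, 1, 28, 28)
grid = tile_images(samples, nrow=4)            # (112, 112), dtype uint8
```

To write the grid to disk, one option is Pillow (`Image.fromarray(grid.numpy()).save('result.png')`); alternatively, `torchvision.utils.save_image(samples, 'result.png', normalize=True)` performs the tiling and rescaling itself.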

To learn more about implementing GANs and other generative models with MMEngine, we highly recommend trying [MMGeneration](https://github.com/open-mmlab/mmgeneration/tree/dev-1.x), a generative-model framework built on MMEngine.