Using the Compile API
We will start by learning the basic usage of the Compile API, applying it to our well-known CNN model and the Fashion-MNIST dataset. After that, we will accelerate a heavier model used to classify images from the CIFAR-10 dataset.
Basic usage
Instead of describing the API’s components and explaining a bunch of optional parameters, let’s dive into a simple example to show the basic usage of this capability. The following piece of code uses the Compile API to compile the CNN model presented in previous chapters:
model = CNN()
graph_model = torch.compile(model)
Note
The complete code shown in this section is available at https://github.com/PacktPublishing/Accelerate-Model-Training-with-PyTorch-2.X/blob/main/code/chapter03/cnn-graph_mode.ipynb.
To compile a model, we need to call a function named compile, passing the model as a parameter. Nothing else is necessary for the basic usage of this API. compile returns an object that...