Thread
Are you interested in learning JAX with Flax? We have translated our popular Deep Learning tutorials on CNNs, GNNs, (Vision) Transformers, and more from PyTorch to JAX+Flax, with considerable speedups for smaller models! Check them out here: uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html
🧵 1/12
2/12 With the same setup, we found that small CNNs on CIFAR10 train almost 3x faster in JAX! The notebooks are not meant as benchmarks, further optimizations are possible, and larger models see much smaller speedups, but the result still surprised me. Curious to hear others' opinions!
3/12 How is JAX able to do that? JAX supports just-in-time (JIT) compilation with XLA, which compiles and optimizes the combined forward and backward pass for the available accelerator (GPU/TPU). Want to learn more or see how it works in practice? Check out our tutorials! 👇
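A minimal sketch of what JIT compilation means in practice, in plain JAX (the toy linear loss and the `TRACE_LOG` list are made up for illustration and are not from the tutorials): `jax.jit` traces the Python function once per input shape and dtype, compiles the forward and backward pass produced by `jax.value_and_grad` into a single XLA program, and reuses that compiled program on later calls.

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping: Python side effects only run while JAX traces the function.
TRACE_LOG = []

def loss_fn(w, x, y):
    # Toy linear regression loss (illustrative assumption).
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

@jax.jit
def loss_and_grad(w, x, y):
    TRACE_LOG.append(x.shape)  # runs at trace time, not on every call
    # Forward and backward pass are traced together and compiled as one XLA program.
    return jax.value_and_grad(loss_fn)(w, x, y)

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
for _ in range(3):
    loss, grad = loss_and_grad(w, x, y)  # same shapes: the compiled program is reused

# A new input shape triggers a fresh trace and compilation.
x2 = jnp.ones((16, 3))
y2 = jnp.ones(16)
loss2, grad2 = loss_and_grad(w, x2, y2)
```

The Python side effect fires once per new input shape rather than on every call, which is a quick way to see when JAX recompiles.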
4/12 "Tutorial 2 (JAX): Introduction to JAX+Flax" - We give a simple introduction to how JAX works, and how you can start training your own small networks with Flax and Optax. uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html
5/12 "Tutorial 5 (JAX): Inception, ResNet and DenseNet" - We explain the basics of the most popular CNN architectures, and train a small GoogleNet, (Pre-Activation) ResNet, and DenseNet on CIFAR10. uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet....
6/12 "Tutorial 6 (JAX): Transformers and Multi-Head Attention" - We implement a Transformer from scratch and train it for set anomaly prediction on CIFAR100, where we use a pretrained ResNet as a feature extractor. uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial6/Transformers_and_MHAttenti...
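At the core of multi-head attention is scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. A minimal sketch in `jax.numpy` (the function name and the masking convention are our own choices, not necessarily the tutorial's). Since no positional encoding is involved, permuting the set elements simply permutes the outputs, which is what makes attention a natural fit for set tasks.

```python
import jax
import jax.numpy as jnp

def scaled_dot_product(q, k, v, mask=None):
    """Attention over the last two axes: q, k, v have shape [..., seq_len, d_k]."""
    d_k = q.shape[-1]
    logits = q @ jnp.swapaxes(k, -2, -1) / jnp.sqrt(d_k)  # [..., seq_len, seq_len]
    if mask is not None:
        # Keep positions where mask is True; the rest get a large negative logit.
        logits = jnp.where(mask, logits, -1e9)
    attention = jax.nn.softmax(logits, axis=-1)  # each row sums to 1
    return attention @ v, attention

# Example: batch of 2, 4 heads, a set of 5 elements, 8 dims per head.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 5, 8))
k = jax.random.normal(kk, (2, 4, 5, 8))
v = jax.random.normal(kv, (2, 4, 5, 8))
out, attn = scaled_dot_product(q, k, v)
```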
7/12 "Tutorial 7 (JAX): Graph Neural Networks" - We review GNNs by implementing the Graph Convolutional Network layer and the Graph Attention layer. Examples of training full GNNs coming soon! uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial7/GNN_overview.html
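A GCN layer projects each node's features and then aggregates them over the node's neighbourhood. The sketch below uses mean aggregation with added self-loops, H' = D^-1 (A + I) H W, which is one common simplification of the symmetric normalisation D^-1/2 (A + I) D^-1/2 from the original GCN formulation. The `gcn_layer` name, the dense adjacency matrix, and the absence of a bias are assumptions for illustration.

```python
import jax.numpy as jnp

def gcn_layer(node_feats, adj_matrix, weight):
    """One GCN-style layer with mean aggregation (hypothetical helper).

    node_feats: [num_nodes, c_in]
    adj_matrix: [num_nodes, num_nodes], 0/1 entries, no self-loops
    weight:     [c_in, c_out]
    """
    num_nodes = adj_matrix.shape[0]
    adj_hat = adj_matrix + jnp.eye(num_nodes, dtype=adj_matrix.dtype)  # add self-loops
    degree = jnp.sum(adj_hat, axis=-1, keepdims=True)  # neighbours incl. the node itself
    projected = node_feats @ weight                     # H W
    return (adj_hat @ projected) / degree               # average over each neighbourhood

# A 4-node path graph 0 - 1 - 2 - 3.
adj = jnp.array([[0, 1, 0, 0],
                 [1, 0, 1, 0],
                 [0, 1, 0, 1],
                 [0, 0, 1, 0]], dtype=jnp.float32)
feats = jnp.arange(8.0).reshape(4, 2)
out = gcn_layer(feats, adj, jnp.eye(2))  # identity weight exposes the averaging
```

With an identity weight, node 0 ends up with the mean of its own features and those of node 1, and so on along the path.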
8/12 "Tutorial 9 (JAX): Deep Autoencoders" - We show how to implement and train simple autoencoders on CIFAR10, and how you can use an autoencoder as an image retrieval model! uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.html
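Once an autoencoder is trained, retrieval amounts to encoding the query image and the database images, then returning the database entries whose latent codes are closest. A minimal sketch of that last step, assuming the embeddings were already produced by the encoder (not shown) and using squared Euclidean distance; both the distance choice and the `retrieve` helper are assumptions.

```python
import jax.numpy as jnp

def retrieve(query_emb, db_embs, k):
    """Return indices of the k database embeddings closest to the query (hypothetical helper)."""
    sq_dists = jnp.sum((db_embs - query_emb) ** 2, axis=-1)  # distance to every entry
    return jnp.argsort(sq_dists)[:k]                          # nearest first

# Toy 2-D "latent codes" standing in for encoder outputs.
db = jnp.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.9, 0.1]])
query = jnp.array([1.0, 0.0])
top2 = retrieve(query, db, k=2)
```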
9/12 "Tutorial 11 (JAX): Normalizing Flows for image modeling" - We implement a normalizing flow as a generative model on MNIST, and discuss the challenges that discrete data poses for flows! uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial11/NF_image_modeling.html
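The core issue is that a flow models a continuous density, so on integer pixel values it can place arbitrarily high density on the discrete points. A common fix is uniform dequantization: add noise u ~ U[0, 1) to each integer value and rescale, so every pixel becomes a continuous value in [0, 1). A minimal sketch; the function names and the 256-level default (8-bit pixels) are assumptions.

```python
import jax
import jax.numpy as jnp

def dequantize(key, x_int, num_levels=256):
    """Map integer pixels in {0, ..., num_levels-1} to continuous values in [0, 1)."""
    noise = jax.random.uniform(key, x_int.shape)  # u ~ U[0, 1) per pixel
    return (x_int.astype(jnp.float32) + noise) / num_levels

def quantize(x_cont, num_levels=256):
    """Undo dequantization by flooring back to the integer bin."""
    x = jnp.floor(x_cont * num_levels).astype(jnp.int32)
    return jnp.clip(x, 0, num_levels - 1)

pixels = jnp.array([[0, 17], [128, 255]], dtype=jnp.int32)
x_cont = dequantize(jax.random.PRNGKey(0), pixels)
```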
10/12 "Tutorial 15 (JAX): Vision Transformers" - Following up on Tutorial 6, we implement a small Vision Transformer and train it on CIFAR10. Spoiler: without pre-training, a Vision Transformer does not necessarily outperform all CNNs on CIFAR10. uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.html
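A Vision Transformer first cuts each image into non-overlapping patches and flattens every patch into a token vector. A minimal sketch of that step in `jax.numpy`, assuming channels-last images whose height and width are divisible by the patch size (the `img_to_patch` name and the 4x4 patch size in the example are our choices):

```python
import jax.numpy as jnp

def img_to_patch(x, patch_size):
    """[B, H, W, C] images -> [B, num_patches, patch_size * patch_size * C] tokens."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // patch_size, patch_size, W // patch_size, patch_size, C)
    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))  # [B, H', W', p, p, C]
    return x.reshape(B, -1, patch_size * patch_size * C)  # patches in row-major order

# One CIFAR10-sized image with distinct pixel values.
x = jnp.arange(32 * 32 * 3, dtype=jnp.float32).reshape(1, 32, 32, 3)
patches = img_to_patch(x, 4)
```

With 32x32x3 images and 4x4 patches, each image becomes a sequence of 64 tokens of size 48.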
11/12 We have also added the PyTorch tutorials from the DL2 course on advanced topics, like Geometric DL, Neural ODEs, and more! Thanks to all the TAs! @_gabrielecesa_ @davidknigge @BryanEikema @oneapra @AmandaIlze @LeonardBereska @RValperga @MiltosKofinas Adeel Pervez, Cyril Hsu
12/12 If you find our tutorials helpful, consider ⭐-ing our repository: github.com/phlippe/uvadlc_notebooks
Similarly, if you found a bug or have a question/comment on a notebook, feel free to open a GitHub issue!