Awesome JAX & Flax LLMs
by sour_amethyst_jobey
Explore cutting-edge LLM implementations with JAX and Flax
Pitch

This repository offers a curated selection of open-source large language model implementations in JAX and Flax, optimized for high-speed TPU/GPU training. Built on a modular and extensible codebase, it supports multiple architectures, with fine-tuning capabilities planned, and is suited to educational and experimental use.

Description

Welcome to Awesome JAX & Flax LLMs, a curated collection of open-source large language model (LLM) implementations constructed using JAX and Flax. This repository features modular, efficient, and scalable implementations of transformer-based models that are specifically optimized for high-speed TPU/GPU training and efficient inference.

Note: The implementations provided here are intended for educational purposes and are not production-ready as shipped. They do, however, include all core model components and can be adapted to meet production requirements.

Key Features

  • Multiple LLM architectures implemented in JAX/Flax.
  • Optimization for TPU acceleration utilizing JAX’s XLA compiler.
  • A highly modular and extensible codebase for enhanced adaptability.
  • Efficient training facilitated by Optax optimizers (see the training-step sketch after this list).
  • Support for Hugging Face to enable training on various datasets.
  • Fine-tuning support is forthcoming.
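
For a feel of what the Optax-based training described above looks like, here is a minimal, self-contained sketch of a Flax train step. The toy model, the AdamW optimizer, and the hyperparameters are illustrative assumptions, not the repository's actual configuration.

    import jax
    import jax.numpy as jnp
    import optax
    from flax import linen as nn

    class TinyLM(nn.Module):
        """Toy next-token model (embedding -> logits), used only to illustrate the step."""
        vocab_size: int = 256
        hidden: int = 64

        @nn.compact
        def __call__(self, tokens):
            x = nn.Embed(self.vocab_size, self.hidden)(tokens)
            return nn.Dense(self.vocab_size)(x)

    model = TinyLM()
    params = model.init(jax.random.PRNGKey(0), jnp.zeros((2, 16), dtype=jnp.int32))

    # AdamW is a common choice; the repository may use a different Optax setup.
    optimizer = optax.adamw(learning_rate=3e-4)
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, tokens, targets):
        def loss_fn(p):
            logits = model.apply(p, tokens)
            return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

The jax.jit decorator hands the whole step to the XLA compiler, which is where the TPU/GPU speedups mentioned above come from.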

Implemented Models

GPT-2 - JAX/Flax

This compact transformer-based language model is implemented in pure JAX/Flax and leverages XLA optimizations for parallelism, making it efficient on TPUs and GPUs. It serves as the starting point for exploring JAX-based language modeling in this repository (a sketch of its core building block follows the links below).

  • Notebook: models/gpt-2/gpt2_in_jax.ipynb
  • Script: models/gpt-2/train.py
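
As a rough illustration of the kind of building block a GPT-2-style model is made of, here is a hedged sketch of single-head causal self-attention in Flax; the names and dimensions are assumptions and will not match the notebook exactly.

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    class CausalSelfAttention(nn.Module):
        """Single-head causal attention, the core operation of a GPT-2-style block."""
        d_model: int = 128

        @nn.compact
        def __call__(self, x):  # x: (batch, seq, d_model)
            seq_len = x.shape[1]
            q = nn.Dense(self.d_model)(x)
            k = nn.Dense(self.d_model)(x)
            v = nn.Dense(self.d_model)(x)
            scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(self.d_model)
            # Causal mask: a position may only attend to itself and earlier positions.
            mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
            scores = jnp.where(mask, scores, -1e9)
            return jax.nn.softmax(scores, axis=-1) @ v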

Llama 3 - JAX/Flax

This implementation of the Llama 3 architecture incorporates state-of-the-art JAX optimizations to handle longer context windows and to reduce the memory footprint through precision tuning (sketched after the links below).

  • Notebook: models/llama3/llama3_in_jax.ipynb
  • Script: models/llama3/llama3_in_jax.py
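
The "precision tuning" mentioned above usually means running compute in bfloat16 while keeping parameters in float32. Below is a minimal sketch of how this is commonly expressed in Flax; the SwiGLU feed-forward block and its sizes are assumptions, not the notebook's exact configuration.

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    class LlamaMLP(nn.Module):
        """SwiGLU-style feed-forward block with bfloat16 compute (illustrative only)."""
        d_model: int = 256
        d_hidden: int = 1024
        dtype: jnp.dtype = jnp.bfloat16  # compute dtype; params stay float32 by default

        @nn.compact
        def __call__(self, x):
            gate = nn.Dense(self.d_hidden, dtype=self.dtype, use_bias=False)(x)
            up = nn.Dense(self.d_hidden, dtype=self.dtype, use_bias=False)(x)
            return nn.Dense(self.d_model, dtype=self.dtype, use_bias=False)(jax.nn.silu(gate) * up)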

DeepSeek-R1 - JAX/Flax (Work In Progress)

This work-in-progress implementation targets DeepSeek-R1, a reasoning-focused model, using advanced transformer architectures and JAX optimizations aimed at faster inference and lower computational cost.

Mistral - JAX/Flax (Coming Soon)

This upcoming implementation is expected to deliver a high-performance version of the Mistral architecture, featuring dense and sparse mixture-of-experts layers, and will showcase advanced TPU utilization along with optimized autoregressive decoding.
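
Since mixture-of-experts routing is the distinguishing piece mentioned above, here is a hedged sketch of the usual top-2 routing step in plain JAX; it is an assumption about how such a layer is commonly routed, not the forthcoming implementation's actual code.

    import jax
    import jax.numpy as jnp

    def top2_route(router_logits):
        """Keep the two highest-scoring experts per token and renormalise their weights."""
        weights, experts = jax.lax.top_k(router_logits, k=2)  # shapes: (tokens, 2)
        return jax.nn.softmax(weights, axis=-1), experts

    logits = jnp.array([[0.1, 2.0, -1.0, 0.5]])  # one token, four experts
    w, e = top2_route(logits)                    # e == [[1, 3]], w sums to 1 per token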

Recommended Environment

The models are best executed in Google Colab, which provides free TPU support for optimal performance.
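
Before running the notebooks, it is worth confirming that the Colab TPU runtime is actually visible to JAX; a quick check such as the one below is usually enough (recent Colab TPU runtimes expose devices to JAX without extra setup, though older runtimes may need additional initialisation).

    import jax

    # On a TPU runtime this should list TpuDevice entries; otherwise JAX
    # silently falls back to CPU or GPU.
    print(jax.devices())
    print("default backend:", jax.default_backend())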

Exploring the Implementations

Each model comes with its dedicated Jupyter notebook. Users can navigate to the specific directories and open the notebooks in Google Colab for hands-on exploration of the implementations.

Future Enhancements

  • Enable fine-tuning support to facilitate training on custom datasets.
  • Expand the repository with larger model implementations, adding diverse LLMs.

Contributing

Contributions to the repository are welcome, and community members are encouraged to submit issues and pull requests.
