Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e
Generative AI models, such as Stable Diffusion XL (SDXL), enable the creation of high-quality, realistic content with wide-ranging applications. However, harnessing the power of such models comes with significant challenges and computational costs. SDXL is a large image generation model whose UNet component is about three times as large as the one in the previous version of the model. Deploying a model like this in production is challenging because of its higher memory requirements and longer inference times. Today, we are thrilled to announce that Hugging Face Diffusers now supports serving SDXL using JAX on Cloud TPUs, enabling high-performance, cost-efficient inference.

Google Cloud TPUs are custom-designed AI accelerators, which are optimized for training and inference of large AI models, including state-of-the-art LLMs and generative AI models such as SDXL. The new Cloud TPU v5e is purpose-built to bring the cost-efficiency and performance required for large-scale AI training and inference. At less than half the cost of TPU v4, TPU v5e makes it possible for more organizations to train and deploy AI models.

🧨 Diffusers JAX integration offers a convenient way to run SDXL on TPU via XLA, and we built a demo to showcase it. You can try it out in this Space or in the playground embedded below:

Under the hood, this demo runs on several TPU v5e-4 instances (each instance has 4 TPU chips) and takes advantage of parallelization to serve four large 1024×1024 images in about 4 seconds. This time includes format conversions, communication time, and frontend processing; the actual generation time is about 2.3s, as we'll see below!
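
To give a feel for the kind of parallelization described above, here is a minimal JAX sketch that generates one output per device with `jax.pmap`. It is only an illustration under stated assumptions, not the demo's actual code: `toy_denoise_step` and `generate` are hypothetical stand-ins for the SDXL denoising loop, the step count and scale are arbitrary, and the 4×128×128 latent shape is used simply because it matches a 1024×1024 image after 8× VAE downsampling.

```python
import jax
import jax.numpy as jnp


def toy_denoise_step(latents, step_scale):
    # Hypothetical stand-in for one denoising step; this is NOT the SDXL UNet.
    return latents - step_scale * jnp.tanh(latents)


def generate(latents, num_steps=4):
    # Runs independently on each device, on that device's slice of the batch.
    def body(_, x):
        return toy_denoise_step(x, 0.1)

    return jax.lax.fori_loop(0, num_steps, body, latents)


# 4 on a TPU v5e-4 instance; 1 on a machine with a single CPU device.
num_devices = jax.local_device_count()

# One latent per device. The leading axis is the device axis that pmap splits over.
keys = jax.random.split(jax.random.PRNGKey(0), num_devices)
latents = jax.vmap(lambda k: jax.random.normal(k, (4, 128, 128)))(keys)

# pmap compiles `generate` once with XLA and runs a copy on every device in parallel.
p_generate = jax.pmap(generate)
images = p_generate(latents)

print(images.shape)  # (num_devices, 4, 128, 128)
```

On an instance with four chips, the leading axis has size 4, so each chip handles one image, which is roughly how a single request can return four images at once. The same script also runs on a single CPU device, where the device axis simply has size 1.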