Image Generation Using Text - Stable Diffusion
Introduction
Stable Diffusion is a text-to-image model trained on 512x512 images from a subset of the LAION-5B dataset.
The goal of this notebook is to demonstrate how easily you can implement text-to-image generation using the 🤗 Diffusers library, which is the go-to library for state-of-the-art pre-trained diffusion models for generating images, audio, and 3D structures.
Before jumping into the coding, however, we need to know what exactly Stable Diffusion is.
What is Stable Diffusion?
The architecture of Stable Diffusion
Stable Diffusion is based on a type of diffusion model called Latent Diffusion, the details of which can be found in the paper High-Resolution Image Synthesis with Latent Diffusion Models.
Diffusion models are a type of generative model trained to gradually denoise an object, such as an image, in order to obtain a sample of interest. At each step, the model removes a small amount of noise, until a clean sample is obtained. This process can be seen below:
Image denoising process
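To make this concrete, below is a minimal sketch of the forward (noising) half of this process, using the DDPMScheduler from diffusers. The tensor shapes and scheduler settings here are assumptions chosen purely for illustration, not part of the Stable Diffusion pipeline we build later.
# A sketch of the forward (noising) process with a diffusers scheduler;
# shapes and settings are illustrative assumptions
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

clean_image = torch.randn(1, 3, 64, 64)  # stand-in for a real image tensor
noise = torch.randn_like(clean_image)
timestep = torch.tensor([500])           # halfway through the noise schedule

# add_noise produces the partially noised image the model learns to denoise
noisy_image = scheduler.add_noise(clean_image, noise, timestep)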
These diffusion models have gained popularity in recent years, especially for their ability to achieve state-of-the-art results in image generation. However, diffusion models can consume a lot of memory and be computationally expensive to work with.
Latent Diffusion, on the other hand, reduces complexity and memory usage by applying the diffusion process over a lower-dimensional latent space. In latent diffusion, the model is trained to generate compressed representations of images rather than the full-resolution images themselves.
There are three main components in latent diffusion:
1. An autoencoder (VAE)
2. A U-Net
3. A text encoder
Autoencoder (VAE)
The Variational Autoencoder architecture
The Variational Autoencoder (VAE) is a model that consists of both an encoder and a decoder. The encoder converts the image into a low-dimensional latent representation, which serves as the input to the U-Net, while the decoder transforms the latent representation back into an image.
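As a rough illustration, the snippet below loads the VAE component of the v1-4 checkpoint on its own and round-trips a tensor through it. The subfolder name and tensor shapes are assumptions based on the standard diffusers checkpoint layout; the full pipeline also rescales latents internally.
# A sketch of the VAE's encode/decode roles (checkpoint layout assumed)
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4",
                                    subfolder="vae")

image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()  # low-dimensional latent
    reconstruction = vae.decode(latents).sample       # back to image space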
U-Net
The U-Net architecture
The U-Net is a convolutional neural network that is widely used in image segmentation tasks. It also has an encoder and a decoder, both composed of ResNet blocks. The encoder compresses the image into a lower-resolution representation, while the decoder decodes that representation back to the original, higher resolution, producing an output that is less noisy than the input.
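The sketch below illustrates a single denoising step with the U-Net loaded on its own. The shapes (4 latent channels, 64x64 latents, 768-dimensional text embeddings) are assumptions matching the standard Stable Diffusion v1 configuration.
# A sketch of one U-Net denoising step (shapes are SD v1 assumptions)
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4",
                                            subfolder="unet")

latents = torch.randn(1, 4, 64, 64)        # noisy latent in the VAE's space
timestep = torch.tensor([500])
text_embeddings = torch.randn(1, 77, 768)  # placeholder for real CLIP embeddings

with torch.no_grad():
    # The U-Net predicts the noise present in the latent at this timestep
    noise_pred = unet(latents, timestep,
                      encoder_hidden_states=text_embeddings).sample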
Text Encoder
How does a text encoder work?
The text encoder is responsible for transforming the text input prompt into an embedding space that can be understood by the U-Net. It is usually a simple transformer-based encoder that maps a sequence of input tokens to a sequence of latent text embeddings.
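In Stable Diffusion v1, this encoder is CLIP's text encoder. A minimal sketch of the tokenize-then-encode step, assuming the standard v1-4 checkpoint layout, might look like this:
# A sketch of the text-encoding step (checkpoint layout assumed)
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4",
                                          subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4",
                                             subfolder="text_encoder")

tokens = tokenizer("Sunset on a beach", padding="max_length",
                   max_length=tokenizer.model_max_length,
                   return_tensors="pt")

with torch.no_grad():
    # One embedding per token; consumed by the U-Net as conditioning
    text_embeddings = text_encoder(tokens.input_ids)[0]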
Stable Diffusion Pipeline with 🤗 Diffusers
The StableDiffusionPipeline is a pipeline from the 🤗 Diffusers library that allows us to generate images from text with just a few lines of Python. It has many versions and checkpoints available, which you can explore on the Text-to-Image Generation page of the library documentation.
For this notebook, we are going to use Stable Diffusion version 1.4 (CompVis/stable-diffusion-v1-4). We are also passing torch_dtype=torch.float16 to load the fp16 weights, which helps reduce memory usage.
Let’s start our project by installing the diffusers library.
# Installing diffusers library
!pip install diffusers
We may now import all the relevant libraries.
# Library imports
# Importing the PyTorch library, for building and working with neural networks
import torch
# Importing StableDiffusionPipeline to use pre-trained Stable Diffusion models
from diffusers import StableDiffusionPipeline
# Image is a class from the PIL module, used to visualize images in a notebook
from PIL import Image
Let’s create an instance of the pipeline.
The .from_pretrained("CompVis/stable-diffusion-v1-4") call will initialize the diffusion model with pre-trained weights and settings, including the pre-trained VAE, U-Net, and text encoder used to generate images from text. The torch_dtype=torch.float16 argument sets the datatype of the model to float16, a lower-precision floating-point format, to help speed up inference.
# Creating pipeline
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
torch_dtype=torch.float16)
Now we can define a function that creates and displays a grid of the images generated with Stable Diffusion.
# Defining function for the creation of a grid of images
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Paste each image at its (column, row) offset in the grid
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
Next, we use PyTorch's .to method to move our pipeline to the GPU, which greatly speeds up inference.
# Moving pipeline to GPU
pipeline = pipeline.to('cuda')
Now, we can finally use Stable Diffusion to generate images from text!
In the code below, n_images defines how many images will be generated, while prompt is the text description the images will be generated from.
n_images = 6 # Let's generate 6 images based on the prompt below
prompt = ['Sunset on a beach'] * n_images
images = pipeline(prompt).images
grid = image_grid(images, rows=2, cols=3)
grid
Sunset on a beach
n_images = 6
prompt = ['Portrait of Napoleon Bonaparte'] * n_images
images = pipeline(prompt).images
grid = image_grid(images, rows=2, cols=3)
grid
Portrait of Napoleon Bonaparte
n_images = 6
prompt = ['Skyline of a cyberpunk megalopolis'] * n_images
images = pipeline(prompt).images
grid = image_grid(images, rows=2, cols=3)
grid
Skyline of a cyberpunk megalopolis
n_images = 6
prompt = ['Painting of a woman in the style of Van Gogh'] * n_images
images = pipeline(prompt).images
grid = image_grid(images, rows=2, cols=3)
grid
Painting of a woman in the style of Van Gogh
n_images = 6
prompt = ['Picture of an astronaut in space'] * n_images
images = pipeline(prompt).images
grid = image_grid(images, rows=2, cols=3)
grid
Picture of an astronaut in space
n_images = 6
prompt = ['Renaissance marble bust sculpture'] * n_images
images = pipeline(prompt).images
grid = image_grid(images, rows=2, cols=3)
grid
Renaissance marble bust sculpture
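If you want to keep any of the results, the pipeline's outputs are standard PIL images, so both the grid and the individual images can be saved to disk (the filenames below are just examples):
# Saving the generated images (filenames are examples)
grid.save("bust_grid.png")      # save the whole grid
images[0].save("bust_0.png")    # or save a single image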
Thank you for reading,
LinkedIn: Abdul Hadi Ali | LinkedIn
GitHub repo: Abdul Hadi Ali | GitHub