Diffusion Models (Image)
Arcus Prompt Enrichment enriches the inputs provided to image generation models (e.g. diffusion models) with external data matched by the Arcus Data Exchange. The exchange automatically matches your input to high-value external image data related to your task at hand, then provides this context to your diffusion model to:
- Prevent hallucinations. By grounding generation on real-world images, you can avoid common diffusion model errors such as misspelled text and non-photorealistic output.
- Improve performance on specific terms and concepts. Your diffusion model's training data may not contain examples of the specific concepts you care about. By injecting relevant images into your diffusion model, you gain better semantic control over your generated images.
- Improve performance on specific image styles. By injecting relevant images in the appropriate styles, you can control the style of your generated images.
Arcus Prompt Enrichment works by using Data Augmented Generation, a technique that first retrieves relevant external data from the exchange, then composes and injects this context into your diffusion model alongside your original prompt (see Retrieval-Augmented Diffusion Models and Re-Imagen). Arcus Prompt Enrichment also provides a summary of the images that were supplied to your model, so you can better understand the injected context.
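Conceptually, the retrieval step of Data Augmented Generation ranks external image candidates by their relevance to your prompt. The sketch below illustrates this with toy embedding vectors and cosine similarity; the actual matching algorithms, ranking signals, and embedding models used by the exchange are not part of this example.

```python
import numpy as np

def cosine_similarity(a, b):
    # Standard cosine similarity between two embedding vectors.
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

def retrieve_top_k(prompt_embedding, candidate_embeddings, k=2):
    """Rank external image candidates by similarity to the prompt embedding."""
    scores = [cosine_similarity(prompt_embedding, c) for c in candidate_embeddings]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

# Toy 4-dimensional embeddings standing in for CLIP-style vectors.
prompt = np.array([1.0, 0.0, 1.0, 0.0])
candidates = [
    np.array([1.0, 0.1, 0.9, 0.0]),  # closely related image
    np.array([0.0, 1.0, 0.0, 1.0]),  # unrelated image
    np.array([0.9, 0.0, 1.1, 0.1]),  # closely related image
]
top = retrieve_top_k(prompt, candidates, k=2)
# The two related candidates (indices 0 and 2) rank highest.
```

The retrieved images are then composed into additional conditioning for the diffusion model, alongside the text prompt.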
Let’s walk through how to use Arcus Prompt Enrichment for diffusion models using the Arcus Prompt SDK. Before we get started, you should create a prompt enrichment project on the Arcus platform (request early access here) and have your Project ID and API Key ready.
First, configure your environment to connect the Prompt SDK to your Arcus project and diffusion model. To do this, wrap your diffusion model in an Arcus `DiffusionModel` object, which connects to the exchange to discover and compose relevant external images that enrich your prompts, while maintaining the same functionality as your original diffusion model.
Let’s look at an example using Stable Diffusion in PyTorch with a `DPMSampler` to generate images. In a few lines of code, we’ll initialize an Arcus `Config` object and use it to wrap our existing model with an Arcus `DiffusionModel`. The `Config` object takes your Project ID and API Key, which you can find in the Arcus platform.
```python
import arcus

# Set the Config object with your credentials from the Arcus platform
arcus_config = arcus.prompt.text.Config(
    api_key="MY_API_KEY",
    project_id="MY_PROJECT_ID",
)

# Initialize your original Diffusion Model
my_diffusion_model = MyDiffusionModel()

# Wrap it in an Arcus Diffusion Model object
arcus_diffusion_model = arcus.prompt.image.DiffusionModel(
    my_diffusion_model,
    arcus_config,
)
```
This new `arcus_diffusion_model` object maintains all the same functionality you use to control image generation with Stable Diffusion, with the added ability to combine your prompts with external images provided by the exchange.
Now that you’ve configured your connection to Arcus, you are ready to enrich your image generations with external images matched by the exchange.
Let’s look at an example below. To generate a single 768 × 768 image using Stable Diffusion and a `DPMSampler` with 100 steps, your original code might look like the following:
```python
sampler = DPMSampler(my_diffusion_model)
prompt = "a professional photograph of an astronaut riding a horse"
c = my_diffusion_model.get_learned_conditioning([prompt])
generated_image, _ = sampler.sample(
    S=100,
    conditioning=c,
    batch_size=1,
    shape=[3, 768, 768],
)
```
Now, using the wrapped `arcus_diffusion_model` object, you can generate an image with the same code, while enriching your prompt with external images matched by the exchange:
```python
sampler = DPMSampler(arcus_diffusion_model)
prompt = "a professional photograph of an astronaut riding a horse"
c = arcus_diffusion_model.get_learned_conditioning([prompt])
response, _ = sampler.sample(
    S=100,
    conditioning=c,
    batch_size=1,
    shape=[3, 768, 768],
)
```
When you call `sampler.sample()`, the `arcus_diffusion_model` object performs the following steps:
- Queries the Arcus Data Exchange to discover high-value, relevant images for your given prompt. Arcus’ matching algorithms rank external data candidates on the exchange by their inherent quality and their relevance to your prompt.
- Retrieves this external data and passes it to your diffusion model as additional conditioning.
- Samples an image using your diffusion model with the combined prompt and external images as conditioning.
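The control flow of the steps above can be sketched as follows. This is a hypothetical illustration only: the class, the `retrieve_fn` hook, and the dictionary-shaped conditioning are assumptions for exposition, not the real Arcus `DiffusionModel` internals.

```python
class EnrichedDiffusionModel:
    """Hypothetical sketch of a wrapper that enriches conditioning."""

    def __init__(self, base_model, retrieve_fn):
        self.base_model = base_model
        self.retrieve_fn = retrieve_fn  # stands in for querying the exchange

    def get_learned_conditioning(self, prompts):
        # Step 1: discover relevant external images for the prompt.
        external_images = self.retrieve_fn(prompts)
        # Step 2: combine them with the ordinary text conditioning;
        # the sampler then draws an image from the combined conditioning.
        text_cond = self.base_model.get_learned_conditioning(prompts)
        return {"text": text_cond, "images": external_images}

class DummyBaseModel:
    # Placeholder model that returns a fake text embedding per prompt.
    def get_learned_conditioning(self, prompts):
        return [f"embedding({p})" for p in prompts]

model = EnrichedDiffusionModel(
    DummyBaseModel(),
    retrieve_fn=lambda prompts: ["matched_image_1", "matched_image_2"],
)
cond = model.get_learned_conditioning(["an astronaut riding a horse"])
```

In this sketch, the wrapper preserves the base model's `get_learned_conditioning` interface, which is what lets the sampler code stay unchanged.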
`sampler.sample()` returns an Arcus `DiffusionOutput` object, which stores the generated image as well as summary information about the images provided by the exchange. You can access the generated image as follows:
```python
>>> import matplotlib.pyplot as plt
>>> plt.imshow(response.get_generation())
```
The Arcus Prompt SDK also provides summary information about the context that was provided to your diffusion model. This helps you understand what context the exchange found most valuable and relevant to your prompt.
You can access this information using the `get_context_summary()` method on your `DiffusionOutput` object:
```python
>>> print(response.get_context_summary())
The context includes images of professional photographs, astronauts, and horses.
```
In this instance, we see that the exchange supplied context directly relevant to the prompt, which helps prevent hallucinations and improves semantic control over the generated image.