RequestScopedPipeline: Concurrent Inference in Diffusers for Production Servers
Diffusers pipelines were not designed for concurrency. Calling `pipe()` from several requests at the same time leads to race conditions in schedulers and "Already borrowed" errors in Rust-based tokenizers, and the obvious workaround of keeping one pipeline per worker duplicates entire models in memory.
To solve these problems, I created `RequestScopedPipeline` along with `BaseAsyncScheduler`. The solution works as follows:
- **Shallow copy per request:** instead of duplicating the entire pipeline, a shallow copy is created that keeps references to the heavy components.
- **Selective cloning:** only small, mutable components are cloned, such as:
  - Schedulers
  - Offloading hooks
  - RNG state (random number generator)
  - Temporary latents
  - Callback configuration
- **Shared components:** heavy models (UNet, VAE, Text Encoder/Transformer) remain shared in GPU memory and are not duplicated per request.
This strategy allows running multiple inferences in parallel without exploding memory.
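To make the intuition concrete, here is a rough illustration (not the library's internals; the model id is only an example) of why the approach is cheap: a shallow copy shares the heavy modules by reference, and only the small mutable pieces are re-created per request.

```python
import copy

from diffusers import DiffusionPipeline

# Example model id; any diffusers pipeline behaves the same way for this check
base = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

local = copy.copy(base)                 # shallow copy: new pipeline object, shared attributes
assert local.unet is base.unet          # heavy modules are NOT duplicated in memory
assert local.vae is base.vae

local.scheduler = copy.deepcopy(base.scheduler)  # small mutable state gets its own copy
assert local.scheduler is not base.scheduler
```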
What does the code below do?
- Defines a list of default mutable attributes (`DEFAULT_MUTABLE_ATTRS`)
- Stores references to heavy components (UNet, VAE, Text Encoder) that will be shared between requests
- Automatically wraps the scheduler in `BaseAsyncScheduler` if `wrap_scheduler=True`
- Configures a threading lock for tokenizers (avoids the "Already borrowed" error)
- Allows auto-detection of mutable attributes, with a configurable size threshold for tensors
```python
class RequestScopedPipeline:
    DEFAULT_MUTABLE_ATTRS = [
        "_all_hooks",
        "_offload_device",
        "_progress_bar_config",
        "_progress_bar",
        "_rng_state",
        "_last_seed",
        "latents",
    ]

    def __init__(
        self,
        pipeline: Any,
        mutable_attrs: Optional[Iterable[str]] = None,
        auto_detect_mutables: bool = True,
        tensor_numel_threshold: int = 1_000_000,
        tokenizer_lock: Optional[threading.Lock] = None,
        wrap_scheduler: bool = True,
    ):
        self._base = pipeline

        # Heavy components are kept as references: shared across requests, never duplicated
        self.unet = getattr(pipeline, "unet", None)
        self.vae = getattr(pipeline, "vae", None)
        self.text_encoder = getattr(pipeline, "text_encoder", None)
        self.components = getattr(pipeline, "components", None)

        # Wrap the scheduler so it can be cloned per request
        if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
            if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)

        # Small mutable state to clone per request, plus the tokenizer lock and auto-detection settings
        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
        self._auto_detect_mutables = bool(auto_detect_mutables)
        self._tensor_numel_threshold = int(tensor_numel_threshold)
        self._auto_detected_attrs: List[str] = []
```
Per-Request Scheduler Isolation
While the initialization above sets the stage, the critical component for safe concurrency is the `_make_local_scheduler` method. This method solves the main problem: preventing multiple requests from modifying the same scheduler simultaneously.
The Problem with Shared Schedulers
When multiple requests call `pipe()` at the same time:
- They all try to modify `scheduler.timesteps` with `set_timesteps()`
- This causes race conditions where one request overwrites another's timesteps
- The result is corrupted images or inference errors
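A contrived sketch of the failure mode, using `EulerDiscreteScheduler` as a stand-in for whatever scheduler the pipeline holds (whether the message actually prints depends on thread timing):

```python
import threading

from diffusers import EulerDiscreteScheduler

shared_scheduler = EulerDiscreteScheduler()  # one scheduler instance shared by all "requests"

def fake_request(steps: int):
    shared_scheduler.set_timesteps(steps)         # mutates shared state in place
    # ... the denoising loop would now iterate over shared_scheduler.timesteps ...
    if len(shared_scheduler.timesteps) != steps:  # another thread may have overwritten them
        print(f"race detected: expected {steps} timesteps, got {len(shared_scheduler.timesteps)}")

threads = [threading.Thread(target=fake_request, args=(s,)) for s in (20, 30, 50)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```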
The Solution: Scheduler Cloning
The `_make_local_scheduler` method creates a completely isolated scheduler for each request:
```python
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
    base_sched = getattr(self._base, "scheduler", None)
    if base_sched is None:
        return None

    # Ensure the scheduler is wrapped in BaseAsyncScheduler
    if not isinstance(base_sched, BaseAsyncScheduler):
        wrapped_scheduler = BaseAsyncScheduler(base_sched)
    else:
        wrapped_scheduler = base_sched

    try:
        # Clone the scheduler for this specific request
        return wrapped_scheduler.clone_for_request(
            num_inference_steps=num_inference_steps,
            device=device,
            **clone_kwargs,
        )
    except Exception as e:
        # Fallback if cloning through the wrapper does not work
        logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
        try:
            return copy.deepcopy(wrapped_scheduler)
        except Exception as e:
            # If deepcopy also fails, return the original scheduler (risky)
            logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
            return wrapped_scheduler
```
BaseAsyncScheduler: The Cloning Wrapper
`BaseAsyncScheduler` acts as a proxy that adds safe cloning capability:
```python
class BaseAsyncScheduler:
    def __init__(self, scheduler: Any):
        self.scheduler = scheduler

    def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
        # Create a deep copy of the underlying scheduler
        local = copy.deepcopy(self.scheduler)
        # Configure timesteps ON THE LOCAL COPY (not on the shared scheduler)
        local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
        # Wrap the copy in a new BaseAsyncScheduler
        cloned = self.__class__(local)
        return cloned
```
Why does it work?
- `deepcopy` creates a completely independent copy of the scheduler
- `set_timesteps()` is called on the local copy, not on the shared scheduler
- Each request gets its own scheduler with its own timesteps
- Requests never interfere with each other
Complete flow:

```
Request 1 → _make_local_scheduler() → deepcopy(scheduler) → scheduler_local_1 → inference
Request 2 → _make_local_scheduler() → deepcopy(scheduler) → scheduler_local_2 → inference
Request 3 → _make_local_scheduler() → deepcopy(scheduler) → scheduler_local_3 → inference
```

Each request runs inference with its own `scheduler_local_N`: all isolated, no race conditions.
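As a quick sanity check (a sketch, assuming a plain diffusers scheduler such as `EulerDiscreteScheduler`), two clones made from the same wrapper hold fully independent timestep tables:

```python
from diffusers import EulerDiscreteScheduler

wrapped = BaseAsyncScheduler(EulerDiscreteScheduler())

sched_a = wrapped.clone_for_request(num_inference_steps=30)
sched_b = wrapped.clone_for_request(num_inference_steps=50)

# Each clone carries its own timesteps; neither touched the shared scheduler
assert sched_a.scheduler is not sched_b.scheduler
assert len(sched_a.scheduler.timesteps) == 30
assert len(sched_b.scheduler.timesteps) == 50
```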
Helper Methods for Safe Concurrency
Automatic Mutable Detection
The `_autodetect_mutables(...)` method automatically detects additional mutable components, such as lists, dicts, sets, or small tensors (below the configured `tensor_numel_threshold`), and marks them for cloning when creating a per-request instance.
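The exact heuristics live in the PR; a simplified sketch of the idea (the helper name and logic below are illustrative, not the real code) looks roughly like this:

```python
import torch

def autodetect_mutables_sketch(pipe, numel_threshold: int = 1_000_000):
    """Rough idea only: flag small containers and small tensors as per-request state."""
    detected = []
    for name, value in vars(pipe).items():
        if isinstance(value, (list, dict, set)):
            detected.append(name)   # plain Python containers are mutable
        elif torch.is_tensor(value) and value.numel() < numel_threshold:
            detected.append(name)   # small tensors (e.g. cached latents), not model weights
    return detected
```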
Read-Only Property Verification
The `_is_readonly_property(...)` method detects whether `attr_name` on `base_obj` is a descriptor-style property without a setter. This lets us skip read-only properties and avoid the `AttributeError` that setting them on the per-request copy would raise, so the cloning process does not break.
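Conceptually the check looks at the class rather than the instance (simplified sketch; the PR's version may cover more descriptor types):

```python
def is_readonly_property_sketch(base_obj, attr_name: str) -> bool:
    """True if attr_name is defined on the class as a property with no setter."""
    descriptor = getattr(type(base_obj), attr_name, None)
    return isinstance(descriptor, property) and descriptor.fset is None
```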
Mutable Attribute Cloning
The `_clone_mutable_attrs(...)` method clones the relevant mutable attributes from the `base` object onto `local`, so that `local` has its own isolated state per request.
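In essence it copies each tracked attribute onto the local pipeline and skips anything it cannot set (simplified sketch of the idea, not the exact implementation):

```python
import copy

def clone_mutable_attrs_sketch(base, local, attrs):
    """Give `local` its own copy of every small mutable attribute present on `base`."""
    for attr in attrs:
        if not hasattr(base, attr):
            continue
        try:
            setattr(local, attr, copy.copy(getattr(base, attr)))  # shallow copy is enough for small state
        except (AttributeError, TypeError):
            pass  # e.g. a read-only property: keep the shared value
```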
Tokenizer Detection
The `_is_tokenizer_component(...)` method uses a simple heuristic to decide whether a component is a real tokenizer: it checks for the three typical tokenizer methods `encode`, `decode`, and `tokenize`.
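The heuristic fits in a couple of lines (sketch; the actual method may check more):

```python
def is_tokenizer_component_sketch(component) -> bool:
    """Heuristic: a real tokenizer exposes encode, decode and tokenize."""
    return all(hasattr(component, method) for method in ("encode", "decode", "tokenize"))
```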
Concurrent Inference
Given everything stated above, safe concurrent inference can now be performed without breaking any component of the model pipeline, using the `generate(...)` method.
- We start by cloning the scheduler and adapting it to this request (for example, with request-specific timesteps/state).
```python
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
    local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
    ...
```
- We copy the base pipeline to get an independent `local_pipe` object at the attribute level (allowing reassignment of `local_pipe.scheduler`, etc.) without cloning all the content.
```python
try:
    # Shallow copy: copies the attribute structure, but references to heavy objects remain shared
    local_pipe = copy.copy(self._base)
except Exception as e:
    logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
    # Deepcopy fallback (heavier in memory and time)
    local_pipe = copy.deepcopy(self._base)
```
- Replace the scheduler on the local pipeline if `local_scheduler` exists
```python
if local_scheduler is not None:
    try:
        timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
            local_scheduler.scheduler,
            num_inference_steps=num_inference_steps,
            device=device,
            return_scheduler=True,
            **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
        )
        final_scheduler = BaseAsyncScheduler(configured_scheduler)
        setattr(local_pipe, "scheduler", final_scheduler)
    except Exception:
        logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
```
- Clone mutable attributes and wrap tokenizers behind the shared lock. Attributes listed in `_mutable_attrs` and those detected by `_autodetect_mutables` are cloned; tokenizer components are wrapped so that every call to them is serialized through `self._tokenizer_lock`, avoiding problems with non-thread-safe tokenizers (a sketch of the `safe_tokenize` helper appears after this walkthrough).
```python
self._clone_mutable_attrs(self._base, local_pipe)

# Wrap tokenizers in the local pipeline with the locking wrapper
tokenizer_wrappers = {}  # name -> original_tokenizer
try:
    # a) Wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
    for name in dir(local_pipe):
        if "tokenizer" in name and not name.startswith("_"):
            tok = getattr(local_pipe, name, None)
            if tok is not None and self._is_tokenizer_component(tok):
                tokenizer_wrappers[name] = tok
                setattr(
                    local_pipe,
                    name,
                    lambda *args, tok=tok, **kwargs: safe_tokenize(tok, *args, lock=self._tokenizer_lock, **kwargs),
                )

    # b) Wrap tokenizers in the components dictionary
    if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
        for key, val in local_pipe.components.items():
            if val is None:
                continue
            if self._is_tokenizer_component(val):
                tokenizer_wrappers[f"components[{key}]"] = val
                local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
                    tokenizer, *args, lock=self._tokenizer_lock, **kwargs
                )
except Exception as e:
    logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
```
- Execute inference (inside the offload context if one is available) and restore the tokenizers. `model_cpu_offload_context` can be either a function that returns a context manager (call `cm()`) or an already instantiated context manager (use `with cm`). The code handles both cases: it catches the `TypeError` raised when `cm()` fails and retries with `with cm`, and if everything fails it falls back to running `local_pipe` without the offload context. The `finally` block guarantees that the original tokenizers are restored even if inference raises an exception.
```python
result = None
cm = getattr(local_pipe, "model_cpu_offload_context", None)

try:
    if callable(cm):
        try:
            with cm():
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
        except TypeError:
            # cm could be a context manager instance instead of a callable
            try:
                with cm:
                    result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
            except Exception as e:
                logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
    else:
        # No offload context available: call directly
        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

    return result
finally:
    try:
        for name, tok in tokenizer_wrappers.items():
            if name.startswith("components["):
                key = name[len("components["):-1]
                local_pipe.components[key] = tok
            else:
                setattr(local_pipe, name, tok)
    except Exception as e:
        logger.debug(f"Error restoring wrapped tokenizers: {e}")
```
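The wrapping step above delegates to a `safe_tokenize` helper that is not shown in this post; conceptually it only serializes tokenizer calls behind the shared lock, roughly like this (sketch):

```python
import threading

def safe_tokenize(tokenizer, *args, lock: threading.Lock, **kwargs):
    """Call the tokenizer while holding the lock so only one thread tokenizes at a time.
    This sidesteps the 'Already borrowed' error raised by Rust-based fast tokenizers."""
    with lock:
        return tokenizer(*args, **kwargs)
```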
Example of RequestScopedPipeline usage in an asynchronous FastAPI server
```python
# All imports, lifespan, etc.

app = FastAPI(lifespan=lifespan)

logger = logging.getLogger("DiffusersServer.Pipelines")

# Wrapper to initialize and launch the model pipeline
initializer = ModelPipelineInitializer(
    model=server_config.model,
    type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()

# Pass the base pipeline to RequestScopedPipeline
request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()

logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")

# Save component state so it is easily accessible from FastAPI
app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock

# Basic Pydantic model describing everything needed for inference
class JSONBodyQueryAPI(BaseModel):
    model: str | None = None
    prompt: str
    negative_prompt: str | None = None
    num_inference_steps: int = 28
    num_images_per_prompt: int = 1

@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
    prompt = json.prompt
    negative_prompt = json.negative_prompt or ""
    num_steps = json.num_inference_steps
    num_images_per_prompt = json.num_images_per_prompt

    wrapper = app.state.MODEL_PIPELINE
    initializer = app.state.MODEL_INITIALIZER
    utils_app = app.state.utils_app

    if not wrapper or not wrapper.pipeline:
        raise HTTPException(500, "Model not initialized correctly")
    if not prompt.strip():
        raise HTTPException(400, "No prompt provided")

    def make_generator():
        g = torch.Generator(device=initializer.device)
        return g.manual_seed(random.randint(0, 10_000_000))

    req_pipe = app.state.REQUEST_PIPE

    # Function to perform safe per-request inference using app.state.REQUEST_PIPE
    def infer():
        gen = make_generator()
        return req_pipe.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=gen,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            device=initializer.device,
            output_type="pil",
        )

    try:
        async with app.state.metrics_lock:
            app.state.active_inferences += 1

        # Execute inference without blocking the event loop
        output = await run_in_threadpool(infer)

        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)

        # When inference finishes, build the list of generated image URLs
        # (e.g. ["/images/image_123.png", "/images/image_1234.png"]) so they can be downloaded
        urls = [utils_app.save_image(img) for img in output.images]
        return {"response": urls}

    except Exception as e:
        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)
        logger.error(f"Error during inference: {e}")
        raise HTTPException(500, f"Error in processing: {e}")

    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.ipc_collect()
        gc.collect()

# Endpoint to download the generated images
@app.get("/images/{filename}")
async def serve_image(filename: str):
    utils_app = app.state.utils_app
    file_path = os.path.join(utils_app.image_dir, filename)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(file_path, media_type="image/png")
```
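With the server running (the host and port below are assumptions; adjust them to your deployment), a client can request images and download the results:

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/diffusers/inference",
    json={
        "prompt": "a photo of an astronaut riding a horse",
        "num_inference_steps": 28,
        "num_images_per_prompt": 1,
    },
)
urls = resp.json()["response"]        # e.g. ["/images/image_123.png"]
for url in urls:
    image_bytes = requests.get(f"http://localhost:8000{url}").content  # served by GET /images/{filename}
```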
References
- PR #12328 with all changes: https://github.com/huggingface/diffusers/pull/12328
- Async server code: https://github.com/huggingface/diffusers/tree/main/examples/server-async
Questions or suggestions? Contact me at fredyriveraacevedo13@gmail.com