Diffusers pipelines were not designed for concurrency. For example, calling pipe() from several threads at the same time can cause race conditions in schedulers and "Already borrowed" errors in Rust tokenizers, and avoiding them by giving each request its own pipeline duplicates entire models in memory.
To solve these problems, I created RequestScopedPipeline along with BaseAsyncScheduler. The solution works as follows:
Shallow copy per request: Instead of duplicating the entire pipeline, a shallow copy is created that maintains references to the heavy components.
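Selective cloning: Only small, mutable components are cloned, such as the scheduler and the attributes listed in DEFAULT_MUTABLE_ATTRS (for example _rng_state and latents).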
Shared components: Heavy models (UNet, VAE, Text Encoder/Transformer) remain shared in GPU memory and are not duplicated per request.
This strategy allows running multiple inferences in parallel without exploding memory.
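To make the shallow-copy idea concrete, here is a minimal sketch using toy stand-ins. HeavyModel, ToyScheduler, and ToyPipeline are hypothetical classes invented for illustration, not Diffusers classes. It only shows that copy.copy shares the heavy component while the small mutable one can be cloned separately.

```python
import copy


class HeavyModel:
    """Hypothetical stand-in for a large component such as a UNet."""

    def __init__(self):
        self.weights = [0.0] * 1000


class ToyScheduler:
    """Hypothetical stand-in for a small, stateful scheduler."""

    def __init__(self):
        self.timesteps = []


class ToyPipeline:
    """Hypothetical pipeline holding one heavy and one mutable component."""

    def __init__(self):
        self.unet = HeavyModel()
        self.scheduler = ToyScheduler()


base = ToyPipeline()

# Shallow copy: a new pipeline object whose attributes still point
# at the same component objects as the base pipeline.
local = copy.copy(base)

# Clone only the small mutable component for this request.
local.scheduler = copy.deepcopy(base.scheduler)
local.scheduler.timesteps = [999, 500, 0]

shares_unet = local.unet is base.unet
base_timesteps_untouched = base.scheduler.timesteps == []
print(shares_unet, base_timesteps_untouched)
```

The heavy object is referenced, not copied, so the per-request cost is limited to the small cloned parts.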
What does this code do?
The constructor stores the base pipeline, records which attributes are cloned per request (DEFAULT_MUTABLE_ATTRS by default), and wraps the scheduler in BaseAsyncScheduler if wrap_scheduler=True.

    import copy
    import threading
    from typing import Any, Iterable, List, Optional

    class RequestScopedPipeline:
        DEFAULT_MUTABLE_ATTRS = [
            "_all_hooks",
            "_offload_device",
            "_progress_bar_config",
            "_progress_bar",
            "_rng_state",
            "_last_seed",
            "latents",
        ]

        def __init__(
            self,
            pipeline: Any,
            mutable_attrs: Optional[Iterable[str]] = None,
            auto_detect_mutables: bool = True,
            tensor_numel_threshold: int = 1_000_000,
            tokenizer_lock: Optional[threading.Lock] = None,
            wrap_scheduler: bool = True,
        ):
            self._base = pipeline
            # Keep direct references to the heavy, shared components
            self.unet = getattr(pipeline, "unet", None)
            self.vae = getattr(pipeline, "vae", None)
            self.text_encoder = getattr(pipeline, "text_encoder", None)
            self.components = getattr(pipeline, "components", None)

            # Wrap the scheduler once so it can be cloned safely per request
            if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
                if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
                    pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)

            self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
            self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()

            self._auto_detect_mutables = bool(auto_detect_mutables)
            self._tensor_numel_threshold = int(tensor_numel_threshold)

            self._auto_detected_attrs: List[str] = []

While the initialization above sets the stage, the critical component for safe concurrency is the _make_local_scheduler method. This method solves the main problem: preventing multiple requests from modifying the same scheduler simultaneously.
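When multiple requests call pipe() at the same time, they all overwrite the shared scheduler.timesteps with set_timesteps(), so one request can change the timesteps that another request is still iterating over.
The _make_local_scheduler method creates a completely isolated scheduler for each request: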
    def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
        base_sched = getattr(self._base, "scheduler", None)
        if base_sched is None:
            return None

        # Ensure the scheduler is wrapped in BaseAsyncScheduler
        if not isinstance(base_sched, BaseAsyncScheduler):
            wrapped_scheduler = BaseAsyncScheduler(base_sched)
        else:
            wrapped_scheduler = base_sched

        try:
            # Clone the scheduler for this specific request
            return wrapped_scheduler.clone_for_request(
                num_inference_steps=num_inference_steps, device=device, **clone_kwargs
            )
        except Exception as e:
            # Fallback if cloning through the wrapper does not work
            logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
            try:
                return copy.deepcopy(wrapped_scheduler)
            except Exception as e:
                # If it fails again, return the original scheduler (risky)
                logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
                return wrapped_scheduler

BaseAsyncScheduler acts as a proxy that adds safe cloning capability:
    import copy
    from typing import Any, Union

    import torch

    class BaseAsyncScheduler:
        def __init__(self, scheduler: Any):
            self.scheduler = scheduler

        def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
            # Create a deep copy of the scheduler
            local = copy.deepcopy(self.scheduler)
            # Configure timesteps IN THE LOCAL COPY (not in the shared scheduler)
            local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
            # Wrap the copy in a new BaseAsyncScheduler
            cloned = self.__class__(local)
            return cloned

Why does it work?
deepcopy creates a completely independent copy of the scheduler
set_timesteps() is called on the local copy, not on the shared scheduler
Complete flow:
    Request 1 → _make_local_scheduler() → deepcopy(scheduler) → scheduler_local_1
    Request 2 → _make_local_scheduler() → deepcopy(scheduler) → scheduler_local_2
    Request 3 → _make_local_scheduler() → deepcopy(scheduler) → scheduler_local_3
                      ↓                                   ↓
       Inference with scheduler_local_1     All isolated, no race conditions

The _autodetect_mutables(...) method automatically detects mutable components such as lists, dicts, sets, tuples, or small tensors and marks them for cloning when creating a per-request instance.
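The body of _autodetect_mutables is not shown in this post, so the following is only a sketch of how such detection could work, not the actual implementation. The function name autodetect_mutables, the FakeTensor class, and the duck-typed numel() check (used here instead of importing torch) are assumptions made for illustration.

```python
from typing import Any, List


def autodetect_mutables(obj: Any, tensor_numel_threshold: int = 1_000_000) -> List[str]:
    """Return names of instance attributes that look small and mutable (sketch)."""
    found = []
    for name, value in vars(obj).items():
        if name.startswith("__"):
            continue
        if isinstance(value, (list, dict, set, tuple)):
            found.append(name)
        elif callable(getattr(value, "numel", None)):
            # Duck-typed tensor check (assumption): only small tensors qualify.
            if value.numel() <= tensor_numel_threshold:
                found.append(name)
    return found


class FakeTensor:
    """Hypothetical stand-in exposing numel() like a tensor."""

    def __init__(self, n: int):
        self._n = n

    def numel(self) -> int:
        return self._n


class ToyPipeline:
    """Hypothetical pipeline with a mix of attribute kinds."""

    def __init__(self):
        self._progress_bar_config = {}
        self.latents = FakeTensor(16)
        self.big_weights = FakeTensor(10_000_000)
        self.name = "toy"


detected = autodetect_mutables(ToyPipeline())
print(detected)
```

The size threshold is what keeps large weight tensors out of the clone list while small per-request state is picked up.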
The _is_readonly_property(...) method checks whether attr_name on base_obj is a property without a setter. Such read-only properties are skipped during cloning, which avoids an AttributeError that would otherwise abort the process.
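One possible way to implement such a check is sketched below. It is an assumption about the approach rather than the method's actual body: it looks the attribute up on the class, where a property appears as a descriptor object, and tests whether its fset is None. ToyPipeline is a hypothetical class.

```python
def is_readonly_property(obj, attr_name: str) -> bool:
    """True if attr_name is a property on obj's class with no setter (sketch)."""
    # Looking up on the class returns the property object itself,
    # not the computed value.
    descriptor = getattr(type(obj), attr_name, None)
    return isinstance(descriptor, property) and descriptor.fset is None


class ToyPipeline:
    """Hypothetical class with one read-only and one writable property."""

    def __init__(self):
        self._dtype = "float16"
        self.latents = None

    @property
    def device(self):
        return "cuda"

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value


pipe = ToyPipeline()
print(is_readonly_property(pipe, "device"))   # read-only property
print(is_readonly_property(pipe, "dtype"))    # property with a setter
print(is_readonly_property(pipe, "latents"))  # plain instance attribute
```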
The _clone_mutable_attrs(...) method clones relevant mutable attributes from the base object to local so that local has its own isolated state per request.
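As a rough illustration of this cloning step, the sketch below copies a list of named attributes from a base object onto a shallow copy. The function clone_mutable_attrs and the toy classes are hypothetical, and the copy strategy (shallow copy for containers, deepcopy for everything else) is an assumption, not necessarily what the real method does.

```python
import copy


def clone_mutable_attrs(base, local, attr_names):
    """Give `local` its own copies of the named attributes of `base` (sketch)."""
    for name in attr_names:
        if not hasattr(base, name):
            continue
        # Skip read-only properties: assigning to them raises AttributeError.
        descriptor = getattr(type(base), name, None)
        if isinstance(descriptor, property) and descriptor.fset is None:
            continue
        value = getattr(base, name)
        if isinstance(value, (list, dict, set)):
            cloned = copy.copy(value)
        else:
            cloned = copy.deepcopy(value)
        setattr(local, name, cloned)


class ToyPipeline:
    """Hypothetical pipeline with shared and per-request state."""

    def __init__(self):
        self.unet = object()  # stands in for a heavy shared model
        self._progress_bar_config = {"disable": False}

    @property
    def device(self):
        return "cpu"


base = ToyPipeline()
local = copy.copy(base)
clone_mutable_attrs(base, local, ["_progress_bar_config", "device", "missing"])

# Mutating the local state does not leak into the base pipeline.
local._progress_bar_config["disable"] = True
print(base._progress_bar_config, local.unet is base.unet)
```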
The _is_tokenizer_component(...) method uses a heuristic to decide whether a component is a tokenizer: it looks for the three typical tokenizer methods encode, decode, and tokenize.
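The heuristic can be sketched as a simple method check. The snippet also shows one way a detected tokenizer could then be guarded by a lock, in the spirit of the tokenizer_lock passed to the constructor. Both is_tokenizer_component as written here and the LockedTokenizer proxy are hypothetical illustrations, and ToyTokenizer is a stand-in rather than a real tokenizer.

```python
import threading


def is_tokenizer_component(component) -> bool:
    """Heuristic: a tokenizer exposes encode, decode and tokenize (sketch)."""
    if component is None:
        return False
    return all(callable(getattr(component, m, None)) for m in ("encode", "decode", "tokenize"))


class LockedTokenizer:
    """Hypothetical proxy: serializes calls and forwards other attributes."""

    def __init__(self, tokenizer, lock: threading.Lock):
        self._tokenizer = tokenizer
        self._lock = lock

    def __call__(self, *args, **kwargs):
        # Only one thread at a time may run the underlying tokenizer.
        with self._lock:
            return self._tokenizer(*args, **kwargs)

    def __getattr__(self, name):
        # Reached only for names not found on the proxy itself.
        return getattr(self._tokenizer, name)


class ToyTokenizer:
    """Hypothetical whitespace tokenizer used only for this example."""

    model_max_length = 77

    def tokenize(self, text):
        return text.split()

    def encode(self, text):
        return [len(tok) for tok in self.tokenize(text)]

    def decode(self, ids):
        return " ".join("x" * i for i in ids)

    def __call__(self, text):
        return {"input_ids": self.encode(text)}


tok = ToyTokenizer()
locked = LockedTokenizer(tok, threading.Lock())
print(is_tokenizer_component(tok), is_tokenizer_component(object()))
print(locked("a bc"), locked.model_max_length)
```

A proxy object, unlike a bare function wrapper, still exposes attributes such as model_max_length to code that reads them from the tokenizer.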
With these pieces in place, the generate(...) method can run concurrent inference safely without breaking any component of the model pipeline.
    def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)

A shallow copy of the base pipeline is then created, so that per-request attributes can be replaced (local_pipe.scheduler, etc.) without cloning all the content.

        try:
            # Shallow copy: copies the attribute structure,
            # but references to complex objects remain shared.
            local_pipe = copy.copy(self._base)
        except Exception as e:
            logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
            # If it fails, use deepcopy (heavier on memory and time)
            local_pipe = copy.deepcopy(self._base)

The local pipeline's scheduler is replaced if local_scheduler exists:

        if local_scheduler is not None:
            try:
                timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
                    local_scheduler.scheduler,
                    num_inference_steps=num_inference_steps,
                    device=device,
                    return_scheduler=True,
                    **{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
                )

                final_scheduler = BaseAsyncScheduler(configured_scheduler)
                setattr(local_pipe, "scheduler", final_scheduler)
            except Exception:
                logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")

Next, the attributes listed in _mutable_attrs and those detected by _autodetect_mutables are cloned. Tokenizer components are then located, and calls to them are serialized with self._tokenizer_lock to avoid problems with tokenizers that are not thread-safe.
        self._clone_mutable_attrs(self._base, local_pipe)

        # Wrap tokenizers in the local pipeline with the locking wrapper
        tokenizer_wrappers = {}  # name -> original_tokenizer
        try:
            # a) Wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
            for name in dir(local_pipe):
                if "tokenizer" in name and not name.startswith("_"):
                    tok = getattr(local_pipe, name, None)
                    if tok is not None and self._is_tokenizer_component(tok):
                        tokenizer_wrappers[name] = tok
                        setattr(
                            local_pipe,
                            name,
                            lambda *args, tok=tok, **kwargs: safe_tokenize(tok, *args, lock=self._tokenizer_lock, **kwargs)
                        )

            # b) Wrap tokenizers in the components dictionary
            if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
                for key, val in local_pipe.components.items():
                    if val is None:
                        continue
                    if self._is_tokenizer_component(val):
                        tokenizer_wrappers[f"components[{key}]"] = val
                        local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
                            tokenizer, *args, lock=self._tokenizer_lock, **kwargs
                        )

        except Exception as e:
            logger.debug(f"Tokenizer wrapping step encountered an error: {e}")

model_cpu_offload_context can be:
- A function that returns a context manager (call cm())
- An already instantiated context manager (use with cm)
The code handles both cases: it catches the TypeError raised if cm() fails and retries with cm. If everything fails, it falls back to executing local_pipe without an offload context. The finally block guarantees that the tokenizers are restored even if inference raises an exception.
        result = None
        cm = getattr(local_pipe, "model_cpu_offload_context", None)
        try:
            if callable(cm):
                try:
                    with cm():
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                except TypeError:
                    # cm could be a context manager instance instead of a callable
                    try:
                        with cm:
                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
                    except Exception as e:
                        logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
            else:
                # No offload context available: call directly
                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)

            return result

        finally:
            # Restore the original tokenizers, even if inference failed
            try:
                for name, tok in tokenizer_wrappers.items():
                    if name.startswith("components["):
                        key = name[len("components["):-1]
                        local_pipe.components[key] = tok
                    else:
                        setattr(local_pipe, name, tok)
            except Exception as e:
                logger.debug(f"Error restoring wrapped tokenizers: {e}")

    # All imports, lifespan, etc.
    app = FastAPI(lifespan=lifespan)

    logger = logging.getLogger("DiffusersServer.Pipelines")

    # Wrapper to initialize and launch the model pipeline
    initializer = ModelPipelineInitializer(
        model=server_config.model,
        type_models=server_config.type_models,
    )
    model_pipeline = initializer.initialize_pipeline()
    model_pipeline.start()

    # Pass the base pipeline to RequestScopedPipeline
    request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
    pipeline_lock = threading.Lock()

    logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")

    # Save component states to make them easily accessible from FastAPI
    app.state.MODEL_INITIALIZER = initializer
    app.state.MODEL_PIPELINE = model_pipeline
    app.state.REQUEST_PIPE = request_pipe
    app.state.PIPELINE_LOCK = pipeline_lock

    # Basic Pydantic model to request everything needed for inference
    class JSONBodyQueryAPI(BaseModel):
        model: str | None = None
        prompt: str
        negative_prompt: str | None = None
        num_inference_steps: int = 28
        num_images_per_prompt: int = 1

    @app.post("/api/diffusers/inference")
    async def api(json: JSONBodyQueryAPI):
        prompt = json.prompt
        negative_prompt = json.negative_prompt or ""
        num_steps = json.num_inference_steps
        num_images_per_prompt = json.num_images_per_prompt

        wrapper = app.state.MODEL_PIPELINE
        initializer = app.state.MODEL_INITIALIZER

        utils_app = app.state.utils_app

        if not wrapper or not wrapper.pipeline:
            raise HTTPException(500, "Model not initialized correctly")
        if not prompt.strip():
            raise HTTPException(400, "No prompt provided")

        def make_generator():
            g = torch.Generator(device=initializer.device)
            return g.manual_seed(random.randint(0, 10_000_000))

        req_pipe = app.state.REQUEST_PIPE

        # Function to perform safe per-request inference using app.state.REQUEST_PIPE
        def infer():
            gen = make_generator()
            return req_pipe.generate(
                prompt=prompt,
                negative_prompt=negative_prompt,
                generator=gen,
                num_inference_steps=num_steps,
                num_images_per_prompt=num_images_per_prompt,
                device=initializer.device,
                output_type="pil",
            )

        try:
            async with app.state.metrics_lock:
                app.state.active_inferences += 1

            # Execute inference in a non-blocking manner
            output = await run_in_threadpool(infer)

            async with app.state.metrics_lock:
                app.state.active_inferences = max(0, app.state.active_inferences - 1)

            # When inference finishes, build the list of generated image URLs
            # (e.g. ["/images/image_123.png", "/images/image_1234.png"])
            # so they can be downloaded
            urls = [utils_app.save_image(img) for img in output.images]
            return {"response": urls}

        except Exception as e:
            async with app.state.metrics_lock:
                app.state.active_inferences = max(0, app.state.active_inferences - 1)
            logger.error(f"Error during inference: {e}")
            raise HTTPException(500, f"Error in processing: {e}")

        finally:
            # Release cached GPU memory after every request
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                torch.cuda.ipc_collect()
            gc.collect()

    # This serves to allow downloading of generated images
    @app.get("/images/{filename}")
    async def serve_image(filename: str):
        utils_app = app.state.utils_app
        file_path = os.path.join(utils_app.image_dir, filename)
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="Image not found")
        return FileResponse(file_path, media_type="image/png")

Questions or suggestions? Contact me at fredyriveraacevedo13@gmail.com