pipelogic.infer
Helpers for wiring ML models into a Pipelogic component.
from pipelogic.infer import (
select_device, select_dtype, ort_providers, pipeline_device_index,
is_offline, ensure_local_dir, find_model_file, load_state_dict,
)
These helpers went unused for a long time while most worker components reinvented them; prefer the helpers over hand-rolled versions.
Device selection
select_device(prefer_cuda=True, prefer_mps=False, gpu_id=0)
Returns a torch.device. Falls back through CUDA → MPS → CPU based on availability and flags.
from pipelogic.infer import select_device
device = select_device() # CUDA if available, else CPU
device = select_device(prefer_mps=True) # CUDA → MPS → CPU
device = select_device(gpu_id=1) # second GPU when CUDA available
select_dtype(device, check_capability=False)
Returns the recommended torch.dtype for the device: float16 for CUDA, float32 for CPU and MPS. With check_capability=True, float16 is returned only on GPUs of compute capability 7.0 or higher (Volta and newer):
device = select_device()
dtype = select_dtype(device, check_capability=True)
model = model.to(device).to(dtype)
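With check_capability=True the helper inspects the GPU before committing to half precision. A minimal sketch of the equivalent check, assuming a CUDA device (not the actual implementation):
import torch
major, _minor = torch.cuda.get_device_capability(device)  # e.g. (8, 6) on Ampere
dtype = torch.float16 if major >= 7 else torch.float32    # fp16 only on Volta and newer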
pipeline_device_index(device)
Converts a torch.device to the integer index HuggingFace pipeline() expects (0 for first CUDA GPU, -1 for CPU).
from transformers import pipeline
from pipelogic.infer import select_device, pipeline_device_index
device = select_device()
pipe = pipeline("image-classification", device=pipeline_device_index(device))
ort_providers(device=None)
Returns an ordered list of ONNX Runtime execution providers. With device=None, auto-detects what's installed. With a torch.device, picks providers matching the device type.
import onnxruntime as ort
from pipelogic.infer import select_device, ort_providers
device = select_device()
session = ort.InferenceSession("model.onnx", providers=ort_providers(device))
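With device=None the provider list depends on what's installed; on a host with onnxruntime-gpu it is likely ["CUDAExecutionProvider", "CPUExecutionProvider"], and just the CPU provider elsewhere:
session = ort.InferenceSession("model.onnx", providers=ort_providers())  # auto-detect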
Model loading
is_offline()
Returns True when HF_HUB_OFFLINE=1 or TRANSFORMERS_OFFLINE=1. Use it to short-circuit network calls in airgapped deployments.
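For example, a worker can skip Hub lookups entirely when offline. A sketch; the baked-in /opt/models path is hypothetical:
from pipelogic.infer import is_offline, ensure_local_dir
if is_offline():
    local = "/opt/models/resnet-50"  # shipped with the container image
else:
    local = ensure_local_dir("microsoft/resnet-50")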
ensure_local_dir(model_or_repo, token=None, revision=None)
If model_or_repo is a local directory, returns it as-is. Otherwise downloads the snapshot from HuggingFace Hub. Honors is_offline() (uses local_files_only=True when offline).
import os
from pipelogic.infer import ensure_local_dir
local = ensure_local_dir("microsoft/resnet-50")
local = ensure_local_dir("/path/to/already/local/model")
local = ensure_local_dir("org/private-model", token=os.environ["HF_TOKEN"])
local = ensure_local_dir("microsoft/resnet-50", revision="main")
Don't add numpy or huggingface_hub to your requirements.txt — they're already provided by pipelogic.
find_model_file(directory, extensions=(".pt", ".onnx", ".safetensors", ".bin"))
Searches directory recursively for exactly one file with one of the given extensions. Raises FileNotFoundError when none match, ValueError when more than one matches.
from pipelogic.infer import ensure_local_dir, find_model_file
local = ensure_local_dir("yolo-org/yolov8")
weights = find_model_file(local, extensions=(".pt",))
This replaces the boilerplate Path(...).rglob("*.pt") pattern — and includes the "exactly one file" invariant most callers want.
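Because missing files raise instead of returning None, format fallbacks stay explicit. A sketch, reusing local from the example above:
from pipelogic.infer import find_model_file
try:
    weights = find_model_file(local, extensions=(".safetensors",))
except FileNotFoundError:
    weights = find_model_file(local, extensions=(".pt", ".bin"))  # older export formats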
load_state_dict(path)
Loads a model state dict from .safetensors, .pt, .pth, .bin, or .ckpt. Automatically:
- uses safetensors.torch.load_file for .safetensors,
- calls torch.load(..., weights_only=True) for the rest (with a graceful fallback for older torch),
- unwraps nested checkpoints ({"state_dict": ...}, {"model": ...}, {"ema_state_dict": ...}),
- strips the module. prefix from DistributedDataParallel checkpoints.
from pipelogic.infer import load_state_dict
state = load_state_dict("/path/to/checkpoint.pt")
model.load_state_dict(state)
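For reference, the unwrapping and prefix handling are roughly the boilerplate the helper replaces. A sketch, not the actual implementation:
import torch
ckpt = torch.load("/path/to/checkpoint.pt", weights_only=True)
for key in ("state_dict", "model", "ema_state_dict"):  # unwrap nested checkpoints
    if isinstance(ckpt, dict) and key in ckpt:
        ckpt = ckpt[key]
        break
state = {k.removeprefix("module."): v for k, v in ckpt.items()}  # strip the DDP prefix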
Putting it together
The canonical "load a model" recipe for a Python worker:
from pipelogic.worker import config
from pipelogic.infer import (
select_device, select_dtype, ensure_local_dir,
find_model_file, load_state_dict,
)
import torch
device = select_device()
dtype = select_dtype(device)
local = ensure_local_dir(str(config.model_path))
weights = find_model_file(local, extensions=(".pt",))
model = MyArchitecture()  # your model class
model.load_state_dict(load_state_dict(weights))
model = model.to(device).to(dtype).eval()
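The ONNX analogue follows the same shape. A sketch, assuming config.model_path points at a repo or directory containing a single .onnx file:
import onnxruntime as ort
from pipelogic.worker import config
from pipelogic.infer import select_device, ensure_local_dir, find_model_file, ort_providers
device = select_device()
local = ensure_local_dir(str(config.model_path))
onnx_path = find_model_file(local, extensions=(".onnx",))
session = ort.InferenceSession(str(onnx_path), providers=ort_providers(device))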
What's next
- API: batch, covering the @accumulate and @sliding_window decorators that work well with infer.
- Recipes, with full worker examples that use these helpers.