pipelogic.infer

Helpers for wiring ML models into a Pipelogic component.

from pipelogic.infer import (
    select_device, select_dtype, ort_providers, pipeline_device_index,
    is_offline, ensure_local_dir, find_model_file, load_state_dict,
)

These helpers sat unused for a long time while most worker components reinvented the same functions. Use the helpers instead.

Device selection

select_device(prefer_cuda=True, prefer_mps=False, gpu_id=0)

Returns a torch.device. Falls back through CUDA → MPS → CPU based on availability and flags.

from pipelogic.infer import select_device

device = select_device()              # CUDA if available, else CPU
device = select_device(prefer_mps=True)   # CUDA → MPS → CPU
device = select_device(gpu_id=1)      # second GPU when CUDA available

select_dtype(device, check_capability=False)

Returns the recommended torch.dtype for the device. CUDA → float16; CPU/MPS → float32. With check_capability=True, float16 is only returned for GPUs with compute capability ≥ 7.0 (Volta or newer):

device = select_device()
dtype = select_dtype(device, check_capability=True)
model = model.to(device).to(dtype)

pipeline_device_index(device)

Converts a torch.device to the integer index HuggingFace pipeline() expects (0 for first CUDA GPU, -1 for CPU).

from transformers import pipeline
from pipelogic.infer import select_device, pipeline_device_index

device = select_device()
pipe = pipeline("image-classification", device=pipeline_device_index(device))

ort_providers(device=None)

Returns an ordered list of ONNX Runtime execution providers. With device=None, auto-detects what's installed. With a torch.device, picks providers matching the device type.

import onnxruntime as ort
from pipelogic.infer import select_device, ort_providers

device = select_device()
session = ort.InferenceSession("model.onnx", providers=ort_providers(device))
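
With no argument, the provider list is built from whatever ONNX Runtime packages are installed:

session = ort.InferenceSession("model.onnx", providers=ort_providers())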

Model loading

is_offline()

Returns True when HF_HUB_OFFLINE=1 or TRANSFORMERS_OFFLINE=1. Use it to short-circuit network calls in air-gapped deployments.
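
For example, a worker can skip cache warm-up entirely when offline. A minimal sketch (warm_cache is a hypothetical helper, not part of pipelogic):

from huggingface_hub import snapshot_download
from pipelogic.infer import is_offline

def warm_cache(repo_id):
    # In air-gapped deployments, rely on the local cache instead of the Hub.
    if is_offline():
        return
    snapshot_download(repo_id)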

ensure_local_dir(model_or_repo, token=None, revision=None)

If model_or_repo is a local directory, returns it as-is. Otherwise downloads the snapshot from HuggingFace Hub. Honors is_offline() (uses local_files_only=True when offline).

import os

from pipelogic.infer import ensure_local_dir

local = ensure_local_dir("microsoft/resnet-50")
local = ensure_local_dir("/path/to/already/local/model")
local = ensure_local_dir("org/private-model", token=os.environ["HF_TOKEN"])
local = ensure_local_dir("microsoft/resnet-50", revision="main")

find_model_file(directory, extensions=(".pt", ".onnx", ".safetensors", ".bin"))

Searches directory recursively for exactly one file with one of the given extensions. Raises FileNotFoundError when none match, ValueError when more than one matches.

from pipelogic.infer import ensure_local_dir, find_model_file

local = ensure_local_dir("yolo-org/yolov8")
weights = find_model_file(local, extensions=(".pt",))

This replaces the boilerplate Path(...).rglob("*.pt") pattern and enforces the "exactly one file" invariant most callers want.
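
For comparison, roughly what callers had to hand-roll before (an illustrative sketch, not the helper's actual implementation):

from pathlib import Path

# Hand-rolled equivalent of find_model_file(local, extensions=(".pt",)):
matches = list(Path(local).rglob("*.pt"))
if not matches:
    raise FileNotFoundError("no .pt file found")
if len(matches) > 1:
    raise ValueError(f"expected exactly one .pt file, found {len(matches)}")
weights = matches[0]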

load_state_dict(path)

Loads a model state dict from .safetensors, .pt, .pth, .bin, or .ckpt. Automatically:

  • uses safetensors.torch.load_file for .safetensors,
  • calls torch.load(..., weights_only=True) for the rest (with a graceful fallback for older torch),
  • unwraps nested checkpoints ({"state_dict": ...}, {"model": ...}, {"ema_state_dict": ...}),
  • strips the module. prefix from DistributedDataParallel checkpoints.

from pipelogic.infer import load_state_dict

state = load_state_dict("/path/to/checkpoint.pt")
model.load_state_dict(state)
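
For a plain .pt checkpoint, the behavior is roughly equivalent to the following (an illustrative sketch of the steps above, not the actual implementation):

import torch

ckpt = torch.load("/path/to/checkpoint.pt", weights_only=True)
# Unwrap common checkpoint nesting conventions.
for key in ("state_dict", "model", "ema_state_dict"):
    if isinstance(ckpt, dict) and key in ckpt:
        ckpt = ckpt[key]
        break
# Strip the DistributedDataParallel "module." prefix.
state = {k.removeprefix("module."): v for k, v in ckpt.items()}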

Putting it together

The canonical "load a model" recipe for a Python worker:

from pipelogic.worker import config
from pipelogic.infer import (
    select_device, select_dtype, ensure_local_dir,
    find_model_file, load_state_dict,
)
import torch

device = select_device()
dtype = select_dtype(device)

local = ensure_local_dir(str(config.model_path))
weights = find_model_file(local, extensions=(".pt",))

model = MyArchitecture()              # your model class
model.load_state_dict(load_state_dict(weights))
model = model.to(device).to(dtype).eval()

What's next

  • API: @batch_accumulate and @sliding_window decorators that work well with infer.
  • Recipes — full worker examples that use these helpers.
