# PyTorch Frontend

The PyTorch frontend enables the Proxy Base Agent (PBA) to use Large Language Models loaded via the popular PyTorch framework and Hugging Face `transformers`.
## Overview

This frontend uses the `agent.llm.frontend.torch.TorchInference` class, which integrates standard PyTorch models (such as `LlamaForCausalLM` or other `transformers` models) with the PSE `StructuringEngine`.
Key Features:

- Broad Compatibility: Works with a wide range of Hugging Face `transformers` models compatible with PyTorch.
- Hardware Flexibility: Runs on CPUs or GPUs (NVIDIA, AMD) supported by PyTorch.
- PSE Integration: Uses the `PSETorchMixin` from the `pse` library to integrate the `StructuringEngine`'s `process_logits` and `sample` methods into the standard `transformers` `generate()` workflow.
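For context, the snippet below shows the plain `transformers` `generate()` workflow that this integration hooks into. The model name is only a placeholder, and no PSE-specific code is involved yet.

```python
# Standard transformers generation workflow, before any PSE structuring is applied.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, world", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

As described above, the PSE integration keeps this call surface intact and only changes how logits are filtered and tokens are selected inside `generate()`.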
## Usage
- Installation: Ensure you have PyTorch and the necessary `transformers` dependencies installed. This is typically handled by installing PBA with the `[torch]` extra (a quick device-availability check is shown after this list):

  ```bash
  pip install proxy-base-agent[torch]
  # or
  uv pip install proxy-base-agent[torch]
  ```

  You may need to install a specific version of PyTorch separately depending on your hardware (CPU/CUDA/ROCm). See the PyTorch installation guide.
- Model Selection: During the PBA setup wizard (`python -m agent`), choose a model compatible with PyTorch (most standard Hugging Face models). Select "PyTorch" as the inference backend when prompted.
- Configuration: The `LocalInference` class will automatically instantiate `TorchInference` when a PyTorch-compatible model path and the PyTorch frontend are selected. Relevant `inference_kwargs` (like `temp`, `seed`, `max_tokens`, `top_k`, `top_p`) passed to the `Agent` constructor will be used by the `model.generate()` method.
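After installing the `[torch]` extra, you can confirm that PyTorch sees your hardware before launching the setup wizard. This check uses only standard `torch` APIs:

```python
# Quick hardware check after installing the [torch] extra.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())  # True for NVIDIA GPUs; ROCm builds also report through the cuda API
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("Falling back to CPU execution.")
```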
## How it Works

- Loading: `TorchInference` loads the model using `transformers.AutoModelForCausalLM.from_pretrained` (specifically via the `PSE_Torch` class, which incorporates the `PSETorchMixin`) and the tokenizer using `agent.llm.tokenizer.Tokenizer.load`.
- Mixin Integration: The `PSETorchMixin` modifies the model's `_sample` method (used by `generate()` when `do_sample=True`) to:
    - Include `engine.process_logits` in the `logits_processor` list.
    - Use `engine.sample` (wrapping a basic multinomial sampler or argmax) for token selection.
- Inference Loop: The `inference()` method sets up a `TextIteratorStreamer` and runs `model.generate()` in a separate thread, yielding tokens as they become available from the streamer (see the sketch after this list).
- Caching: Currently, the PyTorch frontend in PBA does not implement persistent KV cache saving/loading to disk like the MLX frontend (`supports_reusing_prompt_cache()` returns `False`). Standard `transformers` in-memory KV caching during generation is used if enabled (`use_cache=True`).
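The snippet below sketches this threaded streaming pattern using only standard `transformers` APIs. The model name is a placeholder, and the PSE-specific hooks (`PSETorchMixin`, `engine.process_logits`, `engine.sample`) are indicated in comments rather than imported, since their exact import paths and signatures depend on the installed `pse` version.

```python
# Illustrative sketch of the streaming inference loop described above.
from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    TextIteratorStreamer,
)

model_id = "gpt2"  # placeholder; PBA uses the model chosen in the setup wizard
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# In PBA, PSETorchMixin patches model._sample so that engine.process_logits is
# added to this list and engine.sample drives token selection.
logits_processors = LogitsProcessorList([])  # e.g. [engine.process_logits]

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer("Write a short greeting.", return_tensors="pt")

generation_kwargs = dict(
    **inputs,
    max_new_tokens=64,
    do_sample=True,          # required so the patched _sample path is used
    top_p=0.95,
    use_cache=True,          # in-memory KV cache only; nothing is persisted to disk
    logits_processor=logits_processors,
    streamer=streamer,
)

# generate() runs on a background thread; tokens are consumed as they arrive.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for token_text in streamer:
    print(token_text, end="", flush=True)
thread.join()
```

Running `generate()` on a worker thread lets the caller consume tokens from the streamer as soon as they are produced, which is what allows PBA to yield partial output instead of waiting for the full completion.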
The PyTorch frontend offers broad model compatibility for running PBA on various hardware configurations.