# PyTorch Frontend

The PyTorch frontend enables the Proxy Base Agent (PBA) to run Large Language Models loaded via the PyTorch framework and Hugging Face `transformers`.
## Overview
This frontend uses the `agent.llm.frontend.torch.TorchInference` class, which integrates standard PyTorch models (like `LlamaForCausalLM` or other `transformers` models) with the PSE `StructuringEngine`.
**Key Features:**

- **Broad Compatibility:** Works with a wide range of Hugging Face `transformers` models compatible with PyTorch.
- **Hardware Flexibility:** Runs on CPUs or GPUs (NVIDIA, AMD) supported by PyTorch.
- **PSE Integration:** Uses the `PSETorchMixin` from the `pse` library to easily integrate the `StructuringEngine`'s `process_logits` and `sample` methods into the standard `transformers` `generate()` workflow.
## Usage
1. **Installation:** Ensure you have PyTorch and the necessary `transformers` dependencies installed. This is typically handled by installing PBA with the `[torch]` extra:

    ```bash
    pip install proxy-base-agent[torch]
    # or
    uv pip install proxy-base-agent[torch]
    ```

    You may need to install a specific version of PyTorch separately depending on your hardware (CPU/CUDA/ROCm). See the PyTorch installation guide.

2. **Model Selection:** During the PBA setup wizard (`python -m agent`), choose a model compatible with PyTorch (most standard Hugging Face models). Select "PyTorch" as the inference backend when prompted.

3. **Configuration:** The `LocalInference` class will automatically instantiate `TorchInference` when a PyTorch-compatible model path and the PyTorch frontend are selected. Relevant `inference_kwargs` (like `temp`, `seed`, `max_tokens`, `top_k`, `top_p`) passed to the `Agent` constructor will be used by the `model.generate()` method, as illustrated in the sketch after this list.
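For orientation, the sketch below shows a plain `transformers` `generate()` call using the kinds of values PBA forwards through `inference_kwargs`. The mapping of PBA's parameter names (`temp`, `max_tokens`) onto `transformers` arguments (`temperature`, `max_new_tokens`), and the example model path, are assumptions for illustration rather than details taken from the PBA source.

```python
# Illustrative only: a direct transformers generate() call with sampling
# parameters of the kind PBA forwards via inference_kwargs. The parameter-name
# mapping (temp -> temperature, max_tokens -> max_new_tokens) is an assumption.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(0)  # PBA's `seed` would be applied roughly along these lines (assumption)

model_path = "gpt2"  # placeholder; PBA uses the model chosen in the setup wizard
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

inputs = tokenizer("Hello, world", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    do_sample=True,       # sampling path (required for the PSE sampling hook)
    temperature=0.7,      # PBA's `temp`
    top_k=50,             # PBA's `top_k`
    top_p=0.95,           # PBA's `top_p`
    max_new_tokens=128,   # PBA's `max_tokens`
    use_cache=True,       # standard in-memory KV caching (see "How it Works")
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```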
## How it Works
- **Loading:** `TorchInference` loads the model using `transformers.AutoModelForCausalLM.from_pretrained` (specifically via the `PSE_Torch` class, which incorporates the `PSETorchMixin`) and the tokenizer using `agent.llm.tokenizer.Tokenizer.load`.
- **Mixin Integration:** The `PSETorchMixin` modifies the model's `_sample` method (used by `generate` when `do_sample=True`) to:
    - Include `engine.process_logits` in the `logits_processor` list.
    - Use `engine.sample` (wrapping a basic multinomial sampler or argmax) for token selection (a conceptual sketch of this hook pattern appears after this list).
- **Inference Loop:** The `inference()` method sets up a `TextIteratorStreamer` and runs `model.generate()` in a separate thread, yielding tokens as they become available from the streamer (see the streaming sketch below).
- **Caching:** Currently, the PyTorch frontend in PBA does not implement persistent KV cache saving/loading to disk like the MLX frontend does (`supports_reusing_prompt_cache()` returns `False`). Standard `transformers` in-memory KV caching during generation is used if enabled (`use_cache=True`).
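To make the Mixin Integration step concrete, here is a minimal conceptual sketch of the hook pattern: logits are filtered by a `process_logits`-style hook, then the next token is chosen by a `sample`-style hook. This is not the actual `PSETorchMixin` or `StructuringEngine` code; `DummyEngine` is a toy stand-in, and the real mixin injects these hooks into `generate()`'s `_sample` path rather than using a hand-rolled loop.

```python
# Conceptual sketch of the process_logits/sample hook pattern (not the PSE code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class DummyEngine:
    """Toy stand-in for the StructuringEngine's process_logits/sample hooks."""

    def process_logits(self, input_ids, scores):
        # A real engine would mask out tokens that violate the target structure here.
        return scores

    def sample(self, logits):
        # Basic multinomial sampling, mirroring the "wraps a basic sampler" idea.
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)


def generate_structured(model, tokenizer, engine, prompt, max_new_tokens=64):
    """Token-by-token loop: process logits, then let the engine pick the token."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_new_tokens):
        with torch.no_grad():
            # Recomputes the full prefix each step for simplicity (no KV cache).
            logits = model(input_ids).logits[:, -1, :]
        logits = engine.process_logits(input_ids, logits)  # process_logits hook
        next_token = engine.sample(logits)                  # sample hook
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


# Example usage (model name is illustrative):
# model = AutoModelForCausalLM.from_pretrained("gpt2")
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# print(generate_structured(model, tokenizer, DummyEngine(), "Hello"))
```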
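The Inference Loop step follows the standard `transformers` streaming recipe: run `generate()` on a worker thread and iterate over a `TextIteratorStreamer` on the caller's side. Below is a hedged sketch of that pattern; the model name and prompt are placeholders, and PBA's actual `inference()` wiring may differ.

```python
# Sketch of the streaming pattern: generate() blocks until completion, so it
# runs in a background thread while the main thread consumes decoded chunks.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder; PBA would use the configured model path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Write a haiku about autumn:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64, "do_sample": True},
)
thread.start()

for text_chunk in streamer:
    print(text_chunk, end="", flush=True)

thread.join()
```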
The PyTorch frontend offers broad model compatibility for running PBA on various hardware configurations.