# agent_prm
**Repository Path**: whs075/agent_prm
## Basic Information
- **Project Name**: agent_prm
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: MIT
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 0
- **Created**: 2025-06-28
- **Last Updated**: 2025-06-28
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
# Process Reward Models for LLM Agents: Practical Framework and Directions
Paper link: [https://arxiv.org/pdf/2502.10325](https://arxiv.org/pdf/2502.10325)
## Installation
### Create Conda environment
To set up the project, clone the repository and create a Conda environment:
```bash
cd agent_prm
conda env create -f environment.yml
conda activate agent_prm
pip install -e .
```
### Optional: Set up OpenAI / Gemini / Anthropic environment keys
Ensure you have a `.env` file with the requisite keys:
```bash
OPENAI_API_KEY=your_openai_api_key
OPENAI_ORGANIZATION=your_openai_organization_id
GEMINI_API_KEY=your_gemini_key
ANTHROPIC_API_KEY=your_anthropic_key
```
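If parts of your workflow expect these keys as shell environment variables rather than reading the `.env` file directly (an assumption; adjust to your setup), you can export them into the current shell:
```bash
# Export every variable defined in .env into the current shell session
set -a        # automatically export all variables that get set
source .env
set +a        # turn automatic export back off
```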
### Set up external dependencies
We build on [Open-Instruct](https://github.com/allenai/open-instruct) for training, with some minor compatibility fixes, so it needs to be installed locally from our fork:
```bash
# Clone and install Open-Instruct
git clone --branch fix_vllm https://github.com/sanjibanc/open-instruct.git
cd open-instruct
pip install -e .
cd ..
```
We use [SGLang](https://github.com/sgl-project/sglang) for fast inference, with some minor Llama compatibility fixes, so it also needs to be installed locally from our fork:
```bash
# Clone and install SGLang
git clone --branch new_llama_model https://github.com/sanjibanc/sglang.git
cd sglang
pip install -e .
cd ..
```
To use the SGLang server, see the [SGLang instructions](#sglang-instructions).
To set up external environments like AlfWorld, [go to external environment instructions](#external-environment-instructions).
## Agent PRM Training
Agent PRM iterates over 3 stages:
1. Rollout policy and compute PRM targets
2. Train PRM
3. Train policy via RL
Stages 2 and 3 are similar to standard RLHF operations; Stage 1 is the agent-specific step.
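Putting the three stages together, one iteration looks roughly like the sketch below (a minimal outline using the scripts described in the following subsections; config files and checkpoint paths need to be updated between iterations):
```bash
# Stage 1: roll out the current policy and compute PRM targets
python scripts/dataproc/rollout_alfworld.py --config configs/rollout_alfworld.yaml
python scripts/dataproc/compute_prm_target.py --config configs/compute_prm_target.yaml

# Stage 2: train the PRM on the computed targets
bash bash/train-rm-llama3.2-3B.sh

# Stage 3: train the policy against the PRM via online DPO
bash bash/online-dpo-llama3.2-3B.sh
```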
### Initialize policy with SFT
We collect SFT training data from our prior work [LEAP](https://github.com/sanjibanc/leap_llm) and train a policy via SFT:
```bash
bash bash/train-sft-llama3.2-3B.sh
```
For simplicity, we provide the resulting model at [rl-llm-agent/Llama-3.2-3B-Instruct-sft-alfworld-iter0](https://huggingface.co/rl-llm-agent/Llama-3.2-3B-Instruct-sft-alfworld-iter0).
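If you want to start from this checkpoint directly, you can fetch it with the Hugging Face CLI (assuming the model is hosted on the Hugging Face Hub under that ID):
```bash
# Download the SFT-initialized policy checkpoint (requires huggingface_hub to be installed)
huggingface-cli download rl-llm-agent/Llama-3.2-3B-Instruct-sft-alfworld-iter0
```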
### Stage 1: Rollout and Compute Target
We roll out in a batched fashion and recommend using the `SGLangServerAgent` for fast inference. See the [SGLang instructions](#sglang-instructions) to set up the SGLang server, then run the following script to collect rollouts:
```bash
python scripts/dataproc/rollout_alfworld.py --config configs/rollout_alfworld.yaml
```
Once you have the rollouts, set the rollout path in `configs/compute_prm_target.yaml` and run:
```bash
python scripts/dataproc/compute_prm_target.py --config configs/compute_prm_target.yaml
```
This creates train and test files for training the PRM.
### Stage 2: Training the PRM
To train the PRM, run the script that calls Open-Instruct:
```bash
bash bash/train-rm-llama3.2-3B.sh
```
Upload the best checkpoint to the Hugging Face Hub for convenience:
```bash
python scripts/utils/upload_model_to_hf.py --input_model <input_model_path> --output_model <output_model_name> --accelerate
```
### Stage 3: Training the Policy via RL
To train the policy via online DPO to optimize the PRM, run the following script:
```bash
bash bash/online-dpo-llama3.2-3B.sh
```
Upload the best checkpoint to the Hugging Face Hub for convenience, as in Stage 2.
Repeat Stages 1 through 3 to iterate.
## Agent PRM Inference
Configure the agents you want to evaluate in `configs/eval_alfworld.yaml` and run the following script:
```bash
python scripts/eval/eval_alfworld.py --config configs/eval_alfworld.yaml
```
It will create a folder in `data/eval/alfworld/` named with the current datetime, where logs and `summary.csv` are saved.
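For a quick look at the most recent run, something like the following works (a convenience sketch; the only assumption is the datetime-named folder layout described above):
```bash
# Pick the most recently created evaluation folder and pretty-print its summary
latest=$(ls -dt data/eval/alfworld/*/ | head -n 1)
column -s, -t < "${latest}summary.csv" | less -S
```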
For fast inference, use an SGLang server agent and host the policy on an SGLang server.
To evaluate a Best-of-N policy, host both the policy and the PRM on SGLang servers and run the script with the `best_of_n` agent (see the [SGLang instructions](#sglang-instructions) for a launch sketch).
## Ablations and Extensions
### Inverse PRM
Stage 1: Given expert demonstrations and policy rollouts, compute inverse PRM targets:
```bash
python scripts/dataproc/compute_inverse_prm_target.py --config configs/compute_inverse_prm_target.yaml
```
Stage 2: Train the PRM:
```bash
bash bash/train-inverse-prm-llama3.2-3B.sh
```
Stage 3: Train the generator (policy) as in Agent PRM.
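Concretely, this is the same online DPO step as Stage 3 of Agent PRM, assuming the training script is pointed at the inverse PRM checkpoint instead of the standard PRM:
```bash
bash bash/online-dpo-llama3.2-3B.sh
```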
### Relative Loss
To train the PRM using a relative loss, change the target computation to produce a preference dataset:
```bash
python scripts/dataproc/compute_prm_preference_target.py --config configs/compute_prm_preference_target.yaml
```
To train the PRM using preference data, use this script:
```bash
bash bash/train-rm-pref-llama3.2-3B.sh
```
### Steered Exploration
To train the policy using the steered exploration prompt `prompts/alfworld/alfworld_exploration_template.j2`, run the following script:
```bash
bash bash/online-dpo-exploration-llama3.2-3B.sh
```
### Process Reward Shaping
Given a reference policy, collect rollouts, compute value targets, and train a value estimate:
```bash
python scripts/dataproc/compute_value_target.py --config configs/compute_value_target.yaml
bash bash/train-value-model-llama3.2-3B.sh
```
Use the value function to compute shaped PRM targets. This requires running the value function as a critic on an SGLang server:
```bash
python scripts/dataproc/compute_shaped_prm_target.py --config configs/compute_shaped_prm_target.yaml
```
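If the shaping follows standard potential-based reward shaping with the learned value function as the potential (an assumption here; see the paper for the exact formulation), the shaped per-step target has the form:

$$
\tilde{r}(s_t, a_t) = r(s_t, a_t) + \gamma \, V(s_{t+1}) - V(s_t)
$$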
Train the shaped PRM:
```bash
bash bash/train-shaped-rm-llama3.2-3B.sh
```
Train the policy via online DPO:
```bash
bash bash/online-dpo-shaped-prm-llama3.2-3B.sh
```
## SGLang instructions
To use SGLang for inference, grab a node on the same network as the machine running your inference scripts, so the scripts and the server can communicate.
SGLang has some compatibility issues with the `agent_prm` Conda environment, so we recommend using a separate `sglang` environment:
```bash
conda env create -f sglang_environment.yml
conda activate sglang
```
To host a model, run:
```bash
python -m sglang.launch_server --model-path <model_path> --port <port>
```
When doing Best-of-N inference with a PRM, you may want to grab two such nodes, one for the generator and one for the verifier, and assign them two different ports (e.g., 30000 and 30010).
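A minimal launch sketch for this setup might look like the following (an illustration, not an exact recipe: model paths are placeholders, and in practice each server typically runs on its own node as described above):
```bash
# On node 1: host the generator (policy)
python -m sglang.launch_server --model-path <policy_model_path> --port 30000

# On node 2: host the verifier (PRM)
python -m sglang.launch_server --model-path <prm_model_path> --port 30010

# Then point the best_of_n agent at these two servers in the eval config and run:
python scripts/eval/eval_alfworld.py --config configs/eval_alfworld.yaml
```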
## External environment instructions
### Setup AlfWorld
Clone AlfWorld from the [AlfWorld GitHub repository](https://github.com/alfworld/alfworld) and follow the instructions in its README to get the game files.
Create an `env_assets` folder and copy the data to `env_assets/alfworld`. Then set the following environment variable:
```bash
export ALFWORLD_DATA=<path_to>/env_assets/alfworld
```
## Contact
This project is actively being developed. For any questions or issues, please contact us at sanjibanc@cornell.edu.