# PopulationTransformer
This is the code for the [PopulationTransformer](https://arxiv.org/abs/2406.03044) paper.

[arXiv](https://arxiv.org/abs/2406.03044)
| [Project page](https://glchau.github.io/population-transformer/)
| [Pretrained weights](https://huggingface.co/PopulationTransformer/popt_brainbert_stft)
| [Brain Treebank](https://braintreebank.dev/)
[Quick Start](#quick-start)
| [Prerequisites](#prerequisites)
| [Data](#data)
| [Pretraining](#pretraining)
| [Fine-tuning](#fine-tuning)
| [Citation](#citation)
## Quick Start
With the [requirements](#prerequisites) installed and the [data](#data) downloaded, run these scripts for their desired outcomes:
| Script | Outcome |
| -------- | ------- |
| `0_write_pretraining_data.sh` | Write pretraining data |
| `1_create_pretraining_manifest.sh` | Create the pretraining manifest file |
| `2_run_pretraining.sh` | Run pretraining |
| `3_write_finetuning_data.sh` | Write fine-tuning data |
| `4_create_finetuning_manifest.sh` | Create fine-tuning manifest file |
| `5_run_finetuning.sh` | Run fine-tuning |
More details about each script are given in the [Pretraining](#pretraining) and [Fine-tuning](#fine-tuning) sections below.
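If you just want to run the whole pipeline end to end, a minimal sketch is below. It assumes the numbered scripts sit in the repository root and that you have already edited the paths inside them to point at your data and checkpoints.
```bash
# Minimal sketch: run the numbered scripts in order (assumes they live in the
# repository root and that the paths inside them have been edited first).
set -e
for script in 0_write_pretraining_data.sh 1_create_pretraining_manifest.sh \
              2_run_pretraining.sh 3_write_finetuning_data.sh \
              4_create_finetuning_manifest.sh 5_run_finetuning.sh; do
    bash "$script"
done
```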
## Prerequisites
Requirements:
- pytorch >= 1.13.1
```bash
pip install -r requirements.txt
```
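Since pytorch is listed separately from `requirements.txt`, one way to set up an environment is sketched below; the environment name is illustrative, and you should pick the torch build that matches your CUDA setup.
```bash
# Optional setup sketch: an isolated environment plus the pytorch requirement
# (the environment name "popt_env" is illustrative).
python3 -m venv popt_env
source popt_env/bin/activate
pip install "torch>=1.13.1"
pip install -r requirements.txt
```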
## Data
If you are using data from the Brain Treebank, it can be downloaded from [braintreebank.dev](https://braintreebank.dev).
The following commands expect the Brain Treebank data to have the following structure:
```
/braintreebank_data
|_electrode_labels
|_subject_metadata
|_localization
|_all_subject_data
|_sub_*_trial*.h5
```
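A quick way to sanity-check that layout before running anything (the root path here is a placeholder):
```bash
# Sanity-check sketch: confirm the expected Brain Treebank layout is present.
BRAINTREEBANK_DIR="/path/to/braintreebank_data"   # placeholder path
for d in electrode_labels subject_metadata localization all_subject_data; do
    [ -d "${BRAINTREEBANK_DIR}/${d}" ] || echo "missing: ${d}"
done
ls "${BRAINTREEBANK_DIR}"/all_subject_data/sub_*_trial*.h5 | head -n 3
```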
## Pretraining
The commands below assume that you will be using the PopulationTransformer in conjunction with BrainBERT.
If so, you will need to download the [BrainBERT](https://github.com/czlwang/BrainBERT) weights.
First, we write the BrainBERT features for pre-training. This command takes a list of brain recordings (see below) and creates a training dataset from their BrainBERT representations:
```bash
REPO_DIR="/path/to/PopulationTransformer"
BRAINTREEBANK_DIR="/path/to/braintreebank_data"
python3 -m data.write_nsp_pretraining_data \
+preprocessor=multi_elec_spec_pretrained \
++preprocessor.upstream_ckpt=${REPO_DIR}/pretrained_weights/stft_large_pretrained.pth \
+data_prep=pretrain_multi_subj_multi_chan_template \
++data_prep.task_name=nsp_pretraining \
++data_prep.brain_runs=${REPO_DIR}/trial_selections/pretrain_split_trials.json \
++data_prep.electrodes=${REPO_DIR}/electrode_selections/clean_laplacian.json \
++data_prep.output_directory=${REPO_DIR}/saved_examples/cr_pretrain_examples \
+data=pretraining_subject_data_template \
++data.cached_transcript_aligns=${REPO_DIR}/semantics/saved_aligns \
++data.cached_data_array=${REPO_DIR}/cached_data_arrays/ \
++data.raw_brain_data_dir=${BRAINTREEBANK_DIR}
```
Salient arguments:
- Input:
- `preprocessor.upstream_ckpt` is the path to the BrainBERT weights
- `data_prep.brain_runs` is the path to a json file mapping each subject to its trials, in the format `{"<subject_id>": ["<trial_name>", ...]}` (see the sketch after this list). This specifies the brain recording files that will be used.
- `data_prep.electrodes` is the path to a json file of the same shape, `{"<subject_id>": ["<electrode_name>", ...]}`. Similar to the above, this specifies which channels will be used.
- `data.raw_brain_data_dir` is the path to the root of the Brain Treebank data (see the Data section above)
- Output:
- `data_prep.output_directory` is the path where the output will be written
- `data.cached_data_array` is the path to an (optional) cache where intermediate outputs can be written for faster processing
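For reference, here is a hedged sketch of what those two selection files might look like. The subject, trial, and electrode names below are hypothetical; substitute the names from your own recordings (or use the files shipped under `trial_selections/` and `electrode_selections/`).
```bash
# Hypothetical selection files: subject, trial, and electrode names are made up.
cat > my_pretrain_trials.json <<'EOF'
{"sub_1": ["trial000", "trial001"], "sub_2": ["trial000"]}
EOF
cat > my_electrodes.json <<'EOF'
{"sub_1": ["T1b2", "T1b3"], "sub_2": ["F3a1"]}
EOF
```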
Next, we need to create a manifest for all the training examples we've just created.
```bash
REPO_DIR="/path/to/PopulationTransformer"
python3 -m data.make_pretrain_replace_manifest +data_prep=combine_nsp_datasets \
++data_prep.source_dir=${REPO_DIR}/saved_examples/cr_pretrain_examples \
++data_prep.output_dir=${REPO_DIR}/saved_examples/nsp_replace_task-0_5s \
++data_prep.task="nsp_negative_any"
```
Salient arguments:
- Input:
- `data_prep.source_dir` should match `data_prep.output_directory` from the previous step.
- Output:
- `data_prep.output_dir` is the path where the output will be written.
Now, we can run the pretraining:
```bash
REPO_DIR="/path/to/PopulationTransformer"
python3 run_train.py \
+exp=multi_elec_pretrain \
++exp.runner.device=cuda \
+data=nsp_replace_only_pretrain \
++data.data_path=${REPO_DIR}/saved_examples/nsp_replace_task-0_5s \
++data.saved_data_split=${REPO_DIR}/saved_data_splits/pretrain_split \
++data.test_data_cfg.name=nsp_replace_only_deterministic \
++data.test_data_cfg.data_path=${REPO_DIR}/saved_examples/nsp_replace_task-0_5s \
+model=pt_custom_model \
+task=nsp_replace_only_pretrain \
+criterion=nsp_replace_only_pretrain \
+preprocessor=empty_preprocessor
```
Salient arguments:
- Input:
- `data.data_path` should match the `data_prep.output_dir` from the manifest creation step above.
- Output:
- The final weights will be saved in an automatically created directory under `outputs`.
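Since the run directory is created automatically, one way to locate the most recent run is sketched below; the exact directory naming is whatever the training script generates.
```bash
# Hedged sketch: show the most recently created run directory under outputs/.
REPO_DIR="/path/to/PopulationTransformer"
ls -td "${REPO_DIR}"/outputs/*/ | head -n 1
```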
## Fine-tuning
Now, let's write the BrainBERT features for a fine-tuning task. For this example, let's decode volume (rms) from one electrode over the course of one trial.
```bash
REPO_DIR="/path/to/PopulationTransformer"
BRAINTREEBANK_DIR="/path/to/braintreebank_data"
python3 -m data.write_multi_subject_multi_channel \
+data_prep=pretrain_multi_subj_multi_chan_template \
++data_prep.task_name=rms \
++data_prep.brain_runs=${REPO_DIR}/trial_selections/test_trials.json \
++data_prep.electrodes=${REPO_DIR}/electrode_selections/clean_laplacian.json \
++data_prep.output_directory=${REPO_DIR}/saved_examples/all_test_rms \
+preprocessor=multi_elec_spec_pretrained \
++preprocessor.upstream_ckpt=${REPO_DIR}/pretrained_weights/stft_large_pretrained.pth \
+data=subject_data_template \
++data.cached_transcript_aligns=${REPO_DIR}/semantics/saved_aligns \
++data.cached_data_array=${REPO_DIR}/cached_data_arrays/ \
++data.raw_brain_data_dir=${BRAINTREEBANK_DIR}/ \
++data.movie_transcripts_dir=${BRAINTREEBANK_DIR}/transcripts
```
- Inputs:
- `data_prep.electrodes` and `data_prep.brain_runs`: as in pretraining, these files specify the trials and channels that will be used to create the dataset.
- Outputs:
- `data_prep.output_directory` is the path to where the BrainBERT embeddings will be written.
Let's write the manifest for this decoding task.
```bash
REPO_DIR="/path/to/PopulationTransformer"
SUBJECT=sub_1; TASK=rms; python3 -m data.make_subject_specific_manifest \
+data_prep=subject_specific_manifest \
++data_prep.data_path=${REPO_DIR}/saved_examples/all_test_${TASK} \
++data_prep.subj=${SUBJECT} \
++data_prep.out_path=${REPO_DIR}/saved_examples/${SUBJECT}_${TASK}_cr
```
- Inputs:
- `data_prep.data_path` should match the `data_prep.output_directory` given above.
Now, we are ready to run the fine-tuning. You can either fine-tune a model that you have pre-trained yourself, or use a model from [our huggingface repo](https://huggingface.co/PopulationTransformer).
```bash
REPO_DIR="/path/to/PopulationTransformer"
SUBJECT=sub_1; TASK=rms; N=1; NAME=popt_brainbert_stft; WEIGHTS=pretrained_popt_brainbert_stft;
python3 run_train.py \
+exp=multi_elec_feature_extract \
++exp.runner.results_dir=${REPO_DIR}/outputs/${SUBJECT}_${TASK}_top${N}_${NAME} \
++exp.runner.save_checkpoints=False \
++model.frozen_upstream=False \
+task=pt_feature_extract_coords \
+criterion=pt_feature_extract_coords_criterion \
+preprocessor=empty_preprocessor \
+data=pt_supervised_task_coords \
++data.data_path=${REPO_DIR}/saved_examples/${SUBJECT}_${TASK}_cr \
++data.saved_data_split=${REPO_DIR}/saved_data_splits/${SUBJECT}_${TASK}_fine_tuning \
++data.sub_sample_electrodes=${REPO_DIR}/electrode_selections/debug_electrodes.json \
+model=pt_downstream_model \
++model.upstream_path=${REPO_DIR}/pretrained_weights/${WEIGHTS}.pth
```
- Inputs:
- `data.data_path` should match the `out_path` of the manifest creation step above.
- `model.upstream_path` should be a path to the weights from pretraining, either from the steps above or from [the huggingface repo](https://huggingface.co/PopulationTransformer).
- `data.saved_data_split` is a path to where the indices for train/val/test splits will be written. You can use this to ensure that splits are consistent between runs.
- Outputs:
- `exp.runner.results_dir` will contain performance metrics (F1, ROC-AUC) on the test set.
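If you are starting from the released checkpoint rather than your own pretraining run, one way to fetch the weights is sketched below. It assumes the `huggingface_hub` CLI is available; the file names inside the Hugging Face repo may differ from the `WEIGHTS` variable above, so adjust the local paths accordingly.
```bash
# Hedged sketch: download the released checkpoint from Hugging Face.
# File names inside the repo may differ; adjust paths as needed.
REPO_DIR="/path/to/PopulationTransformer"
pip install -U "huggingface_hub[cli]"
huggingface-cli download PopulationTransformer/popt_brainbert_stft \
    --local-dir "${REPO_DIR}/pretrained_weights"
```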
## Citation
```bibtex
@inproceedings{
chau_wang_2025_population,
title={Population Transformer: Learning Population-level Representations of Neural Activity},
url={https://openreview.net/forum?id=FVuqJt3c4L},
booktitle={The Thirteenth International Conference on Learning Representations},
author={Chau, Geeling and Wang, Christopher and Talukder, Sabera J and Subramaniam, Vighnesh and Soedarmadji, Saraswati and Yue, Yisong and Katz, Boris and Barbu, Andrei},
year={2025}
}
```