
tigerchow89/Learn-from-Committee

 
 


Fault-Aware Distillation via Peer-Review (FAIR) approach

Code for the paper "Learning from Committee: Reasoning Distillation from a Mixture of Teachers with Peer-Review", accepted at ACL 2025.

Figure: Overview of our FAIR method.

Before You Start

This repository reproduces the work in our paper, so the supported models and datasets follow the paper's experimental settings. Because the code relies only on public models, datasets, and training frameworks from Hugging Face plus the official teacher-model APIs, it is easy to extend and adapt. Feel free to clone and customize this repo for your own tasks!

Directory structure:

  • /local_models: The cache_dir of AutoModelForCausalLM.from_pretrained() for downloaded student models.
  • /local_dataset: The cache_dir of load_dataset() for downloaded datasets.
  • /data: The folder to store processed data, such as wrong answers from student models, rationales and feedback generated by teacher models.
    To support the open-source community, we have provided four JSON files in /data containing the incorrect answers generated by the Llama2-7B-chat model on four benchmark datasets.
  • /checkpoints: The output_dir for trained student models.
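These folders correspond to the cache_dir and output_dir arguments used with the Hugging Face loaders. The sketch below only prepares them. It assumes the scripts are launched from the repository root, and the helper name prepare_dirs is hypothetical (it is not part of this repo).

```python
from pathlib import Path


def prepare_dirs(root: str = ".") -> dict:
    """Create the working folders (hypothetical helper) and return their paths."""
    base = Path(root)
    dirs = {
        "models": base / "local_models",       # cache_dir for AutoModelForCausalLM.from_pretrained()
        "datasets": base / "local_dataset",    # cache_dir for load_dataset()
        "data": base / "data",                 # wrong answers, feedback, rationales
        "checkpoints": base / "checkpoints",   # output_dir for trained student models
    }
    for path in dirs.values():
        path.mkdir(parents=True, exist_ok=True)  # no error if the folder already exists
    return dirs
```

The returned paths could then be passed as, for instance, cache_dir=str(dirs["models"]) to AutoModelForCausalLM.from_pretrained() and cache_dir=str(dirs["datasets"]) to load_dataset().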

Quick Start

1. Requirements

pip install -r requirements.txt

Since the Llama-series models are gated, you need to request access on Hugging Face first. Once access is granted, create a user access token and pass it to the scripts as hf_token.
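If you prefer not to paste the token into every command, one option is to fall back to the HF_TOKEN environment variable, which huggingface_hub also reads. The helper below (resolve_hf_token) is hypothetical: the repo's scripts take --hf_token explicitly, and this is only a sketch of that fallback.

```python
import os
from typing import Optional


def resolve_hf_token(cli_token: Optional[str] = None) -> str:
    """Return the Hugging Face token, preferring the --hf_token value.

    Falls back to the HF_TOKEN environment variable (an assumed convenience,
    not behavior of this repo's scripts).
    """
    token = cli_token or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError("No Hugging Face token: pass --hf_token or set HF_TOKEN")
    return token
```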

2. Inferring Mistakes with the Student Model

You can choose any student model from Hugging Face and one of the datasets ['gsm8k', 'svamp', 'strategyQA', 'logiQA']. For example:

python infer_student_wrong.py --model meta-llama/Llama-2-7b-chat-hf --dataset svamp --hf_token "Your access token"
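This README does not spell out how a student answer is judged wrong. As an illustration only, the sketch below assumes the gold answer has already been extracted (e.g. the number after "####" in GSM8K) and takes the last number for gsm8k/svamp, the last yes/no for strategyQA, and the last standalone option letter A-D for logiQA. The function names and extraction rules are assumptions, not the logic of infer_student_wrong.py.

```python
import re
from typing import Optional, Union

NUMERIC_DATASETS = {"gsm8k", "svamp"}


def extract_answer(text: str, dataset: str) -> Optional[str]:
    """Pull the final answer out of a model completion (illustrative rules)."""
    if dataset in NUMERIC_DATASETS:
        # Drop thousands separators, then take the last number in the text.
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
        return numbers[-1] if numbers else None
    if dataset == "strategyQA":
        verdicts = re.findall(r"\b(yes|no)\b", text.lower())
        return verdicts[-1] if verdicts else None
    if dataset == "logiQA":
        # Standalone option letters A-D; the last one is taken as the choice.
        letters = re.findall(r"\b([A-D])\b", text)
        return letters[-1] if letters else None
    raise ValueError(f"Unsupported dataset: {dataset}")


def is_wrong(completion: str, gold: Union[str, bool, float], dataset: str) -> bool:
    """Return True when the extracted answer does not match the gold answer."""
    pred = extract_answer(completion, dataset)
    if pred is None:
        return True  # an unparsable answer counts as a mistake
    if dataset in NUMERIC_DATASETS:
        return abs(float(pred) - float(gold)) > 1e-6
    if isinstance(gold, bool):
        gold = "yes" if gold else "no"  # strategyQA labels may be booleans
    return pred.lower() != str(gold).lower()
```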

After running, the student model is downloaded to /local_models and the dataset to /local_dataset. The student's wrong answers are stored under /data, e.g., Llama-2-7b-chat-hf_svamp_false_round0.json.
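The output file names appear to follow a `<model name>_<dataset>_<kind>_round<k>.json` pattern; this is inferred from the examples in this README, not from the code. A hypothetical helper that reproduces the pattern and keeps only incorrect records, assuming each record carries a boolean "correct" field:

```python
import json
from pathlib import Path


def output_name(model_id: str, dataset: str, kind: str, round_idx: int = 0) -> str:
    """Build a name like Llama-2-7b-chat-hf_svamp_false_round0.json.

    kind is "false", "feedback", or "rationale", as in the examples above.
    """
    model_name = model_id.split("/")[-1]  # drop the "meta-llama/" organization prefix
    return f"{model_name}_{dataset}_{kind}_round{round_idx}.json"


def save_wrong_answers(records: list, model_id: str, dataset: str, data_dir: str = "data") -> Path:
    """Write only the records marked incorrect (hypothetical schema) to data_dir."""
    wrong = [r for r in records if not r["correct"]]
    path = Path(data_dir) / output_name(model_id, dataset, "false")
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(wrong, indent=2), encoding="utf-8")
    return path
```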

3. Collecting Teacher Responses to the Student's Mistakes

Set your API keys for GPT-3.5-Turbo, Gemini-1.0-Pro, and Mixtral-8x7B-Instruct-v0.1 as environment variables. For example:

export OPENAI_API_KEY='Your key'  
export GEMINI_API_KEY='Your key'  
export MISTRAL_API_KEY='Your key'

Alternatively, pass the API keys with --gpt_api, --mistral_api, and --gemini_api when running the command:

python collect_teacher_res.py --student_wrong "data/Llama-2-7b-chat-hf_svamp_false_round0.json" --gpt_api "Your key" --mistral_api "Your key" --gemini_api "Your key"
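This README does not say which source wins when a key is provided both ways. The sketch below assumes the command-line value takes precedence over the environment variable; the resolve_api_keys helper and the "gpt"/"gemini"/"mistral" dictionary keys are illustrative, while the environment variable names are the ones listed above.

```python
import os
from typing import Dict, Optional

# Environment variables listed in this README, keyed by teacher.
ENV_VARS = {"gpt": "OPENAI_API_KEY", "gemini": "GEMINI_API_KEY", "mistral": "MISTRAL_API_KEY"}


def resolve_api_keys(cli_keys: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Prefer a key passed on the command line, else fall back to the env var."""
    keys: Dict[str, str] = {}
    missing = []
    for teacher, env_var in ENV_VARS.items():
        key = cli_keys.get(teacher) or os.environ.get(env_var)
        if key:
            keys[teacher] = key
        else:
            missing.append(env_var)
    if missing:
        raise RuntimeError("Missing API keys: " + ", ".join(missing))
    return keys
```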

After running, the feedback generated by the three teacher LLMs and the rationales collected through the peer-review process are stored under /data, for example Llama-2-7b-chat-hf_svamp_feedback_round0.json and Llama-2-7b-chat-hf_svamp_rationale_round0.json.
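To make the idea of peer review concrete, here is a generic filter: each teacher's rationale is scored by the other teachers, self-reviews are ignored, and a rationale is kept when its mean peer score reaches a threshold. The 0-10 scale, the threshold of 7.0, and the data layout are assumptions for illustration; see the paper and collect_teacher_res.py for the actual review prompt and selection rule.

```python
from statistics import mean
from typing import Dict, List


def peer_review(
    rationales: Dict[str, str],
    scores: Dict[str, Dict[str, float]],
    threshold: float = 7.0,
) -> List[str]:
    """Keep rationales whose mean peer score meets the threshold.

    rationales: author teacher -> rationale text
    scores: reviewer teacher -> {author teacher: score}
    """
    accepted = []
    for author, text in rationales.items():
        # Collect scores from every reviewer except the author itself.
        peer_scores = [
            given[author]
            for reviewer, given in scores.items()
            if reviewer != author and author in given
        ]
        if peer_scores and mean(peer_scores) >= threshold:
            accepted.append(text)
    return accepted
```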

4. Training the Student Model

Now you can fine-tune the student model on the data collected for each dataset. The --teacher argument selects whether to distill from an individual teacher model or from the mixture of teachers (the peer-reviewed rationales):
train_parser.add_argument('--teacher', default='mixed', choices=['mixed', 'gpt', 'gemini', 'mistral'], help="which model selected as teacher")
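How --teacher maps onto the collected files is not documented here. Assuming a hypothetical per-question record with one rationale per teacher plus the peer-reviewed set, the selection could look like this:

```python
from typing import Dict, List

TEACHERS = ["mixed", "gpt", "gemini", "mistral"]  # the choices of --teacher


def select_rationales(record: Dict, teacher: str = "mixed") -> List[str]:
    """Pick the training rationale(s) for one question.

    Assumes a hypothetical record layout:
    {"rationales": {"gpt": ..., "gemini": ..., "mistral": ...},
     "peer_reviewed": [...]}
    """
    if teacher not in TEACHERS:
        raise ValueError(f"--teacher must be one of {TEACHERS}")
    if teacher == "mixed":
        return list(record.get("peer_reviewed", []))
    text = record["rationales"].get(teacher)
    return [text] if text else []  # skip empty or missing rationales
```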

# Optional: --learning_rate, --max_seq_length, --batch_size, --epochs, and --alpha can be omitted to use their default values.
python train.py \
  --model_path meta-llama/Llama-2-7b-chat-hf \
  --hf_token "Your access token" \
  --output_path checkpoints/llama2_svamp_round10_mixed \
  --feedback_path data/Llama-2-7b-chat-hf_svamp_feedback_round0.json \
  --rationale_path data/Llama-2-7b-chat-hf_svamp_rationale_round0.json \
  --teacher mixed \
  --learning_rate 1e-5 \
  --max_seq_length 512 \
  --batch_size 8 \
  --epochs 10 \
  --alpha 0.5
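For reference, the flags in the command above can be mirrored in an argparse parser. This is a sketch: the defaults simply copy the example values and the required/optional split is assumed, so consult train.py for the real definitions, including what --alpha weights, which this README does not describe.

```python
import argparse


def build_train_parser() -> argparse.ArgumentParser:
    """Mirror train.py's flags from the example command (defaults are assumed)."""
    p = argparse.ArgumentParser(description="Fine-tune a student model (sketch)")
    p.add_argument("--model_path", required=True)
    p.add_argument("--hf_token", default=None)
    p.add_argument("--output_path", required=True)
    p.add_argument("--feedback_path", required=True)
    p.add_argument("--rationale_path", required=True)
    p.add_argument("--teacher", default="mixed", choices=["mixed", "gpt", "gemini", "mistral"],
                   help="which model selected as teacher")
    # The defaults below copy the example values; the real ones live in train.py.
    p.add_argument("--learning_rate", type=float, default=1e-5)
    p.add_argument("--max_seq_length", type=int, default=512)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--alpha", type=float, default=0.5)
    return p
```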

5. Testing

Finally, you can test your distilled student model on different datasets:

python test.py --model checkpoints/llama2_svamp_round10_mixed --dataset svamp

This reports the accuracy and saves the wrong predictions to a JSON file, e.g., llama2_svamp_round10_mixed_test.json.
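The accuracy and the wrong-prediction file can be pictured as follows, assuming each prediction is a dict with hypothetical "prediction" and "answer" fields that have already been normalized to comparable strings. This is an illustration, not the code in test.py.

```python
import json
from pathlib import Path
from typing import Dict, List, Tuple


def evaluate(preds: List[Dict], out_path: str) -> Tuple[float, Path]:
    """Return accuracy and write the wrong predictions to out_path as JSON."""
    wrong = [p for p in preds if str(p["prediction"]).strip() != str(p["answer"]).strip()]
    accuracy = (1.0 - len(wrong) / len(preds)) if preds else 0.0
    path = Path(out_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(wrong, indent=2), encoding="utf-8")
    return accuracy, path
```

For example, evaluate(preds, "llama2_svamp_round10_mixed_test.json") would return the accuracy together with the path of the saved file.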

Citation

If you find this work helpful, we would appreciate it if you could cite it!

@article{li2024learning,
  title={Learning from Committee: Reasoning Distillation from a Mixture of Teachers with Peer-Review},
  author={Li, Zhuochun and Ji, Yuelyu and Meng, Rui and He, Daqing},
  journal={arXiv preprint arXiv:2410.03663},
  year={2024}
}
