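"""Fine-tune a LoRA adapter for Mistral-7B-Instruct with QLoRA (4-bit NF4)
using TRL's SFTTrainer, reading training examples from a JSONL file."""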
import argparse
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--max_seq_len", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Decide mixed precision once: prefer bf16 where the GPU supports it,
    # fall back to fp16 on older CUDA GPUs, full precision on CPU.
    use_cuda = torch.cuda.is_available()
    bf16_ok = use_cuda and torch.cuda.is_bf16_supported()

    # QLoRA: load the base model in 4-bit NF4 with double quantization.
    # Note that bitsandbytes 4-bit loading effectively requires a CUDA GPU.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if bf16_ok else torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    # Mistral's tokenizer has no pad token; reuse EOS so batches can be padded.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16 if bf16_ok else torch.float16,
    )

    # LoRA adapter config: rank-16 adapters on all attention (q/k/v/o)
    # and MLP (gate/up/down) projections.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    # Each JSONL record must carry a field SFTTrainer understands,
    # e.g. a "text" string or chat-style "messages".
    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")

    training_args = SFTConfig(
        output_dir=args.out_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        max_length=args.max_seq_len,
        bf16=bf16_ok,
        fp16=use_cuda and not bf16_ok,  # fp16 mixed precision needs a GPU
        report_to=[],
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,  # SFTTrainer attaches the LoRA adapter itself
    )

    trainer.train()
    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)

    print(f"Done! LoRA adapter saved to: {args.out_dir}")


if __name__ == "__main__":
    main()
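
# A minimal inference sketch (an assumption, not part of this script): the
# saved adapter can be loaded back on top of the quantized base model with
# peft's AutoPeftModelForCausalLM, e.g.
#
#   from peft import AutoPeftModelForCausalLM
#   model = AutoPeftModelForCausalLM.from_pretrained(
#       "out/mistral_balitwin_lora", device_map="auto"
#   )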