mirror of https://github.com/marvinscham/masterthesis-playground.git (synced 2026-03-23 00:42:43 +01:00)
RAFT updates, BERTopic config, cleanup
@@ -1,5 +1,4 @@
 import argparse
 import json
 import os
 
 import torch
@@ -12,7 +11,7 @@ from trl import SFTConfig, SFTTrainer
 def main():
     ap = argparse.ArgumentParser()
     ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
-    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
+    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
     ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
     ap.add_argument("--max_seq_len", type=int, default=2048)
     ap.add_argument("--batch_size", type=int, default=1)
@@ -23,7 +22,7 @@ def main():
 
     os.makedirs(args.out_dir, exist_ok=True)
 
-    # QLoRA (4-bit) config (good default for 7B on limited VRAM)
+    # QLoRA (4-bit) config
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_quant_type="nf4",
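The BitsAndBytesConfig is cut off by the hunk; for reference, a minimal sketch of the usual QLoRA recipe, where the double-quant flag and compute dtype are assumptions not shown in this diff:

import torch
from transformers import BitsAndBytesConfig

# Sketch of a complete QLoRA 4-bit config; only load_in_4bit and nf4 appear
# in the diff, the remaining kwargs are assumed from the common QLoRA recipe.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize base weights to 4-bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_use_double_quant=True,         # assumption: also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: bf16 for the dequantized matmuls
)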
@@ -34,7 +33,6 @@ def main():
     )
 
     tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
-    # Mistral usually has a valid chat template; keep it intact. :contentReference[oaicite:9]{index=9}
 
     model = AutoModelForCausalLM.from_pretrained(
         args.base_model,
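The from_pretrained call is split across this hunk and the next; a sketch of the assembled call, where the quantization_config wiring is implied by the script and device_map="auto" is an assumption:

# Sketch: wiring the quantization config into the model load; the diff shows
# args.base_model and the torch_dtype line, the rest is assumed.
model = AutoModelForCausalLM.from_pretrained(
    args.base_model,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
)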
@@ -43,7 +41,7 @@ def main():
         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
     )
 
-    # LoRA adapter config (tweak r/alpha if needed)
+    # LoRA adapter config
     peft_config = LoraConfig(
         r=16,
         lora_alpha=32,
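The LoraConfig is also truncated; a sketch of a complete adapter config, where dropout, bias, task_type, and target_modules are assumptions (only r and lora_alpha are visible in the diff):

from peft import LoraConfig

# Sketch of a complete LoRA adapter config for a Mistral-style model;
# everything except r and lora_alpha is an assumption.
peft_config = LoraConfig(
    r=16,               # adapter rank (from the diff)
    lora_alpha=32,      # scaling factor (from the diff)
    lora_dropout=0.05,  # assumption
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumption: attention projections
)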
@@ -75,7 +73,7 @@ def main():
         max_length=args.max_seq_len,
         bf16=torch.cuda.is_available(),
         fp16=not torch.cuda.is_available(),
-        assistant_only_loss=True,  # only learn from assistant turns in messages :contentReference[oaicite:10]{index=10}
+        assistant_only_loss=True,  # only learn from assistant turns in messages
         report_to=[],
     )
 
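assistant_only_loss=True requires the dataset in the conversational messages format so the trainer can mask every token outside assistant turns; a hypothetical record from out/raft_train.jsonl might look like this (field contents are illustrative, only the "messages" schema matters):

# Hypothetical shape of one training record; the RAFT-style content is
# an illustration, not taken from the repository.
example = {
    "messages": [
        {"role": "user", "content": "Context: <retrieved chunks>\n\nQuestion: ..."},
        {"role": "assistant", "content": "<reasoning over the context, then the answer>"},
    ]
}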
@@ -91,7 +89,7 @@ def main():
     trainer.save_model(args.out_dir)
     tokenizer.save_pretrained(args.out_dir)
 
-    print(f"Saved LoRA adapter to: {args.out_dir}")
+    print(f"Done! LoRA adapter saved: {args.out_dir}")
 
 
 if __name__ == "__main__":
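Once saved, the adapter can be reattached to the base model for inference; a minimal sketch, assuming the script's default model ID and output directory (everything else, including the generation settings, is an assumption):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Sketch: load the base model, attach the saved LoRA adapter, and generate.
# Paths mirror the script's defaults; generation settings are assumptions.
base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "out/mistral_balitwin_lora")
tokenizer = AutoTokenizer.from_pretrained("out/mistral_balitwin_lora")

messages = [{"role": "user", "content": "..."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
out = model.generate(inputs, max_new_tokens=128)
print(tokenizer.decode(out[0], skip_special_tokens=True))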