RAFT updates, BERTopic config, cleanup

This commit is contained in:
2026-02-21 01:57:14 +01:00
parent 8cadcb1f69
commit 1a99b53d44
12 changed files with 10750 additions and 9778 deletions

View File

@@ -1,5 +1,4 @@
import argparse
import json
import os
import torch
@@ -12,7 +11,7 @@ from trl import SFTConfig, SFTTrainer
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
ap.add_argument("--max_seq_len", type=int, default=2048)
ap.add_argument("--batch_size", type=int, default=1)
@@ -23,7 +22,7 @@ def main():
os.makedirs(args.out_dir, exist_ok=True)
# QLoRA (4-bit) config (good default for 7B on limited VRAM)
# QLoRA (4-bit) config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -34,7 +33,6 @@ def main():
)
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
# Mistral usually has a valid chat template; keep it intact. :contentReference[oaicite:9]{index=9}
model = AutoModelForCausalLM.from_pretrained(
args.base_model,
@@ -43,7 +41,7 @@ def main():
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
)
# LoRA adapter config (tweak r/alpha if needed)
# LoRA adapter config
peft_config = LoraConfig(
r=16,
lora_alpha=32,
@@ -75,7 +73,7 @@ def main():
max_length=args.max_seq_len,
bf16=torch.cuda.is_available(),
fp16=not torch.cuda.is_available(),
assistant_only_loss=True, # only learn from assistant turns in messages :contentReference[oaicite:10]{index=10}
assistant_only_loss=True, # only learn from assistant turns in messages
report_to=[],
)
@@ -91,7 +89,7 @@ def main():
trainer.save_model(args.out_dir)
tokenizer.save_pretrained(args.out_dir)
print(f"Saved LoRA adapter to: {args.out_dir}")
print(f"Fertig! LoRA-Adapter gespeichert: {args.out_dir}")
if __name__ == "__main__":