mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
New RAFT approach
This commit is contained in:
162
raft/make_raft_data.py
Normal file
162
raft/make_raft_data.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
## Usage: python make_raft_data.py --out_dir out --n_examples 5000
|
||||
|
||||
|
||||
# Persona prompt for the teacher model: grounded, quote-backed answers only.
# (Fixed typo: "nand guidance" -> "and guidance".)
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
"""
|
||||
|
||||
|
||||
def load_docstore(path):
    """Read a JSON Lines docstore file and return its records as a list of dicts."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(row) for row in handle]
|
||||
|
||||
|
||||
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and return (ids, scores) of the top_k nearest docs in the index."""
    vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    sims, hits = index.search(vec, top_k)
    return hits[0].tolist(), sims[0].tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Run sampled generation for a chat *messages* list and return only the new text."""
    # Rely on the tokenizer's chat template (where available) to format the turn.
    prompt_ids = tok.apply_chat_template(messages, return_tensors="pt")
    prompt_ids = prompt_ids.to(model.device)
    prompt_len = prompt_ids.shape[1]
    generated = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
    )
    # Strip the echoed prompt: decode only tokens produced after it.
    completion = generated[0][prompt_len:]
    return tok.decode(completion, skip_special_tokens=True).strip()
|
||||
|
||||
|
||||
def main():
    """Generate a RAFT training set as JSONL in TRL's conversational format.

    Pipeline per example: pick a gold chunk, have the teacher model invent a
    question answerable from it, retrieve top-k chunks for that question, mix
    in random distractor chunks, then have the teacher answer grounded in the
    mixed context (with short quotes as evidence).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args()

    random.seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")

    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)

    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # pick a "gold" chunk
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 1) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
                    f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k for that question
            # FAISS pads with -1 ids when fewer than top_k results exist;
            # drop them so we never index docstore with -1 (which would
            # silently wrap to the last document).
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            ids = [i for i in ids if i >= 0]
            retrieved = [docstore[i] for i in ids]

            # 3) add distractors (random docs not in retrieved)
            used_ids = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in used_ids:
                    continue
                # BUGFIX: record the chosen candidate so the same distractor
                # cannot be appended twice in this example.
                used_ids.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Mix: retrieved + distractors
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate grounded answer WITH short quotes
            context_blob = ""
            for j, d in enumerate(context_docs):
                context_blob += f"[DOC {j}] {d['text']}\n\n"

            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    "- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    print(f"Wrote {out_path}")


if __name__ == "__main__":
    main()
|
||||
122
raft/prepare_corpus.py
Normal file
122
raft/prepare_corpus.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
## Usage: python prepare_corpus.py --input_tab your_reviews.tab --out_dir out
|
||||
|
||||
|
||||
def simple_clean(text: str) -> str:
    """Normalize whitespace in *text*; any non-string input becomes ''."""
    if not isinstance(text, str):
        return ""
    # Non-breaking spaces count as whitespace too; collapse every run to one space.
    normalized = text.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", normalized).strip()
|
||||
|
||||
|
||||
def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
    """Split *text* into overlapping character windows.

    Simple char-based chunking (good enough for reviews). For better
    chunking, split by sentences and cap token length.
    """
    text = simple_clean(text)
    if len(text) <= chunk_chars:
        return [text] if text else []
    # Advance by at least one char per iteration so the loop always terminates.
    step = max(1, chunk_chars - overlap)
    pieces = []
    start = 0
    while start < len(text):
        window = text[start : start + chunk_chars]
        if window:
            pieces.append(window)
        start += step
    return pieces
|
||||
|
||||
|
||||
def detect_text_col(df: pd.DataFrame) -> str:
    """Guess which DataFrame column holds the review text.

    Heuristic: the column whose first 200 non-null values have the highest
    average string length wins.

    Raises:
        ValueError: if no column has any non-null values.
    """
    best_col = None
    best_score = -1
    for name in df.columns:
        head = df[name].dropna().astype(str).head(200)
        if len(head) == 0:
            continue
        mean_len = head.map(len).mean()
        if mean_len > best_score:
            best_score, best_col = mean_len, name
    if best_col is None:
        raise ValueError("Could not detect a text column in the .tab file.")
    return best_col
|
||||
|
||||
|
||||
def main():
    """Build retrieval artifacts from a Tripadvisor .tab export.

    Reads the TSV, auto-detects the text column, cleans and chunks each
    review, then writes corpus.jsonl, a FAISS inner-product index over
    normalized embeddings, and docstore.jsonl (faiss_id -> chunk) to out_dir.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--chunk_chars", type=int, default=900)
    ap.add_argument("--overlap", type=int, default=150)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Many .tab files are TSV; malformed rows are skipped rather than fatal.
    df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
    text_col = detect_text_col(df)

    rows = df[text_col].fillna("").astype(str).tolist()

    corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
    corpus = []
    doc_id = 0

    for r in tqdm(rows, desc="Chunking"):
        r = simple_clean(r)
        # Skip near-empty reviews (and, below, near-empty chunks): < 30 chars.
        if len(r) < 30:
            continue
        chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
        for ch in chunks:
            if len(ch) < 30:
                continue
            corpus.append({"doc_id": doc_id, "text": ch})
            # NOTE: doc_id increments per kept chunk, not per source review.
            doc_id += 1

    with open(corpus_path, "w", encoding="utf-8") as f:
        for ex in corpus:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # Build FAISS index over normalized embeddings.
    embedder = SentenceTransformer(args.embedding_model)
    texts = [c["text"] for c in corpus]
    embs = embedder.encode(
        texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True
    )
    embs = np.asarray(embs, dtype=np.float32)

    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)  # inner product == cosine when vectors are normalized
    index.add(embs)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    faiss.write_index(index, faiss_path)

    # Store mapping doc row -> text; faiss_id equals the row's position in the index.
    mapping_path = os.path.join(args.out_dir, "docstore.jsonl")
    with open(mapping_path, "w", encoding="utf-8") as f:
        for i, c in enumerate(corpus):
            f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")

    print(
        f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
    )


if __name__ == "__main__":
    main()
|
||||
89
raft/rag_chat.py
Normal file
89
raft/rag_chat.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
## Usage: python rag_chat.py --lora_dir out/mistral_balitwin_lora
|
||||
|
||||
# Persona prompt used at inference time; must match the training persona.
# (Fixed typo: "nand guidance" -> "and guidance".)
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence.
If the context does not support the claim, say so.
"""
|
||||
|
||||
|
||||
def load_docstore(path):
    """Return the docstore at *path* (JSON Lines) as a list of dicts."""
    with open(path, "r", encoding="utf-8") as fh:
        return list(map(json.loads, fh))
|
||||
|
||||
|
||||
def retrieve(index, embedder, query, top_k=6):
    """Search the FAISS index for the top_k matches of *query*; return (ids, scores)."""
    embedded = embedder.encode([query], normalize_embeddings=True)
    score_mat, id_mat = index.search(embedded.astype(np.float32), top_k)
    return id_mat[0].tolist(), score_mat[0].tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Interactive RAG chat loop using a base model plus a LoRA adapter.

    Loads the FAISS index and docstore built by prepare_corpus.py, retrieves
    top-k chunks for each user question, and answers grounded in that context.
    Exits via Ctrl+C.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--lora_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=6)
    args = ap.parse_args()

    # Retrieval artifacts produced by prepare_corpus.py.
    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    # Base model with the trained LoRA adapter applied (inference only).
    tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(
        args.base_model, device_map="auto", torch_dtype=torch.float16
    )
    model = PeftModel.from_pretrained(base, args.lora_dir)
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    while True:
        q = input("\nYou: ").strip()
        if not q:
            continue

        # Retrieve top-k chunks and format them as numbered context docs.
        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
        context_docs = [docstore[i]["text"] for i in ids]
        context_blob = "\n\n".join(
            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
        )

        messages = [
            {"role": "system", "content": SYSTEM_PERSONA},
            {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
        ]
        inp = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)

        out = model.generate(
            inp,
            max_new_tokens=320,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tok.eos_token_id,
        )
        # Decode only the tokens generated after the prompt.
        ans = tok.decode(out[0][inp.shape[1] :], skip_special_tokens=True).strip()
        print(f"\nBaliTwin: {ans}")


if __name__ == "__main__":
    main()
|
||||
118
raft/rag_chat_merged.py
Normal file
118
raft/rag_chat_merged.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
## Usage: python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
|
||||
|
||||
# Persona prompt used at inference time; must match the training persona.
# (Fixed typo: "nand guidance" -> "and guidance".)
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence.
If the context does not support the claim, say so.
"""
|
||||
|
||||
|
||||
def load_docstore(path):
    """Parse the JSONL docstore at *path* into an in-memory list of dicts."""
    records = []
    with open(path, "r", encoding="utf-8") as fh:
        records.extend(json.loads(entry) for entry in fh)
    return records
|
||||
|
||||
|
||||
def retrieve(index, embedder, query, top_k=6):
    """Return (ids, scores) for the top_k index entries nearest to *query*."""
    query_vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    top_scores, top_ids = index.search(query_vec, top_k)
    return top_ids[0].tolist(), top_scores[0].tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Interactive RAG chat with a merged (already finetuned) model from disk.

    Loads the FAISS index and docstore built by prepare_corpus.py, retrieves
    top-k chunks per user question, and generates a grounded answer using the
    chat template shipped in the model folder. Exits via Ctrl+C.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--model_dir", required=True, help="Path to your finetuned model folder"
    )
    ap.add_argument(
        "--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
    )
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--max_new_tokens", type=int, default=320)
    args = ap.parse_args()

    # Retrieval artifacts produced by prepare_corpus.py.
    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    # Load your externally finetuned model directly from disk.
    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)

    # Important: ensure pad token exists for generation; Mistral often uses eos as pad.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    while True:
        q = input("\nYou: ").strip()
        if not q:
            continue

        # Retrieve top-k chunks and format them as numbered context docs.
        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
        context_docs = [docstore[i]["text"] for i in ids]
        context_blob = "\n\n".join(
            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
        )

        messages = [
            {"role": "system", "content": SYSTEM_PERSONA},
            {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
        ]

        # BUGFIX: a first apply_chat_template call (without
        # add_generation_prompt) was made here and its result discarded —
        # removed as dead work. Only the encoding below is used.
        enc = tok.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )

        # apply_chat_template may return a bare tensor or a BatchEncoding-like
        # mapping depending on the tokenizer; handle both shapes.
        if isinstance(enc, torch.Tensor):
            input_ids = enc.to(model.device)
            attention_mask = torch.ones_like(input_ids, device=model.device)
        else:
            input_ids = enc["input_ids"].to(model.device)
            attention_mask = enc.get("attention_mask")
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            attention_mask = attention_mask.to(model.device)

        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
        )

        # Decode only the newly generated tokens after the prompt.
        ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
        print(f"\nBaliTwin: {ans}")


if __name__ == "__main__":
    main()
|
||||
100
raft/train_mistral_raft.py
Normal file
100
raft/train_mistral_raft.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
## Usage: python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mistral_balitwin_lora
|
||||
|
||||
|
||||
def main():
    """Fine-tune the base model on the RAFT JSONL dataset with 4-bit QLoRA via TRL's SFTTrainer."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--max_seq_len", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # QLoRA (4-bit) config (good default for 7B on limited VRAM)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=(
            torch.bfloat16 if torch.cuda.is_available() else torch.float16
        ),
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    # Mistral ships a valid chat template; keep it intact.

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
    )

    # LoRA adapter config (tweak r/alpha if needed)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # Attention + MLP projections — the usual target set for Mistral-style models.
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    # Expects the conversational {"messages": [...]} format written by make_raft_data.py.
    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")

    training_args = SFTConfig(
        output_dir=args.out_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        max_length=args.max_seq_len,
        bf16=torch.cuda.is_available(),
        fp16=not torch.cuda.is_available(),
        assistant_only_loss=True,  # only learn from assistant turns in messages
        report_to=[],
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()
    # Saves only the LoRA adapter weights (model was loaded with a peft_config).
    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)

    print(f"Saved LoRA adapter to: {args.out_dir}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user