mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2025-12-06 18:20:53 +01:00
144 lines
4.2 KiB
Python
144 lines
4.2 KiB
Python
import concurrent.futures
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from threading import Lock
|
|
|
|
from dotenv import load_dotenv
|
|
from openai import OpenAI
|
|
|
|
# Pull variables from a local .env file into the environment first, so the
# API key lookup below can see them.
load_dotenv()

# Lock used by process_review to serialize its read-modify-write of the
# shared output file across worker threads.
file_lock = Lock()

# OpenAI SDK client configured with DeepSeek's base URL and API key.
client = OpenAI(
    base_url="https://api.deepseek.com",
    api_key=os.getenv("DEEPSEEK_API_KEY"),
)
|
|
|
|
# System message sent with every request. It lists the allowed categories and
# shows one worked input/output pair so the model replies with a JSON object
# holding "category" and "reason" (the keys validate_response checks for).
system_prompt = """
The user will provide a tourist review. Please categorize them according to the following categories, provide a short reasoning for the decision (max 8 words) and output them in JSON format.
The categories are: adventurer, business, family, backpacker, luxury, or none if no category fits.

EXAMPLE INPUT:
Perfect for families! The hotel had a kids' club, a shallow pool, and spacious rooms. Nearby attractions were child-friendly, and the staff went out of their way to accommodate us. Will definitely return!

EXAMPLE JSON OUTPUT:
{
    "category": "family",
    "reason": "child-friendly amenities and staff"
}
"""
|
|
|
|
|
|
def query_deepseek(review):
    """Ask the DeepSeek chat endpoint to categorize one review.

    The review is sent as the user message alongside the module-level system
    prompt, with a JSON-object response format requested.

    Returns:
        The message content of the first choice, or None if any exception is
        raised while building or sending the request or reading the reply
        (the error is printed).
    """
    try:
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": review},
        ]
        completion = client.chat.completions.create(
            model="deepseek-chat",
            messages=conversation,
            temperature=0.2,
            response_format={"type": "json_object"},
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Best-effort: a failed request is reported and signalled with None
        # instead of aborting the caller.
        print(f"Error querying DeepSeek API: {e}")
        return None
|
|
|
|
|
|
def read_reviews(file_path):
    """Return the non-blank lines of a UTF-8 text file, one review per line.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are dropped. Lines are returned whole and are not split
    on tabs.
    """
    reviews = []
    with open(file_path, "r", encoding="utf-8") as source:
        for raw_line in source:
            cleaned = raw_line.strip()
            if cleaned:
                reviews.append(cleaned)
    return reviews
|
|
|
|
|
|
def validate_response(response):
    """Parse a model reply and check that it has the expected shape.

    Args:
        response: JSON text returned by the model.

    Returns:
        The parsed dict when the reply is a JSON object that contains both
        "category" and "reason" keys and whose reason is a string of at most
        8 whitespace-separated words; otherwise None. Invalid JSON and
        non-text input also yield None instead of raising.
    """
    try:
        data = json.loads(response)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (malformed JSON);
        # TypeError covers input that is not str/bytes, such as None.
        return None

    # json.loads may legally return a list, string, number, bool or None;
    # only a JSON object can carry the two required keys.
    if not isinstance(data, dict):
        return None
    if "category" not in data or "reason" not in data:
        return None

    reason = data["reason"]
    # A non-string reason (e.g. null or a number) has no .split(); reject it
    # instead of letting AttributeError escape to the caller.
    if not isinstance(reason, str) or len(reason.split()) > 8:
        return None
    return data
|
|
|
|
|
|
def process_review(i, review, output_file):
    """Label one review and append the outcome to the JSON output file.

    Args:
        i: Identifier stored under "id" and used in progress messages.
        review: Review text; surrounding double quotes are stripped before
            it is stored.
        output_file: Path of an existing JSON file holding a list of results.
    """
    print(f"Processing review {i}")

    raw_answer = query_deepseek(review)
    labelled = process_response(raw_answer, i, "deepseek")

    record = {
        "id": i,
        "review": review.strip('"'),
        "deepseek": labelled,
    }

    # The whole read-modify-write cycle runs under the lock so concurrent
    # workers cannot interleave their updates to the shared file.
    with file_lock, open(output_file, "r+", encoding="utf-8") as handle:
        try:
            existing = json.load(handle)
        except json.JSONDecodeError:
            # Empty or unparseable file: start over with a fresh list.
            existing = []
        existing.append(record)
        # Rewrite from the start and cut off any leftover tail of the
        # previous contents.
        handle.seek(0)
        json.dump(existing, handle, indent=2)
        handle.truncate()
|
|
|
|
|
|
def process_response(response, i, model_name):
    """Turn a raw model reply into a result dict or an error marker.

    Args:
        response: Raw reply text, or a falsy value if the query failed.
        i: Review identifier, used only in the mismatch message.
        model_name: Model label, used only in the mismatch message.

    Returns:
        {"error": "query failed"} for a falsy response,
        {"error": "format mismatch"} when validation rejects the reply
        (the reply is printed), otherwise the validated dict.
    """
    if not response:
        return {"error": "query failed"}

    parsed = validate_response(response)
    if not parsed:
        print(f"Format mismatch for {model_name} response {i}: {response}")
        return {"error": "format mismatch"}
    return parsed
|
|
|
|
|
|
def main():
    """Label the reviews in data.tab in parallel and collect them in labels.json.

    Creates labels.json holding an empty list if it does not exist, reads the
    input lines, skips the first one, and submits up to 20,000 of the
    following lines to a pool of 10 worker threads. Exceptions raised by a
    worker are printed and do not stop the remaining work.
    """
    input_file = "data.tab"
    output_file = "labels.json"

    # Seed the output with an empty JSON list; an existing file is kept as is.
    output_path = Path(output_file)
    if not output_path.exists():
        with output_path.open("w") as f:
            json.dump([], f)

    all_lines = read_reviews(input_file)

    # The first line is treated as a header; ids for the rest start at 1.
    numbered_reviews = list(enumerate(all_lines[1:20001], start=1))

    # Worker count should be tuned to the API rate limits and the machine.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
        pending = [
            pool.submit(process_review, idx, text, output_file)
            for idx, text in numbered_reviews
        ]

        # Drain the futures as they finish so worker exceptions surface here.
        for finished in concurrent.futures.as_completed(pending):
            try:
                finished.result()
            except Exception as e:
                print(f"Error processing review: {e}")
|
|
|
|
|
|
# Run the labelling job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|