import contextlib
import os
import tempfile
import threading
# import json

from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import whisper
# import numpy as np

# ASGI application instance: the endpoints below register on it via @app.post,
# and the __main__ block hands it to uvicorn.
app = FastAPI()

class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # Text to turn into an embedding vector.
    prompt: str
    # Language tag echoed back in the /embed response; it is not passed to the
    # encoder. Defaults to English ("en-GB").
    language_code: str = "en-GB"

class KeywordRequest(BaseModel):
    """Request body carrying a single ``target`` string."""

    # NOTE(review): no endpoint in this file references this model - confirm
    # whether it is used elsewhere or is left over and can be removed.
    target: str

# Load both models once, at import time, so every request reuses the same
# instances. These calls block module import until each load finishes.
print("Loading SentenceTransformer model...")
# Shared embedding model used by the /embed endpoint.
model = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded successfully!")

print("Loading Whisper model...")
# Shared speech-to-text model used by the /transcribe endpoint ("base" size).
whisper_model = whisper.load_model("base")
print("Whisper model loaded successfully!")


# The embedding model is shared by every request and encode() now runs on worker
# threads; its tokenizer is not guaranteed safe for concurrent use, so calls are
# serialized. The event loop itself stays free while a call is in progress.
_embed_lock = threading.Lock()


def _encode_prompt(prompt: str):
    """Encode *prompt* with the shared SentenceTransformer while holding the lock."""
    with _embed_lock:
        return model.encode(prompt)


@app.post("/embed")
async def generate_embedding(request: EmbedRequest):
    """Return the sentence embedding of ``request.prompt``.

    On success the JSON body has the keys ``embedding`` (list of floats),
    ``model``, ``dimensions`` (length of ``embedding``), ``language_code`` and
    ``prompt`` (both echoed from the request).

    Any exception is reported as ``{"error": <message>}``.
    NOTE(review): errors are returned with HTTP 200 rather than an error status;
    kept for client compatibility - consider ``HTTPException``.
    """
    try:
        print(f"Request data: {request}")

        prompt = request.prompt
        language_code = request.language_code
        print(f"Processing prompt: '{prompt}' in language: {language_code}")

        # language_code is only echoed back; it does not influence encoding.
        # NOTE(review): all-MiniLM-L6-v2 is described as an English model -
        # confirm it is adequate for non-English prompts.
        # encode() is compute-bound, so run it off the event loop; otherwise
        # every other request stalls until it finishes.
        embedding = await run_in_threadpool(_encode_prompt, prompt)
        print(f"Embedding generated successfully, shape: {embedding.shape}")

        # Convert the numpy array to a plain list for JSON serialization.
        embedding_list = embedding.tolist()

        response = {
            "embedding": embedding_list,
            "model": "all-MiniLM-L6-v2",
            "dimensions": len(embedding_list),
            "language_code": language_code,
            "prompt": prompt,
        }
        # Log a summary only: dumping every float of the vector on each
        # request is slow and floods the log.
        print(f"Response: {len(embedding_list)}-dimensional embedding returned")
        return response

    except Exception as e:
        print(f"Error: {str(e)}")
        return {"error": str(e)}


# Extensions rejected up front with an FFmpeg hint (see the NOTE in
# transcribe_audio).
_FFMPEG_ONLY_EXTENSIONS = frozenset({".webm", ".mp4", ".m4a", ".ogg"})

# transcribe() now runs on worker threads but the Whisper model is shared and
# its decoding attaches state to the model, so concurrent calls are serialized.
_whisper_lock = threading.Lock()


def _remove_quietly(path: str) -> None:
    """Delete *path*, ignoring OS errors so cleanup never masks the response."""
    with contextlib.suppress(OSError):
        os.remove(path)


async def _save_upload_to_temp(file: UploadFile, suffix: str) -> str:
    """Write the upload's bytes to a new, uniquely named temp file; return its path.

    The caller owns the returned file and must delete it.
    """
    content = await file.read()
    # mkstemp creates a collision-free file in the system temp dir, so the
    # client-supplied filename never becomes part of the path (no traversal,
    # and concurrent uploads with the same name don't clobber each other).
    fd, path = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(content)
    except BaseException:
        _remove_quietly(path)
        raise
    return path


def _transcribe_file(path: str) -> dict:
    """Run the shared Whisper model on *path* while holding ``_whisper_lock``."""
    with _whisper_lock:
        return whisper_model.transcribe(path, fp16=False)


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper.

    On success the JSON body has the keys ``text``, ``language`` (``"unknown"``
    when Whisper reports none) and ``filename`` (as sent by the client).

    Failures, and uploads with a ``.webm``/``.mp4``/``.m4a``/``.ogg``
    extension, are reported as ``{"error": <message>}``.

    The temporary copy of the upload is always deleted before returning.
    """
    temp_file_path = None
    try:
        print(f"File name: {file.filename}")
        print(f"File content type: {file.content_type}")

        # Decide on the extension before touching the body, so rejected uploads
        # are never written to disk. `or ""` guards a missing filename.
        file_extension = os.path.splitext(file.filename or "")[1].lower()
        print(f"File extension: {file_extension}")

        if file_extension in _FFMPEG_ONLY_EXTENSIONS:
            # NOTE(review): whisper's audio loader decodes every input through
            # the ffmpeg CLI (WAV included), so this rejection may be
            # unnecessary - confirm. The message also omits OGG.
            return {
                "error": "WebM/MP4/M4A files require FFmpeg to be installed on your system."
            }

        temp_file_path = await _save_upload_to_temp(file, file_extension)
        print(
            f"File saved successfully to {temp_file_path}. "
            f"Size: {os.path.getsize(temp_file_path)} bytes"
        )

        try:
            # Compute-bound: run off the event loop so other requests proceed.
            result = await run_in_threadpool(_transcribe_file, temp_file_path)
            print(f"Transcription result: {result['text']}")
        except Exception as whisper_error:
            print(f"Whisper error: {str(whisper_error)}")
            return {"error": f"Transcription failed: {str(whisper_error)}"}

        return {
            "text": result["text"],
            "language": result.get("language", "unknown"),
            "filename": file.filename,
        }

    except Exception as e:
        print(f"Error during transcription: {str(e)}")
        return {"error": f"Transcription failed: {str(e)}"}

    finally:
        # Runs on every exit path, including the early error returns that
        # previously leaked the temp file.
        if temp_file_path is not None:
            _remove_quietly(temp_file_path)


def _serve() -> None:
    """Serve ``app`` with uvicorn on 127.0.0.1:8000 (blocks until shutdown)."""
    # Imported here because uvicorn is needed only when the file runs as a script.
    import uvicorn

    print("Starting embedding server on http://127.0.0.1:8000")
    uvicorn.run(app, host="127.0.0.1", port=8000)


if __name__ == "__main__":
    _serve()
