# doc_rag_app.py
import os
import json
import uvicorn
import requests
from dotenv import load_dotenv
from typing import Optional
from openai import OpenAI
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings

# Load .env
load_dotenv()

# -----------------------
# Configuration
# -----------------------
GITPASHA_HOST = os.getenv(
    "GITPASHA_HOST",
    "https://app1-f06df021060b.hosted.ghaymah.systems"
)  # remote GitPasha vector-store endpoint
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # used only for final LLM answer generation (optional)
DOC_FILE = os.getenv("DOC_FILE", "full_ghaymah_docs.txt")
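
# Example .env (values are illustrative; the key format depends on your GenAI provider):
#   GITPASHA_HOST=https://app1-f06df021060b.hosted.ghaymah.systems
#   OPENAI_API_KEY=<your OpenAI-compatible API key>
#   DOC_FILE=full_ghaymah_docs.txt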

# -----------------------
# FastAPI + client
# -----------------------
app = FastAPI(title="Ghaymah Docs RAG API (Restarted)", version="1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or ["http://127.0.0.1:5500"] if serving HTML with Live Server
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# optional remote LLM client (only required if you want final answer generation)
client = None
if OPENAI_API_KEY:
    client = OpenAI(api_key=OPENAI_API_KEY, base_url="https://genai.ghaymah.systems")

# -----------------------
# Embedding model (512 dims)
# -----------------------
print("Initializing local embedding model (sentence-transformers/distiluse-base-multilingual-cased)...")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/distiluse-base-multilingual-cased")
print("Embedding model loaded.")
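
# The model above produces 512-dimensional vectors; the remote collection is assumed to be
# configured for the same dimension. A quick sanity check if in doubt:
#   assert len(embeddings.embed_query("test")) == 512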

# -----------------------
# Request Models
# -----------------------
class QueryRequest(BaseModel):
    query: str
    k: Optional[int] = 10  # allow overriding k


class IngestRequest(BaseModel):
    # kept for future use, if dynamic file-name ingestion is wanted
    filename: Optional[str] = None


# -----------------------
# Helpers
# -----------------------
def _embed_texts(texts):
    """Return a list of embeddings for the given texts."""
    return embeddings.embed_documents(texts)


def _embed_query(text):
    """Return a single embedding (list of floats) for the query."""
    return embeddings.embed_query(text)


def store_text_chunks_remote(text: str) -> bool:
    """Split text, embed the chunks, and insert them into the remote GitPasha store."""
    if not text:
        print("No text provided to store.")
        return False

    # Split
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)
    print(f"[store] Split into {len(chunks)} chunks.")

    # Create embeddings
    try:
        chunk_vectors = _embed_texts(chunks)
    except Exception as e:
        print(f"[store] Embedding creation error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to create embeddings: {e}")

    # Log embedding dimension as a sanity check
    if chunk_vectors and isinstance(chunk_vectors[0], list):
        print(f"[store] Embedding vector dimension: {len(chunk_vectors[0])}")
    else:
        print(f"[store] Unexpected embedding format: {type(chunk_vectors[0]) if chunk_vectors else 'empty result'}")

    payloads = [{"text_chunk": chunk} for chunk in chunks]

    # Send to GitPasha
    try:
        resp = requests.post(
            f"{GITPASHA_HOST.rstrip('/')}/insert",
            json={"vectors": chunk_vectors, "payloads": payloads},
            headers={"Content-Type": "application/json"},
            timeout=60
        )
        resp.raise_for_status()
        print(f"[store] Remote insert status: {resp.status_code}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"[store] Error calling remote insert: {e} / Response: {getattr(e, 'response', None)}")
        raise HTTPException(status_code=500, detail=f"Failed to insert to remote vector store: {e}")
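
# Assumed GitPasha /insert request body (this mirrors what the function above sends;
# the authoritative schema is defined by the remote vector store, not by this app):
#   {"vectors": [[0.12, -0.03, ...], ...], "payloads": [{"text_chunk": "..."}, ...]}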


def search_remote_by_vector(vector, k=10):
    """Call remote /search with the given vector and return the parsed JSON (raw)."""
    try:
        resp = requests.post(
            f"{GITPASHA_HOST.rstrip('/')}/search",
            json={"vector": vector, "k": k},
            headers={"Content-Type": "application/json"},
            timeout=30
        )
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as e:
        print(f"[search] Error calling remote search: {e}")
        raise HTTPException(status_code=500, detail=f"Remote search failed: {e}")


def build_context_from_search_results(search_results, min_score: Optional[float] = None):
    """Given remote search results, optionally filter by min_score and return the context text plus metadata."""
    if not search_results or "results" not in search_results:
        return "", []

    items = []
    for r in search_results["results"]:
        score = r.get("score", None)
        payload = r.get("payload", {})
        text_chunk = payload.get("text_chunk", "")
        if min_score is None or (score is not None and score >= min_score):
            items.append({"score": score, "text": text_chunk})

    context = "\n\n".join([it["text"] for it in items])
    return context, items
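
# Note on response shapes: search_remote_by_vector returns the remote JSON verbatim.
# build_context_from_search_results expects {"results": [{"score": ..., "payload": {"text_chunk": ...}}]},
# while the /query/ endpoint below reads top-level "payloads" and "scores" keys; which of the two
# shapes the GitPasha /search endpoint actually returns depends on the remote API.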

# -----------------------
# Startup: optionally auto-ingest the doc file
# -----------------------
@app.on_event("startup")
def startup_ingest():
    """On startup, attempt to ingest DOC_FILE automatically (non-fatal if it fails)."""
    print(f"[startup] Attempting to ingest '{DOC_FILE}' on startup (if present).")
    if not os.path.exists(DOC_FILE):
        print(f"[startup] File '{DOC_FILE}' not found; skipping automatic ingestion.")
        return
    try:
        with open(DOC_FILE, "r", encoding="utf-8") as f:
            text = f.read()
        ok = store_text_chunks_remote(text)
        if ok:
            print(f"[startup] Ingested '{DOC_FILE}' successfully.")
    except Exception as e:
        # do not prevent the server from starting
        print(f"[startup] Ingest error (non-fatal): {e}")

# -----------------------
# Endpoints
# -----------------------
@app.post("/ingest-docs/")
async def ingest_docs(req: IngestRequest = None):
    """Read DOC_FILE (default: full_ghaymah_docs.txt) and store it remotely. Returns a success message."""
    filename = DOC_FILE
    try:
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"{filename} not found in the working folder.")

    ok = store_text_chunks_remote(text)
    if ok:
        return JSONResponse(content={"message": f"Successfully ingested '{filename}' into the vector store."})
    raise HTTPException(status_code=500, detail="Ingestion failed.")
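
# Example (assuming the server is running locally on port 8000):
#   curl -X POST http://localhost:8000/ingest-docs/ -H "Content-Type: application/json" -d '{}'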


@app.post("/query/")
async def query_docs(request: QueryRequest):
    query = request.query
    k = request.k or 10
    print(f"[query] Received query: {query} (k={k})")

    # Embed query
    qvec = _embed_query(query)

    # Remote vector search
    search_results = search_remote_by_vector(qvec, k=k)
    payloads = [p["text_chunk"] for p in search_results.get("payloads", [])]

    if not payloads:
        return {"answer": "No relevant chunks found.", "search_results": search_results}

    # Deduplicate chunks (keep first occurrence)
    seen = set()
    context_chunks = []
    for chunk in payloads:
        if chunk not in seen:
            context_chunks.append(chunk)
            seen.add(chunk)

    context = "\n\n".join(context_chunks)

    # Use the LLM if available
    if client:
        try:
            completion = client.chat.completions.create(
                model="DeepSeek-V3-0324",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant for Ghaymah Cloud. Answer the question using the context provided."},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ],
                temperature=0.0,
            )
            answer = completion.choices[0].message.content
            return {"answer": answer, "context": context_chunks, "scores": search_results.get("scores", [])}
        except Exception as e:
            print(f"[query] LLM failed: {e}")
            return {"answer": context, "context": context_chunks, "scores": search_results.get("scores", [])}
    else:
        return {"answer": context, "context": context_chunks, "scores": search_results.get("scores", [])}
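
# Example (assuming the server is running locally on port 8000; the question is illustrative):
#   curl -X POST http://localhost:8000/query/ \
#        -H "Content-Type: application/json" \
#        -d '{"query": "How do I deploy an app on Ghaymah Cloud?", "k": 5}'
# When matching chunks are found, the response contains "answer", "context" (deduplicated chunks), and "scores".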


@app.post("/debug-search/")
async def debug_search(request: QueryRequest):
    """
    Debug endpoint: returns raw search response from remote vector store for the provided query.
    Use this to inspect exact 'results' and scores returned remotely.
    """
    query = request.query
    k = request.k or 10
    print(f"[debug-search] Query: {query} (k={k})")

    try:
        qvec = _embed_query(query)
        print(f"[debug-search] Query embedding length: {len(qvec)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Embedding failed: {e}")

    raw = search_remote_by_vector(qvec, k=k)
    return JSONResponse(content={"search_response": raw})
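
# Example (illustrative query):
#   curl -X POST http://localhost:8000/debug-search/ -H "Content-Type: application/json" -d '{"query": "pricing", "k": 3}'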


@app.get("/")
def read_root():
    return {"message": "Ghaymah Docs RAG API. Use /docs for interactive UI."}


# -----------------------
# Run
# -----------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
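
# Alternatively, start it from the shell (module name assumed to match the file name doc_rag_app.py):
#   uvicorn doc_rag_app:app --host 0.0.0.0 --port 8000 --reload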