# rag-app/doc_rag_app.py
import os
import json
import uvicorn
import requests
from dotenv import load_dotenv
from typing import Optional, List
from openai import OpenAI
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
# Load .env
load_dotenv()
# -----------------------
# Configuration
# -----------------------
GITPASHA_HOST = os.getenv(
    "GITPASHA_HOST",
    "https://app1-f06df021060b.hosted.ghaymah.systems"
)  # remote GitPasha vector store endpoint
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # used only for final LLM summarization if needed
DOC_FILE = os.getenv("DOC_FILE", "full_ghaymah_docs.txt")
# -----------------------
# FastAPI + client
# -----------------------
app = FastAPI(title="Ghaymah Docs RAG API (Restarted)", version="1.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # or ["http://127.0.0.1:5500"] if serving HTML with Live Server
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# optional remote LLM client (only required if you want final answer generation)
client = None
if OPENAI_API_KEY:
client = OpenAI(api_key=OPENAI_API_KEY, base_url="https://genai.ghaymah.systems")
# -----------------------
# Models (Embedding + Reranking)
# -----------------------
print("Initializing local embedding model (sentence-transformers/distiluse-base-multilingual-cased)...")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/distiluse-base-multilingual-cased")
print("Embedding model loaded.")
print("Initializing local CrossEncoder model (ms-marco-MiniLM-L-6-v2)...")
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
print("CrossEncoder model loaded.")
# -----------------------
# Request Models
# -----------------------
class QueryRequest(BaseModel):
    query: str
    k: Optional[int] = 5  # final number of chunks to use

class IngestRequest(BaseModel):
    # kept for future use, in case dynamic filename ingestion is added
    filename: Optional[str] = None
# -----------------------
# Helpers
# -----------------------
def _embed_texts(texts: List[str]) -> List[List[float]]:
    """Return list of embeddings for given texts."""
    return embeddings.embed_documents(texts)

def _embed_query(text: str) -> List[float]:
    """Return single embedding for query."""
    return embeddings.embed_query(text)
def store_text_chunks_remote(text: str) -> bool:
    """Split text, embed the chunks, and insert them into the remote GitPasha store."""
    if not text:
        print("No text provided to store.")
        return False
    # Split
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)
    print(f"[store] Split into {len(chunks)} chunks.")
    # Create embeddings
    try:
        chunk_vectors = _embed_texts(chunks)
    except Exception as e:
        print(f"[store] Embedding creation error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to create embeddings: {e}")
    # Log embedding dimension sanity check
    if chunk_vectors and isinstance(chunk_vectors[0], list):
        print(f"[store] Embedding vector dimension: {len(chunk_vectors[0])}")
    else:
        print(f"[store] Unexpected embedding format: {type(chunk_vectors[0]) if chunk_vectors else 'empty result'}")
    payloads = [{"text_chunk": chunk} for chunk in chunks]
    # Send to GitPasha
    try:
        resp = requests.post(
            f"{GITPASHA_HOST.rstrip('/')}/insert",
            json={"vectors": chunk_vectors, "payloads": payloads},
            headers={"Content-Type": "application/json"},
            timeout=60
        )
        resp.raise_for_status()
        print(f"[store] Remote insert status: {resp.status_code}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"[store] Error calling remote insert: {e} / Response: {getattr(e, 'response', None)}")
        raise HTTPException(status_code=500, detail=f"Failed to insert to remote vector store: {e}")
def search_remote_by_vector(vector: List[float], k: int = 10):
    """Call remote /search with given vector and return parsed JSON (raw)."""
    try:
        resp = requests.post(
            f"{GITPASHA_HOST.rstrip('/')}/search",
            json={"vector": vector, "k": k},
            headers={"Content-Type": "application/json"},
            timeout=30
        )
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.RequestException as e:
        print(f"[search] Error calling remote search: {e}")
        raise HTTPException(status_code=500, detail=f"Remote search failed: {e}")
# -----------------------
# Startup: optionally auto-ingest file on startup
# -----------------------
@app.on_event("startup")
def startup_ingest():
"""On startup, attempt to ingest DOC_FILE automatically (non-fatal)."""
print(f"[startup] Attempting to ingest '{DOC_FILE}' on startup (if present).")
if not os.path.exists(DOC_FILE):
print(f"[startup] File '{DOC_FILE}' not found; skipping automatic ingestion.")
return
try:
with open(DOC_FILE, "r", encoding="utf-8") as f:
text = f.read()
ok = store_text_chunks_remote(text)
if ok:
print(f"[startup] Ingested '{DOC_FILE}' successfully.")
except Exception as e:
# do not prevent server from starting
print(f"[startup] Ingest error (non-fatal): {e}")
# -----------------------
# Endpoints
# -----------------------
@app.post("/ingest-docs/")
async def ingest_docs(req: IngestRequest = None):
"""Read full_ghaymah_docs.txt and store it remotely. Returns success message."""
filename = DOC_FILE
try:
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"{filename} not found in working folder.")
ok = store_text_chunks_remote(text)
if ok:
return JSONResponse(content={"message": f"Successfully ingested '{filename}' into vector store."})
raise HTTPException(status_code=500, detail="Ingestion failed.")
@app.post("/query/")
async def query_docs(request: QueryRequest):
query = request.query
k_final = request.k or 5 # The final number of documents to use
k_initial = 25 # The number of documents to retrieve initially
print(f"[query] Received query: '{query}' (k_initial={k_initial}, k_final={k_final})")
# 1. Embed query
qvec = _embed_query(query)
# 2. Initial Retrieval from vector store
search_results = search_remote_by_vector(qvec, k=k_initial)
initial_chunks = [p.get("text_chunk", "") for p in search_results.get("payloads", [])]
if not initial_chunks:
return {"answer": "No relevant chunks found.", "search_results": search_results}
# Deduplicate initial chunks before re-ranking
seen = set()
unique_chunks = []
for chunk in initial_chunks:
if chunk not in seen:
unique_chunks.append(chunk)
seen.add(chunk)
print(f"[query] Retrieved {len(unique_chunks)} unique chunks for re-ranking.")
# 3. Re-ranking with CrossEncoder
# Create pairs of (query, chunk) for the model
rerank_pairs = [(query, chunk) for chunk in unique_chunks]
# Predict new relevance scores
rerank_scores = cross_encoder.predict(rerank_pairs)
# Combine chunks with their new scores
reranked_results = list(zip(rerank_scores, unique_chunks))
# Sort by the new score in descending order
reranked_results.sort(key=lambda x: x[0], reverse=True)
# 4. Select top k_final results after re-ranking
top_k_chunks = [chunk for score, chunk in reranked_results[:k_final]]
top_k_scores = [float(score) for score, chunk in reranked_results[:k_final]]
context = "\n\n".join(top_k_chunks)
print(f"[query] Built context with {len(top_k_chunks)} re-ranked chunks.")
# 5. Use LLM if available to generate a final answer
if client:
try:
completion = client.chat.completions.create(
model="DeepSeek-V3-0324",
messages=[
{"role": "system", "content": "You are a helpful assistant for Ghaymah Cloud. Answer the question using the context provided."},
{"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
],
temperature=0.0,
)
answer = completion.choices[0].message.content
return {"answer": answer, "context": top_k_chunks, "scores": top_k_scores}
except Exception as e:
print(f"[query] LLM failed: {e}")
# Fallback to returning the context directly
return {"answer": context, "context": top_k_chunks, "scores": top_k_scores}
else:
# If no LLM, return the context as the answer
return {"answer": context, "context": top_k_chunks, "scores": top_k_scores}
@app.post("/test-rerank/")
async def test_rerank(request: QueryRequest):
"""
Endpoint for visualization. Returns initial and re-ranked results.
"""
query = request.query
k_final = request.k or 5
k_initial = 25
print(f"[test-rerank] Received query: '{query}' (k_initial={k_initial}, k_final={k_final})")
# 1. Embed query
qvec = _embed_query(query)
# 2. Initial Retrieval
search_results = search_remote_by_vector(qvec, k=k_initial)
initial_payloads = search_results.get("payloads", [])
initial_scores = search_results.get("scores", [])
# Ensure we have the same number of scores and payloads
min_len = min(len(initial_payloads), len(initial_scores))
initial_results = [
{"text": p.get("text_chunk", ""), "score": s}
for p, s in zip(initial_payloads[:min_len], initial_scores[:min_len])
]
# Deduplicate
seen_texts = set()
unique_initial_results = []
for res in initial_results:
if res["text"] not in seen_texts:
unique_initial_results.append(res)
seen_texts.add(res["text"])
unique_chunks = [res["text"] for res in unique_initial_results]
if not unique_chunks:
return {"initial_results": [], "reranked_results": []}
# 3. Re-ranking
rerank_pairs = [(query, chunk) for chunk in unique_chunks]
rerank_scores = cross_encoder.predict(rerank_pairs)
reranked_results_with_scores = [
{"text": chunk, "score": float(score)}
for score, chunk in zip(rerank_scores, unique_chunks)
]
# Sort by new score
reranked_results_with_scores.sort(key=lambda x: x["score"], reverse=True)
return {
"initial_results": unique_initial_results,
"reranked_results": reranked_results_with_scores[:k_final]
}
@app.post("/debug-search/")
def debug_search(request: QueryRequest):
"""
Debug endpoint: returns raw search response from remote vector store for the provided query.
Use this to inspect exact 'results' and scores returned remotely.
"""
query = request.query
k = request.k or 10
print(f"[debug-search] Query: {query} (k={k})")
try:
qvec = _embed_query(query)
print(f"[debug-search] Query embedding length: {len(qvec)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Embedding failed: {e}")
raw = search_remote_by_vector(qvec, k=k)
return JSONResponse(content={"search_response": raw})
@app.get("/")
def read_root():
return {"message": "Ghaymah Docs RAG API. Use /docs for interactive UI."}
# -----------------------
# Run
# -----------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)