Initial commit of the RAG API application

This commit is contained in:
doc_rag_app.py (185 changed lines)
@@ -4,7 +4,7 @@ import json
import uvicorn
import requests
from dotenv import load_dotenv
-from typing import Optional
+from typing import Optional, List
from openai import OpenAI
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
@@ -12,13 +12,14 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
+from sentence_transformers import CrossEncoder

# Load .env
load_dotenv()

# -----------------------
# Configuration
# -----------------------
GITPASHA_HOST = os.getenv(
    "GITPASHA_HOST",
    "https://app1-f06df021060b.hosted.ghaymah.systems"
@@ -26,9 +27,9 @@ GITPASHA_HOST = os.getenv(
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # used only for final LLM summarization if needed
DOC_FILE = os.getenv("DOC_FILE", "full_ghaymah_docs.txt")

# -----------------------
# FastAPI + client
# -----------------------
app = FastAPI(title="Ghaymah Docs RAG API (Restarted)", version="1.0")

app.add_middleware(
@@ -44,33 +45,37 @@ client = None
if OPENAI_API_KEY:
    client = OpenAI(api_key=OPENAI_API_KEY, base_url="https://genai.ghaymah.systems")

# -----------------------
-# Embedding model (512 dims)
+# Models (Embedding + Reranking)
# -----------------------
print("Initializing local embedding model (sentence-transformers/distiluse-base-multilingual-cased)...")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/distiluse-base-multilingual-cased")
print("Embedding model loaded.")

+print("Initializing local CrossEncoder model (ms-marco-MiniLM-L-6-v2)...")
+cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+print("CrossEncoder model loaded.")
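For reference, a minimal sketch of what the cross-encoder is used for further down (illustrative only, not part of the commit; the example passages are invented): CrossEncoder.predict() takes (query, passage) pairs and returns one relevance score per pair, with higher scores meaning more relevant.

pairs = [
    ("how do I deploy an app", "Deployments are created from the dashboard or the CLI ..."),
    ("how do I deploy an app", "Billing is calculated per vCPU-hour ..."),
]
scores = cross_encoder.predict(pairs)  # one float per pair; the first passage should score higher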

# -----------------------
# Request Models
# -----------------------
class QueryRequest(BaseModel):
    query: str
-    k: Optional[int] = 10  # allow overriding k
+    k: Optional[int] = 5  # final number of chunks to use

class IngestRequest(BaseModel):
    # keep for future if want dynamic file name content ingestion
    filename: Optional[str] = None

# -----------------------
# Helpers
# -----------------------
-def _embed_texts(texts):
+def _embed_texts(texts: List[str]) -> List[List[float]]:
    """Return list of embeddings for given texts."""
    return embeddings.embed_documents(texts)

-def _embed_query(text):
-    """Return single embedding for query (list)."""
+def _embed_query(text: str) -> List[float]:
+    """Return single embedding for query."""
    return embeddings.embed_query(text)
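As a quick reference (not part of the commit): both helpers wrap the LangChain embeddings interface, and distiluse-base-multilingual-cased produces 512-dimensional vectors, so a call would look roughly like this (example strings are made up):

doc_vecs = _embed_texts(["How do I create a project?", "Which regions are available?"])
q_vec = _embed_query("How do I create a project?")
print(len(doc_vecs), len(q_vec))  # -> 2 512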

def store_text_chunks_remote(text: str) -> bool:
@@ -114,7 +119,7 @@ def store_text_chunks_remote(text: str) -> bool:
        print(f"[store] Error calling remote insert: {e} / Response: {getattr(e, 'response', None)}")
        raise HTTPException(status_code=500, detail=f"Failed to insert to remote vector store: {e}")

-def search_remote_by_vector(vector, k=10):
+def search_remote_by_vector(vector: List[float], k: int = 10):
    """Call remote /search with given vector and return parsed JSON (raw)."""
    try:
        resp = requests.post(
@@ -129,24 +134,9 @@ def search_remote_by_vector(vector, k=10):
        print(f"[search] Error calling remote search: {e}")
        raise HTTPException(status_code=500, detail=f"Remote search failed: {e}")

-def build_context_from_search_results(search_results, min_score: Optional[float] = None):
-    """Given remote search results, optionally filter by min_score and return context text and metadata."""
-    if not search_results or "results" not in search_results:
-        return "", []
-
-    items = []
-    for r in search_results["results"]:
-        score = r.get("score", None)
-        payload = r.get("payload", {})
-        text_chunk = payload.get("text_chunk", "")
-        if min_score is None or (score is not None and score >= min_score):
-            items.append({"score": score, "text": text_chunk})
-    context = "\n\n".join([it["text"] for it in items])
-    return context, items
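The remote vector-store contract itself is not part of this diff. Inferred from how the new code reads the response (a payloads list whose items carry a text_chunk field, plus a parallel scores list), the /search reply is assumed to look roughly like this (values invented):

assumed_search_response = {
    "payloads": [
        {"text_chunk": "Ghaymah Cloud supports ..."},
        {"text_chunk": "To deploy a container ..."},
    ],
    "scores": [0.83, 0.71],  # similarity scores, parallel to payloads
}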

# -----------------------
# Startup: optionally auto-ingest file on startup
# -----------------------
@app.on_event("startup")
def startup_ingest():
    """On startup, attempt to ingest DOC_FILE automatically (non-fatal)."""
@@ -164,9 +154,9 @@ def startup_ingest():
        # do not prevent server from starting
        print(f"[startup] Ingest error (non-fatal): {e}")

# -----------------------
# Endpoints
# -----------------------
@app.post("/ingest-docs/")
async def ingest_docs(req: IngestRequest = None):
    """Read full_ghaymah_docs.txt and store it remotely. Returns success message."""
@@ -181,54 +171,135 @@ async def ingest_docs(req: IngestRequest = None):
    if ok:
        return JSONResponse(content={"message": f"Successfully ingested '{filename}' into vector store."})
    raise HTTPException(status_code=500, detail="Ingestion failed.")
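A minimal call sketch for this endpoint (assumes the server is running locally on port 8000; not part of the commit):

import requests

resp = requests.post("http://localhost:8000/ingest-docs/", json={})
print(resp.json())  # e.g. {"message": "Successfully ingested 'full_ghaymah_docs.txt' into vector store."}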

@app.post("/query/")
async def query_docs(request: QueryRequest):
    query = request.query
-    k = request.k or 10
-    print(f"[query] Received query: {query} (k={k})")
+    k_final = request.k or 5  # The final number of documents to use
+    k_initial = 25  # The number of documents to retrieve initially
+    print(f"[query] Received query: '{query}' (k_initial={k_initial}, k_final={k_final})")

-    # Embed query
+    # 1. Embed query
    qvec = _embed_query(query)

-    # Remote vector search
-    search_results = search_remote_by_vector(qvec, k=k)
-    payloads = [p["text_chunk"] for p in search_results.get("payloads", [])]
+    # 2. Initial Retrieval from vector store
+    search_results = search_remote_by_vector(qvec, k=k_initial)
+    initial_chunks = [p.get("text_chunk", "") for p in search_results.get("payloads", [])]

-    if not payloads:
+    if not initial_chunks:
        return {"answer": "No relevant chunks found.", "search_results": search_results}

-    # Deduplicate chunks (keep first occurrence)
+    # Deduplicate initial chunks before re-ranking
    seen = set()
-    context_chunks = []
-    for chunk in payloads:
+    unique_chunks = []
+    for chunk in initial_chunks:
        if chunk not in seen:
-            context_chunks.append(chunk)
+            unique_chunks.append(chunk)
            seen.add(chunk)

+    print(f"[query] Retrieved {len(unique_chunks)} unique chunks for re-ranking.")

-    context = "\n\n".join(context_chunks)
+    # 3. Re-ranking with CrossEncoder
+    # Create pairs of (query, chunk) for the model
+    rerank_pairs = [(query, chunk) for chunk in unique_chunks]

+    # Predict new relevance scores
+    rerank_scores = cross_encoder.predict(rerank_pairs)

+    # Combine chunks with their new scores
+    reranked_results = list(zip(rerank_scores, unique_chunks))

+    # Sort by the new score in descending order
+    reranked_results.sort(key=lambda x: x[0], reverse=True)

+    # 4. Select top k_final results after re-ranking
+    top_k_chunks = [chunk for score, chunk in reranked_results[:k_final]]
+    top_k_scores = [float(score) for score, chunk in reranked_results[:k_final]]

-    # Use LLM if available
+    context = "\n\n".join(top_k_chunks)
+    print(f"[query] Built context with {len(top_k_chunks)} re-ranked chunks.")

+    # 5. Use LLM if available to generate a final answer
    if client:
        try:
            completion = client.chat.completions.create(
                model="DeepSeek-V3-0324",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant for Ghaymah Cloud. Answer the question using the context provided."},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ],
                temperature=0.0,
            )
            answer = completion.choices[0].message.content
-            return {"answer": answer, "context": context_chunks, "scores": search_results.get("scores", [])}
+            return {"answer": answer, "context": top_k_chunks, "scores": top_k_scores}
        except Exception as e:
            print(f"[query] LLM failed: {e}")
-            return {"answer": context, "context": context_chunks, "scores": search_results.get("scores", [])}
+            # Fallback to returning the context directly
+            return {"answer": context, "context": top_k_chunks, "scores": top_k_scores}
    else:
-        return {"answer": context, "context": context_chunks, "scores": search_results.get("scores", [])}
+        # If no LLM, return the context as the answer
+        return {"answer": context, "context": top_k_chunks, "scores": top_k_scores}
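A usage sketch for the re-ranked query flow (again assuming a local server on port 8000; not part of the commit):

import requests

resp = requests.post(
    "http://localhost:8000/query/",
    json={"query": "How do I deploy a container on Ghaymah Cloud?", "k": 5},
)
data = resp.json()
print(data["answer"])  # LLM answer, or the raw context if no OPENAI_API_KEY is configured
print(data["scores"])  # cross-encoder scores of the chunks used as context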

+@app.post("/test-rerank/")
+async def test_rerank(request: QueryRequest):
+    """
+    Endpoint for visualization. Returns initial and re-ranked results.
+    """
+    query = request.query
+    k_final = request.k or 5
+    k_initial = 25
+    print(f"[test-rerank] Received query: '{query}' (k_initial={k_initial}, k_final={k_final})")
+
+    # 1. Embed query
+    qvec = _embed_query(query)
+
+    # 2. Initial Retrieval
+    search_results = search_remote_by_vector(qvec, k=k_initial)
+
+    initial_payloads = search_results.get("payloads", [])
+    initial_scores = search_results.get("scores", [])
+
+    # Ensure we have the same number of scores and payloads
+    min_len = min(len(initial_payloads), len(initial_scores))
+
+    initial_results = [
+        {"text": p.get("text_chunk", ""), "score": s}
+        for p, s in zip(initial_payloads[:min_len], initial_scores[:min_len])
+    ]
+
+    # Deduplicate
+    seen_texts = set()
+    unique_initial_results = []
+    for res in initial_results:
+        if res["text"] not in seen_texts:
+            unique_initial_results.append(res)
+            seen_texts.add(res["text"])
+
+    unique_chunks = [res["text"] for res in unique_initial_results]
+
+    if not unique_chunks:
+        return {"initial_results": [], "reranked_results": []}
+
+    # 3. Re-ranking
+    rerank_pairs = [(query, chunk) for chunk in unique_chunks]
+    rerank_scores = cross_encoder.predict(rerank_pairs)
+
+    reranked_results_with_scores = [
+        {"text": chunk, "score": float(score)}
+        for score, chunk in zip(rerank_scores, unique_chunks)
+    ]
+
+    # Sort by new score
+    reranked_results_with_scores.sort(key=lambda x: x["score"], reverse=True)
+
+    return {
+        "initial_results": unique_initial_results,
+        "reranked_results": reranked_results_with_scores[:k_final]
+    }
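And a sketch of how this visualization endpoint can be used to compare retrieval order before and after re-ranking (same local-server assumption; the query text is invented):

import requests

out = requests.post(
    "http://localhost:8000/test-rerank/",
    json={"query": "reset an API key", "k": 3},
).json()
for item in out["reranked_results"]:
    print(round(item["score"], 3), item["text"][:60])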

@app.post("/debug-search/")
-async def debug_search(request: QueryRequest):
+def debug_search(request: QueryRequest):
    """
    Debug endpoint: returns raw search response from remote vector store for the provided query.
    Use this to inspect exact 'results' and scores returned remotely.
@@ -250,8 +321,8 @@ async def debug_search(request: QueryRequest):
def read_root():
    return {"message": "Ghaymah Docs RAG API. Use /docs for interactive UI."}

# -----------------------
# Run
# -----------------------
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)