# file: src/query_hybrid.py
import os
from dotenv import load_dotenv
import psycopg2
from weaviate import WeaviateClient
from weaviate.connect import ConnectionParams
from src.embed_utils import get_embedding
import json

load_dotenv()

DB_HOST = os.getenv("DB_HOST", "localhost")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME", "pdf_research")
DB_USER = os.getenv("DB_USER", "pdf_user")
DB_PASSWORD = os.getenv("DB_PASSWORD")

WEAVIATE_HOST = os.getenv("WEAVIATE_HOST", "localhost")
WEAVIATE_HTTP_PORT = int(os.getenv("WEAVIATE_HTTP_PORT", "8080"))
WEAVIATE_CLASS = os.getenv("WEAVIATE_CLASS", "ScientificArticle")

def get_db_conn():
    """Open a psycopg2 connection to the Postgres/pgvector database."""
    return psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        database=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )

def get_weaviate_client() -> WeaviateClient:
    """Weaviate v4 client over HTTP (the v4 API also requires gRPC connection parameters)."""
    client = WeaviateClient(connection_params=ConnectionParams.from_params(
        http_host=WEAVIATE_HOST,
        http_port=WEAVIATE_HTTP_PORT,
        http_secure=False,
        grpc_host=WEAVIATE_HOST,
        grpc_port=50051,
        grpc_secure=False,
    ))
    client.connect()
    return client
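
# An equivalent shortcut for a default local setup is the v4 convenience helper
# shown below. This is only a sketch of an alternative (assuming the standard
# gRPC port 50051); the rest of this module does not rely on it.
#
#   import weaviate
#   client = weaviate.connect_to_local(
#       host=WEAVIATE_HOST, port=WEAVIATE_HTTP_PORT, grpc_port=50051
#   )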

def search_weaviate_articles_all(limit: int = 10000) -> list[dict]:
    """
    Reads ALL articles from Weaviate (fetch_objects);
    does not use near_text (which would require Ollama embeddings).
    """
    client = get_weaviate_client()
    try:
        collection = client.collections.get(WEAVIATE_CLASS)

        # Read every object, without any search query
        results = collection.query.fetch_objects(
            limit=limit,
            return_properties=["title", "source_file", "summary_et", "key_concepts",
                               "authors", "transport_context", "relevance_score", "abstract_en"]
        )
        articles = []
        for obj in results.objects:
            props = obj.properties
            articles.append({
                "title": props.get("title", "N/A"),
                "article_id": str(obj.uuid),
                "summary_et": (props.get("summary_et", "") or "")[:500],
                "key_concepts": props.get("key_concepts", []),
                "authors": props.get("authors", []),
                "transport_context": props.get("transport_context", ""),
                "relevance_score": props.get("relevance_score", 0),
                "abstract_en": (props.get("abstract_en", "") or "")[:500],
                "source_file": props.get("source_file", ""),
            })
        return articles
    except Exception as e:
        print(f"❌ Weaviate query error: {e}")
        import traceback
        traceback.print_exc()
        return []
    finally:
        client.close()
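
# Note: fetch_objects with a large limit pulls every article into memory in one
# call. For larger collections, the v4 client also exposes a paging iterator; the
# commented sketch below is an assumed alternative and is not used by this module:
#
#   for obj in collection.iterator():
#       ...  # same per-object handling as above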

def search_articles_by_query(query: str, all_articles: list[dict], limit: int = 10) -> list[dict]:
    """
    Searches the articles for a query using local embeddings (bge-small),
    scoring each article by vector similarity.
    """
    query_emb = get_embedding(query)

    # Search over the articles' summary_et and abstract_en texts
    scored_articles = []

    for article in all_articles:
        # Take the text (summary_et or abstract_en)
        text = article.get("summary_et") or article.get("abstract_en") or ""

        if not text:
            continue

        # Compute the embedding
        text_emb = get_embedding(text)

        # Cosine similarity (a plain dot product, assuming normalized vectors)
        similarity = sum(a * b for a, b in zip(query_emb, text_emb))

        scored_articles.append({
            **article,
            "score": similarity
        })

    # Sort and keep the top N
    scored_articles.sort(key=lambda x: x["score"], reverse=True)
    return scored_articles[:limit]
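
# The dot product above equals cosine similarity only if get_embedding returns
# unit-length vectors. If that assumption does not hold for your embedding
# backend, the helper below computes the full cosine similarity instead. It is
# a sketch for that case and is not used elsewhere in this module.
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors, without assuming normalization."""
    import math
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0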

def search_pgvector_chunks(query: str, article_ids: list[str], limit: int = 20) -> list[dict]:
    """
    Searches pgvector for the chunks that belong to the given articles.
    """
    if not article_ids:
        return []
    conn = get_db_conn()
    cur = conn.cursor()
    query_emb = get_embedding(query)
    cur.execute("""
        SELECT
            c.id,
            c.text,
            c.page,
            c.chunk_index,
            r.filename,
            r.weaviate_article_id,
            1 - (c.embedding <=> %s::vector) AS similarity
        FROM chunks c
        JOIN raw_documents r ON c.raw_doc_id = r.id
        WHERE r.weaviate_article_id = ANY(%s::UUID[])
        ORDER BY c.embedding <=> %s::vector
        LIMIT %s
    """, (query_emb, article_ids, query_emb, limit))
    chunks = []
    for row in cur.fetchall():
        chunks.append({
            "chunk_id": row[0],
            "text": row[1],
            "page": row[2],
            "chunk_index": row[3],
            "filename": row[4],
            "weaviate_article_id": str(row[5]) if row[5] else None,
            "similarity": float(row[6]),
        })
    cur.close()
    conn.close()
    return chunks
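
# The query above passes the embedding straight through psycopg2, which sends a
# Python list as a Postgres array and relies on pgvector being able to cast that
# array to a vector. If that cast is unavailable in your setup, formatting the
# embedding as a vector literal string is a safe fallback; this helper is a
# sketch of that (pass its result in place of query_emb, keeping the ::vector cast).
def to_pgvector_literal(embedding: list[float]) -> str:
    """Format an embedding as a pgvector text literal, e.g. '[0.1,0.2,0.3]'."""
    return "[" + ",".join(str(x) for x in embedding) + "]"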

def hybrid_search(query: str, top_articles: int = 10, top_chunks: int = 20) -> dict:
    """
    Two-stage search:
    1. Find the top articles in Weaviate (with local embeddings)
    2. Find the best chunks of those articles in pgvector
    """
    print(f"\n🔍 Hybrid search: '{query}'\n")

    # Step 1: Read all articles from Weaviate
    print("Step 1: Reading articles from Weaviate...")
    all_articles = search_weaviate_articles_all(limit=10000)
    print(f"✓ Found {len(all_articles)} articles in total")

    # Step 2: Search with the query (local embeddings)
    print(f"Step 2: Searching (among {len(all_articles)} articles)...")
    articles = search_articles_by_query(query, all_articles, limit=top_articles)
    print(f"✓ Found {len(articles)} relevant articles")
    if articles:
        for i, art in enumerate(articles, 1):
            print(f"  {i}. {art['title'][:60]}...")
            print(f"     Relevance: {art['relevance_score']:.1f} | Score: {art['score']:.3f}")

    # Step 3: pgvector chunk search
    print("\nStep 3: Searching pgvector (chunk level)...")
    article_ids = [a["article_id"] for a in articles if a["article_id"]]

    if article_ids:
        chunks = search_pgvector_chunks(query, article_ids, limit=top_chunks)
        print(f"✓ Found {len(chunks)} chunks")
        if chunks:
            for i, chunk in enumerate(chunks[:5], 1):
                print(f"  {i}. {chunk['filename']} (page {chunk['page']}, sim: {chunk['similarity']:.3f})")
                print(f"     \"{chunk['text'][:100]}...\"")
    else:
        chunks = []
        print("⚠️ No articles found, so the chunk search cannot be run")
    return {
        "query": query,
        "articles": articles,
        "chunks": chunks,
    }

if __name__ == "__main__":
    # results = hybrid_search("young driver accident risk", top_articles=10, top_chunks=20)
    results = hybrid_search("modelling traffic volume in rural areas", top_articles=10, top_chunks=20)
    print("\n" + "=" * 60)
    print("FULL RESULT:")
    print(json.dumps(results, default=str, indent=2, ensure_ascii=False)[:3000])