# query_hybrid.py — two-stage hybrid search: article ranking in Weaviate,
# then chunk-level retrieval from PostgreSQL/pgvector.
  1. # file: src/query_hybrid.py
  2. import os
  3. from dotenv import load_dotenv
  4. import psycopg2
  5. from weaviate import WeaviateClient
  6. from weaviate.connect import ConnectionParams
  7. from src.embed_utils import get_embedding
  8. import json
  9. load_dotenv()
  10. DB_HOST = os.getenv("DB_HOST", "localhost")
  11. DB_PORT = os.getenv("DB_PORT", "5432")
  12. DB_NAME = os.getenv("DB_NAME", "pdf_research")
  13. DB_USER = os.getenv("DB_USER", "pdf_user")
  14. DB_PASSWORD = os.getenv("DB_PASSWORD")
  15. WEAVIATE_HOST = os.getenv("WEAVIATE_HOST", "localhost")
  16. WEAVIATE_HTTP_PORT = int(os.getenv("WEAVIATE_HTTP_PORT", "8080"))
  17. WEAVIATE_CLASS = os.getenv("WEAVIATE_CLASS", "ScientificArticle")
  18. def get_db_conn():
  19. return psycopg2.connect(
  20. host=DB_HOST,
  21. port=DB_PORT,
  22. database=DB_NAME,
  23. user=DB_USER,
  24. password=DB_PASSWORD,
  25. )
  26. def get_weaviate_client() -> WeaviateClient:
  27. """Weaviate v4 klient - HTTP ainult"""
  28. client = WeaviateClient(connection_params=ConnectionParams.from_params(
  29. http_host=WEAVIATE_HOST,
  30. http_port=WEAVIATE_HTTP_PORT,
  31. http_secure=False,
  32. grpc_host=WEAVIATE_HOST,
  33. grpc_port=50051,
  34. grpc_secure=False,
  35. ))
  36. client.connect()
  37. return client
  38. def search_weaviate_articles_all(limit: int = 10000) -> list[dict]:
  39. """
  40. Loeb KÕIK artiklid Weaviate'st (fetch_objects),
  41. ei kasuta near_text (mis vajavad Ollama embedding'uid).
  42. """
  43. client = get_weaviate_client()
  44. try:
  45. collection = client.collections.get(WEAVIATE_CLASS)
  46. # Loeme kõik objektid ilma otsinguta
  47. results = collection.query.fetch_objects(
  48. limit=limit,
  49. return_properties=["title", "source_file", "summary_et", "key_concepts",
  50. "authors", "transport_context", "relevance_score", "abstract_en"]
  51. )
  52. articles = []
  53. for obj in results.objects:
  54. props = obj.properties
  55. articles.append({
  56. "title": props.get("title", "N/A"),
  57. "article_id": str(obj.uuid),
  58. "summary_et": (props.get("summary_et", "") or "")[:500],
  59. "key_concepts": props.get("key_concepts", []),
  60. "authors": props.get("authors", []),
  61. "transport_context": props.get("transport_context", ""),
  62. "relevance_score": props.get("relevance_score", 0),
  63. "abstract_en": (props.get("abstract_en", "") or "")[:500],
  64. "source_file": props.get("source_file", ""),
  65. })
  66. return articles
  67. except Exception as e:
  68. print(f"❌ Weaviate päringu viga: {e}")
  69. import traceback
  70. traceback.print_exc()
  71. return []
  72. finally:
  73. client.close()
  74. def search_articles_by_query(query: str, all_articles: list[dict], limit: int = 10) -> list[dict]:
  75. """
  76. Otsib artikleid query abil, kasutades lokaalseid embeddings'eid (bge-small).
  77. Arvutab sarnasuse vektoritega.
  78. """
  79. query_emb = get_embedding(query)
  80. # Otsime artiklite summary_et ja abstract_en seast
  81. scored_articles = []
  82. for article in all_articles:
  83. # Võta tekst (summary_et või abstract_en)
  84. text = article.get("summary_et") or article.get("abstract_en") or ""
  85. if not text:
  86. continue
  87. # Arvuta embedding
  88. text_emb = get_embedding(text)
  89. # Arvuta cosine similarity (dot product normaliseeritud vektorite jaoks)
  90. similarity = sum(a * b for a, b in zip(query_emb, text_emb))
  91. scored_articles.append({
  92. **article,
  93. "score": similarity
  94. })
  95. # Sorteeri ja võta top N
  96. scored_articles.sort(key=lambda x: x["score"], reverse=True)
  97. return scored_articles[:limit]
  98. def search_pgvector_chunks(query: str, article_ids: list[str], limit: int = 20) -> list[dict]:
  99. """
  100. Otsib pgvectorist chunkid, mis kuuluvad antud artiklitele.
  101. """
  102. if not article_ids:
  103. return []
  104. conn = get_db_conn()
  105. cur = conn.cursor()
  106. query_emb = get_embedding(query)
  107. cur.execute("""
  108. SELECT
  109. c.id,
  110. c.text,
  111. c.page,
  112. c.chunk_index,
  113. r.filename,
  114. r.weaviate_article_id,
  115. 1 - (c.embedding <=> %s::vector) AS similarity
  116. FROM chunks c
  117. JOIN raw_documents r ON c.raw_doc_id = r.id
  118. WHERE r.weaviate_article_id = ANY(%s::UUID[])
  119. ORDER BY c.embedding <=> %s::vector
  120. LIMIT %s
  121. """, (query_emb, article_ids, query_emb, limit))
  122. chunks = []
  123. for row in cur.fetchall():
  124. chunks.append({
  125. "chunk_id": row[0],
  126. "text": row[1],
  127. "page": row[2],
  128. "chunk_index": row[3],
  129. "filename": row[4],
  130. "weaviate_article_id": str(row[5]) if row[5] else None,
  131. "similarity": float(row[6]),
  132. })
  133. cur.close()
  134. conn.close()
  135. return chunks
  136. def hybrid_search(query: str, top_articles: int = 10, top_chunks: int = 20) -> dict:
  137. """
  138. Kaheastmeline otsing:
  139. 1. Leia top artikleid Weaviate'st (lokaalsete embeddings'idega)
  140. 2. Leia nende artiklite chunkide seast parimad pgvectorist
  141. """
  142. print(f"\n🔍 Hübriidotsing: '{query}'\n")
  143. # Samm 1: Loe kõik artiklid Weaviate'st
  144. print("Samm 1: Loeme artikleid Weaviate'st...")
  145. all_articles = search_weaviate_articles_all(limit=10000)
  146. print(f"✓ Leitud {len(all_articles)} artiklit kokku")
  147. # Samm 2: Otsime query abil (lokaalsed embeddings)
  148. print(f"Samm 2: Otsing ({len(all_articles)} artikli seast)...")
  149. articles = search_articles_by_query(query, all_articles, limit=top_articles)
  150. print(f"✓ Leitud {len(articles)} asjakohast artiklit")
  151. if articles:
  152. for i, art in enumerate(articles, 1):
  153. print(f" {i}. {art['title'][:60]}...")
  154. print(f" Relevants: {art['relevance_score']:.1f} | Score: {art['score']:.3f}")
  155. # Samm 3: pgvector chunk otsing
  156. print(f"\nSamm 3: Otsing pgvector'ist (chunk tase)...")
  157. article_ids = [a["article_id"] for a in articles if a["article_id"]]
  158. if article_ids:
  159. chunks = search_pgvector_chunks(query, article_ids, limit=top_chunks)
  160. print(f"✓ Leitud {len(chunks)} chunk'i")
  161. if chunks:
  162. for i, chunk in enumerate(chunks[:5], 1):
  163. print(f" {i}. {chunk['filename']} (leht {chunk['page']}, sim: {chunk['similarity']:.3f})")
  164. print(f" \"{chunk['text'][:100]}...\"")
  165. else:
  166. chunks = []
  167. print(f"⚠️ Ühtegi artiklit ei leitud, chunkide otsingut ei saa teha")
  168. return {
  169. "query": query,
  170. "articles": articles,
  171. "chunks": chunks,
  172. }
  173. if __name__ == "__main__":
  174. #results = hybrid_search("young driver accident risk", top_articles=10, top_chunks=20)
  175. results = hybrid_search("modelling traffic volume in rural areas", top_articles=10, top_chunks=20)
  176. print("\n" + "=" * 60)
  177. print("TÄIELIK TULEMUS:")
  178. print(json.dumps(results, default=str, indent=2, ensure_ascii=False)[:3000])