weaviate_export_import_clean.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. """
  2. Weaviate Collection Export/Import Utility
  3. Korduvkasutatav tööriist Weaviate kollektsioonide eksportimiseks ja importimiseks JSON backup failide kaudu.
  4. Toetab UUID normaliseerimist, int/float tüüpe, doc_hash, vigu ja batch operatsioone.
  5. """
  6. import datetime
  7. import json
  8. import uuid
  9. import logging
  10. from pathlib import Path
  11. from typing import Dict, List, Any, Optional, Union
  12. from weaviate import WeaviateClient, ConnectionParams
  13. from weaviate.classes.config import Property, DataType
  14. from decimal import Decimal
  15. import ijson
  # Configure the root logger once at import time: INFO level, with a
  # timestamp and the level name in front of every message.
  logging.basicConfig(
      format='%(asctime)s - %(levelname)s - %(message)s',
      level=logging.INFO,
  )
  # Module-level logger used by the code in this file.
  logger = logging.getLogger(__name__)
  18. class WeaviateExportImport:
  19. """Korduvkasutatav klass Weaviate kollektsioonide eksportimiseks ja importimiseks."""
  20. def __init__(self, src_client: Optional[WeaviateClient] = None, dst_client: Optional[WeaviateClient] = None):
  21. self.src_client = src_client
  22. self.dst_client = dst_client
  @staticmethod
  def create_client(host: str, http_port: int = 9020, grpc_port: int = 50051, secure: bool = False) -> WeaviateClient:
      """Build a WeaviateClient for ``host``, connect it and return it.

      The same host and the same ``secure`` flag are applied to both the
      HTTP and the gRPC endpoint.

      Args:
          host: Host name used for the HTTP and gRPC connection.
          http_port: Port of the HTTP endpoint.
          grpc_port: Port of the gRPC endpoint.
          secure: Value passed as both ``http_secure`` and ``grpc_secure``.

      Returns:
          The client, after ``connect()`` has been called on it.
      """
      connection = ConnectionParams.from_params(
          http_host=host,
          http_port=http_port,
          http_secure=secure,
          grpc_host=host,
          grpc_port=grpc_port,
          grpc_secure=secure,
      )
      client = WeaviateClient(connection_params=connection)
      client.connect()
      logger.info(f"Ühendatud Weaviate'ga: {host}:{http_port}")
      return client
  def normalize_int_fields(self, props: Dict[str, Any], int_fields: List[str] = None) -> Dict[str, Any]:
      """Turn whole-number float/Decimal values of selected fields into int.

      ``props`` is modified in place and also returned.

      Args:
          props: Property mapping to normalize.
          int_fields: Names of the fields to inspect. When None, the fields
              "page_start", "page_end" and "chunk" are used.

      Returns:
          The same ``props`` object. A float with an integral value becomes
          int; a Decimal becomes int when integral and float otherwise.
          Every other value is left untouched.
      """
      names = ["page_start", "page_end", "chunk"] if int_fields is None else int_fields
      for name in names:
          if name not in props:
              continue
          raw = props[name]
          if isinstance(raw, Decimal):
              # Decimal never stays in the result: int when integral, float otherwise.
              props[name] = int(raw) if raw % 1 == 0 else float(raw)
          elif isinstance(raw, float) and raw.is_integer():
              props[name] = int(raw)
      return props
  def normalize_doc_hash(self, doc_hash: Any) -> str:
      """Return ``doc_hash`` as a string, dropping the dashes of UUID-shaped input.

      Args:
          doc_hash: A ``uuid.UUID``, a string, or any other value.

      Returns:
          The ``hex`` form for a ``uuid.UUID``; the string without "-" for a
          36-character string that contains "-"; ``str(doc_hash)`` otherwise.
      """
      if isinstance(doc_hash, uuid.UUID):
          normalized = doc_hash.hex
      elif isinstance(doc_hash, str) and len(doc_hash) == 36 and "-" in doc_hash:
          normalized = doc_hash.replace("-", "")
      else:
          normalized = str(doc_hash)
      return normalized
  55. def clean_uuid(self, obj: Any) -> Any:
  56. if isinstance(obj, dict):
  57. return {k: self.clean_uuid(v) for k, v in obj.items()}
  58. if isinstance(obj, (list, tuple)):
  59. return [self.clean_uuid(x) for x in obj]
  60. if isinstance(obj, uuid.UUID):
  61. return str(obj)
  62. if hasattr(obj, "__str__") and obj.__class__.__name__.lower().startswith("uuid"):
  63. return str(obj)
  64. return obj
  65. def process_properties(self, props: Dict[str, Any], int_fields: List[str] = None, hash_fields: List[str] = None) -> Dict[str, Any]:
  66. if hash_fields is None:
  67. hash_fields = ["doc_hash"]
  68. props = self.clean_uuid(props)
  69. props = self.normalize_int_fields(props, int_fields)
  70. for field in hash_fields:
  71. if field in props:
  72. props[field] = self.normalize_doc_hash(props[field])
  73. return props
  def export_collection(self, collection_name: str, output_file: Union[str, Path],
                        int_fields: List[str] = None, hash_fields: List[str] = None,
                        include_vectors: bool = True) -> int:
      """Stream every object of a source collection into a JSON array file.

      Objects are written one at a time, so the whole collection is never
      held in memory. Each element has the keys 'uuid' and 'properties',
      plus 'vector' when ``include_vectors`` is true.

      Args:
          collection_name: Name of the collection fetched from the source client.
          output_file: Path of the JSON file to write (UTF-8).
          int_fields: Forwarded to ``process_properties``.
          hash_fields: Forwarded to ``process_properties``.
          include_vectors: Passed to the collection iterator as
              ``include_vector`` and controls whether 'vector' is written.

      Returns:
          Number of objects written.

      Raises:
          ValueError: If no source client is set.
      """
      if not self.src_client:
          raise ValueError("Source client pole määratud")

      logger.info(f"Alustan kollektsiooni '{collection_name}' streaming eksporti...")
      collection = self.src_client.collections.get(collection_name)
      output_path = Path(output_file)

      def custom_json_encoder(obj):
          # Fallback for values json cannot encode itself: datetimes as
          # ISO text, anything else through str().
          return obj.isoformat() if isinstance(obj, datetime.datetime) else str(obj)

      count = 0
      # Write straight to the file instead of building the array in memory.
      with open(output_path, "w", encoding="utf-8") as f:
          f.write("[\n")  # open the JSON array
          for item in collection.iterator(include_vector=include_vectors):
              props = self.process_properties(
                  dict(item.properties),
                  int_fields=int_fields,
                  hash_fields=hash_fields,
              )
              export_obj = {'uuid': str(item.uuid), 'properties': props}
              if include_vectors:
                  export_obj['vector'] = item.vector

              # Every element except the first is preceded by a separator.
              if count:
                  f.write(",\n")
              json.dump(export_obj, f, ensure_ascii=False, default=custom_json_encoder)
              count += 1

              # Progress log after every 1000 objects.
              if count % 1000 == 0:
                  logger.info(f"Eksporditud: {count} objekti...")
          f.write("\n]")  # close the JSON array

      logger.info(f"Eksport valmis: {count} objekti")
      return count
  113. def clean_decimals(self, obj: Any) -> Any:
  114. '''Teisenda kõik Decimal objektid float-ideks'''
  115. if isinstance(obj, Decimal):
  116. return float(obj)
  117. if isinstance(obj, dict):
  118. return {k: self.clean_decimals(v) for k, v in obj.items()}
  119. if isinstance(obj, list):
  120. return [self.clean_decimals(item) for item in obj]
  121. return obj
  def import_collection(self, collection_name: str, input_file: Union[str, Path],
                        int_fields: List[str] = None, batch_size: int = 100,
                        recreate_collection: bool = False) -> int:
      """Stream a JSON array file into a collection of the destination client.

      The file is parsed incrementally with ``ijson`` and objects are
      handed to ``_import_batch`` in groups of ``batch_size``.

      Args:
          collection_name: Name of the collection fetched from the destination client.
          input_file: Path of the JSON file to read; each array element
              needs the keys "uuid" and "properties", "vector" is optional.
          int_fields: When non-empty, forwarded to ``normalize_int_fields``.
          batch_size: Number of objects collected before a batch is imported.
          recreate_collection: Accepted but not referenced anywhere in this
              method body.

      Returns:
          Number of objects handed to ``_import_batch`` (counted per batch,
          independent of whether each individual insert succeeded).

      Raises:
          ValueError: If no destination client is set.
      """
      if not self.dst_client:
          raise ValueError("Destination client pole määratud")

      logger.info(f"Alustan kollektsiooni '{collection_name}' streaming importi...")
      collection = self.dst_client.collections.get(collection_name)
      input_path = Path(input_file)

      imported_count = 0
      pending = []

      # ijson.items reads the file piece by piece rather than loading it whole.
      with open(input_path, 'rb') as stream:
          # use_float=True asks ijson for plain floats; clean_decimals below
          # still converts any Decimal that shows up.
          for record in ijson.items(stream, 'item', use_float=True):
              try:
                  properties = self.clean_decimals(record["properties"])
                  if int_fields:
                      properties = self.normalize_int_fields(properties, int_fields)

                  # The vector gets the same Decimal clean-up as the properties.
                  vector = record.get("vector")
                  if vector is not None:
                      vector = self.clean_decimals(vector)

                  pending.append({
                      'uuid': str(record["uuid"]),
                      'properties': properties,
                      'vector': vector,
                  })

                  # Keep collecting until the batch is full.
                  if len(pending) < batch_size:
                      continue

                  self._import_batch(collection, pending)
                  imported_count += len(pending)
                  logger.info(f"Imporditud: {imported_count} objekti...")
                  pending = []
              except Exception as e:
                  # A broken record is reported and skipped; the import goes on.
                  logger.warning(f"Import error: {e}")

      # Import whatever is left in the last, partially filled batch.
      if pending:
          self._import_batch(collection, pending)
          imported_count += len(pending)

      logger.info(f"Import lõpetatud: {imported_count} objekti")
      return imported_count
  164. def _import_batch(self, collection, batch):
  165. '''Batch import helper'''
  166. for item in batch:
  167. try:
  168. collection.data.insert(
  169. properties=item['properties'],
  170. uuid=item['uuid'],
  171. vector=item.get('vector')
  172. )
  173. except Exception as e:
  174. if "already exists" not in str(e):
  175. logger.warning(f"Insert error: {e}")
  def close_clients(self):
      """Close the source client and then the destination client, if set.

      A log line is written after each client that was closed; a client
      that is not set (falsy) is skipped.
      """
      closing_order = (
          (self.src_client, "Source client suletud"),
          (self.dst_client, "Destination client suletud"),
      )
      for client, message in closing_order:
          if not client:
              continue
          client.close()
          logger.info(message)