Commit 285b001c authored by zhengyaoqiu

模块重构

parent ef587af1
...@@ -11,12 +11,12 @@ import logging ...@@ -11,12 +11,12 @@ import logging
class FeatureExtractor: class FeatureExtractor:
__model = None # __model = None
__preprocess = None # __preprocess = None
__logger = logging.getLogger(__name__) # __device = "ViT-B/32"
__instance = None __instance = None
__lock = threading.Lock() __lock = threading.Lock()
__device = "ViT-B/32" __logger = logging.getLogger(__name__)
def __new__(cls, device = "cuda" if torch.cuda.is_available() else "cpu", model_name = "ViT-B/32"): def __new__(cls, device = "cuda" if torch.cuda.is_available() else "cpu", model_name = "ViT-B/32"):
# 第一次检查 - 不带锁 # 第一次检查 - 不带锁
......
from typing import List, Tuple, Optional, Union
import logging import logging
import numpy as np
from PIL import Image
from pymilvus import Collection
# 引入特征提取器和MilvusClient
from feature_extractor import FeatureExtractor
from milvus import MilvusClient # 假设MilvusClient在milvus_client.py文件中
class ImageSearch: class ImageSearch:
""" __logger = logging.getLogger(__name__)
图像搜索类,提供基于图像相似度的搜索功能
"""
_logger = logging.getLogger(__name__)
_feature_extractor_initialized = False
_milvus_client = None
@classmethod
def _initialize_feature_extractor(cls) -> None:
    """Initialize the feature extractor the first time it is needed.

    Calls ``FeatureExtractor.initialize()`` unless
    ``cls._feature_extractor_initialized`` is already true, and sets that
    flag once the call has returned without raising.

    Raises:
        RuntimeError: If ``FeatureExtractor.initialize()`` raises. The
            failure is logged on ``cls._logger`` and the original exception
            is chained as the cause.
    """
    if cls._feature_extractor_initialized:
        return
    try:
        FeatureExtractor.initialize()
    except Exception as e:
        cls._logger.error(f"Failed to initialize feature extractor: {e}")
        # Chain explicitly so the underlying failure stays in the traceback.
        raise RuntimeError(f"Failed to initialize feature extractor: {e}") from e
    # Reached only when initialize() succeeded.
    cls._feature_extractor_initialized = True
@classmethod
def get_milvus_client(cls, host: str = "localhost", port: str = "19530",
                      collection_name: str = "image_collection") -> MilvusClient:
    """Return the class-level Milvus client, creating it on first use.

    The client is stored in ``cls._milvus_client``. Once a client has been
    stored, it is returned as-is and the arguments of later calls are not
    consulted.

    Args:
        host: Milvus server address.
        port: Milvus server port.
        collection_name: Name of the collection.

    Returns:
        MilvusClient: The value returned by
        ``MilvusClient(host, port, collection_name).connect()`` on the
        first call, and that same stored object afterwards.
    """
    existing = cls._milvus_client
    if existing is not None:
        return existing
    created = MilvusClient(host, port, collection_name).connect()
    cls._milvus_client = created
    return created
@classmethod
def get_collection(cls, host: str = "localhost", port: str = "19530",
collection_name: str = "image_collection") -> Collection:
"""
获取Milvus集合
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
Collection: Milvus集合对象
Raises: def __init__(self, feature_extractor, milvus):
RuntimeError: 如果无法获取集合 self.feature_extractor = feature_extractor
""" self.milvus = milvus
try:
client = cls.get_milvus_client(host, port, collection_name)
return client.get_collection()
except Exception as e:
cls._logger.error(f"Failed to get collection: {e}")
raise RuntimeError(f"Failed to get collection: {e}")
@classmethod
def extract_features(cls, image: Union[str, Image.Image]) -> np.ndarray:
"""
从图像提取特征向量
Args:
image: 图像URL或PIL图像对象
Returns:
np.ndarray: 特征向量
Raises:
ValueError: 如果特征提取失败
"""
# 确保特征提取器已初始化
cls._initialize_feature_extractor()
try: # id = product_id
# 根据图像类型调用相应的提取方法 def image_to_image_search(self, image, key_name, top_k = 100):
if isinstance(image, str):
features = FeatureExtractor.extract_from_url(image)
elif isinstance(image, Image.Image):
features = FeatureExtractor.extract_from_image(image)
else:
raise ValueError(f"Unsupported image type: {type(image)}")
if features is None:
raise ValueError("Feature extraction returned None")
return features
except Exception as e:
cls._logger.error(f"Feature extraction failed: {e}")
raise ValueError(f"Failed to extract features: {e}")
@classmethod
def image_to_image_search(
cls,
image: Union[str, Image.Image],
top_k: int = 100,
host: str = "localhost",
port: str = "19530",
collection_name: str = "image_collection"
) -> List[Tuple[str, float]]:
"""
使用图像查询相似图像
Args:
image: 查询图像的URL或PIL图像对象
top_k: 返回的最相似结果数量
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
List[Tuple[str, float]]: 产品ID和相似度分数的列表,按相似度降序排列
Raises:
ValueError: 如果图像处理或搜索过程中出错
"""
try: try:
# 提取查询图像的特征 # 提取查询图像的特征
query_embedding = cls.extract_features(image) query_embedding = self.feature_extractor.extract_features(image)
# 获取Milvus客户端并搜索 results = self.milvus.search(query_embedding, limit=top_k)
client = cls.get_milvus_client(host, port, collection_name)
results = client.search(query_embedding, limit=top_k)
# 处理结果 # 处理结果
if not results or len(results) == 0: if not results or len(results) == 0:
return [] return []
# 返回结果 # 返回结果
product_ids = [hit.entity.get('product_id') for hit in results[0]] keys = [hit.entity.get(key_name) for hit in results[0]]
scores = [hit.score for hit in results[0]] scores = [hit.score for hit in results[0]]
return list(zip(product_ids, scores)) return list(zip(keys, scores))
except Exception as e: except Exception as e:
cls._logger.error(f"Image search failed: {e}") self.__logger.error(f"Image search failed: {e}")
raise ValueError(f"Image search failed: {e}") raise ValueError(f"Image search failed: {e}")
\ No newline at end of file
@classmethod
def batch_image_search(
    cls,
    images: List[Union[str, Image.Image]],
    top_k: int = 100,
    host: str = "localhost",
    port: str = "19530",
    collection_name: str = "image_collection"
) -> List[List[Tuple[str, float]]]:
    """Search for similar images for a batch of query images.

    Features are extracted per image with ``cls.extract_features``. Images
    whose extraction fails are logged as warnings and skipped; their slot in
    the result stays an empty list. The remaining embeddings are sent to the
    collection in a single ``search`` call.

    Args:
        images: Query images, each an image URL or a PIL image object.
        top_k: Number of most similar results to return per query.
        host: Milvus server address.
        port: Milvus server port.
        collection_name: Name of the collection.

    Returns:
        One list per input image, in input order, of
        ``(product_id, score)`` tuples taken from the search hits.

    Raises:
        ValueError: If obtaining the collection or searching fails; the
            original exception is chained as the cause.
    """
    # Make sure the feature extractor is ready before the per-image loop.
    cls._initialize_feature_extractor()

    # Extract features, remembering which input index each embedding came from.
    batch_embeddings = []
    valid_indices = []
    for i, image in enumerate(images):
        try:
            features = cls.extract_features(image)
        except Exception as e:
            # Best effort: one bad image must not abort the whole batch.
            cls._logger.warning(f"Feature extraction failed for image at index {i}: {e}")
            continue
        batch_embeddings.append(features)
        valid_indices.append(i)

    # One result slot per input image; failed images keep the empty list.
    all_results = [[] for _ in images]
    if not batch_embeddings:
        return all_results

    try:
        collection = cls.get_collection(host, port, collection_name)
        collection.load()

        # Milvus documents the HNSW `ef` search parameter as bounded below by
        # the requested limit, so a fixed 100 would be rejected for top_k > 100.
        search_params = {"metric_type": "IP", "params": {"ef": max(100, top_k)}}
        batch_results = collection.search(
            data=batch_embeddings,
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=["image", "product_id"]
        )

        # Map each result set back to the position of its source image.
        for original_idx, hits in zip(valid_indices, batch_results):
            all_results[original_idx] = [
                (hit.entity.get('product_id'), hit.score) for hit in hits
            ]
        return all_results
    except Exception as e:
        cls._logger.error(f"Batch image search failed: {e}")
        raise ValueError(f"Batch image search failed: {e}") from e
...@@ -2,111 +2,70 @@ from typing import Dict, List, Any, Optional, Union ...@@ -2,111 +2,70 @@ from typing import Dict, List, Any, Optional, Union
import numpy as np import numpy as np
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
# , collection_name: str = "image_collection"
class MilvusClient: class MilvusClient:
def __init__(self, host: str = "localhost", port: str = "19530", collection_name: str = "image_collection") -> None: def __init__(self, host = "localhost", port = "19530"):
"""初始化Milvus客户端
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
"""
self.host: str = host self.host: str = host
self.port: str = port self.port: str = port
self.collection_name: str = collection_name
self.collection: Optional[Collection] = None
def connect(self) -> 'MilvusClient': def connect(self, alias = "default") -> 'MilvusClient':
"""连接到Milvus服务器 connections.connect(alias, host=self.host, port=self.port)
Returns:
MilvusClient: 当前客户端实例,支持链式调用
"""
connections.connect("default", host=self.host, port=self.port)
return self return self
def get_collection(self) -> Collection: # # 定义集合结构
"""获取或创建集合 # fields: List[FieldSchema] = [
# FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
Returns: # FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=256),
Collection: Milvus集合对象 # FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512
""" # ]
# 定义集合结构 @staticmethod
fields: List[FieldSchema] = [ def create_new_collection(collection_name, fields, description) -> Collection:
FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False), schema: CollectionSchema = CollectionSchema(fields, description)
FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512 # 创建集合
] collection = Collection(name=collection_name, schema=schema)
schema: CollectionSchema = CollectionSchema(fields, "图像特征集合")
return collection
# 创建或获取集合
if utility.has_collection(self.collection_name): @staticmethod
self.collection = Collection(name=self.collection_name) def get_collection(collection_name):
else: if not utility.has_collection(collection_name):
self.collection = Collection(name=self.collection_name, schema=schema) raise RuntimeError(f"集合 '{collection_name}' 不存在")
self._create_index() return Collection(name=collection_name)
return self.collection # index_params: Dict[str, Any] = {
# "index_type": "HNSW",
def _create_index(self) -> None: # "metric_type": "IP", # 内积相似度
"""创建索引""" # "params": {"M": 16, "efConstruction": 200}
index_params: Dict[str, Any] = { # }
"index_type": "HNSW", # def create_index(self, index_params) -> None:
"metric_type": "IP", # 内积相似度 # if self.collection is not None:
"params": {"M": 16, "efConstruction": 200} # self.collection.create_index(field_name="embedding", index_params=index_params)
}
if self.collection is not None: # anns_field = embedding
self.collection.create_index(field_name="embedding", index_params=index_params) # output_fields=["image", "product_id"]
# search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
def search(self, vector: Union[List[float], np.ndarray], limit: int = 10) -> Any:
"""搜索相似向量 def search(self, collection_name, vector, anns_field, search_params, output_fields, top_k = 10) -> Any:
collection = self.get_collection(collection_name)
Args:
vector: 查询向量,可以是列表或numpy数组 results = collection.search(
limit: 返回结果数量
Returns:
查询结果
"""
if self.collection is None:
self.get_collection()
self.collection.load()
search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
results = self.collection.search(
data=[vector], data=[vector],
anns_field="embedding", anns_field=anns_field,
param=search_params, param=search_params,
limit=limit, limit=top_k,
output_fields=["image", "product_id"] output_fields=output_fields
) )
return results return results
def insert(self, data: List[Dict[str, Any]]) -> None: # entities: List[List[Any]] = [
"""插入数据 # [item["image"] for item in data],
# [item["product_id"] for item in data],
Args: # [item["embedding"] for item in data]
data: 包含image、product_id和embedding的字典列表 # ]
每个字典应包含键: "image", "product_id", "embedding" def insert(self, collection_name, entities):
""" self.get_collection(collection_name).insert(entities)
if self.collection is None:
self.get_collection() @staticmethod
def close(alias ="default") -> None:
entities: List[List[Any]] = [ connections.disconnect(alias)
[item["image"] for item in data],
[item["product_id"] for item in data],
[item["embedding"] for item in data]
]
self.collection.insert(entities)
def drop_collection(self) -> None:
"""删除集合"""
if utility.has_collection(self.collection_name):
utility.drop_collection(self.collection_name)
self.collection = None
def close(self) -> None:
"""关闭连接"""
connections.disconnect("default")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment