Commit 285b001c authored by zhengyaoqiu

模块重构

parent ef587af1
......@@ -11,12 +11,12 @@ import logging
class FeatureExtractor:
__model = None
__preprocess = None
__logger = logging.getLogger(__name__)
# __model = None
# __preprocess = None
# __device = "ViT-B/32"
__instance = None
__lock = threading.Lock()
__device = "ViT-B/32"
__logger = logging.getLogger(__name__)
def __new__(cls, device = "cuda" if torch.cuda.is_available() else "cpu", model_name = "ViT-B/32"):
# 第一次检查 - 不带锁
......
from typing import List, Tuple, Optional, Union
import logging
import numpy as np
from PIL import Image
from pymilvus import Collection
# 引入特征提取器和MilvusClient
from feature_extractor import FeatureExtractor
from milvus import MilvusClient # 假设MilvusClient在milvus_client.py文件中
class ImageSearch:
"""
Image search class providing search based on image similarity.
"""
# Logger named after this module.
_logger = logging.getLogger(__name__)
# Flipped to True by _initialize_feature_extractor once FeatureExtractor.initialize() succeeds.
_feature_extractor_initialized = False
# Cached MilvusClient; populated by the first get_milvus_client call.
_milvus_client = None
@classmethod
def _initialize_feature_extractor(cls) -> None:
    """Initialize the feature extractor the first time it is needed.

    Calls ``FeatureExtractor.initialize()`` unless
    ``cls._feature_extractor_initialized`` is already true, and sets that
    flag after a successful call so later invocations return immediately.

    Raises:
        RuntimeError: If ``FeatureExtractor.initialize()`` raises. The
            original exception is attached as the cause.
    """
    if cls._feature_extractor_initialized:
        return
    try:
        FeatureExtractor.initialize()
    except Exception as e:
        cls._logger.error(f"Failed to initialize feature extractor: {e}")
        # Chain explicitly so the traceback reports the root cause as the
        # direct cause rather than as an incidental secondary error.
        raise RuntimeError(f"Failed to initialize feature extractor: {e}") from e
    cls._feature_extractor_initialized = True
@classmethod
def get_milvus_client(cls, host: str = "localhost", port: str = "19530",
                      collection_name: str = "image_collection") -> MilvusClient:
    """Return the class-level Milvus client, creating it on first use.

    When no client is cached, a ``MilvusClient(host, port, collection_name)``
    is built, ``connect()`` is called on it, and the result is stored on the
    class. When a client is already cached it is returned unchanged and the
    arguments of this call are not consulted.

    Args:
        host: Milvus server address.
        port: Milvus server port.
        collection_name: Collection name.

    Returns:
        MilvusClient: The cached client instance.
    """
    cached = cls._milvus_client
    if cached is not None:
        return cached
    cls._milvus_client = MilvusClient(host, port, collection_name).connect()
    return cls._milvus_client
@classmethod
def get_collection(cls, host: str = "localhost", port: str = "19530",
                   collection_name: str = "image_collection") -> Collection:
    """Return the Milvus collection obtained through the shared client.

    Args:
        host: Milvus server address.
        port: Milvus server port.
        collection_name: Collection name.

    Returns:
        Collection: The value returned by the client's ``get_collection()``.

    Raises:
        RuntimeError: If obtaining the client or the collection raises. The
            original exception is attached as the cause.
    """
    # NOTE(review): the scraped page showed the line
    # `__logger = logging.getLogger(__name__)` inside this method's docstring.
    # It looks like an added diff line belonging to the refactored class, not
    # to this method -- confirm against the commit.
    try:
        return cls.get_milvus_client(host, port, collection_name).get_collection()
    except Exception as e:
        cls._logger.error(f"Failed to get collection: {e}")
        # Chain explicitly so the underlying Milvus error stays visible.
        raise RuntimeError(f"Failed to get collection: {e}") from e
@classmethod
def extract_features(cls, image: Union[str, Image.Image]) -> np.ndarray:
"""
Extract a feature vector from an image.
Args:
image: Image URL or PIL image object
Returns:
np.ndarray: Feature vector
Raises:
ValueError: If feature extraction fails
"""
# Make sure the feature extractor has been initialized
cls._initialize_feature_extractor()
# NOTE(review): the next three lines define a separate __init__ that stores
# feature_extractor and milvus on the instance. They appear to be added diff
# lines interleaved into this method by the page scrape -- confirm.
def __init__(self, feature_extractor, milvus):
self.feature_extractor = feature_extractor
self.milvus = milvus
try:
# Call the extraction method that matches the image type
if isinstance(image, str):
features = FeatureExtractor.extract_from_url(image)
elif isinstance(image, Image.Image):
features = FeatureExtractor.extract_from_image(image)
else:
raise ValueError(f"Unsupported image type: {type(image)}")
if features is None:
raise ValueError("Feature extraction returned None")
return features
except Exception as e:
# Every failure, including the ValueErrors raised above, is logged and
# re-raised as a ValueError with a "Failed to extract features" prefix.
cls._logger.error(f"Feature extraction failed: {e}")
raise ValueError(f"Failed to extract features: {e}")
@classmethod
def image_to_image_search(
cls,
image: Union[str, Image.Image],
top_k: int = 100,
host: str = "localhost",
port: str = "19530",
collection_name: str = "image_collection"
) -> List[Tuple[str, float]]:
"""
Search for similar images using a query image.
Args:
image: URL or PIL image object of the query image
top_k: Number of most-similar results to return
host: Milvus server address
port: Milvus server port
collection_name: Collection name
Returns:
List[Tuple[str, float]]: List of (product ID, similarity score) pairs, in descending order of similarity
Raises:
ValueError: If image processing or the search fails
"""
# id = product_id
# NOTE(review): from here on, two versions of this method appear interleaved
# (a classmethod using cls/host/port and an instance method using
# self/key_name). This looks like removed and added diff lines merged by the
# page scrape -- confirm which lines belong to which version.
def image_to_image_search(self, image, key_name, top_k = 100):
try:
# Extract the features of the query image
query_embedding = cls.extract_features(image)
query_embedding = self.feature_extractor.extract_features(image)
# Get the Milvus client and run the search
client = cls.get_milvus_client(host, port, collection_name)
results = client.search(query_embedding, limit=top_k)
results = self.milvus.search(query_embedding, limit=top_k)
# Handle the results
if not results or len(results) == 0:
return []
# Return the results (only the hits of the first query, results[0], are read)
product_ids = [hit.entity.get('product_id') for hit in results[0]]
keys = [hit.entity.get(key_name) for hit in results[0]]
scores = [hit.score for hit in results[0]]
return list(zip(product_ids, scores))
return list(zip(keys, scores))
except Exception as e:
cls._logger.error(f"Image search failed: {e}")
self.__logger.error(f"Image search failed: {e}")
raise ValueError(f"Image search failed: {e}")
\ No newline at end of file
@classmethod
def batch_image_search(
    cls,
    images: List[Union[str, Image.Image]],
    top_k: int = 100,
    host: str = "localhost",
    port: str = "19530",
    collection_name: str = "image_collection"
) -> List[List[Tuple[str, float]]]:
    """Run a similarity search for several query images in one call.

    Features are extracted per image. An image whose extraction fails is
    logged with a warning and keeps an empty result list at its position.
    The remaining embeddings are sent to the collection in a single search.

    Args:
        images: Query images, each a URL or a PIL image object.
        top_k: Number of most-similar results to return per query.
        host: Milvus server address.
        port: Milvus server port.
        collection_name: Collection name.

    Returns:
        List[List[Tuple[str, float]]]: One list per input image, in input
        order, of (product_id, score) pairs.

    Raises:
        ValueError: If obtaining the collection or searching fails. The
            original exception is attached as the cause.
    """
    # Make sure the feature extractor has been initialized.
    cls._initialize_feature_extractor()

    batch_embeddings = []
    valid_indices = []
    for i, image in enumerate(images):
        try:
            features = cls.extract_features(image)
        except Exception as e:
            # Deliberate best-effort: one bad image must not abort the batch.
            cls._logger.warning(f"Feature extraction failed for image at index {i}: {e}")
            continue
        batch_embeddings.append(features)
        valid_indices.append(i)

    # One (initially empty) result slot per input image.
    all_results = [[] for _ in images]
    if not batch_embeddings:
        return all_results

    try:
        collection = cls.get_collection(host, port, collection_name)
        collection.load()
        search_params = {"metric_type": "IP", "params": {"ef": 100}}
        batch_results = collection.search(
            data=batch_embeddings,
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=["image", "product_id"]
        )
        # Search results come back in embedding order; map each one back to
        # the position of the image it was computed from.
        for original_idx, hits in zip(valid_indices, batch_results):
            all_results[original_idx] = [
                (hit.entity.get('product_id'), hit.score) for hit in hits
            ]
        return all_results
    except Exception as e:
        cls._logger.error(f"Batch image search failed: {e}")
        raise ValueError(f"Batch image search failed: {e}") from e
......@@ -2,111 +2,70 @@ from typing import Dict, List, Any, Optional, Union
import numpy as np
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
# , collection_name: str = "image_collection"
class MilvusClient:
def __init__(self, host: str = "localhost", port: str = "19530", collection_name: str = "image_collection") -> None:
"""Initialize the Milvus client.
Args:
host: Milvus server address
port: Milvus server port
collection_name: Collection name
"""
# NOTE(review): a second __init__ signature (without collection_name) follows
# while collection_name is still assigned below. This looks like removed and
# added diff lines interleaved by the page scrape -- confirm.
def __init__(self, host = "localhost", port = "19530"):
self.host: str = host
self.port: str = port
self.collection_name: str = collection_name
# Starts unset; assigned elsewhere in the class by get_collection().
self.collection: Optional[Collection] = None
def connect(self) -> 'MilvusClient':
"""Connect to the Milvus server.
Returns:
MilvusClient: The current client instance, to allow call chaining
"""
connections.connect("default", host=self.host, port=self.port)
# NOTE(review): a second connect() taking an alias follows. It looks like the
# added version of this method interleaved by the page scrape -- confirm.
def connect(self, alias = "default") -> 'MilvusClient':
connections.connect(alias, host=self.host, port=self.port)
return self
def get_collection(self) -> Collection:
"""Get or create the collection.
Returns:
Collection: Milvus collection object
"""
# Define the collection structure
fields: List[FieldSchema] = [
FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=256),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32 features have dimension 512
]
schema: CollectionSchema = CollectionSchema(fields, "图像特征集合")
# Create or get the collection
if utility.has_collection(self.collection_name):
self.collection = Collection(name=self.collection_name)
else:
self.collection = Collection(name=self.collection_name, schema=schema)
# NOTE(review): indentation is lost in this view, so it is unclear whether
# _create_index() runs only in the else branch or on every call -- confirm.
self._create_index()
return self.collection
def _create_index(self) -> None:
    """Create the index on the "embedding" field when a collection is set.

    Does nothing if ``self.collection`` is ``None``.
    """
    if self.collection is None:
        return
    hnsw_settings = {
        "index_type": "HNSW",
        "metric_type": "IP",  # inner-product similarity
        "params": {"M": 16, "efConstruction": 200},
    }
    self.collection.create_index(field_name="embedding", index_params=hnsw_settings)
def search(self, vector: Union[List[float], np.ndarray], limit: int = 10) -> Any:
"""Search for similar vectors.
Args:
vector: Query vector, either a list or a numpy array
limit: Number of results to return
Returns:
The query results
"""
if self.collection is None:
self.get_collection()
# NOTE(review): indentation is lost in this view, so it is unclear whether
# load() is called only after get_collection() or on every search -- confirm.
self.collection.load()
search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
# NOTE(review): this call is left open here; its arguments appear further down,
# mixed with a later search() definition -- presumably a diff-scrape artifact.
results = self.collection.search(
# # 定义集合结构
# fields: List[FieldSchema] = [
# FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
# FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=256),
# FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=512) # CLIP ViT-B/32的特征维度为512
# ]
@staticmethod
def create_new_collection(collection_name, fields, description) -> Collection:
    """Construct a collection with a schema built from the given fields.

    Args:
        collection_name: Name passed to ``Collection``.
        fields: Field definitions passed to ``CollectionSchema``.
        description: Schema description passed to ``CollectionSchema``.

    Returns:
        Collection: The newly constructed collection object.
    """
    return Collection(
        name=collection_name,
        schema=CollectionSchema(fields, description),
    )
@staticmethod
def get_collection(collection_name):
    """Return the existing collection with the given name.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        A ``Collection`` constructed for ``collection_name``.

    Raises:
        RuntimeError: If ``utility.has_collection`` reports that the
            collection does not exist.
    """
    if utility.has_collection(collection_name):
        return Collection(name=collection_name)
    raise RuntimeError(f"集合 '{collection_name}' 不存在")
# index_params: Dict[str, Any] = {
# "index_type": "HNSW",
# "metric_type": "IP", # 内积相似度
# "params": {"M": 16, "efConstruction": 200}
# }
# def create_index(self, index_params) -> None:
# if self.collection is not None:
# self.collection.create_index(field_name="embedding", index_params=index_params)
# anns_field = embedding
# output_fields=["image", "product_id"]
# search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
def search(self, collection_name, vector, anns_field, search_params, output_fields, top_k = 10) -> Any:
# Looks up the named collection and runs a single-vector search on it,
# returning whatever collection.search() returns.
collection = self.get_collection(collection_name)
results = collection.search(
# The single query vector is wrapped in a list before being passed as data.
data=[vector],
# NOTE(review): the argument list below contains both hard-coded values
# ("embedding", limit, fixed output_fields) and parameterized ones
# (anns_field, top_k, output_fields). This looks like removed and added diff
# lines interleaved by the page scrape -- confirm which set is current.
anns_field="embedding",
anns_field=anns_field,
param=search_params,
limit=limit,
output_fields=["image", "product_id"]
limit=top_k,
output_fields=output_fields
)
return results
def insert(self, data: List[Dict[str, Any]]) -> None:
    """Insert records into the collection.

    The records are regrouped into one list per field, in the order
    "image", "product_id", "embedding", and passed to the collection's
    ``insert``. ``get_collection()`` is called first if no collection is set.

    Args:
        data: List of dicts, each containing the keys "image",
            "product_id" and "embedding".
    """
    if self.collection is None:
        self.get_collection()
    columns: List[List[Any]] = [
        [record[field_name] for record in data]
        for field_name in ("image", "product_id", "embedding")
    ]
    self.collection.insert(columns)
def drop_collection(self) -> None:
"""Drop the collection."""
if utility.has_collection(self.collection_name):
utility.drop_collection(self.collection_name)
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# cached handle is cleared only after a drop or unconditionally -- confirm.
self.collection = None
def close(self) -> None:
    """Close the connection by disconnecting the "default" alias."""
    default_alias = "default"
    connections.disconnect(default_alias)
# entities: List[List[Any]] = [
# [item["image"] for item in data],
# [item["product_id"] for item in data],
# [item["embedding"] for item in data]
# ]
def insert(self, collection_name, entities):
    """Insert ``entities`` into the collection named ``collection_name``.

    The collection is looked up with ``self.get_collection`` and the
    entities are passed to its ``insert`` unchanged.
    """
    target = self.get_collection(collection_name)
    target.insert(entities)
@staticmethod
def close(alias="default") -> None:
    """Disconnect the connection registered under ``alias``.

    Args:
        alias: Connection alias passed to ``connections.disconnect``.
    """
    connections.disconnect(alias)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment