Commit ef587af1 authored by zhengyaoqiu's avatar zhengyaoqiu

特征提取模块重构

parent 9858b763
import threading
import torch
import clip
import requests
from PIL import Image
from io import BytesIO
from typing import Optional, Tuple
import numpy as np
import logging
class FeatureExtractor:
    """
    Process-wide singleton that loads a CLIP model once and uses it to
    extract image feature vectors.

    The first ``FeatureExtractor(...)`` call loads the model. Every later
    call returns that same instance; its ``device`` and ``model_name``
    arguments are then ignored.
    """

    # Name-mangled class-level defaults. The real values are stored on the
    # singleton instance by __new__.
    __model = None
    __preprocess = None
    __logger = logging.getLogger(__name__)
    __instance = None
    __lock = threading.Lock()
    # Placeholder device until __new__ stores the real one (this default used
    # to hold the model name "ViT-B/32" by mistake).
    __device = "cpu"

    def __new__(cls, device = "cuda" if torch.cuda.is_available() else "cpu", model_name = "ViT-B/32"):
        """
        Return the shared extractor, creating and loading it on first use.

        Args:
            device: Device passed to ``clip.load``; defaults to "cuda" when
                available, otherwise "cpu". Only honoured on the first call.
            model_name: CLIP model name passed to ``clip.load``. Only
                honoured on the first call.

        Returns:
            The single FeatureExtractor instance.
        """
        # First check without the lock: the common path once initialised.
        if cls.__instance is None:
            # Take the lock only when an instance may have to be created.
            with cls.__lock:
                # Second check with the lock held (double-checked locking).
                if cls.__instance is None:
                    cls.__logger.info(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}")
                    instance = super().__new__(cls)
                    instance.__model, instance.__preprocess = clip.load(model_name, device=device)
                    instance.__device = device
                    # Publish only after loading succeeded, so a failed load
                    # can be retried and no caller ever sees an instance
                    # without a model.
                    cls.__instance = instance
        return cls.__instance

    @classmethod
    def initialize(cls, model_name="ViT-B/32"):
        """
        Backward-compatible entry point that makes sure the model is loaded.

        Args:
            model_name (str): CLIP model name; only used if the singleton
                has not been created yet.
        """
        cls(model_name=model_name)
@staticmethod
def resize_with_padding(img):
def resize_with_padding(img, target_size = (224, 224)):
"""
调整图像大小,保持纵横比并添加填充
Args:
img (Image.Image): 输入图像
img: 输入图像
target_size: 目标尺寸,默认为(224, 224)
Returns:
PIL.Image: 调整大小后的图像
调整大小后的图像
"""
target_size = (224, 224)
# 计算调整大小的比例
ratio = min(target_size[0] / img.width, target_size[1] / img.height)
new_size = (int(img.width * ratio), int(img.height * ratio))
......@@ -58,62 +64,59 @@ class FeatureExtractor:
return new_img
def extract_from_url(self, image_url):
    """
    Download an image from a URL and extract its feature vector.

    Args:
        image_url: URL of the image to download.

    Returns:
        Whatever ``extract_from_image`` returns for the downloaded image,
        or None if downloading or decoding the image fails.
    """
    try:
        # The context manager closes the streamed response even when
        # raise_for_status() fails before the body has been read.
        with requests.get(image_url, stream=True, timeout=10) as response:
            response.raise_for_status()  # make sure the request succeeded
            payload = response.content
        # Turn the downloaded bytes into a PIL image.
        image = Image.open(BytesIO(payload)).convert("RGB")
        return self.extract_from_image(image)
    except requests.RequestException as e:
        self.__logger.error(f"Network error when downloading image from {image_url}: {e}")
        return None
    except Exception as e:
        self.__logger.error(f"Error extracting features from URL {image_url}: {e}")
        return None
def extract_from_image(self, img):
    """
    Extract a normalised feature vector from a PIL image.

    Args:
        img: Input image.

    Returns:
        A flattened, L2-normalised float32 numpy array, or None if
        extraction fails.
    """
    try:
        # Resize with padding before running the model's own preprocessing.
        image = self.resize_with_padding(img)
        image_tensor = self.__preprocess(image).unsqueeze(0).to(self.__device)
        with torch.no_grad():
            image_features = self.__model.encode_image(image_tensor)
        # CLIP runs in half precision on CUDA; cast up so normalisation and
        # the returned array are float32 on every device.
        image_features = image_features.float()
        # Normalise to unit length.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().flatten()
    except Exception as e:
        self.__logger.error(f"Error extracting features from image: {e}")
        return None
\ No newline at end of file
from typing import List, Tuple, Optional, Union
import logging
import numpy as np
from PIL import Image
from pymilvus import Collection
# 引入特征提取器和MilvusClient
from feature_extractor import FeatureExtractor
from milvus import MilvusClient  # assumes MilvusClient lives in the local milvus module
class ImageSearch:
    """
    Image search helper offering similarity search over image embeddings.

    All methods are classmethods; the feature extractor and the Milvus client
    are created lazily and cached on the class.
    """

    _logger = logging.getLogger(__name__)
    _feature_extractor_initialized = False
    # Shared FeatureExtractor instance, created on first use.
    _feature_extractor = None
    _milvus_client = None

    @classmethod
    def _initialize_feature_extractor(cls) -> None:
        """
        Create the shared feature extractor if it does not exist yet.

        Raises:
            RuntimeError: If the extractor cannot be created.
        """
        if cls._feature_extractor_initialized and cls._feature_extractor is not None:
            return
        try:
            # FeatureExtractor is used through an instance (see its own unit
            # test), so instantiate it instead of calling it on the class.
            cls._feature_extractor = FeatureExtractor()
            cls._feature_extractor_initialized = True
        except Exception as e:
            cls._logger.error(f"Failed to initialize feature extractor: {e}")
            raise RuntimeError(f"Failed to initialize feature extractor: {e}") from e

    @classmethod
    def get_milvus_client(cls, host: str = "localhost", port: str = "19530",
                          collection_name: str = "image_collection") -> "MilvusClient":
        """
        Return the cached Milvus client, creating and connecting it on first use.

        Once a client is cached, the arguments of later calls are ignored.

        Args:
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Collection name.

        Returns:
            MilvusClient: The connected client instance.
        """
        if cls._milvus_client is None:
            cls._milvus_client = MilvusClient(host, port, collection_name).connect()
        return cls._milvus_client

    @classmethod
    def get_collection(cls, host: str = "localhost", port: str = "19530",
                       collection_name: str = "image_collection") -> "Collection":
        """
        Return the Milvus collection of the cached client.

        Args:
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Collection name.

        Returns:
            Collection: The Milvus collection object.

        Raises:
            RuntimeError: If the collection cannot be obtained.
        """
        try:
            client = cls.get_milvus_client(host, port, collection_name)
            return client.get_collection()
        except Exception as e:
            cls._logger.error(f"Failed to get collection: {e}")
            raise RuntimeError(f"Failed to get collection: {e}") from e

    @classmethod
    def extract_features(cls, image: "Union[str, Image.Image]") -> np.ndarray:
        """
        Extract a feature vector from an image.

        Args:
            image: Image URL or PIL image object.

        Returns:
            np.ndarray: The feature vector.

        Raises:
            RuntimeError: If the feature extractor cannot be initialised.
            ValueError: If the image type is unsupported or extraction fails.
        """
        # Make sure the shared extractor exists before using it.
        cls._initialize_feature_extractor()
        extractor = cls._feature_extractor
        try:
            # Dispatch on the input type.
            if isinstance(image, str):
                features = extractor.extract_from_url(image)
            elif isinstance(image, Image.Image):
                features = extractor.extract_from_image(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            # The extractor signals failure by returning None.
            if features is None:
                raise ValueError("Feature extraction returned None")
            return features
        except Exception as e:
            cls._logger.error(f"Feature extraction failed: {e}")
            raise ValueError(f"Failed to extract features: {e}") from e

    @classmethod
    def image_to_image_search(
        cls,
        image: "Union[str, Image.Image]",
        top_k: int = 100,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection"
    ) -> List[Tuple[str, float]]:
        """
        Find images similar to a query image.

        Args:
            image: URL or PIL image object of the query image.
            top_k: Number of most similar results to return.
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Collection name.

        Returns:
            List[Tuple[str, float]]: (product_id, score) pairs in the order
            returned by the search.

        Raises:
            ValueError: If image processing or the search fails.
        """
        try:
            query_embedding = cls.extract_features(image)
            client = cls.get_milvus_client(host, port, collection_name)
            results = client.search(query_embedding, limit=top_k)
            if not results:
                return []
            # One query vector was sent, so only the first hit list matters.
            return [(hit.entity.get('product_id'), hit.score) for hit in results[0]]
        except Exception as e:
            cls._logger.error(f"Image search failed: {e}")
            raise ValueError(f"Image search failed: {e}") from e

    @classmethod
    def batch_image_search(
        cls,
        images: "List[Union[str, Image.Image]]",
        top_k: int = 100,
        host: str = "localhost",
        port: str = "19530",
        collection_name: str = "image_collection"
    ) -> List[List[Tuple[str, float]]]:
        """
        Run a similarity search for several query images at once.

        Images whose features cannot be extracted are logged and get an empty
        result list at their position.

        Args:
            images: List of query image URLs or PIL image objects.
            top_k: Number of most similar results per query.
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Collection name.

        Returns:
            List[List[Tuple[str, float]]]: One result list per input image,
            in input order.

        Raises:
            RuntimeError: If the feature extractor cannot be initialised.
            ValueError: If the batch search itself fails.
        """
        # Nothing to search: skip loading the model and contacting Milvus.
        if not images:
            return []

        # Make sure the shared extractor exists before extracting.
        cls._initialize_feature_extractor()

        # Extract features, remembering which inputs succeeded.
        batch_embeddings = []
        valid_indices = []
        for i, image in enumerate(images):
            try:
                features = cls.extract_features(image)
            except Exception as e:
                cls._logger.warning(f"Feature extraction failed for image at index {i}: {e}")
                continue
            batch_embeddings.append(features)
            valid_indices.append(i)

        # No usable embeddings: every query gets an empty result list.
        if not batch_embeddings:
            return [[] for _ in images]

        try:
            collection = cls.get_collection(host, port, collection_name)
            collection.load()
            search_params = {"metric_type": "IP", "params": {"ef": 100}}
            batch_results = collection.search(
                data=batch_embeddings,
                anns_field="embedding",
                param=search_params,
                limit=top_k,
                output_fields=["image", "product_id"]
            )
            # Map each result list back to the position of its query image.
            all_results = [[] for _ in images]
            for original_idx, hits in zip(valid_indices, batch_results):
                all_results[original_idx] = [
                    (hit.entity.get('product_id'), hit.score) for hit in hits
                ]
            return all_results
        except Exception as e:
            cls._logger.error(f"Batch image search failed: {e}")
            raise ValueError(f"Batch image search failed: {e}") from e
from typing import Dict, List, Any, Optional, Union
import numpy as np
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
class MilvusClient:
    """Thin wrapper around a pymilvus connection and one image-embedding collection."""

    def __init__(self, host: str = "localhost", port: str = "19530", collection_name: str = "image_collection",
                 embedding_dim: int = 512) -> None:
        """Initialise the Milvus client.

        Args:
            host: Milvus server address.
            port: Milvus server port.
            collection_name: Collection name.
            embedding_dim: Dimension of the "embedding" vector field used when
                the collection has to be created. 512 matches CLIP ViT-B/32.
        """
        self.host: str = host
        self.port: str = port
        self.collection_name: str = collection_name
        self.embedding_dim: int = embedding_dim
        self.collection: Optional[Collection] = None

    def connect(self) -> 'MilvusClient':
        """Connect to the Milvus server under the "default" alias.

        Returns:
            MilvusClient: This client, to allow call chaining.
        """
        connections.connect("default", host=self.host, port=self.port)
        return self

    def get_collection(self) -> "Collection":
        """Get the collection, creating it (and its index) if it does not exist.

        Returns:
            Collection: The Milvus collection object.
        """
        if utility.has_collection(self.collection_name):
            self.collection = Collection(name=self.collection_name)
        else:
            # The schema is only needed when the collection has to be created.
            fields: List[FieldSchema] = [
                FieldSchema(name="image", dtype=DataType.VARCHAR, max_length=256, is_primary=True, auto_id=False),
                FieldSchema(name="product_id", dtype=DataType.VARCHAR, max_length=256),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.embedding_dim)
            ]
            schema: CollectionSchema = CollectionSchema(fields, "图像特征集合")
            self.collection = Collection(name=self.collection_name, schema=schema)
            self._create_index()
        return self.collection

    def _create_index(self) -> None:
        """Create the HNSW index on the "embedding" field."""
        index_params: Dict[str, Any] = {
            "index_type": "HNSW",
            "metric_type": "IP",  # inner-product similarity
            "params": {"M": 16, "efConstruction": 200}
        }
        if self.collection is not None:
            self.collection.create_index(field_name="embedding", index_params=index_params)

    def search(self, vector: Union[List[float], np.ndarray], limit: int = 10) -> Any:
        """Search for the vectors most similar to ``vector``.

        Args:
            vector: Query vector, as a list or numpy array.
            limit: Number of results to return.

        Returns:
            The search results as returned by ``Collection.search``.
        """
        collection = self.collection if self.collection is not None else self.get_collection()
        collection.load()
        search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
        return collection.search(
            data=[vector],
            anns_field="embedding",
            param=search_params,
            limit=limit,
            output_fields=["image", "product_id"]
        )

    def insert(self, data: List[Dict[str, Any]]) -> None:
        """Insert rows into the collection.

        Args:
            data: List of dicts, each with the keys "image", "product_id"
                and "embedding". An empty list is a no-op.
        """
        # Nothing to insert: avoid contacting the server with empty columns.
        if not data:
            return
        collection = self.collection if self.collection is not None else self.get_collection()
        # Column-oriented layout, in schema field order.
        entities: List[List[Any]] = [
            [item["image"] for item in data],
            [item["product_id"] for item in data],
            [item["embedding"] for item in data]
        ]
        collection.insert(entities)

    def drop_collection(self) -> None:
        """Drop the collection if it exists and forget the cached handle."""
        if utility.has_collection(self.collection_name):
            utility.drop_collection(self.collection_name)
        self.collection = None

    def close(self) -> None:
        """Disconnect the "default" alias and forget the cached collection handle."""
        connections.disconnect("default")
        # The cached handle belongs to the connection that was just closed.
        self.collection = None
......@@ -33,3 +33,5 @@ urllib3~=2.4.0
fsspec~=2025.3.2
ujson~=5.10.0
pandas~=2.2.3
pymilvus~=2.5.9
typing_extensions~=4.13.2
\ No newline at end of file
......@@ -6,7 +6,7 @@ from app.models.feature_extractor import FeatureExtractor
class TestFeatureExtractorFunction(unittest.TestCase):
    """Smoke test for URL-based feature extraction (performs a real download)."""

    def test_feature_extractor(self):
        """Extract features from a sample image URL and check that a vector came back."""
        url = "https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg"
        feature = FeatureExtractor().extract_from_url(url)
        print(feature)
        # The extractor reports failure by returning None, so without this
        # assertion the test could never fail.
        self.assertIsNotNone(feature)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment