Commit 64288be5 authored by zhengyaoqiu's avatar zhengyaoqiu

优化

parent 3e9d8678
from flask import jsonify, request from flask import jsonify, request
from app.api import api_bp from app.api import api_bp
from app.services.feature_extractor import FeatureExtractor from app.services.feature_extractor import get_feature_extractor
from app.services.image_search import ImageSearch from app.services.image_search import ImageSearch
from app.services.image_upload import ImageUpload from app.services.image_upload import ImageUpload
from app.services.milvus import MilvusClient from app.services.milvus import MilvusClient
...@@ -30,9 +30,8 @@ def upload(): ...@@ -30,9 +30,8 @@ def upload():
bucket = data.get('bucket') bucket = data.get('bucket')
image2keys = data.get('image2keys') image2keys = data.get('image2keys')
feature_extractor = FeatureExtractor()
milvus = MilvusClient().connect() milvus = MilvusClient().connect()
image_upload = ImageUpload(feature_extractor, milvus) image_upload = ImageUpload(get_feature_extractor(), milvus)
image_upload.upload_many(bucket, image2keys) image_upload.upload_many(bucket, image2keys)
return jsonify({ return jsonify({
'code': 0, 'code': 0,
...@@ -45,10 +44,9 @@ def search(): ...@@ -45,10 +44,9 @@ def search():
top_k = request.args.get("top_k", type=int) top_k = request.args.get("top_k", type=int)
bucket = request.args.get("bucket") bucket = request.args.get("bucket")
feature_extractor = FeatureExtractor()
milvus = MilvusClient().connect() milvus = MilvusClient().connect()
result = ImageSearch(feature_extractor, milvus).image_to_image_search(bucket, image, top_k) result = ImageSearch(get_feature_extractor(), milvus).image_to_image_search(bucket, image, top_k)
return jsonify({ return jsonify({
'code': 0, 'code': 0,
......
import threading import numpy as np
import torch import torch
import clip import clip
import requests import requests
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
from typing import Optional, Tuple
import numpy as np
import logging import logging
class FeatureExtractor: class FeatureExtractor:
# __model = None
# __preprocess = None
# __device = "ViT-B/32"
__instance = None
__lock = threading.Lock()
__logger = logging.getLogger(__name__) __logger = logging.getLogger(__name__)
def __new__(cls, device = "cuda" if torch.cuda.is_available() else "cpu", model_name = "ViT-B/32"): device = "cuda" if torch.cuda.is_available() else "cpu"
# 第一次检查 - 不带锁 def __init__(self, device = "xpu" if torch.xpu.is_available() else "cpu", model_name = "ViT-B/32"):
if cls.__instance is None: device = "cpu"
# 只有在可能需要创建实例时才获取锁 self.model, self.preprocess = self.init_model(device, model_name)
with cls.__lock: self.device = device
# 第二次检查 - 带锁
if cls.__instance is None: @staticmethod
def init_model(device="xpu" if torch.xpu.is_available() else "cpu", model_name="ViT-B/32"):
torch.xpu.empty_cache()
print(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}") print(f"创建并初始化 CLIP 模型: {model_name} 在设备: {device}")
# 创建实例 model, preprocess = clip.load(model_name, device=device)
cls.__instance = super().__new__(cls) return model, preprocess
# 在这里直接完成初始化
cls.__instance.__model, cls.__instance.__preprocess = clip.load(model_name, device=device)
cls.__instance.__device = device
return cls.__instance
@staticmethod @staticmethod
def resize_with_padding(img, target_size = (224, 224)): def resize_with_padding(img, target_size = (224, 224)):
...@@ -103,20 +92,35 @@ class FeatureExtractor: ...@@ -103,20 +92,35 @@ class FeatureExtractor:
特征向量,如果提取失败则返回None 特征向量,如果提取失败则返回None
""" """
device = self.device
model = self.model
preprocess = self.preprocess
# device = "xpu" if torch.xpu.is_available() else "cpu"
# device = "cpu"
# model_name = "ViT-B/32"
# model, preprocess = self.init_model(device, model_name)
try: try:
# 调整图像大小并添加填充 # 调整图像大小并添加填充
image = self.resize_with_padding(img) image = self.resize_with_padding(img)
# 预处理并提取特征 # 预处理并提取特征
image_tensor = self.__preprocess(image).unsqueeze(0).to(self.__device) image_tensor = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad(): with torch.no_grad():
image_features = self.__model.encode_image(image_tensor) image_features = model.encode_image(image_tensor)
# 归一化特征向量 # 归一化特征向量
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().flatten() return image_features.cpu().numpy().astype(np.float32).flatten()
except Exception as e: except Exception as e:
self.__logger.error(f"Error extracting features from image: {e}") self.__logger.error(f"Error extracting features from image: {e}")
return None return None
def get_feature_extractor():
return feature_extractor
feature_extractor = FeatureExtractor()
\ No newline at end of file
...@@ -11,11 +11,32 @@ class ImageUpload: ...@@ -11,11 +11,32 @@ class ImageUpload:
self.upload_many(bucket, {image: key}) self.upload_many(bucket, {image: key})
def upload_many(self, bucket, image2keys): def upload_many(self, bucket, image2keys):
images = []
keys = [] keys = []
vectors = [] vectors = []
for image2key in image2keys: images = [image2key["image"] for image2key in image2keys]
new_images, exist_images = self.milvus.filter_new_urls(bucket, images)
images = []
print(f"总图片数: {len(images)}")
print(f"新图片数: {len(new_images)}")
print(f"已存在图片数: {len(exist_images)}")
print(f"新图片: {new_images}")
print(f"已存在图片: {exist_images}")
# 将已存在的图片转换为集合,提高查找效率
exist_images_set = set(exist_images)
# 过滤掉已存在的图片
filtered_image2keys = [
image2key for image2key in image2keys
if image2key["image"] not in exist_images_set
]
if len(filtered_image2keys) == 0:
return
for image2key in filtered_image2keys:
image = image2key["image"] image = image2key["image"]
key = image2key["key"] key = image2key["key"]
vector = self.feature_extractor.extract_from_url(image) vector = self.feature_extractor.extract_from_url(image)
......
from typing import Dict, List, Any, Optional, Union from typing import Dict, List, Any, Optional, Union, Tuple
import numpy as np import numpy as np
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
...@@ -62,8 +62,41 @@ class MilvusClient: ...@@ -62,8 +62,41 @@ class MilvusClient:
# [item["embedding"] for item in data] # [item["embedding"] for item in data]
# ] # ]
def insert(self, collection_name, entities): def insert(self, collection_name, entities):
self.get_collection(collection_name).insert(entities) collection = self.get_collection(collection_name)
collection.insert(entities)
def filter_new_urls(self, collection_name: str, urls: List[str]) -> Tuple[List[str], List[str]]:
if not urls:
return [], []
collection = self.get_collection(collection_name)
# 构建查询表达式,检查image字段是否在URL列表中
# 使用 in 操作符进行批量查询
url_str_list = [f'"{url}"' for url in urls] # 为每个URL添加引号
expr = f"image in [{','.join(url_str_list)}]"
try:
# 执行查询
results = collection.query(
expr=expr,
output_fields=["image"]
)
# 提取已存在的URL
existing_urls = [result["image"] for result in results]
# 计算不存在的URL
new_urls = [url for url in urls if url not in existing_urls]
return new_urls, existing_urls
except Exception as e:
print(f"查询时发生错误: {e}")
# 如果查询失败,返回所有URL为新URL
return urls, []
@staticmethod @staticmethod
def close(alias ="default") -> None: def close(alias ="default") -> None:
connections.disconnect(alias) connections.disconnect(alias)
import torch
import gc
import sys
def print_xpu_memory_stats():
"""打印 XPU 内存分配统计信息"""
try:
if not hasattr(torch, 'xpu') or not torch.xpu.is_available():
print("XPU 不可用")
return
# 清理缓存以获取准确的内存使用情况
gc.collect()
torch.xpu.empty_cache()
# 获取当前内存统计信息
if hasattr(torch.xpu, 'memory_stats'):
stats = torch.xpu.memory_stats()
print("\n===== XPU 内存统计信息 =====")
print(f"分配的内存: {stats.get('allocated_bytes.all', 0) / (1024 ** 3):.2f} GB")
print(f"缓存的内存: {stats.get('reserved_bytes.all', 0) / (1024 ** 3):.2f} GB")
print(f"活跃的内存块: {stats.get('active_bytes.all', 0) / (1024 ** 3):.2f} GB")
print(f"内存分配次数: {stats.get('allocation.all', 0)}")
else:
print("torch.xpu.memory_stats() 不可用")
# 获取当前设备内存信息
if hasattr(torch.xpu, 'get_device_properties'):
device = torch.xpu.current_device()
props = torch.xpu.get_device_properties(device)
print("\n===== XPU 设备属性 =====")
print(f"设备名称: {props.name}")
print(f"总内存: {props.total_memory / (1024 ** 3):.2f} GB")
# 获取当前内存使用情况
if hasattr(torch.xpu, 'memory_allocated') and hasattr(torch.xpu, 'memory_reserved'):
print("\n===== XPU 当前内存使用 =====")
print(f"已分配内存: {torch.xpu.memory_allocated() / (1024 ** 3):.2f} GB")
print(f"已保留内存: {torch.xpu.memory_reserved() / (1024 ** 3):.2f} GB")
print(f"可用内存: {(props.total_memory - torch.xpu.memory_reserved()) / (1024 ** 3):.2f} GB")
else:
print("torch.xpu.get_device_properties() 不可用")
# 检查内存分配比例设置
if hasattr(torch.xpu, 'get_memory_fraction'):
print("\n===== XPU 内存分配比例 =====")
fraction = torch.xpu.get_memory_fraction()
print(f"当前内存分配比例: {fraction:.2f}")
elif hasattr(torch.xpu, 'get_allocator_backend'):
print("\n===== XPU 分配器后端 =====")
backend = torch.xpu.get_allocator_backend()
print(f"当前分配器后端: {backend}")
# 尝试获取最大内存
try:
# 分配测试张量,逐步增加大小直到失败
max_gb = 0
step = 1 # 每次增加 1GB
print("\n===== XPU 最大可分配内存测试 =====")
print("正在测试最大可分配内存...")
while True:
try:
size_bytes = int(max_gb * 1024 ** 3 / 4) # float32 是 4 字节
if size_bytes <= 0:
max_gb += step
continue
test_tensor = torch.zeros(size_bytes, dtype=torch.float32, device='xpu')
del test_tensor
torch.xpu.empty_cache()
print(f"成功分配 {max_gb} GB")
max_gb += step
except Exception as e:
print(f"在尝试分配 {max_gb} GB 时失败")
print(f"最大可分配内存约为: {max_gb - step} GB")
break
# 防止无限循环
if max_gb > 32: # 设置上限为 32GB
print("达到测试上限 (32GB)")
break
except Exception as e:
print(f"内存测试失败: {e}")
except Exception as e:
print(f"获取 XPU 内存统计信息时出错: {e}")
import traceback
print(traceback.format_exc())
# 打印 PyTorch 和系统信息
print(f"PyTorch 版本: {torch.__version__}")
print(f"Python 版本: {sys.version}")
# 检查 IPEX 是否安装
try:
import intel_extension_for_pytorch as ipex
print(f"IPEX 版本: {ipex.__version__}")
except ImportError:
print("IPEX 未安装")
# 打印 XPU 是否可用
if hasattr(torch, 'xpu'):
print(f"XPU 可用: {torch.xpu.is_available()}")
if torch.xpu.is_available():
print(f"XPU 设备数量: {torch.xpu.device_count()}")
print(f"当前 XPU 设备: {torch.xpu.current_device()}")
else:
print("XPU 不可用 (torch.xpu 不存在)")
# 打印内存统计信息
print_xpu_memory_stats()
# 测试设置内存分配比例
if hasattr(torch, 'xpu') and torch.xpu.is_available() and hasattr(torch.xpu, 'set_per_process_memory_fraction'):
print("\n===== 测试设置内存分配比例 =====")
current_fraction = 0.3 # 默认值
try:
# 尝试获取当前值
if hasattr(torch.xpu, 'get_memory_fraction'):
current_fraction = torch.xpu.get_memory_fraction()
print(f"当前内存分配比例: {current_fraction:.2f}")
except:
pass
# 设置新值
new_fraction = 0.8
print(f"设置内存分配比例为: {new_fraction:.2f}")
torch.xpu.set_per_process_memory_fraction(new_fraction)
# 验证设置是否生效
try:
if hasattr(torch.xpu, 'get_memory_fraction'):
updated_fraction = torch.xpu.get_memory_fraction()
print(f"更新后的内存分配比例: {updated_fraction:.2f}")
if abs(updated_fraction - new_fraction) < 0.01:
print("✓ 内存分配比例设置成功!")
else:
print("✗ 内存分配比例设置失败!")
except:
print("无法验证内存分配比例设置")
import unittest import unittest
from app.services.feature_extractor import FeatureExtractor from app.services.feature_extractor import get_feature_extractor
class TestFeatureExtractorFunction(unittest.TestCase): class TestFeatureExtractorFunction(unittest.TestCase):
def test_feature_extractor(self): def test_feature_extractor(self):
url = "https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg" url = "https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg"
feature = FeatureExtractor().extract_from_url(url) feature = get_feature_extractor().extract_from_url(url)
print(feature) print(feature)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment