Commit 1e69ea27 authored by zhengyaoqiu's avatar zhengyaoqiu

上传 & 检索

parent 285b001c
from flask import jsonify from flask import jsonify, request
from app.api import api_bp from app.api import api_bp
@api_bp.route('/hello', methods=['GET']) @api_bp.route('/hello', methods=['GET'])
...@@ -16,3 +16,18 @@ def hello_name(name): ...@@ -16,3 +16,18 @@ def hello_name(name):
'message': f'Hello, {name}!', 'message': f'Hello, {name}!',
'status': 'success' 'status': 'success'
}) })
@api_bp.route('/upload', methods=['PUT'])
def upload(name):
# 获取 JSON 格式的请求体数据
data = request.get_json()
# 访问具体字段
bucket = data.get('bucket')
image = data.get('image')
key = data.get('key')
return jsonify({
'message': f'Hello, {name}!',
'status': 'success'
})
...@@ -8,19 +8,23 @@ class ImageSearch: ...@@ -8,19 +8,23 @@ class ImageSearch:
self.milvus = milvus self.milvus = milvus
# id = product_id # id = product_id
def image_to_image_search(self, image, key_name, top_k = 100): def image_to_image_search(self, bucket, image, top_k = 100):
try: try:
# 提取查询图像的特征 # 提取查询图像的特征
query_embedding = self.feature_extractor.extract_features(image) vector = self.feature_extractor.extract_features(image)
results = self.milvus.search(query_embedding, limit=top_k) # anns_field = embedding
# output_fields=["image", "product_id"]
# search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
results = self.milvus.search(bucket, vector, top_k)
# 处理结果 # 处理结果
if not results or len(results) == 0: if not results or len(results) == 0:
return [] return []
# 返回结果 # 返回结果
keys = [hit.entity.get(key_name) for hit in results[0]] keys = [hit.entity.get("key") for hit in results[0]]
scores = [hit.score for hit in results[0]] scores = [hit.score for hit in results[0]]
return list(zip(keys, scores)) return list(zip(keys, scores))
......
import logging
class Upload:
__logger = logging.getLogger(__name__)
def __init__(self, feature_extractor, milvus):
self.feature_extractor = feature_extractor
self.milvus = milvus
def upload_one(self, bucket, image, key):
self.upload_many(bucket, {image: key})
def upload_many(self, bucket, image2key):
images = []
keys = []
vectors = []
for image, key in image2key.items():
vector = self.feature_extractor.extract_from_url(image)
images.append(image)
keys.append(key)
vectors.append(vector)
entities = [
images,
keys,
vectors
]
self.milvus.insert(bucket, entities)
...@@ -46,15 +46,17 @@ class MilvusClient: ...@@ -46,15 +46,17 @@ class MilvusClient:
# output_fields=["image", "product_id"] # output_fields=["image", "product_id"]
# search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}} # search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
def search(self, collection_name, vector, anns_field, search_params, output_fields, top_k = 10) -> Any: def search(self, collection_name, vector, top_k = 10) -> Any:
collection = self.get_collection(collection_name) collection = self.get_collection(collection_name)
search_params: Dict[str, Any] = {"metric_type": "IP", "params": {"ef": 100}}
results = collection.search( results = collection.search(
data=[vector], data=[vector],
anns_field=anns_field, anns_field="vector",
param=search_params, param=search_params,
limit=top_k, limit=top_k,
output_fields=output_fields output_fields=["image", "key"]
) )
return results return results
......
import unittest import unittest
from app.models.feature_extractor import FeatureExtractor from app.services.feature_extractor import FeatureExtractor
class TestFeatureExtractorFunction(unittest.TestCase): class TestFeatureExtractorFunction(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment