Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
image_search
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
service
image_search
Commits
ef587af1
Commit
ef587af1
authored
May 24, 2025
by
zhengyaoqiu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
特征提取模块重构
parent
9858b763
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
387 additions
and
46 deletions
+387
-46
feature_extractor.py
app/models/feature_extractor.py
+47
-44
image_search.py
app/models/image_search.py
+224
-0
milvus.py
app/models/milvus.py
+112
-0
requirements.txt
requirements.txt
+3
-1
test_feature_extractor.py
tests/test_feature_extractor.py
+1
-1
No files found.
app/models/feature_extractor.py
View file @
ef587af1
import
threading
import
torch
import
clip
import
requests
from
PIL
import
Image
from
io
import
BytesIO
from
typing
import
Optional
,
Tuple
import
numpy
as
np
import
logging
class
FeatureExtractor
:
"""
使用CLIP模型提取图像特征向量的工具类
"""
# 类变量,用于存储模型实例
_model
=
None
_preprocess
=
None
_device
=
None
@
classmethod
def
initialize
(
cls
,
model_name
=
"ViT-B/32"
):
"""
初始化CLIP模型
Args:
model_name (str): CLIP模型名称
"""
cls
.
_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
cls
.
_model
,
cls
.
_preprocess
=
clip
.
load
(
model_name
,
device
=
cls
.
_device
)
print
(
f
"CLIP model {model_name} loaded on {cls._device}"
)
class
FeatureExtractor
:
__model
=
None
__preprocess
=
None
__logger
=
logging
.
getLogger
(
__name__
)
__instance
=
None
__lock
=
threading
.
Lock
()
__device
=
"ViT-B/32"
def
__new__
(
cls
,
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
,
model_name
=
"ViT-B/32"
):
# 第一次检查 - 不带锁
if
cls
.
__instance
is
None
:
# 只有在可能需要创建实例时才获取锁
with
cls
.
__lock
:
# 第二次检查 - 带锁
if
cls
.
__instance
is
None
:
print
(
f
"创建并初始化 CLIP 模型: {model_name} 在设备: {device}"
)
# 创建实例
cls
.
__instance
=
super
()
.
__new__
(
cls
)
# 在这里直接完成初始化
cls
.
__instance
.
__model
,
cls
.
__instance
.
__preprocess
=
clip
.
load
(
model_name
,
device
=
device
)
cls
.
__instance
.
__device
=
device
return
cls
.
__instance
@
staticmethod
def
resize_with_padding
(
img
):
def
resize_with_padding
(
img
,
target_size
=
(
224
,
224
)
):
"""
调整图像大小,保持纵横比并添加填充
Args:
img (Image.Image): 输入图像
img: 输入图像
target_size: 目标尺寸,默认为(224, 224)
Returns:
PIL.Image:
调整大小后的图像
调整大小后的图像
"""
target_size
=
(
224
,
224
)
# 计算调整大小的比例
ratio
=
min
(
target_size
[
0
]
/
img
.
width
,
target_size
[
1
]
/
img
.
height
)
new_size
=
(
int
(
img
.
width
*
ratio
),
int
(
img
.
height
*
ratio
))
...
...
@@ -58,62 +64,59 @@ class FeatureExtractor:
return
new_img
@
classmethod
def
extract_from_url
(
cls
,
image_url
):
def
extract_from_url
(
self
,
image_url
):
"""
从URL加载图像并提取特征向量
Args:
image_url
(str)
: 图像URL
image_url: 图像URL
Returns:
numpy.ndarray: 特征向量
特征向量,如果提取失败则返回None
"""
if
cls
.
_model
is
None
:
cls
.
initialize
()
try
:
# 下载图片
response
=
requests
.
get
(
image_url
,
stream
=
True
)
response
=
requests
.
get
(
image_url
,
stream
=
True
,
timeout
=
10
)
response
.
raise_for_status
()
# 确保请求成功
# 将图片数据转换为 PIL Image 对象
image
=
Image
.
open
(
BytesIO
(
response
.
content
))
.
convert
(
"RGB"
)
return
cls
.
extract_from_image
(
image
)
return
self
.
extract_from_image
(
image
)
except
requests
.
RequestException
as
e
:
self
.
__logger
.
error
(
f
"Network error when downloading image from {image_url}: {e}"
)
return
None
except
Exception
as
e
:
print
(
f
"Error extracting features from URL
: {e}"
)
self
.
__logger
.
error
(
f
"Error extracting features from URL {image_url}
: {e}"
)
return
None
@
classmethod
def
extract_from_image
(
cls
,
img
):
def
extract_from_image
(
self
,
img
):
"""
从PIL图像对象提取特征向量
Args:
img
(Image.Image)
: 输入图像
img: 输入图像
Returns:
numpy.ndarray: 特征向量
特征向量,如果提取失败则返回None
"""
if
cls
.
_model
is
None
:
cls
.
initialize
()
try
:
# 调整图像大小并添加填充
image
=
cls
.
resize_with_padding
(
img
)
image
=
self
.
resize_with_padding
(
img
)
# 预处理并提取特征
image_tensor
=
cls
.
_preprocess
(
image
)
.
unsqueeze
(
0
)
.
to
(
cls
.
_device
)
image_tensor
=
self
.
__preprocess
(
image
)
.
unsqueeze
(
0
)
.
to
(
self
.
_
_device
)
with
torch
.
no_grad
():
image_features
=
cls
.
_model
.
encode_image
(
image_tensor
)
image_features
=
self
.
_
_model
.
encode_image
(
image_tensor
)
# 归一化特征向量
image_features
/=
image_features
.
norm
(
dim
=-
1
,
keepdim
=
True
)
return
image_features
.
cpu
()
.
numpy
()
.
flatten
()
except
Exception
as
e
:
print
(
f
"Error extracting features from image: {e}"
)
self
.
__logger
.
error
(
f
"Error extracting features from image: {e}"
)
return
None
\ No newline at end of file
app/models/image_search.py
0 → 100644
View file @
ef587af1
from
typing
import
List
,
Tuple
,
Optional
,
Union
import
logging
import
numpy
as
np
from
PIL
import
Image
from
pymilvus
import
Collection
# 引入特征提取器和MilvusClient
from
feature_extractor
import
FeatureExtractor
from
milvus
import
MilvusClient
# 假设MilvusClient在milvus_client.py文件中
class
ImageSearch
:
"""
图像搜索类,提供基于图像相似度的搜索功能
"""
_logger
=
logging
.
getLogger
(
__name__
)
_feature_extractor_initialized
=
False
_milvus_client
=
None
@
classmethod
def
_initialize_feature_extractor
(
cls
)
->
None
:
"""初始化特征提取器"""
if
not
cls
.
_feature_extractor_initialized
:
try
:
FeatureExtractor
.
initialize
()
cls
.
_feature_extractor_initialized
=
True
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Failed to initialize feature extractor: {e}"
)
raise
RuntimeError
(
f
"Failed to initialize feature extractor: {e}"
)
@
classmethod
def
get_milvus_client
(
cls
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
MilvusClient
:
"""
获取或创建Milvus客户端
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
MilvusClient: Milvus客户端实例
"""
if
cls
.
_milvus_client
is
None
:
cls
.
_milvus_client
=
MilvusClient
(
host
,
port
,
collection_name
)
.
connect
()
return
cls
.
_milvus_client
@
classmethod
def
get_collection
(
cls
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
Collection
:
"""
获取Milvus集合
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
Collection: Milvus集合对象
Raises:
RuntimeError: 如果无法获取集合
"""
try
:
client
=
cls
.
get_milvus_client
(
host
,
port
,
collection_name
)
return
client
.
get_collection
()
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Failed to get collection: {e}"
)
raise
RuntimeError
(
f
"Failed to get collection: {e}"
)
@
classmethod
def
extract_features
(
cls
,
image
:
Union
[
str
,
Image
.
Image
])
->
np
.
ndarray
:
"""
从图像提取特征向量
Args:
image: 图像URL或PIL图像对象
Returns:
np.ndarray: 特征向量
Raises:
ValueError: 如果特征提取失败
"""
# 确保特征提取器已初始化
cls
.
_initialize_feature_extractor
()
try
:
# 根据图像类型调用相应的提取方法
if
isinstance
(
image
,
str
):
features
=
FeatureExtractor
.
extract_from_url
(
image
)
elif
isinstance
(
image
,
Image
.
Image
):
features
=
FeatureExtractor
.
extract_from_image
(
image
)
else
:
raise
ValueError
(
f
"Unsupported image type: {type(image)}"
)
if
features
is
None
:
raise
ValueError
(
"Feature extraction returned None"
)
return
features
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Feature extraction failed: {e}"
)
raise
ValueError
(
f
"Failed to extract features: {e}"
)
@
classmethod
def
image_to_image_search
(
cls
,
image
:
Union
[
str
,
Image
.
Image
],
top_k
:
int
=
100
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
List
[
Tuple
[
str
,
float
]]:
"""
使用图像查询相似图像
Args:
image: 查询图像的URL或PIL图像对象
top_k: 返回的最相似结果数量
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
List[Tuple[str, float]]: 产品ID和相似度分数的列表,按相似度降序排列
Raises:
ValueError: 如果图像处理或搜索过程中出错
"""
try
:
# 提取查询图像的特征
query_embedding
=
cls
.
extract_features
(
image
)
# 获取Milvus客户端并搜索
client
=
cls
.
get_milvus_client
(
host
,
port
,
collection_name
)
results
=
client
.
search
(
query_embedding
,
limit
=
top_k
)
# 处理结果
if
not
results
or
len
(
results
)
==
0
:
return
[]
# 返回结果
product_ids
=
[
hit
.
entity
.
get
(
'product_id'
)
for
hit
in
results
[
0
]]
scores
=
[
hit
.
score
for
hit
in
results
[
0
]]
return
list
(
zip
(
product_ids
,
scores
))
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Image search failed: {e}"
)
raise
ValueError
(
f
"Image search failed: {e}"
)
@
classmethod
def
batch_image_search
(
cls
,
images
:
List
[
Union
[
str
,
Image
.
Image
]],
top_k
:
int
=
100
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
List
[
List
[
Tuple
[
str
,
float
]]]:
"""
批量图像搜索
Args:
images: 查询图像URL或PIL图像对象的列表
top_k: 每个查询返回的最相似结果数量
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
Returns:
List[List[Tuple[str, float]]]: 每个查询图像对应的结果列表
"""
# 确保特征提取器已初始化
cls
.
_initialize_feature_extractor
()
# 批量提取特征
batch_embeddings
=
[]
valid_indices
=
[]
for
i
,
image
in
enumerate
(
images
):
try
:
features
=
cls
.
extract_features
(
image
)
batch_embeddings
.
append
(
features
)
valid_indices
.
append
(
i
)
except
Exception
as
e
:
cls
.
_logger
.
warning
(
f
"Feature extraction failed for image at index {i}: {e}"
)
# 如果没有有效的特征,返回空列表
if
not
batch_embeddings
:
return
[[]
for
_
in
range
(
len
(
images
))]
try
:
# 获取集合
collection
=
cls
.
get_collection
(
host
,
port
,
collection_name
)
collection
.
load
()
# 批量搜索
search_params
=
{
"metric_type"
:
"IP"
,
"params"
:
{
"ef"
:
100
}}
batch_results
=
collection
.
search
(
data
=
batch_embeddings
,
anns_field
=
"embedding"
,
param
=
search_params
,
limit
=
top_k
,
output_fields
=
[
"image"
,
"product_id"
]
)
# 处理结果
all_results
=
[[]
for
_
in
range
(
len
(
images
))]
for
i
,
results
in
enumerate
(
batch_results
):
original_idx
=
valid_indices
[
i
]
product_ids
=
[
hit
.
entity
.
get
(
'product_id'
)
for
hit
in
results
]
scores
=
[
hit
.
score
for
hit
in
results
]
all_results
[
original_idx
]
=
list
(
zip
(
product_ids
,
scores
))
return
all_results
except
Exception
as
e
:
cls
.
_logger
.
error
(
f
"Batch image search failed: {e}"
)
raise
ValueError
(
f
"Batch image search failed: {e}"
)
app/models/milvus.py
0 → 100644
View file @
ef587af1
from
typing
import
Dict
,
List
,
Any
,
Optional
,
Union
import
numpy
as
np
from
pymilvus
import
connections
,
Collection
,
FieldSchema
,
CollectionSchema
,
DataType
,
utility
class
MilvusClient
:
def
__init__
(
self
,
host
:
str
=
"localhost"
,
port
:
str
=
"19530"
,
collection_name
:
str
=
"image_collection"
)
->
None
:
"""初始化Milvus客户端
Args:
host: Milvus服务器地址
port: Milvus服务器端口
collection_name: 集合名称
"""
self
.
host
:
str
=
host
self
.
port
:
str
=
port
self
.
collection_name
:
str
=
collection_name
self
.
collection
:
Optional
[
Collection
]
=
None
def
connect
(
self
)
->
'MilvusClient'
:
"""连接到Milvus服务器
Returns:
MilvusClient: 当前客户端实例,支持链式调用
"""
connections
.
connect
(
"default"
,
host
=
self
.
host
,
port
=
self
.
port
)
return
self
def
get_collection
(
self
)
->
Collection
:
"""获取或创建集合
Returns:
Collection: Milvus集合对象
"""
# 定义集合结构
fields
:
List
[
FieldSchema
]
=
[
FieldSchema
(
name
=
"image"
,
dtype
=
DataType
.
VARCHAR
,
max_length
=
256
,
is_primary
=
True
,
auto_id
=
False
),
FieldSchema
(
name
=
"product_id"
,
dtype
=
DataType
.
VARCHAR
,
max_length
=
256
),
FieldSchema
(
name
=
"embedding"
,
dtype
=
DataType
.
FLOAT_VECTOR
,
dim
=
512
)
# CLIP ViT-B/32的特征维度为512
]
schema
:
CollectionSchema
=
CollectionSchema
(
fields
,
"图像特征集合"
)
# 创建或获取集合
if
utility
.
has_collection
(
self
.
collection_name
):
self
.
collection
=
Collection
(
name
=
self
.
collection_name
)
else
:
self
.
collection
=
Collection
(
name
=
self
.
collection_name
,
schema
=
schema
)
self
.
_create_index
()
return
self
.
collection
def
_create_index
(
self
)
->
None
:
"""创建索引"""
index_params
:
Dict
[
str
,
Any
]
=
{
"index_type"
:
"HNSW"
,
"metric_type"
:
"IP"
,
# 内积相似度
"params"
:
{
"M"
:
16
,
"efConstruction"
:
200
}
}
if
self
.
collection
is
not
None
:
self
.
collection
.
create_index
(
field_name
=
"embedding"
,
index_params
=
index_params
)
def
search
(
self
,
vector
:
Union
[
List
[
float
],
np
.
ndarray
],
limit
:
int
=
10
)
->
Any
:
"""搜索相似向量
Args:
vector: 查询向量,可以是列表或numpy数组
limit: 返回结果数量
Returns:
查询结果
"""
if
self
.
collection
is
None
:
self
.
get_collection
()
self
.
collection
.
load
()
search_params
:
Dict
[
str
,
Any
]
=
{
"metric_type"
:
"IP"
,
"params"
:
{
"ef"
:
100
}}
results
=
self
.
collection
.
search
(
data
=
[
vector
],
anns_field
=
"embedding"
,
param
=
search_params
,
limit
=
limit
,
output_fields
=
[
"image"
,
"product_id"
]
)
return
results
def
insert
(
self
,
data
:
List
[
Dict
[
str
,
Any
]])
->
None
:
"""插入数据
Args:
data: 包含image、product_id和embedding的字典列表
每个字典应包含键: "image", "product_id", "embedding"
"""
if
self
.
collection
is
None
:
self
.
get_collection
()
entities
:
List
[
List
[
Any
]]
=
[
[
item
[
"image"
]
for
item
in
data
],
[
item
[
"product_id"
]
for
item
in
data
],
[
item
[
"embedding"
]
for
item
in
data
]
]
self
.
collection
.
insert
(
entities
)
def
drop_collection
(
self
)
->
None
:
"""删除集合"""
if
utility
.
has_collection
(
self
.
collection_name
):
utility
.
drop_collection
(
self
.
collection_name
)
self
.
collection
=
None
def
close
(
self
)
->
None
:
"""关闭连接"""
connections
.
disconnect
(
"default"
)
requirements.txt
View file @
ef587af1
...
...
@@ -32,4 +32,6 @@ networkx~=3.4.2
urllib3
~=2.4.0
fsspec
~=2025.3.2
ujson
~=5.10.0
pandas
~=2.2.3
\ No newline at end of file
pandas
~=2.2.3
pymilvus
~=2.5.9
typing_extensions
~=4.13.2
\ No newline at end of file
tests/test_feature_extractor.py
View file @
ef587af1
...
...
@@ -6,7 +6,7 @@ from app.models.feature_extractor import FeatureExtractor
class
TestFeatureExtractorFunction
(
unittest
.
TestCase
):
def
test_feature_extractor
(
self
):
url
=
"https://pc3oscdn.chillcy.com/3359847025/QSIiPR0XExYACM/00f9bdfa63158ec9477e4f7fe70f5989.jpg"
feature
=
FeatureExtractor
.
extract_from_url
(
url
)
feature
=
FeatureExtractor
()
.
extract_from_url
(
url
)
print
(
feature
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment