Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Pavlo Beylin
MaD Patch Yolov5
Commits
00ae676a
Commit
00ae676a
authored
Oct 05, 2021
by
Pavlo Beylin
Browse files
Add cosine similarity matrix calculation for YOLO predictions for all classes.
parent
acf685f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
CSM.py
0 → 100644
View file @
00ae676a
import
torch
import
torch.nn.functional
as
F
def
calc_yolo_csms
(
imgs_and_preds
:
torch
.
Tensor
,
sign
:
bool
=
True
,
rescale
:
bool
=
True
)
->
torch
.
Tensor
:
'''
computes the cosine similarity map for given input images X
Parameters
---------
model: torch model
imgs_and_preds: torch tensor; shape: (Batch_Size, Channels, Width, Height)
sign: use sign of gradients to calculate cosine similarity maps
rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
Returns
---------
return: cosine_similarity_map:
'''
csms
=
[]
# saliency maps w.r.t. all possible output classes
imgs
=
[]
for
tup
in
imgs_and_preds
:
img
,
pred
,
frame
,
x1
,
y1
,
x2
,
y2
=
tup
if
not
img
.
requires_grad
:
img
.
requires_grad_
()
logit
=
pred
[
5
:]
imgs
.
append
(
img
)
# rescale network output to avoid gradient obfuscation
if
rescale
:
logit
=
logit
/
torch
.
max
(
torch
.
abs
(
logit
))
*
10
classes
=
len
(
logit
)
deltas
=
[]
for
c
in
range
(
classes
):
# calculate loss and compute gradient w.r.t. the input of the current class
y
=
torch
.
ones
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
*
c
loss
=
F
.
cross_entropy
(
logit
.
unsqueeze
(
0
),
y
)
frame_grad
=
torch
.
autograd
.
grad
(
loss
,
frame
,
retain_graph
=
True
)[
0
][:,
5
:]
img_grad
=
frame_grad
[
int
(
y1
):
int
(
y2
),
int
(
x1
):
int
(
x2
),
:]
# take sign of gradient as in the original paper
if
sign
:
img_grad
=
torch
.
sign
(
img_grad
)
deltas
.
append
(
img_grad
.
clone
().
detach
())
deltas
=
torch
.
stack
(
deltas
)
# compute cosine similarity matrices
try
:
deltas
=
torch
.
max
(
deltas
,
dim
=-
3
).
values
# take only the maximum value of all channels to compute the
deltas
=
deltas
.
view
(
classes
,
1
,
-
1
)
norm
=
torch
.
norm
(
deltas
,
p
=
2
,
dim
=
2
,
keepdim
=
True
)
deltas
=
deltas
/
norm
deltas
=
deltas
.
transpose
(
0
,
1
)
csm
=
torch
.
matmul
(
deltas
,
deltas
.
transpose
(
1
,
2
))
except
Exception
as
e
:
print
(
"error"
)
raise
e
# division by zero can lead to NaNs
if
torch
.
isnan
(
csm
).
any
():
# raise Exception("NaNs in CSM!")
print
(
"NaNs in csm"
)
else
:
print
(
f
'
{
deltas
.
mean
()
}
'
)
csms
.
append
(
csm
)
return
imgs
,
csms
def
calc_csm
(
model
:
torch
.
nn
.
Module
,
X
:
torch
.
Tensor
,
sign
:
bool
=
True
,
rescale
:
bool
=
True
)
->
torch
.
Tensor
:
'''
computes the cosine similarity map for given input images X
Parameters
---------
model: torch model
X: torch tensor; shape: (Batch_Size, Channels, Width, Height)
sign: use sign of gradients to calculate cosine similarity maps
rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
Returns
---------
return: cosine_similarity_map:
'''
deltas
=
[]
# saliency maps w.r.t. all possible output classes
if
not
X
.
requires_grad
:
X
.
requires_grad_
()
logits
=
model
(
X
)
# network output
# rescale network output to avoid gradient obfuscation
if
rescale
:
logits
=
logits
/
torch
.
max
(
torch
.
abs
(
logits
),
1
,
keepdim
=
True
).
values
*
10
B
=
logits
.
shape
[
0
]
# batch size
classes
=
logits
.
shape
[
-
1
]
# output classes
for
c
in
range
(
classes
):
# calculate loss and compute gradient w.r.t. the input of the current class
y
=
torch
.
ones
(
B
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
*
c
loss
=
F
.
cross_entropy
(
logits
,
y
)
grad
=
torch
.
autograd
.
grad
(
loss
,
X
,
retain_graph
=
True
)[
0
]
# take sign of gradient as in the original paper
if
sign
:
grad
=
torch
.
sign
(
grad
)
deltas
.
append
(
grad
.
detach
().
clone
())
model
.
zero_grad
()
deltas
=
torch
.
stack
(
deltas
,
dim
=
0
)
deltas
=
torch
.
max
(
deltas
,
dim
=-
3
).
values
# take only the maximum value of all channels to compute the cosine similarity
# compute cosine similarity matrices
deltas
=
deltas
.
view
(
classes
,
B
,
-
1
)
norm
=
torch
.
norm
(
deltas
,
p
=
2
,
dim
=
2
,
keepdim
=
True
)
deltas
=
deltas
/
norm
deltas
=
deltas
.
transpose
(
0
,
1
)
csm
=
torch
.
matmul
(
deltas
,
deltas
.
transpose
(
1
,
2
))
# division by zero can lead to NaNs
if
torch
.
isnan
(
csm
).
any
():
raise
Exception
(
"NaNs in CSM!"
)
return
csm
def
calc_csm_partial_network
(
model_first_part
:
torch
.
nn
.
Module
,
model_second_part
:
torch
.
nn
.
Module
,
X
:
torch
.
Tensor
,
sign
:
bool
=
True
,
rescale
:
bool
=
True
,
scalar_product
:
bool
=
False
)
->
torch
.
Tensor
:
'''
computes the cosine similarity map for given input images X
Parameters
---------
model: torch model
X: torch tensor; shape: (Batch_Size, Channels, Width, Height)
sign: use sign of gradients to calculate cosine similarity maps
rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
Returns
---------
return: cosine_similarity_map:
'''
deltas
=
[]
# saliency maps w.r.t. all possible output classes
pre_ultimate_output
=
model_first_part
(
X
)
pre_ultimate_output
.
requires_grad_
()
logits
=
model_second_part
(
pre_ultimate_output
)
# network output
# rescale network output to avoid gradient obfuscation
if
rescale
:
logits
=
logits
/
torch
.
max
(
torch
.
abs
(
logits
),
1
,
keepdim
=
True
).
values
*
10
B
=
logits
.
shape
[
0
]
# batch size
classes
=
logits
.
shape
[
-
1
]
# output classes
for
c
in
range
(
classes
):
# calculate loss and compute gradient w.r.t. the input of the current class
y
=
torch
.
ones
(
B
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
*
c
loss
=
F
.
cross_entropy
(
logits
,
y
)
grad
=
torch
.
autograd
.
grad
(
loss
,
pre_ultimate_output
,
retain_graph
=
True
)[
0
]
# take sign of gradient as in the original paper
if
sign
:
grad
=
torch
.
sign
(
grad
)
deltas
.
append
(
grad
.
detach
().
clone
())
deltas
=
torch
.
stack
(
deltas
,
dim
=
0
)
# compute cosine similarity matrices
deltas
=
deltas
.
view
(
classes
,
B
,
-
1
)
norm
=
torch
.
norm
(
deltas
,
p
=
2
,
dim
=
2
,
keepdim
=
True
)
if
not
scalar_product
:
deltas
=
deltas
/
norm
deltas
=
deltas
.
transpose
(
0
,
1
)
csm
=
torch
.
matmul
(
deltas
,
deltas
.
transpose
(
1
,
2
))
# division by zero can lead to NaNs
if
torch
.
isnan
(
csm
).
any
():
raise
Exception
(
"NaNs in CSM!"
)
return
csm
main.py
View file @
00ae676a
...
...
@@ -9,6 +9,7 @@ import math
import
matplotlib
from
torch
import
optim
import
CSM
import
models
from
models.common
import
Detections
from
utils.external
import
TotalVariation
...
...
@@ -129,6 +130,7 @@ def bb_intersection_over_union(boxA, boxB):
# return the intersection over union value
return
iou
def
save_image
(
image
):
print
(
"save image called!"
)
im
=
transforms
.
ToPILImage
(
'RGB'
)(
image
)
...
...
@@ -136,6 +138,7 @@ def save_image(image):
plt
.
show
()
im
.
save
(
f
"saved_patches/
{
time
.
time
()
}
.jpg"
)
def
get_best_prediction
(
true_box
,
res
,
cls_nr
):
min_distance
=
float
(
"inf"
)
max_iou
=
float
(
0
)
...
...
@@ -149,10 +152,32 @@ def get_best_prediction(true_box, res, cls_nr):
max_iou
=
pred_iou
best_prediction
=
pred
[
cls_nr
+
5
]
print
(
f
"max found iou:
{
max_iou
}
"
)
# print(f"max found iou: {max_iou}")
return
max_iou
,
best_prediction
def
calculate_csms
(
frame
,
predictions
):
imgs_and_preds
=
[]
for
pred
in
predictions
:
x1
,
y1
,
x2
,
y2
,
conf
=
pred
[:
5
].
float
()
pred_img_section
=
frame
.
flip
(
2
)[
int
(
y1
):
int
(
y2
),
int
(
x1
):
int
(
x2
),
:]
tup
=
(
pred_img_section
,
pred
,
frame
,
x1
,
y1
,
x2
,
y2
)
# print(tup)
imgs_and_preds
.
append
(
tup
)
# if conf > 0.8:
# cls = classes[int(pred[5:].argmax())]
# print(f"{cls}: {conf} - {pred[:5].float()}")
# show(frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :] / 255.)
# print("done")
imgs
,
csms
=
CSM
.
calc_yolo_csms
(
imgs_and_preds
)
if
__name__
==
"__main__"
:
# init
patch_transformer
=
PatchTransformer
().
cuda
()
...
...
@@ -209,8 +234,8 @@ if __name__ == "__main__":
pred
=
-
1
frame_read
=
False
fix_frame
=
False
patch_transformer
.
maxangle
=
5
/
180
*
math
.
pi
patch_transformer
.
minangle
=
-
5
/
180
*
math
.
pi
patch_transformer
.
maxangle
=
5
/
180
*
math
.
pi
patch_transformer
.
minangle
=
-
5
/
180
*
math
.
pi
loss
=
None
while
True
:
if
not
(
fix_frame
and
frame_read
):
...
...
@@ -257,6 +282,12 @@ if __name__ == "__main__":
# debug_preds()
pass
# calculate Cosine Similarity Matrix
imgs
,
csms
=
calculate_csms
(
frame
,
raw_results
)
for
i
in
range
(
len
(
imgs
)):
show
(
imgs
[
i
])
show
(
csms
[
i
])
iou
,
pred
=
get_best_prediction
(
bounding_box
,
raw_results
,
15
)
# get cat
# iou, pred = get_best_prediction(bounding_box, raw_results, 0) # get personal
# iou, pred = get_best_prediction(bounding_box, raw_results, 12) # get parking meter
...
...
@@ -296,7 +327,7 @@ if __name__ == "__main__":
# sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
# optimizer.param_groups[0]['params'][0].grad = sgn_grads
# optimizer.step()
patch
.
data
-=
torch
.
sign
(
gradient_sum
)
*
0.001
# * 0 # TODO reactivate
patch
.
data
-=
torch
.
sign
(
gradient_sum
)
*
0.001
# * 0 # TODO reactivate
patch
.
data
=
patch
.
detach
().
clone
().
clamp
(
MIN_THRESHOLD
,
0.99999
).
data
gradient_sum
=
0
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment