diff --git a/loss_functions.py b/loss_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e52dac14f829afe1ff25cfc31ff96eb175f8a9 --- /dev/null +++ b/loss_functions.py @@ -0,0 +1,14 @@ +import tensorflow as tf + +def wgan_d_loss(y_true, y_pred, EPS=1e-12): + return tf.reduce_mean(y_pred) - tf.reduce_mean(y_true) + +def wgan_g_loss(y_true, y_pred, EPS=1e-12): + return tf.reduce_mean(y_pred + EPS) + +def gan_d_loss(y_true, y_pred, EPS=1e-12): + return tf.reduce_mean(tf.math.log(y_true + EPS) + tf.math.log(1 - y_pred + EPS)) + +def gan_g_loss(y_true, y_pred, EPS=1e-12): + return tf.reduce_mean(-tf.math.log(y_pred + EPS)) +