# Copyright (c) 2022 The BayesFlow Developers
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import tensorflow as tf
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Performs spectral normalization on neural network weights. Adapted from:

    https://www.tensorflow.org/addons/api_docs/python/tfa/layers/SpectralNormalization

    This wrapper controls the Lipschitz constant of a layer by
    constraining its spectral norm, which can stabilize the training of generative networks.

    See [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957).
    """

    def __init__(self, layer, power_iterations=1, **kwargs):
        """Wrap ``layer`` so that its weights can be spectrally normalized.

        Parameters
        ----------
        layer            : tf.keras.layers.Layer
            The layer to wrap. Once built, it must expose either a ``kernel`` or an
            ``embeddings`` attribute (checked in ``build``).
        power_iterations : int, optional, default: 1
            Number of power-iteration steps performed on every call to ``normalize_weights``.
        **kwargs         : dict
            Additional keyword arguments forwarded to ``tf.keras.layers.Wrapper``.

        Raises
        ------
        ValueError
            If ``power_iterations`` is not greater than zero.
        """
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got `power_iterations={}`".format(power_iterations)
            )
        self.power_iterations = power_iterations
        # Not read anywhere in this class; retained so the instance attributes stay unchanged.
        self._initialized = False

    def build(self, input_shape):
        """Build `Layer`

        Parameters
        ----------
        input_shape : shape specification
            Passed on unchanged to the parent ``build`` method.

        Raises
        ------
        AttributeError
            If the wrapped layer has neither a ``kernel`` nor an ``embeddings`` attribute.
        """
        # Run the parent wrapper's build step before inspecting the wrapped layer
        super().build(input_shape)

        # Store reference to the weights that will be normalized
        if hasattr(self.layer, "kernel"):
            self.w = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.w = self.layer.embeddings
        else:
            raise AttributeError(
                "{} object has no attribute 'kernel' nor 'embeddings'".format(type(self.layer).__name__)
            )

        self.w_shape = self.w.shape.as_list()

        # Non-trainable row vector carried across calls as the starting point of the power iteration
        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=self.w.dtype,
        )

    def call(self, inputs, training=False):
        """Call `Layer`

        Parameters
        ----------
        inputs   : tf.Tensor of shape (None,...,condition_dim + target_dim)
            The inputs to the corresponding layer.
        training : bool, optional, default: False
            If truthy, the wrapped layer's weights are re-normalized in place before
            the forward pass; otherwise the weights are used as they are.

        Returns
        -------
        output : the result of applying the wrapped layer to ``inputs``.
        """
        if training:
            self.normalize_weights()
        return self.layer(inputs)

    def normalize_weights(self):
        """Generate spectral normalized weights.

        This method will update the value of `self.w` with the
        spectral normalized value, so that the layer is ready for `call()`.
        It also stores the updated power-iteration vector in `self.u`.
        """
        # Flatten all leading axes so the weights form a 2D matrix with w_shape[-1] columns
        w = tf.reshape(self.w, [-1, self.w_shape[-1]])
        u = self.u

        with tf.name_scope("spectral_normalize"):
            # Power iteration: alternate between the two vectors v = norm(u w^T), u = norm(v w)
            for _ in range(self.power_iterations):
                v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
                u = tf.math.l2_normalize(tf.matmul(v, w))

            # The iteration vectors are treated as constants w.r.t. gradients
            u = tf.stop_gradient(u)
            v = tf.stop_gradient(v)

            # (1, 1) estimate of the spectral norm: v w u^T
            sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)

            self.u.assign(tf.cast(u, self.u.dtype))

            # A plain division would write NaN into the weights when sigma is exactly
            # zero (e.g., an all-zero kernel); divide_no_nan yields zeros in that case
            # and is identical to `self.w / sigma` otherwise.
            normalized_w = tf.math.divide_no_nan(self.w, sigma)
            self.w.assign(tf.cast(tf.reshape(normalized_w, self.w_shape), self.w.dtype))

    def get_config(self):
        """Return the parent wrapper's configuration extended with `power_iterations`."""
        config = {"power_iterations": self.power_iterations}
        base_config = super().get_config()
        return {**base_config, **config}