Spaces:
Build error
Build error
| import logging, os | |
| logging.disable(logging.WARNING) | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
| import tensorflow as tf | |
| from basic_ops import * | |
| from resnet_module import * | |
| """This script generates the U-Net architecture according to conf_unet. | |
| """ | |
class UNet(object):
    """Builds a U-Net out of residual blocks for 2D or 3D inputs.

    The network layout is driven entirely by the ``conf_unet`` dictionary
    passed to the constructor; calling the instance adds the operations to
    the current TensorFlow graph and returns the final decoder output.
    """

    def __init__(self, conf_unet):
        """Store the architecture settings read from ``conf_unet``.

        Args:
            conf_unet: A dict providing the following keys:
                'depth': int, number of resolution levels (encoding blocks).
                'dimension': '2D' or '3D'.
                'first_output_filters': int, number of filters at level 1;
                    level i uses first_output_filters * 2**(i-1).
                'encoding_block_sizes': sequence of ints indexed 0..depth-1,
                    number of res_blocks in encoding_block_1..depth.
                'downsampling': sequence of names indexed 0..depth-2, one per
                    encoding_block_2..depth; each is 'down_res_block' or
                    'convolution'.
                'decoding_block_sizes': sequence of ints indexed 0..depth-2,
                    number of res_blocks in decoding_block_(depth-1)..1.
                'skip_method': 'add' or 'concat'.

        Raises:
            KeyError: If any of the keys above is missing.
        """
        self.depth = conf_unet['depth']
        self.dimension = conf_unet['dimension']
        self.first_output_filters = conf_unet['first_output_filters']
        self.encoding_block_sizes = conf_unet['encoding_block_sizes']
        self.downsampling = conf_unet['downsampling']
        self.decoding_block_sizes = conf_unet['decoding_block_sizes']
        self.skip_method = conf_unet['skip_method']

    def __call__(self, inputs, training):
        """Add the U-Net operations on top of a batch of inputs.

        Args:
            inputs: A Tensor representing a batch of input images.
            training: A boolean forwarded to every block of the network.

        Returns:
            The Tensor produced by the last decoding block (or by the bottom
            block when depth == 1).

        Raises:
            ValueError: If the configured dimension or a downsampling name is
                not supported.
        """
        return self._build_network(inputs, training)

    ############################################################################
    # Composite blocks building the network
    ############################################################################
    def _build_network(self, inputs, training):
        """Assemble first convolution, encoder, bottom block and decoder.

        Args:
            inputs: A Tensor representing a batch of input images.
            training: A boolean forwarded to every block of the network.

        Returns:
            The output Tensor of the decoder path.

        Raises:
            ValueError: If the configured dimension or a downsampling name is
                not supported.
        """
        # first_convolution
        # Resolve the convolution first so that a bad 'dimension' fails with a
        # clear ValueError before any operation is added to the graph.
        convolution = self._select_convolution()
        # NOTE(review): the positional arguments 3, 1, False are presumably
        # kernel size, stride and use_bias -- confirm against basic_ops.
        inputs = convolution(inputs, self.first_output_filters, 3, 1, False,
                             'first_convolution')

        # encoding_block_1
        with tf.variable_scope('encoding_block_1'):
            inputs = self._stack_res_blocks(
                inputs, self.first_output_filters,
                self.encoding_block_sizes[0], training)

        # encoding_block_i (down) = downsampling + zero or more res_block,
        # i = 2, 3, ..., depth
        skip_inputs = []  # for identity skip connections
        for i in range(2, self.depth + 1):
            # Saved before downsampling: this is the output of
            # encoding_block_(i-1), stored at index i-2.
            skip_inputs.append(inputs)
            with tf.variable_scope('encoding_block_%d' % i):
                output_filters = self.first_output_filters * (2 ** (i - 1))
                # downsampling
                downsampling_func = self._get_downsampling_function(
                    self.downsampling[i - 2])
                inputs = downsampling_func(inputs, output_filters, training,
                                           self.dimension, 'downsampling')
                inputs = self._stack_res_blocks(
                    inputs, output_filters,
                    self.encoding_block_sizes[i - 1], training)

        # bottom_block = a single res_block at the deepest resolution
        with tf.variable_scope('bottom_block'):
            output_filters = self.first_output_filters * (2 ** (self.depth - 1))
            inputs = res_block(inputs, output_filters, training,
                               self.dimension, 'block_0')

        # Note: Identity skip connections are between the output of
        # encoding_block_i and the output of upsampling in decoding_block_i,
        # i = 1, 2, ..., depth-1.
        #   skip_inputs[i-1] is the output of encoding_block_i now.
        #   len(skip_inputs) == depth - 1
        #   skip_inputs[depth-2] is combined during decoding_block_(depth-1)
        #   skip_inputs[0] is combined during decoding_block_1

        # decoding_block_j (up) = upsampling + zero or more res_block,
        # j = depth-1, depth-2, ..., 1
        for j in range(self.depth - 1, 0, -1):
            with tf.variable_scope('decoding_block_%d' % j):
                output_filters = self.first_output_filters * (2 ** (j - 1))
                # upsampling
                inputs = up_transposed_convolution(
                    inputs, output_filters, training, self.dimension,
                    'upsampling')
                # combine with skip connections
                # NOTE(review): any other skip_method value silently disables
                # the skip connection; kept as-is for backward compatibility.
                if self.skip_method == 'add':
                    inputs = tf.add(inputs, skip_inputs[j - 1])
                elif self.skip_method == 'concat':
                    inputs = tf.concat([inputs, skip_inputs[j - 1]], axis=-1)
                # decoding_block_sizes is ordered from the deepest decoding
                # block (j = depth-1, index 0) to the shallowest (j = 1).
                inputs = self._stack_res_blocks(
                    inputs, output_filters,
                    self.decoding_block_sizes[self.depth - 1 - j], training)

        return inputs

    def _select_convolution(self):
        """Return the convolution function matching self.dimension.

        Raises:
            ValueError: If self.dimension is neither '2D' nor '3D'.
        """
        if self.dimension == '2D':
            return convolution_2D
        elif self.dimension == '3D':
            return convolution_3D
        else:
            raise ValueError("Unsupported dimension: %s." % (self.dimension))

    def _stack_res_blocks(self, inputs, output_filters, num_blocks, training):
        """Apply num_blocks res_block calls named 'res_0', 'res_1', ... in order."""
        for block_index in range(0, num_blocks):
            inputs = res_block(inputs, output_filters, training,
                               self.dimension, 'res_%d' % block_index)
        return inputs

    def _get_downsampling_function(self, name):
        """Map a downsampling name from the configuration to its function.

        Args:
            name: 'down_res_block' or 'convolution'.

        Returns:
            down_res_block or down_convolution, respectively.

        Raises:
            ValueError: If name is not one of the supported names.
        """
        if name == 'down_res_block':
            return down_res_block
        elif name == 'convolution':
            return down_convolution
        else:
            raise ValueError("Unsupported function: %s." % (name))