Advertisement
HawkeyeHS

Composer

Sep 2nd, 2024 (edited)
242
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.11 KB | None | 0 0
  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import networkx as nx
  4. from queue import Queue
  5. import numpy as np
  6.  
  7. def construct_model(G, start_node_name):
  8.     # BFS traversal
  9.     visited=set()
  10.     q = Queue()
  11.     q.put((start_node_name, None))
  12.     visited.add(start_node_name)
  13.     model_inputs=G.nodes[start_node_name]['data'].input_shape
  14.    
  15.     while not q.empty():
  16.         current_node_name, parent = q.get()
  17.         final_output=parent
  18.         for neighbor_name in G.neighbors(current_node_name):
  19.             if neighbor_name not in visited:
  20.                 current_node = G.nodes[current_node_name]['data']
  21.                 neighbor = G.nodes[neighbor_name]['data']
  22.                
  23.                 if current_node.type=="input":
  24.                     if neighbor.type=="model":
  25.                         q.put((neighbor_name, None))
  26.                        
  27.                 elif current_node.type=="model":
  28.                     if neighbor.type=="join":
  29.                         if parent is not None:
  30.                             q.put((neighbor_name, None))
  31.                             neighbor.model_outputs.append(parent)
  32.                         else:
  33.                             neighbor.model_outputs.append(current_node.model.output)
  34.                             q.put((neighbor_name, None))
  35.                            
  36.                         neighbor.nodes_connected= neighbor.nodes_connected - 1
  37.                        
  38.                     elif neighbor.type == "model":
  39.                         if parent is not None:
  40.                             q.put((neighbor_name,neighbor.model(parent)))
  41.                         else:
  42.                             q.put((neighbor_name,neighbor.model(current_node.model.output)))
  43.                            
  44.                 elif current_node.type=="join" and neighbor.type=="model":
  45.                     q.put((neighbor_name,neighbor.model(current_node.model_outputs)))
  46.                    
  47.                    
  48.                 if neighbor.type=="join":
  49.                     if neighbor.nodes_connected==0:
  50.                         visited.add(neighbor_name)
  51.                 else:
  52.                     visited.add(neighbor_name)
  53.    
  54.     model = tf.keras.Model(inputs=model_inputs, outputs=final_output)
  55.     model.save('final_model.h5')
  56.     return model        
  57.          
  58. class Node:
  59.     def __init__(self, type, name):
  60.         self.type = type
  61.         self.name = name
  62.        
  63. class InputNode(Node):
  64.     def __init__(self, type, name, input_shape):
  65.         super().__init__(type, name)      
  66.         self.input_shape=input_shape
  67.        
  68. class JoinNode(Node):
  69.     def __init__(self, type, name, nodes_connected=3, join_method='concatenate'):
  70.         super().__init__(type, name)
  71.         self.nodes_connected=nodes_connected
  72.         self.join_method=join_method
  73.         self.model_outputs=[]
  74.        
  75. class ModelNode(Node):
  76.     def __init__(self, type, name, model):
  77.         super().__init__(type, name)
  78.         self.model=model
  79.  
  80. def encoder_decoder_model_1(input_shape):
  81.     inputs = tf.keras.Input(shape=input_shape)
  82.     x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
  83.     x = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
  84.     x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
  85.     encoded = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
  86.    
  87.     x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(encoded)
  88.     x = tf.keras.layers.UpSampling2D((2, 2))(x)
  89.     x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
  90.     x = tf.keras.layers.UpSampling2D((2, 2))(x)
  91.     decoded = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
  92.    
  93.     model = tf.keras.Model(inputs, decoded)
  94.     return model
  95.  
  96.  
  97. def encoder_decoder_model_2(input_shape):
  98.     inputs = tf.keras.Input(shape=input_shape)
  99.     x = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(inputs)
  100.     x = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
  101.     x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
  102.     encoded = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
  103.    
  104.     x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
  105.     x = tf.keras.layers.UpSampling2D((2, 2))(x)
  106.     x = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(x)
  107.     x = tf.keras.layers.UpSampling2D((2, 2))(x)
  108.     decoded = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
  109.    
  110.     model = tf.keras.Model(inputs, decoded)
  111.     return model
  112.  
  113.  
  114. def encoder_decoder_model_3(input_shape):
  115.     inputs = tf.keras.Input(shape=input_shape)
  116.     x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
  117.     x = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
  118.     x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
  119.     encoded = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
  120.    
  121.     x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(encoded)
  122.     x = tf.keras.layers.UpSampling2D((2, 2))(x)
  123.     x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
  124.     x = tf.keras.layers.UpSampling2D((2, 2))(x)
  125.     decoded = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
  126.    
  127.     model = tf.keras.Model(inputs, decoded)
  128.     return model
  129.  
  130. def regression_model(input_shapes):
  131.     inputs1 = tf.keras.Input(shape=input_shapes[0])
  132.     inputs2 = tf.keras.Input(shape=input_shapes[1])
  133.     inputs3 = tf.keras.Input(shape=input_shapes[2])
  134.    
  135.     x1 = tf.keras.layers.Flatten()(inputs1)
  136.     x2 = tf.keras.layers.Flatten()(inputs2)
  137.     x3 = tf.keras.layers.Flatten()(inputs3)
  138.    
  139.     concatenated = tf.keras.layers.Concatenate()([x1, x2, x3])
  140.     x = tf.keras.layers.Dense(64, activation='relu')(concatenated)
  141.     x = tf.keras.layers.Dense(32, activation='relu')(x)
  142.     output = tf.keras.layers.Dense(1)(x)
  143.    
  144.     model = tf.keras.Model(inputs=[inputs1, inputs2, inputs3], outputs=output)
  145.     return model
  146.  
  147.  
  148. input_shape = (28, 28, 3)
  149. model1 = encoder_decoder_model_1(input_shape)
  150. model2 = encoder_decoder_model_2(input_shape)
  151. model3 = encoder_decoder_model_3(input_shape)
  152.  
  153.  
  154. input_shapes = [input_shape] * 3
  155. reg_model = regression_model(input_shapes)
  156.  
  157. input_node = InputNode(name='input_node', type="input", input_shape=[model1.input,model2.input,model3.input])
  158. join_node_1 = JoinNode(name='Join Node 1', type='join')
  159. model_1_node = ModelNode(model=model1, name='Model 1', type='model')
  160. model_2_node = ModelNode(model=model2, name='Model 2', type='model')
  161. model_3_node = ModelNode(model=model3, name='Model 3', type='model')
  162. reg_model_node = ModelNode(model=reg_model, name='Regression Model', type='model')
  163.  
  164. # Create a NetworkX graph
  165. G = nx.Graph()
  166.  
  167. # Add nodes with custom labels and data attributes
  168. G.add_node(input_node.name, data=input_node)
  169. G.add_node(model_1_node.name, data=model_1_node)
  170. G.add_node(model_2_node.name, data=model_2_node)
  171. G.add_node(model_3_node.name, data=model_3_node)
  172. G.add_node(join_node_1.name, data=join_node_1)
  173. G.add_node(reg_model_node.name, data=reg_model_node)
  174.  
  175. # Add edges between the nodes
  176. G.add_edge(input_node.name, model_1_node.name)
  177. G.add_edge(input_node.name, model_2_node.name)
  178. G.add_edge(input_node.name, model_3_node.name)
  179. G.add_edge(model_1_node.name, join_node_1.name)
  180. G.add_edge(model_2_node.name, join_node_1.name)
  181. G.add_edge(model_3_node.name, join_node_1.name)
  182. G.add_edge(join_node_1.name, reg_model_node.name)
  183.  
  184. final_output = construct_model(G, input_node.name)
  185. image = np.random.random((1, 28, 28, 3))  # Batch size of 1
  186. print(final_output.predict([image,image,image]))
  187. # Plot the model
  188. tf.keras.utils.plot_model(final_output, to_file='model_plot_1.png', show_shapes=True, show_layer_names=True)
  189.  
  190. plt.figure(figsize=(10, 10))
  191. pos = nx.spring_layout(G)  # positions for all nodes
  192. nx.draw(G, pos, with_labels=True, node_size=3000, node_color='lightblue', font_size=10, font_weight='bold')
  193. plt.title('Model Graph with Custom Labels')
  194. # plt.show()
  195. plt.savefig('model_graph_1.png')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement