Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from collections import deque
from queue import Queue

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import tensorflow as tf
def construct_model(G, start_node_name, save_path='final_model.h5'):
    """Compose the Keras sub-models stored in graph ``G`` into one model.

    The graph is walked breadth-first from ``start_node_name``. Every graph
    node must carry a ``'data'`` attribute holding a node object whose
    ``type`` is ``"input"``, ``"model"`` or ``"join"``. Edges are wired as:

    * input -> model: the model starts from its own ``model.output``.
    * model -> model: the downstream ``model`` is called on the upstream
      output tensor.
    * model -> join: the upstream output is collected by the join.
    * join -> model: once the join has collected ``nodes_connected``
      outputs, the downstream ``model`` is called on that list.

    Any other combination of node types is ignored. ``G.neighbors`` also
    yields upstream nodes, so a visited set stops the walk from going
    backwards. Join outputs are collected in the order the walk reaches
    the contributing models. Join bookkeeping lives in local state: the
    node objects are not modified, so the function can be called
    repeatedly on the same graph.

    Args:
        G: graph whose nodes have a ``'data'`` attribute (see above) and
            which provides ``G.nodes[name]`` and ``G.neighbors(name)``.
        start_node_name: key of the input node. Its ``input_shape``
            attribute is passed as ``inputs`` to ``tf.keras.Model``.
        save_path: file the assembled model is saved to. ``None`` skips
            saving. Defaults to ``'final_model.h5'``.

    Returns:
        The assembled ``tf.keras.Model``. Its output is the output of the
        last model node reached by the walk.

    Raises:
        ValueError: if a join collected fewer than ``nodes_connected``
            inputs, or if no model node is reachable from the start node.
    """
    model_inputs = G.nodes[start_node_name]['data'].input_shape
    visited = {start_node_name}
    frontier = deque([(start_node_name, None)])
    collected = {}  # join name -> upstream outputs gathered so far
    final_output = None

    while frontier:
        current_name, incoming = frontier.popleft()
        current = G.nodes[current_name]['data']

        output = None
        if current.type == "model":
            # `incoming` is this model already applied to its upstream tensor.
            # Models fed straight from the input node have none and use their
            # own output tensor instead.
            output = incoming if incoming is not None else current.model.output
            final_output = output

        for neighbor_name in G.neighbors(current_name):
            if neighbor_name in visited:
                continue
            neighbor = G.nodes[neighbor_name]['data']

            if current.type == "input" and neighbor.type == "model":
                visited.add(neighbor_name)
                frontier.append((neighbor_name, None))
            elif current.type == "model" and neighbor.type == "model":
                visited.add(neighbor_name)
                frontier.append((neighbor_name, neighbor.model(output)))
            elif current.type == "model" and neighbor.type == "join":
                gathered = collected.setdefault(neighbor_name, [])
                gathered.append(output)
                # Queue the join exactly once, after every expected input has
                # arrived. Firing it earlier would feed the next model a
                # partial list.
                if len(gathered) == neighbor.nodes_connected:
                    visited.add(neighbor_name)
                    frontier.append((neighbor_name, None))
            elif current.type == "join" and neighbor.type == "model":
                visited.add(neighbor_name)
                frontier.append(
                    (neighbor_name, neighbor.model(collected[current_name]))
                )

    incomplete = [name for name in collected if name not in visited]
    if incomplete:
        details = ", ".join(
            f"{name!r} ({len(collected[name])}/"
            f"{G.nodes[name]['data'].nodes_connected})"
            for name in incomplete
        )
        raise ValueError(
            f"join node(s) did not receive all expected inputs: {details}"
        )
    if final_output is None:
        raise ValueError(
            f"no model node is reachable from {start_node_name!r}"
        )

    model = tf.keras.Model(inputs=model_inputs, outputs=final_output)
    if save_path is not None:
        model.save(save_path)
    return model
class Node:
    """Base class for a vertex in the model-composition graph.

    Attributes:
        type: node kind string; the graph builder distinguishes "input",
            "model" and "join".
        name: identifier used as this node's key when it is added to the
            graph.
    """

    def __init__(self, type, name):
        self.name = name
        self.type = type
class InputNode(Node):
    """Entry vertex of the model-composition graph.

    Attributes:
        input_shape: value passed unchanged as ``inputs`` to
            ``tf.keras.Model`` when the graph is assembled. Despite the
            name, the example script stores a list of Keras Input tensors
            (``[model1.input, model2.input, model3.input]``) here, not a
            shape tuple.
    """

    def __init__(self, type, name, input_shape):
        super().__init__(type=type, name=name)
        self.input_shape = input_shape
class JoinNode(Node):
    """Vertex that merges the outputs of several upstream model nodes.

    Attributes:
        nodes_connected: number of upstream model outputs this join expects
            before it feeds the next model (defaults to 3).
        join_method: merge strategy label (default ``'concatenate'``). It is
            stored but not consulted when the graph is wired; the downstream
            model receives the plain list of collected outputs.
        model_outputs: a fresh, per-instance empty list for gathering
            upstream outputs.
    """

    def __init__(self, type, name, nodes_connected=3, join_method='concatenate'):
        super().__init__(type=type, name=name)
        self.join_method = join_method
        self.nodes_connected = nodes_connected
        self.model_outputs = list()
class ModelNode(Node):
    """Vertex wrapping a callable Keras model.

    Attributes:
        model: the wrapped ``tf.keras.Model``. The graph builder calls it on
            upstream tensors and reads its ``output`` attribute.
    """

    def __init__(self, type, name, model):
        super().__init__(type=type, name=name)
        self.model = model
def encoder_decoder_model_1(input_shape):
    """Return a convolutional encoder-decoder with widths 32 -> 64 -> 64 -> 32.

    Encoder: two stages of a 3x3 ReLU Conv2D followed by 2x2 'same'
    max-pooling. Decoder: two stages of a 3x3 ReLU Conv2D followed by 2x
    upsampling, then a 3x3 sigmoid Conv2D with 3 filters.

    Notes:
        The output always has 3 channels, whatever the input's channel count.
        The input's height and width are reproduced only when both are
        divisible by 4 (two pool/upsample round trips).

    Args:
        input_shape: per-sample input shape without the batch dimension,
            e.g. ``(28, 28, 3)``.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_shape)

    features = inputs
    for width in (32, 64):  # encoder stages
        features = layers.Conv2D(width, (3, 3), activation='relu', padding='same')(features)
        features = layers.MaxPooling2D((2, 2), padding='same')(features)
    for width in (64, 32):  # decoder stages mirror the encoder
        features = layers.Conv2D(width, (3, 3), activation='relu', padding='same')(features)
        features = layers.UpSampling2D((2, 2))(features)

    decoded = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(features)
    return tf.keras.Model(inputs, decoded)
def encoder_decoder_model_2(input_shape):
    """Return a convolutional encoder-decoder with widths 16 -> 32 -> 32 -> 16.

    Same layout as the other encoder-decoders in this file, with narrower
    layers. Encoder: two stages of a 3x3 ReLU Conv2D followed by 2x2 'same'
    max-pooling. Decoder: two stages of a 3x3 ReLU Conv2D followed by 2x
    upsampling, then a 3x3 sigmoid Conv2D with 3 filters.

    Notes:
        The output always has 3 channels, whatever the input's channel count.
        The input's height and width are reproduced only when both are
        divisible by 4.

    Args:
        input_shape: per-sample input shape without the batch dimension.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_shape)

    features = inputs
    for width in (16, 32):  # encoder stages
        features = layers.Conv2D(width, (3, 3), activation='relu', padding='same')(features)
        features = layers.MaxPooling2D((2, 2), padding='same')(features)
    for width in (32, 16):  # decoder stages mirror the encoder
        features = layers.Conv2D(width, (3, 3), activation='relu', padding='same')(features)
        features = layers.UpSampling2D((2, 2))(features)

    decoded = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(features)
    return tf.keras.Model(inputs, decoded)
def encoder_decoder_model_3(input_shape):
    """Return a convolutional encoder-decoder with widths 64 -> 128 -> 128 -> 64.

    Same layout as the other encoder-decoders in this file, with wider
    layers. Encoder: two stages of a 3x3 ReLU Conv2D followed by 2x2 'same'
    max-pooling. Decoder: two stages of a 3x3 ReLU Conv2D followed by 2x
    upsampling, then a 3x3 sigmoid Conv2D with 3 filters.

    Notes:
        The output always has 3 channels, whatever the input's channel count.
        The input's height and width are reproduced only when both are
        divisible by 4.

    Args:
        input_shape: per-sample input shape without the batch dimension.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_shape)

    features = inputs
    for width in (64, 128):  # encoder stages
        features = layers.Conv2D(width, (3, 3), activation='relu', padding='same')(features)
        features = layers.MaxPooling2D((2, 2), padding='same')(features)
    for width in (128, 64):  # decoder stages mirror the encoder
        features = layers.Conv2D(width, (3, 3), activation='relu', padding='same')(features)
        features = layers.UpSampling2D((2, 2))(features)

    decoded = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(features)
    return tf.keras.Model(inputs, decoded)
def regression_model(input_shapes):
    """Build a small dense regression head over one or more inputs.

    Every input is flattened. With several inputs, the flattened vectors are
    concatenated. The result passes through Dense(64, ReLU) -> Dense(32, ReLU)
    -> Dense(1) with a linear output.

    One Input is created per entry of ``input_shapes``, so the head matches
    the caller's fan-in. With three shapes, as in the example script, the
    architecture is the classic three-branch head.

    Args:
        input_shapes: iterable of per-sample shapes (no batch dimension), one
            per model input.

    Returns:
        A ``tf.keras.Model`` taking a list of tensors (one per shape) and
        producing a single value per sample.

    Raises:
        ValueError: if ``input_shapes`` is empty.
    """
    shapes = list(input_shapes)
    if not shapes:
        raise ValueError("regression_model requires at least one input shape")

    inputs = [tf.keras.Input(shape=shape) for shape in shapes]
    flattened = [tf.keras.layers.Flatten()(tensor) for tensor in inputs]
    # A single flattened input needs no concatenation.
    if len(flattened) > 1:
        features = tf.keras.layers.Concatenate()(flattened)
    else:
        features = flattened[0]

    x = tf.keras.layers.Dense(64, activation='relu')(features)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    output = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs=inputs, outputs=output)
# --- Build the sub-models ------------------------------------------------------
input_shape = (28, 28, 3)
model1, model2, model3 = (
    builder(input_shape)
    for builder in (encoder_decoder_model_1, encoder_decoder_model_2, encoder_decoder_model_3)
)
input_shapes = [input_shape] * 3
reg_model = regression_model(input_shapes)

# --- Wrap them as graph nodes --------------------------------------------------
# The input node carries the three sub-models' Input tensors; construct_model
# hands them to tf.keras.Model(inputs=...).
input_node = InputNode(
    type="input",
    name='input_node',
    input_shape=[model1.input, model2.input, model3.input],
)
join_node_1 = JoinNode(type='join', name='Join Node 1', nodes_connected=3)
model_1_node = ModelNode(type='model', name='Model 1', model=model1)
model_2_node = ModelNode(type='model', name='Model 2', model=model2)
model_3_node = ModelNode(type='model', name='Model 3', model=model3)
reg_model_node = ModelNode(type='model', name='Regression Model', model=reg_model)

# --- Assemble the NetworkX graph -----------------------------------------------
# Keep node and edge order as is: construct_model's BFS visits neighbours in
# the order G.neighbors() yields them.
G = nx.Graph()
G.add_nodes_from(
    (node.name, {'data': node})
    for node in (input_node, model_1_node, model_2_node, model_3_node,
                 join_node_1, reg_model_node)
)
G.add_edges_from([
    (input_node.name, model_1_node.name),
    (input_node.name, model_2_node.name),
    (input_node.name, model_3_node.name),
    (model_1_node.name, join_node_1.name),
    (model_2_node.name, join_node_1.name),
    (model_3_node.name, join_node_1.name),
    (join_node_1.name, reg_model_node.name),
])

# --- Compose, run once, and visualise -------------------------------------------
final_output = construct_model(G, input_node.name)  # the assembled tf.keras.Model
image = np.random.random((1, 28, 28, 3))  # one random 3-channel sample (batch of 1)
print(final_output.predict([image, image, image]))

tf.keras.utils.plot_model(
    final_output,
    to_file='model_plot_1.png',
    show_shapes=True,
    show_layer_names=True,
)

plt.figure(figsize=(10, 10))
pos = nx.spring_layout(G)  # positions for all nodes
nx.draw(G, pos, with_labels=True, node_size=3000, node_color='lightblue',
        font_size=10, font_weight='bold')
plt.title('Model Graph with Custom Labels')
plt.savefig('model_graph_1.png')  # written to disk rather than shown interactively
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement