Spaces:
Sleeping
Sleeping
| # Streamlit App: PyTorch Geometric Structure Visualization | |
| import streamlit as st | |
| import torch | |
| from torch_geometric.data import Data | |
| import numpy as np | |
| import plotly.graph_objs as go | |
| import math # Moved import to the top | |
| # Function Definitions | |
def generate_sierpinski_triangle(depth):
    """Build a Sierpinski-triangle graph by recursive midpoint subdivision.

    Starting from a unit-side equilateral triangle in the z = 0 plane, each
    triangle is replaced by its three corner sub-triangles ``depth`` times.
    Corners shared by touching sub-triangles are merged into one vertex, and
    every leaf triangle contributes its three sides as edges.

    Args:
        depth: Number of subdivision levels (0 yields the single base
            triangle).

    Returns:
        ``(pos, edge_index)`` where ``pos`` is a float array of shape
        ``(N, 3)`` holding the unique vertex coordinates (sorted
        lexicographically by ``np.unique``) and ``edge_index`` is an int64
        array of shape ``(2, 3 * 3**depth)`` whose columns are
        ``(src, dst)`` row indices into ``pos``.

    Raises:
        ValueError: If ``depth`` is negative (the recursion would never end).
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")

    vertices = np.array([
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, np.sqrt(3) / 2, 0.0],
    ])

    def recurse_triangle(v1, v2, v3, level):
        # Flat list of corners, three consecutive entries per leaf triangle.
        if level == 0:
            return [v1, v2, v3]
        m12 = (v1 + v2) / 2
        m23 = (v2 + v3) / 2
        m31 = (v3 + v1) / 2
        return (recurse_triangle(v1, m12, m31, level - 1)
                + recurse_triangle(m12, v2, m23, level - 1)
                + recurse_triangle(m31, m23, v3, level - 1))

    corners = np.array(recurse_triangle(vertices[0], vertices[1], vertices[2], depth))

    # Touching sub-triangles share corner points (the same midpoint array is
    # passed to both), so exact de-duplication merges them. `inverse` maps
    # every corner to the row of its unique vertex in `pos`.
    pos, inverse = np.unique(corners, axis=0, return_inverse=True)
    # Flatten defensively: the shape of the inverse has varied across NumPy
    # versions.
    inverse = np.asarray(inverse).reshape(-1)

    # One row of vertex indices (a, b, c) per leaf triangle; edges are the
    # sides a->b, b->c, c->a. Edges must come from the triangle structure,
    # not from row adjacency in the sorted `pos` array.
    triangles = inverse.reshape(-1, 3)
    src = triangles.ravel()
    dst = triangles[:, [1, 2, 0]].ravel()
    edge_index = np.stack([src, dst]).astype(np.int64)
    return pos, edge_index
def generate_spiral(turns, points_per_turn):
    """Sample a conical spiral and chain the samples into a path graph.

    The curve climbs linearly from z = 0 to z = 1 while its radius grows in
    step with z, winding ``turns`` full revolutions around the z axis.

    Args:
        turns: Number of full revolutions.
        points_per_turn: Samples per revolution; ``turns * points_per_turn``
            samples are taken in total, first and last endpoints included.

    Returns:
        ``(pos, edge_index)``: ``pos`` has shape ``(n, 3)`` with
        ``n = turns * points_per_turn``; ``edge_index`` has shape
        ``(2, n - 1)`` and links every sample to the next one.
    """
    n = turns * points_per_turn
    angles = np.linspace(0, 2 * np.pi * turns, n)
    heights = np.linspace(0, 1, n)
    radius = heights  # radius grows linearly with height

    pos = np.column_stack([radius * np.cos(angles),
                           radius * np.sin(angles),
                           heights])

    # Path graph: sample i -> sample i + 1.
    src = np.arange(n - 1)
    edge_index = np.array([src, src + 1])
    return pos, edge_index
def generate_plant(iterations, angle):
    """Build a branching-plant graph from a bracketed L-system in the z = 0 plane.

    The axiom ``"F"`` is rewritten ``iterations`` times with
    ``F -> F[+F]F[-F]F``. The resulting string is then interpreted
    turtle-style, starting at the origin and heading along +y:

    * ``F``: step one unit forward, adding a node and an edge from the
      current node to it;
    * ``+`` / ``-``: turn by ``+angle`` / ``-angle`` degrees about the z axis;
    * ``[`` / ``]``: push / pop the position, heading and current node.

    Args:
        iterations: Number of rewriting passes (0 yields a single segment).
        angle: Turning angle in degrees.

    Returns:
        ``(pos, edge_index)``: ``pos`` has shape ``(N, 3)`` with node 0 at the
        origin; ``edge_index`` is an int64 array of shape ``(2, N - 1)`` whose
        columns are ``(parent, child)`` node indices.
    """
    rules = {"F": "F[+F]F[-F]F"}
    commands = "F"
    for _ in range(iterations):
        # join() keeps each rewrite linear in the string length.
        commands = "".join(rules.get(ch, ch) for ch in commands)

    # The two turn matrices never change, so compute them once, not per symbol.
    z_axis = np.array([0, 0, 1])
    turn_left = rotation_matrix_3d(z_axis, np.radians(angle))
    turn_right = rotation_matrix_3d(z_axis, np.radians(-angle))

    stack = []
    current_pos = np.array([0.0, 0.0, 0.0])
    direction = np.array([0.0, 1.0, 0.0])
    pos_list = [current_pos.copy()]
    edge_list = []
    idx = 0  # index in pos_list of the node the turtle currently stands on
    for command in commands:
        if command == 'F':
            next_pos = current_pos + direction
            # A new node's index is its slot in pos_list. It is NOT idx + 1,
            # because a ']' may have moved idx back to an earlier node.
            new_idx = len(pos_list)
            pos_list.append(next_pos.copy())
            edge_list.append([idx, new_idx])
            current_pos = next_pos
            idx = new_idx
        elif command == '+':
            direction = turn_left @ direction
        elif command == '-':
            direction = turn_right @ direction
        elif command == '[':
            stack.append((current_pos.copy(), direction.copy(), idx))
        elif command == ']':
            current_pos, direction, idx = stack.pop()

    pos = np.array(pos_list)
    # reshape(-1, 2) keeps a (2, 0) result when no edges were produced.
    edge_index = np.array(edge_list, dtype=np.int64).reshape(-1, 2).T
    return pos, edge_index
def rotation_matrix_3d(axis, theta):
    """Return the 3x3 matrix that rotates vectors by ``theta`` radians about ``axis``.

    Implements Rodrigues' formula
    ``R = cos(t) I + sin(t) [u]_x + (1 - cos(t)) u u^T``, where ``u`` is the
    unit vector along ``axis``. Positive angles rotate counter-clockwise when
    viewed from the tip of ``axis`` looking toward the origin (right-hand
    rule).

    Args:
        axis: Length-3 vector giving the rotation axis; it need not be unit
            length.
        theta: Rotation angle in radians.

    Returns:
        A ``(3, 3)`` float ndarray; apply it as ``R @ v``.

    Raises:
        ValueError: If ``axis`` has zero length, in which case the rotation is
            undefined. Previously this produced a matrix of NaNs.
    """
    axis = np.asarray(axis, dtype=float)
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    ux, uy, uz = axis / norm

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    one_minus_cos = 1 - cos_t
    return np.array([
        [cos_t + one_minus_cos * ux * ux,
         one_minus_cos * ux * uy - uz * sin_t,
         one_minus_cos * ux * uz + uy * sin_t],
        [one_minus_cos * uy * ux + uz * sin_t,
         cos_t + one_minus_cos * uy * uy,
         one_minus_cos * uy * uz - ux * sin_t],
        [one_minus_cos * uz * ux - uy * sin_t,
         one_minus_cos * uz * uy + ux * sin_t,
         cos_t + one_minus_cos * uz * uz],
    ])
def plot_graph_3d(pos, edge_index):
    """Render a graph as an interactive Plotly 3-D figure.

    Args:
        pos: Array of node coordinates indexed as ``pos[:, 0..2]``
            (x, y, z columns).
        edge_index: Array indexed as ``edge_index[0]`` / ``edge_index[1]``;
            column ``i`` is the (source, target) node-index pair of edge ``i``.

    Returns:
        A ``go.Figure`` with a gray line trace for the edges followed by a red
        marker trace for the nodes, equal-scale ("data") axes, no legend and
        zero margins.
    """
    xs, ys, zs = (pos[:, axis] for axis in range(3))
    pairs = list(zip(edge_index[0], edge_index[1]))

    def segment_coords(coords):
        # A None after every (src, dst) pair breaks the polyline, so Plotly
        # draws each edge as its own segment instead of joining them.
        flat = []
        for src, dst in pairs:
            flat.extend((coords[src], coords[dst], None))
        return flat

    edge_trace = go.Scatter3d(
        x=segment_coords(xs),
        y=segment_coords(ys),
        z=segment_coords(zs),
        line=dict(width=2, color='gray'),
        hoverinfo='none',
        mode='lines',
    )
    node_trace = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        marker=dict(size=4, color='red'),
        hoverinfo='none',
    )

    layout = dict(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data',
        ),
        showlegend=False,
        margin=dict(l=0, r=0, b=0, t=0),
    )
    figure = go.Figure(data=[edge_trace, node_trace])
    figure.update_layout(**layout)
    return figure
| # Main App Code | |
def main():
    """Streamlit entry point: choose a structure, build it, and plot it.

    The sidebar selects one of three generators and exposes that generator's
    parameters as sliders. The resulting node positions and edge list are
    wrapped in a PyTorch Geometric ``Data`` object and drawn with
    ``plot_graph_3d``.
    """
    st.title("PyTorch Geometric Structure Visualization")
    structure_type = st.sidebar.selectbox(
        "Select Structure Type",
        ("Sierpinski Triangle", "Spiral", "Plant Structure")
    )

    # Each branch only gathers its own parameters and builds the arrays; the
    # conversion and rendering below are shared instead of being repeated.
    if structure_type == "Sierpinski Triangle":
        depth = st.sidebar.slider("Recursion Depth", 0, 5, 3)
        pos, edge_index = generate_sierpinski_triangle(depth)
    elif structure_type == "Spiral":
        turns = st.sidebar.slider("Number of Turns", 1, 20, 5)
        points_per_turn = st.sidebar.slider("Points per Turn", 10, 100, 50)
        pos, edge_index = generate_spiral(turns, points_per_turn)
    elif structure_type == "Plant Structure":
        iterations = st.sidebar.slider("L-system Iterations", 1, 5, 3)
        angle = st.sidebar.slider("Branching Angle", 15, 45, 25)
        pos, edge_index = generate_plant(iterations, angle)
    else:
        # Not reachable through the selectbox options above; draw nothing.
        return

    # NOTE(review): `data` is not consumed further yet. Building it confirms the
    # generated arrays convert to float/long tensors for a PyG Data object.
    data = Data(pos=torch.tensor(pos, dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype=torch.long))
    fig = plot_graph_3d(pos, edge_index)
    st.plotly_chart(fig)
if __name__ == "__main__":
    # Render the app only when this file is executed as the top-level script.
    main()