CellModules
visualization.py
Go to the documentation of this file.
1import networkx as nx
2import matplotlib.pyplot as plt
3import matplotlib.colors as mcolors
4import numpy as np
5import plotly.graph_objs as go
6from typing import Dict, Any, Optional
7
# Id of the phantom node every founder cell hangs from (the original "-1 = no
# mother" convention). Cell ids are assumed not to collide with it.
_PHANTOM_ROOT = -1
# Vertical distance between two neighbouring leaf rows, in data units.
_ROW_SPACING = 2
# Plot height per leaf row, in pixels (Matplotlib uses pixels / 100 inches).
_PIXELS_PER_LEAF = 30
# Floor for the plot height so trees with few leaves still leave room for the
# title, axes and legend/colorbar.
_MIN_PLOT_HEIGHT = 300


def _build_lineage_graph(cells, cell_attribute):
    """Return a mother -> daughter ``nx.DiGraph`` rooted at ``_PHANTOM_ROOT``.

    Each node stores ``birth_step``, ``death_step`` and ``characteristic`` (the
    cell's per-step values of ``cell_attribute``). Cells whose ``mother_id`` is
    -1, or names a cell absent from ``cells``, are attached to the phantom root
    so that every recorded cell is reachable and gets a position.
    """
    graph = nx.DiGraph()
    graph.add_node(_PHANTOM_ROOT, birth_step=-1, death_step=-1, characteristic=[])
    known_ids = {cell['id'] for cell in cells}
    for cell in cells:
        graph.add_node(cell['id'],
                       birth_step=cell['birth_step'],
                       death_step=cell['death_step'],
                       characteristic=cell[cell_attribute])
        mother = cell['mother_id']
        # An unrecorded mother would otherwise create an attribute-less node
        # that is never positioned and crashes the drawing code.
        graph.add_edge(mother if mother in known_ids else _PHANTOM_ROOT, cell['id'])
    return graph


def _lineage_layout(graph, root, y_spacing):
    """Return ``{node: (birth_step, row)}`` for every node reachable from *root*.

    Leaves are stacked ``y_spacing`` apart in depth-first order (children in
    edge-insertion order); each internal node is centred between its first and
    last child. Iterative, so deep lineages cannot hit the recursion limit.
    """
    preorder = []
    visited = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        preorder.append(node)
        # Reversed so that children are popped in their original order.
        stack.extend(reversed(list(graph.successors(node))))

    pos = {}
    next_leaf_row = 0
    for node in preorder:
        if graph.out_degree(node) == 0:
            pos[node] = (graph.nodes[node]['birth_step'], next_leaf_row)
            next_leaf_row += y_spacing
    # Reverse pre-order visits every child before its parent.
    for node in reversed(preorder):
        children = list(graph.successors(node))
        if children:
            row = (pos[children[0]][1] + pos[children[-1]][1]) / 2
            pos[node] = (graph.nodes[node]['birth_step'], row)
    return pos


def _value_range(cells, cell_attribute):
    """Return ``(vmin, vmax)`` over all recorded values, or ``(0.0, 1.0)`` if there are none."""
    lows, highs = [], []
    for cell in cells:
        values = cell[cell_attribute]
        if len(values) > 0:  # np.min/np.max raise on empty input
            lows.append(np.min(values))
            highs.append(np.max(values))
    if not lows:
        return 0.0, 1.0
    return float(min(lows)), float(max(highs))


def _lifeline_segments(values, birth_step, death_step, row, color_of):
    """Split one cell's lifeline into runs of constant colour.

    ``values[i]`` is read as the attribute value at step ``birth_step + i``.
    Returns ``(x_start, x_end, row, colour)`` tuples covering ``birth_step`` to
    ``death_step``; a cell with no recorded values yields no segment.
    """
    if len(values) == 0:
        return []
    segments = []
    start = birth_step
    color = color_of(values[0])
    # Read only indices that are both recorded and inside the lifespan, so a
    # value list shorter than the lifespan no longer raises IndexError.
    for i in range(1, min(len(values), death_step - birth_step)):
        new_color = color_of(values[i])
        if new_color != color:
            segments.append((start, birth_step + i, row, color))
            start, color = birth_step + i, new_color
    segments.append((start, death_step, row, color))
    return segments


def _show_with_plotly(life_segments, division_segments, legend_entries,
                      color_scale, norm, label, x_range, y_range, height):
    """Render the lineage with Plotly and call ``fig.show()``.

    ``legend_entries`` is a list of ``(name, colour)`` for discrete attributes,
    or None for continuous ones, in which case a colorbar spanning
    ``norm.vmin``..``norm.vmax`` on ``color_scale`` is added.
    """
    traces = []

    # Division connectors first so lifelines are drawn over them, matching the
    # Matplotlib z-order.
    edge_x, edge_y = [], []
    for x, y0, y1 in division_segments:
        edge_x.extend([x, x, None])
        edge_y.extend([y0, y1, None])
    traces.append(go.Scatter(x=edge_x, y=edge_y, mode='lines',
                             line=dict(width=2, color='#888'),
                             hoverinfo='none', showlegend=False))

    # One trace per colour; None entries break the line between segments.
    by_color = {}
    for x0, x1, row, color in life_segments:
        xs, ys = by_color.setdefault(color, ([], []))
        xs.extend([x0, x1, None])
        ys.extend([row, row, None])
    for color, (xs, ys) in by_color.items():
        traces.append(go.Scatter(x=xs, y=ys, mode='lines',
                                 line=dict(width=4, color=color),
                                 hoverinfo='none', showlegend=False))

    if legend_entries is not None:
        # Invisible traces that only exist to populate the legend.
        for name, color in legend_entries:
            traces.append(go.Scatter(x=[None], y=[None], mode='lines',
                                     line=dict(width=2, color=color),
                                     hoverinfo='none', showlegend=True, name=name))
    else:
        # A colorbar is only rendered when marker.color is numeric and
        # showscale is set, hence the two dummy points.
        traces.append(go.Scatter(
            x=[None, None], y=[None, None], mode='markers',
            marker=dict(
                colorscale=color_scale,
                cmin=norm.vmin,
                cmax=norm.vmax,
                color=[norm.vmin, norm.vmax],
                showscale=True,
                colorbar=dict(thickness=15,
                              title=dict(text=label, side='right'),
                              xanchor='left'),
            ),
            hoverinfo='none', showlegend=False))

    fig = go.Figure(
        data=traces,
        layout=go.Layout(
            title=dict(text='Lineage tree of Cells', font=dict(size=16)),
            showlegend=True,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=True, zeroline=True, showticklabels=True, range=list(x_range)),
            yaxis=dict(showgrid=True, zeroline=True, showticklabels=False, range=list(y_range)),
            height=height,
        ),
    )
    fig.show()


def _show_with_matplotlib(life_segments, division_segments, legend_entries,
                          cmap, norm, label, x_range, y_range, height):
    """Render the lineage with Matplotlib and call ``plt.show()``.

    ``legend_entries`` is a list of ``(name, colour)`` for discrete attributes,
    or None for continuous ones, in which case a colorbar for ``cmap``/``norm``
    is added.
    """
    from matplotlib.collections import LineCollection

    _, ax = plt.subplots(figsize=(10, height / 100))

    # One collection per kind instead of one Line2D per segment; projecting
    # caps reproduce the default ax.plot line ends.
    if division_segments:
        ax.add_collection(LineCollection(
            [[(x, y0), (x, y1)] for x, y0, y1 in division_segments],
            colors='#888', linewidths=2, capstyle='projecting', zorder=1))
    if life_segments:
        ax.add_collection(LineCollection(
            [[(x0, row), (x1, row)] for x0, x1, row, _ in life_segments],
            colors=[segment[3] for segment in life_segments],
            linewidths=4, capstyle='projecting', zorder=3))

    if legend_entries is not None:
        handles = [plt.Line2D([0], [0], color=color, lw=2, label=name)
                   for name, color in legend_entries]
        ax.legend(handles=handles, title=label, loc='center left', bbox_to_anchor=(1, 0.5))
    else:
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, fraction=0.02, pad=0.04)
        cbar.set_label(label)

    ax.set_title('Lineage tree of Cells')
    ax.set_xlabel('Step')
    ax.set_ylabel('Cells')
    ax.grid(True)
    ax.set_xlim(*x_range)
    ax.set_ylim(*y_range)

    plt.tight_layout()
    plt.show()


def drawCellLineage(data: 'Dict[str, Any]', cell_attribute: str, use_plotly: bool = False, color_scale: str = 'Spectral_r', discrete_colors: Optional[Dict[Any, str]] = None) -> None:
    """
    Draw the lineage tree of recorded cells, colouring each lifeline by a cell attribute.

    Each cell is drawn as a horizontal lifeline from its ``birth_step`` to its
    ``death_step`` on its own row. A grey vertical connector at the daughter's
    ``birth_step`` links it to its mother's row. A lifeline is split into
    segments wherever the colour of the attribute changes.

    Parameters
    ----------
    data : dict
        Lineage record with the keys ``'data'`` and ``'enumMap'``.
        ``data['data']`` is a sequence of per-cell dicts providing ``'id'``,
        ``'mother_id'`` (-1 for founder cells), ``'birth_step'``,
        ``'death_step'`` and the attribute named by ``cell_attribute``. That
        attribute is a sequence whose i-th entry is the value at step
        ``birth_step + i``. ``data['enumMap']`` maps attribute names to
        ``{value: display_name}`` dicts for enumerated (discrete) attributes.
        A cell whose mother is not in ``data['data']`` is drawn as a founder.
    cell_attribute : str
        Name of the attribute to visualise; it is lower-cased before lookup.
        If it is a key of ``enumMap`` it is treated as discrete (legend),
        otherwise as continuous numbers (colorbar).
    use_plotly : bool, optional
        If True, draw with Plotly; otherwise with Matplotlib. Default False.
    color_scale : str, optional
        Colormap for continuous attributes. It is always resolved with
        ``matplotlib.pyplot.get_cmap`` (so it must be a Matplotlib name such as
        ``'viridis'``). It is also used as the Plotly colorbar scale when
        ``use_plotly`` is True. Default ``'Spectral_r'``.
    discrete_colors : dict, optional
        Maps enum display names to colours for discrete attributes. Every name
        of the attribute's ``enumMap`` entry must be present. If omitted, a
        ``tab10``-based palette is assigned to the values present in the data,
        and only those values appear in the legend.

    Returns
    -------
    None
        The figure is displayed (``fig.show()`` / ``plt.show()``).

    Raises
    ------
    ValueError
        If ``data['data']`` contains no cells.

    Example
    -------
    ```python
    simu = Simu(params)
    lineage_data = simu.recordLineage()
    discrete_colors = {'A': '#1f77b4', 'B': '#ff7f0e', 'C': '#2ca02c'}
    drawCellLineage(lineage_data, 'state', use_plotly=True, discrete_colors=discrete_colors)
    drawCellLineage(lineage_data, 'volume', color_scale='viridis')
    ```
    """
    cell_attribute = cell_attribute.lower()
    charac_mapping = None
    if cell_attribute in data['enumMap']:
        charac_mapping = data['enumMap'][cell_attribute]
    cells = list(data['data'])
    if not cells:
        raise ValueError("drawCellLineage: data['data'] contains no cells to draw")

    graph = _build_lineage_graph(cells, cell_attribute)
    pos = _lineage_layout(graph, _PHANTOM_ROOT, y_spacing=_ROW_SPACING)

    # Resolve how an attribute value becomes a hex colour.
    cmap = norm = legend_entries = None
    if charac_mapping is not None:
        if discrete_colors:
            color_mapping = {val: discrete_colors[name] for val, name in charac_mapping.items()}
        else:
            unique_values = sorted(set(val for cell in cells for val in cell[cell_attribute]))
            palette = plt.cm.tab10(np.linspace(0, 1, len(unique_values)))
            color_mapping = {val: mcolors.rgb2hex(palette[i]) for i, val in enumerate(unique_values)}

        def color_of(value):
            return color_mapping[value]

        # Skip enum values that have no colour (absent from the data) instead
        # of raising KeyError while building the legend.
        legend_entries = [(name, color_mapping[val])
                          for val, name in charac_mapping.items() if val in color_mapping]
    else:
        cmap = plt.get_cmap(color_scale)
        vmin, vmax = _value_range(cells, cell_attribute)
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

        def color_of(value):
            return mcolors.rgb2hex(cmap(norm(value)))

    cell_nodes = [node for node in pos if node != _PHANTOM_ROOT]

    life_segments = []
    for node in cell_nodes:
        attrs = graph.nodes[node]
        life_segments.extend(_lifeline_segments(
            attrs['characteristic'], attrs['birth_step'], attrs['death_step'],
            pos[node][1], color_of))

    division_segments = [
        (graph.nodes[child]['birth_step'], pos[mother][1], pos[child][1])
        for mother, child in graph.edges()
        if mother != _PHANTOM_ROOT and mother in pos and child in pos
    ]

    # The x range must reach the last death step, not only the last birth,
    # otherwise lifelines after the final division are clipped.
    x_values = [graph.nodes[node]['birth_step'] for node in cell_nodes]
    x_values.extend(x for x0, x1, _, _ in life_segments for x in (x0, x1))
    x_range = (min(x_values), max(x_values))
    if x_range[0] == x_range[1]:
        x_range = (x_range[0] - 0.5, x_range[1] + 0.5)
    rows = [pos[node][1] for node in cell_nodes]
    y_range = (min(rows) - 1, max(rows) + 1)

    num_leaves = sum(1 for node in cell_nodes if graph.out_degree(node) == 0)
    plot_height = max(num_leaves * _PIXELS_PER_LEAF, _MIN_PLOT_HEIGHT)

    if use_plotly:
        _show_with_plotly(life_segments, division_segments, legend_entries,
                          color_scale, norm, cell_attribute, x_range, y_range, plot_height)
    else:
        _show_with_matplotlib(life_segments, division_segments, legend_entries,
                              cmap, norm, cell_attribute, x_range, y_range, plot_height)
None drawCellLineage('Dict[str, Any]' data, str cell_attribute, bool use_plotly=False, str color_scale='Spectral_r', Optional[Dict[Any, str]] discrete_colors=None)
Definition: visualization.py:8