2import matplotlib.pyplot
as plt
3import matplotlib.colors
as mcolors
5import plotly.graph_objs
as go
6from typing
import Dict, Any, Optional
8def drawCellLineage(data:
'Dict[str, Any]', cell_attribute: str, use_plotly: bool =
False, color_scale: str =
'Spectral_r', discrete_colors: Optional[Dict[Any, str]] =
None) ->
None:
10 Draws the lineage of cells in the simulation based on a specified cell attribute.
15 The dictionary containing cell lineage data
and enumeration map. It should have the keys
'data' and 'enumMap'.
16 - cell_attribute : str
17 The name of the cell attribute to visualize. This can be a discrete
or continuous attribute.
18 - use_plotly : bool, optional
19 If
True, use Plotly
for visualization. If
False, use Matplotlib. Default
is False.
20 - color_scale : str, optional
21 The color scale to use
for visualization. Default
is 'Spectral_r'.
22 - discrete_colors : dict, optional
23 A dictionary mapping discrete attribute values to specific colors. If provided, this will be used
for coloring discrete attributes.
28 This function does
not return any value. It generates
and displays a lineage plot.
34 lineage_data = simu.recordLineage()
35 discrete_colors = {
'A':
'#1f77b4',
'B':
'#ff7f0e',
'C':
'#2ca02c'}
36 drawCellLineage(lineage_data,
'state', use_plotly=
True, color_scale=
'Viridis', discrete_colors=discrete_colors)
41 - If the `cell_attribute`
is an Enum, the corresponding values are replaced
and the data type
is set to
'category'.
42 - For continuous attributes, a color map
is used to represent the range of values.
45 cell_attribute=cell_attribute.lower()
46 if cell_attribute
in data[
'enumMap']:
47 charac_mapping = data[
'enumMap'][cell_attribute]
53 G.add_node(-1, birth_step=-1, death_step=-1, characteristic=[])
57 G.add_node(cell[
'id'], birth_step=cell[
'birth_step'], death_step=cell[
'death_step'], characteristic=cell[cell_attribute])
58 if cell[
'mother_id'] != -1:
59 G.add_edge(cell[
'mother_id'], cell[
'id'])
61 G.add_edge(-1, cell[
'id'])
64 def get_tree_positions(G, y_spacing=1):
69 def calculate_widths(node):
70 if node
not in widths:
71 if G.out_degree(node) == 0:
74 widths[node] = sum(calculate_widths(child)
for child
in G.successors(node))
77 def assign_position(node, x, y):
86 for child
in G.successors(node):
87 child_width = widths[child]
88 child_x = G.nodes[child][
'birth_step']
89 y = dfs(child, child_x, start_y)
90 start_y += child_width * y_spacing
91 if G.out_degree(node) > 0:
92 y = (pos[list(G.successors(node))[0]][1] + pos[list(G.successors(node))[-1]][1]) / 2
93 y = assign_position(node, x, y)
97 for node
in G.nodes():
98 calculate_widths(node)
101 y_offset = dfs(-1, -1, y_offset) + y_spacing
106 pos = get_tree_positions(G, y_spacing=2)
109 num_leaves = sum(1
for node
in G.nodes
if G.out_degree(node) == 0
and node != -1)
112 plot_height = num_leaves * 30
118 horizontal_lines = []
122 if charac_mapping
is not None:
124 color_mapping = {v:discrete_colors[k]
for v,k
in charac_mapping.items()}
125 norm = mcolors.BoundaryNorm(boundaries=np.arange(len(color_mapping) + 1) - 0.5, ncolors=len(color_mapping))
127 unique_chars = sorted(set(val
for cell
in data
for val
in cell[cell_attribute]))
128 colors = plt.cm.tab10(np.linspace(0, 1, len(unique_chars)))
129 color_mapping = {val: mcolors.rgb2hex(colors[i])
for i, val
in enumerate(unique_chars)}
130 norm = mcolors.BoundaryNorm(boundaries=np.arange(len(unique_chars) + 1) - 0.5, ncolors=len(unique_chars))
132 cmap = plt.get_cmap(color_scale)
136 vmin = min(np.min(cell[cell_attribute]), vmin)
137 vmax = max(np.max(cell[cell_attribute]), vmax)
138 norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
140 for node
in G.nodes():
146 birth_step = G.nodes[node][
'birth_step']
147 death_step = G.nodes[node][
'death_step']
149 characteristic = G.nodes[node][
'characteristic']
150 steps = list(range(birth_step, death_step))
153 if charac_mapping
is not None:
154 current_char = characteristic[0]
155 current_color = color_mapping[current_char]
157 current_color = mcolors.rgb2hex(cmap(norm(characteristic[0])))
158 start_step = birth_step
159 for i
in range(1, len(steps)):
160 if charac_mapping
is not None:
161 new_char = characteristic[i]
162 new_color = color_mapping[new_char]
163 if new_char != current_char:
164 vertical_lines.append(((start_step, y), (steps[i], y), current_color))
165 current_char = new_char
166 current_color = new_color
167 start_step = steps[i]
169 new_color = mcolors.rgb2hex(cmap(norm(characteristic[i])))
170 if new_color != current_color:
171 vertical_lines.append(((start_step, y), (steps[i], y), current_color))
172 current_color = new_color
173 start_step = steps[i]
174 vertical_lines.append(((start_step, y), (death_step, y), current_color))
176 for edge
in G.edges():
181 parent_x, parent_y = pos[parent]
182 child_x, child_y = pos[child]
183 division_step = G.nodes[child][
'birth_step']
184 horizontal_lines.append(((division_step, parent_y), (division_step, child_y)))
188 for line
in vertical_lines:
189 (x0, y0), (x1, y1), color = line
190 life_traces.append(go.Scatter(
194 line=dict(width=4, color=color),
203 for line
in horizontal_lines:
204 (x0, y0), (x1, y1) = line
205 edge_x.extend([x0, x1,
None])
206 edge_y.extend([y0, y1,
None])
208 edge_trace = go.Scatter(
210 line=dict(width=2, color=
'#888'),
217 if charac_mapping
is not None:
218 for char_val, char_name
in charac_mapping.items():
219 life_traces.append(go.Scatter(
223 line=dict(width=2, color=color_mapping[char_val]),
230 colorbar_trace = go.Scatter(
235 colorscale=color_scale,
240 title=cell_attribute,
247 life_traces.append(colorbar_trace)
250 fig = go.Figure(data=life_traces + [edge_trace],
252 title=
'Lineage tree of Cells',
256 margin=dict(b=20, l=5, r=5, t=40),
257 xaxis=dict(showgrid=
True, zeroline=
True, showticklabels=
True, range=[min(node_x), max(node_x)]),
258 yaxis=dict(showgrid=
True, zeroline=
True, showticklabels=
False, range=[min(node_y) - 1, max(node_y) + 1]),
265 fig, ax = plt.subplots(figsize=(10, plot_height / 100))
268 for line
in vertical_lines:
269 (x0, y0), (x1, y1), color = line
270 ax.plot([x0, x1], [y0, y0], color=color, linewidth=4, zorder=3)
273 for line
in horizontal_lines:
274 (x0, y0), (x1, y1) = line
275 ax.plot([x0, x1], [y0, y1], color=
'#888', linewidth=2, zorder=1)
280 if charac_mapping
is not None:
282 handles = [plt.Line2D([0], [0], color=color_mapping[char_val], lw=2, label=char_name)
283 for char_val, char_name
in charac_mapping.items()]
284 ax.legend(handles=handles, title=cell_attribute, loc=
'center left', bbox_to_anchor=(1, 0.5))
287 sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
289 cbar = plt.colorbar(sm, ax=ax, fraction=0.02, pad=0.04)
290 cbar.set_label(cell_attribute)
292 ax.set_title(
'Lineage tree of Cells')
293 ax.set_xlabel(
'Step')
294 ax.set_ylabel(
'Cells')
298 ax.set_xlim(min(node_x), max(node_x))
299 ax.set_ylim(min(node_y) - 1, max(node_y) + 1)
None drawCellLineage('Dict[str, Any]' data, str cell_attribute, bool use_plotly=False, str color_scale='Spectral_r', Optional[Dict[Any, str]] discrete_colors=None)