from collections.abc import Sequence
from typing import Any, Literal
import pandas as pd
from pydantic import (
ConfigDict,
Field,
field_validator,
model_validator,
)
from .base import BaseChart
from .enums import (
DateFormat,
GridDisplay,
GridLabelAlign,
GridLabelPosition,
NumberFormat,
PlotHeightMode,
ValueLabelDisplay,
ValueLabelPlacement,
)
from .models import (
AnnotationsMixin,
CustomRangeMixin,
CustomTicksMixin,
GridDisplayMixin,
GridFormatMixin,
RangeAnnotation,
TextAnnotation,
)
from .serializers import (
ColorCategory,
NegativeColor,
PlotHeight,
ValueLabels,
)
class MultipleColumnTextAnnotation(TextAnnotation):
    """Text annotation that can be tied to one panel of a MultipleColumnChart.

    Two fields are added on top of TextAnnotation:

    Attributes:
        plot: Name of the plot/panel the annotation applies to
            (e.g., "Paris", "London"). Written inside the ``position`` object.
        show_in_all_plots: Whether the annotation is shown in every plot.
            Defaults to False and is written as top-level ``showInAllPlots``.
    """

    #: Which plot/panel this annotation applies to
    plot: str | None = Field(
        default=None,
        description="Which plot/panel this annotation applies to",
    )

    #: Whether to show this annotation in all plots
    show_in_all_plots: bool = Field(
        default=False,
        alias="showInAllPlots",
        description="Whether to show this annotation in all plots",
    )

    def serialize_model(self) -> dict:
        """Build the API payload for this annotation.

        Starts from the parent serialization and then adds the panel-specific
        pieces: ``plot`` goes into the nested ``position`` object (only when
        set) and ``showInAllPlots`` is always written at the top level.

        Returns:
            Dictionary in Datawrapper API format
        """
        payload = super().serialize_model()

        # NOTE(review): relies on the parent payload containing a "position"
        # dict -- confirm against TextAnnotation.serialize_model.
        if self.plot is not None:
            payload["position"]["plot"] = self.plot

        payload["showInAllPlots"] = self.show_in_all_plots
        return payload

    @classmethod
    def deserialize_model(
        cls, api_data: dict[str, dict[Any, Any]] | list[dict[Any, Any]] | None
    ) -> list[dict]:
        """Turn API annotation data into MultipleColumnTextAnnotation init dicts.

        The shared fields are parsed by ``TextAnnotation.deserialize_model``;
        this method then reads ``position.plot`` and the top-level
        ``showInAllPlots`` from the matching raw entry and adds them to each
        parsed dict.

        Args:
            api_data: API response data (dict with UUID keys or list format)

        Returns:
            List of dictionaries that can initialize
            MultipleColumnTextAnnotation instances; empty when ``api_data``
            is falsy.
        """
        if not api_data:
            return []

        keyed_by_id = isinstance(api_data, dict)
        enriched = []

        for index, parsed in enumerate(TextAnnotation.deserialize_model(api_data)):
            # Locate the raw entry this parsed annotation came from.
            if keyed_by_id:
                # Dict input: look the raw entry up through the parsed "id".
                key = parsed.get("id")
                raw = api_data.get(key, {}) if key else {}
            elif index < len(api_data):
                # List input: parsed entries are matched to raw ones by position.
                raw = api_data[index]
            else:
                raw = {}

            placement = raw.get("position", {})
            if isinstance(placement, dict) and "plot" in placement:
                parsed["plot"] = placement["plot"]

            # Text annotations are limited to a single plot unless told otherwise.
            parsed["show_in_all_plots"] = raw.get("showInAllPlots", False)
            enriched.append(parsed)

        return enriched
class MultipleColumnRangeAnnotation(RangeAnnotation):
    """Range annotation that can be tied to one panel of a MultipleColumnChart.

    Two fields are added on top of RangeAnnotation:

    Attributes:
        plot: Name of the plot/panel the annotation applies to
            (e.g., "Paris", "London"). Written inside the ``position`` object.
        show_in_all_plots: Whether the annotation is shown in every plot.
            Defaults to True and is written as top-level ``showInAllPlots``.
    """

    #: Which plot/panel this annotation applies to
    plot: str | None = Field(
        default=None,
        description="Which plot/panel this annotation applies to",
    )

    #: Whether to show this annotation in all plots
    show_in_all_plots: bool = Field(
        default=True,
        alias="showInAllPlots",
        description="Whether to show this annotation in all plots",
    )

    def serialize_model(self) -> dict:
        """Build the API payload for this annotation.

        Starts from the parent serialization and then adds the panel-specific
        pieces: ``plot`` goes into the nested ``position`` object (only when
        set) and ``showInAllPlots`` is always written at the top level.

        Returns:
            Dictionary in Datawrapper API format
        """
        payload = super().serialize_model()

        # NOTE(review): relies on the parent payload containing a "position"
        # dict -- confirm against RangeAnnotation.serialize_model.
        if self.plot is not None:
            payload["position"]["plot"] = self.plot

        payload["showInAllPlots"] = self.show_in_all_plots
        return payload

    @classmethod
    def deserialize_model(
        cls, api_data: dict[str, dict[Any, Any]] | list[dict[Any, Any]] | None
    ) -> list[dict]:
        """Turn API annotation data into MultipleColumnRangeAnnotation init dicts.

        Each raw annotation is copied, given an ``id`` and then flattened:
        ``x0``/``x1``/``y0``/``y1``/``plot`` are lifted out of the nested
        ``position`` object (when present and not None) and
        ``showInAllPlots`` is stored as ``show_in_all_plots``.

        Args:
            api_data: API response data (dict with UUID keys or list format)

        Returns:
            List of dictionaries that can initialize
            MultipleColumnRangeAnnotation instances; empty when ``api_data``
            is falsy.
        """
        if not api_data:
            return []

        if isinstance(api_data, dict):
            # Dict input: the keys are used as the annotation ids.
            entries = api_data.items()
        else:
            # List input carries no keys, so placeholder ids are generated.
            entries = ((f"temp-{index}", raw) for index, raw in enumerate(api_data))

        parsed = []
        for key, raw in entries:
            placement = raw.get("position", {})
            # A non-dict position contributes no coordinates at all.
            coords = placement if isinstance(placement, dict) else {}

            record = {**raw, "id": key}
            for name in ("x0", "x1", "y0", "y1", "plot"):
                value = coords.get(name)
                if value is not None:
                    record[name] = value

            # Range annotations span every plot unless the API says otherwise.
            record["show_in_all_plots"] = raw.get("showInAllPlots", True)
            parsed.append(record)

        return parsed
class MultipleColumnXRangeAnnotation(MultipleColumnRangeAnnotation):
    """Range annotation along the x axis for MultipleColumnChart.

    Convenience subclass: ``type`` defaults to "x" and ``display`` to "range"
    unless the caller passes them, and both ``x0`` and ``x1`` must be set.
    """

    def __init__(self, **data: Any) -> None:
        # Caller-supplied values win over the two presets.
        super().__init__(**{"type": "x", "display": "range", **data})

    @model_validator(mode="after")
    def validate_x_positions_required(self) -> "MultipleColumnXRangeAnnotation":
        """Reject instances that lack either end of the x range."""
        has_both_ends = self.x0 is not None and self.x1 is not None
        if has_both_ends:
            return self
        raise ValueError(
            "MultipleColumnXRangeAnnotation requires both x0 and x1 to be set"
        )
class MultipleColumnYRangeAnnotation(MultipleColumnRangeAnnotation):
    """Range annotation along the y axis for MultipleColumnChart.

    Convenience subclass: ``type`` defaults to "y" and ``display`` to "range"
    unless the caller passes them, and both ``y0`` and ``y1`` must be set.
    """

    def __init__(self, **data: Any) -> None:
        # Caller-supplied values win over the two presets.
        super().__init__(**{"type": "y", "display": "range", **data})

    @model_validator(mode="after")
    def validate_y_positions_required(self) -> "MultipleColumnYRangeAnnotation":
        """Reject instances that lack either end of the y range."""
        has_both_ends = self.y0 is not None and self.y1 is not None
        if has_both_ends:
            return self
        raise ValueError(
            "MultipleColumnYRangeAnnotation requires both y0 and y1 to be set"
        )
class MultipleColumnXLineAnnotation(MultipleColumnRangeAnnotation):
    """Line annotation at a fixed x position for MultipleColumnChart.

    Convenience subclass: ``type`` defaults to "x" and ``display`` to "line"
    unless the caller passes them, and ``x0`` must be set.
    """

    def __init__(self, **data: Any) -> None:
        # Caller-supplied values win over the two presets.
        super().__init__(**{"type": "x", "display": "line", **data})

    @model_validator(mode="after")
    def validate_x_position_required(self) -> "MultipleColumnXLineAnnotation":
        """Reject instances that have no x position."""
        if self.x0 is not None:
            return self
        raise ValueError("MultipleColumnXLineAnnotation requires x0 to be set")
class MultipleColumnYLineAnnotation(MultipleColumnRangeAnnotation):
    """Line annotation at a fixed y position for MultipleColumnChart.

    Convenience subclass: ``type`` defaults to "y" and ``display`` to "line"
    unless the caller passes them, and ``y0`` must be set.
    """

    def __init__(self, **data: Any) -> None:
        # Caller-supplied values win over the two presets.
        super().__init__(**{"type": "y", "display": "line", **data})

    @model_validator(mode="after")
    def validate_y_position_required(self) -> "MultipleColumnYLineAnnotation":
        """Reject instances that have no y position."""
        if self.y0 is not None:
            return self
        raise ValueError("MultipleColumnYLineAnnotation requires y0 to be set")
class MultipleColumnChart(
    GridDisplayMixin,
    GridFormatMixin,
    CustomRangeMixin,
    CustomTicksMixin,
    AnnotationsMixin,
    BaseChart,
):
    """A base class for the Datawrapper API's multiple column chart.

    Note: This chart uses MultipleColumnTextAnnotation and
    MultipleColumnRangeAnnotation for annotations, which extend the base
    annotation classes with plot-specific fields. Dict entries passed to
    ``text_annotations`` / ``range_annotations`` are converted to those
    subclasses by the ``mode="before"`` validators on this class.
    """

    model_config = ConfigDict(
        populate_by_name=True,
        strict=True,
        validate_assignment=True,
        validate_default=True,
        use_enum_values=True,
        json_schema_extra={
            "examples": [
                {
                    "chart-type": "multiple-columns",
                    "title": "Regional Sales Comparison",
                    "data": pd.DataFrame(
                        {
                            "Year": [2020, 2021, 2022, 2023],
                            "North": [100, 110, 120, 130],
                            "South": [90, 95, 100, 105],
                            "East": [80, 85, 90, 95],
                        }
                    ),
                    "grid_column": 3,
                    "grid_row_height": 140,
                }
            ]
        },
    )

    #: The type of datawrapper chart to create
    chart_type: Literal["multiple-columns"] = Field(
        default="multiple-columns",
        alias="chart-type",
        description="The type of datawrapper chart to create",
    )

    #: Panels configuration
    panels: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Panel configurations for the chart",
    )

    #
    # Layout
    #

    #: Fixed vs auto layout. If minimumWidth is selected it's auto layout
    grid_layout: Literal["fixedCount", "minimumWidth"] = Field(
        default="fixedCount",
        alias="grid-layout",
        description="Fixed vs auto layout",
    )

    #: How the panels are laid out on desktop
    grid_column: int = Field(
        default=2,
        alias="grid-column",
        description="How the panels are laid out on desktop",
    )

    #: How the panels are laid out on mobile
    grid_column_mobile: int = Field(
        default=2,
        alias="grid-column-mobile",
        description="How the panels are laid out on mobile",
    )

    #: How the panels are laid out - only changed if layout is not fixedCount
    grid_column_width: int = Field(
        default=200,
        alias="grid-column-width",
        description="Minimum width for auto layout",
    )

    #: Height of rows
    grid_row_height: int = Field(
        default=140,
        alias="grid-row-height",
        description="Height of rows",
    )

    #: Sort of the panels
    sort: bool = Field(
        default=False,
        description="Whether to sort the panels",
    )

    #: Whether to sort the panels in reverse order
    sort_reverse: bool = Field(
        default=False,
        alias="sort-reverse",
        description="Whether to sort the panels in reverse order",
    )

    #: How to sort the panels
    sort_by: Literal["start", "end", "range", "diff", "change", "title"] = Field(
        default="end",
        alias="sort-by",
        description="How to sort the panels",
    )

    #
    # Horizontal axis
    #

    #: The labeling of the x axis
    x_grid_labels: Literal["on", "off"] = Field(
        default="on",
        alias="x-grid-labels",
        description="The labeling of the x axis",
    )

    #: x_grid for panels
    x_grid_all: GridDisplay | str = Field(
        default="off",
        alias="x-grid-all",
        description="x_grid for panels",
    )

    #
    # Vertical axis
    #

    #: The labeling of the y grid labels
    y_grid_labels: GridLabelPosition | str = Field(
        default="outside",
        alias="y-grid-labels",
        description="The labeling of the y grid labels",
    )

    #: Which side to put the y-axis labels on
    y_grid_label_align: GridLabelAlign | str = Field(
        default="left",
        alias="y-grid-label-align",
        description="Which side to put the y-axis labels on",
    )

    #
    # Appearance
    #

    #: The base color for the chart (palette index or hex string)
    base_color: str | int = Field(
        default=0,
        alias="base-color",
        description="The base color for the chart (palette index or hex string)",
    )

    #: The negative color to use, if you want one
    negative_color: str | None = Field(
        default=None,
        alias="negative-color",
        description="The negative color to use, if you want one",
    )

    #: A mapping of layer names to colors
    color_category: dict[str, str] = Field(
        default_factory=dict,
        alias="color-category",
        description="A mapping of layer names to colors",
    )

    #: The padding between bars as a percentage of the bar width
    bar_padding: int = Field(
        default=30,
        alias="bar-padding",
        description="The padding between bars as a percentage of the bar width",
    )

    #: How to set the plot height
    plot_height_mode: PlotHeightMode | str = Field(
        default="fixed",
        alias="plot-height-mode",
        description="How to set the plot height",
    )

    #: The fixed height of the plot
    plot_height_fixed: int = Field(
        default=300,
        alias="plot-height-fixed",
        description="The fixed height of the plot",
    )

    #: The ratio of the plot height
    plot_height_ratio: float = Field(
        default=0.5,
        alias="plot-height-ratio",
        description="The ratio of the plot height",
    )

    @field_validator("plot_height_mode")
    @classmethod
    def validate_plot_height_mode(cls, v: PlotHeightMode | str) -> PlotHeightMode | str:
        """Validate that plot_height_mode is a valid PlotHeightMode value.

        Args:
            v: The candidate value for ``plot_height_mode``.

        Returns:
            ``v`` unchanged when it is acceptable.

        Raises:
            ValueError: If ``v`` is a string that matches no PlotHeightMode value.
        """
        if isinstance(v, str):
            valid_values = [e.value for e in PlotHeightMode]
            if v not in valid_values:
                raise ValueError(f"Invalid value: {v}. Must be one of {valid_values}")
        return v

    @field_validator("text_annotations", mode="before")
    @classmethod
    def convert_text_annotations(
        cls, v: Sequence[MultipleColumnTextAnnotation | dict[Any, Any]]
    ) -> list[MultipleColumnTextAnnotation]:
        """Convert dict annotations to MultipleColumnTextAnnotation instances.

        This ensures that when annotations are passed as dicts, they are
        converted to the proper annotation class so that serialize_model()
        includes the plot field. Non-dict items are passed through untouched.

        Args:
            v: Raw ``text_annotations`` input; falsy input yields an empty list.

        Returns:
            A new list with every dict replaced by an annotation instance.
        """
        if not v:
            return []
        return [
            MultipleColumnTextAnnotation(**item) if isinstance(item, dict) else item
            for item in v
        ]

    @field_validator("range_annotations", mode="before")
    @classmethod
    def convert_range_annotations(
        cls, v: Sequence[MultipleColumnRangeAnnotation | dict[Any, Any]]
    ) -> list[MultipleColumnRangeAnnotation]:
        """Convert dict annotations to MultipleColumnRangeAnnotation instances.

        This ensures that when annotations are passed as dicts, they are
        converted to the proper annotation class so that serialize_model()
        includes the plot field. Non-dict items are passed through untouched.

        Args:
            v: Raw ``range_annotations`` input; falsy input yields an empty list.

        Returns:
            A new list with every dict replaced by an annotation instance.
        """
        if not v:
            return []
        return [
            MultipleColumnRangeAnnotation(**item) if isinstance(item, dict) else item
            for item in v
        ]

    #
    # Tooltips
    #

    #: Whether or not to show tooltips on hover
    show_tooltips: bool = Field(
        default=True,
        alias="show-tooltips",
        description="Whether or not to show tooltips on hover",
    )

    #: Whether to show tooltips synchronously across all panels
    sync_multiple_tooltips: bool = Field(
        default=False,
        alias="syncMultipleTooltips",
        description="Whether to show tooltips synchronously across all panels",
    )

    #: The format for the y-axis values in tooltips (use DateFormat or NumberFormat enum or custom format strings)
    tooltip_number_format: DateFormat | NumberFormat | str = Field(
        default="",
        alias="tooltip-number-format",
        description="The format for the y-axis values in tooltips. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
    )

    #
    # Labels
    #

    #: Whether or not to color labels the same as the column
    label_colors: bool = Field(
        default=False,
        alias="label-colors",
        description="Whether or not to column labels the same as the column",
    )

    #: Whether or not to show the color key above the chart
    show_color_key: bool = Field(
        default=False,
        alias="show-color-key",
        description="Whether or not to show the color key above the chart",
    )

    #: Whether or not to show value labels
    show_value_labels: ValueLabelDisplay | str = Field(
        default="off",
        alias="show-value-labels",
        description="Whether or not to show value labels",
    )

    #: How to format the value labels (use DateFormat or NumberFormat enum or custom format strings)
    value_labels_format: DateFormat | NumberFormat | str = Field(
        default="",
        alias="value-labels-format",
        description="How to format the value labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
    )

    #: Where to place the value labels
    value_labels_placement: ValueLabelPlacement | str = Field(
        default="outside",
        alias="value-labels-placement",
        description="Where to place the value labels",
    )

    #: The amount of margin to leave for the right hand side for labels
    label_margin: int = Field(
        default=0,
        alias="label-margin",
        description="The amount of margin to leave for the right hand side for labels. Zero is automatically calculated.",
    )

    #: Show label for all panels
    x_grid_label_all: bool = Field(
        default=False,
        alias="x-grid-label-all",
        description="Show label for all panels",
    )

    def serialize_model(self) -> dict:
        """Serialize the model to a dictionary.

        Starts from the parent serialization and merges the chart-specific
        settings into ``metadata.visualize``.

        Returns:
            The serialized chart, with ``metadata.visualize`` updated in place.

        Raises:
            KeyError: If an entry in ``panels`` has no ``"column"`` key.
        """
        # Call the parent class's serialize_model method
        model = super().serialize_model()

        # Add chart specific properties to visualize section.
        # Key order matters: later entries override same-named keys produced
        # by the mixin helpers spread in above them.
        visualize_data = {
            # Layout
            "gridLayout": self.grid_layout,
            "gridColumnCount": self.grid_column,
            "gridColumnCountMobile": self.grid_column_mobile,
            "gridColumnMinWidth": self.grid_column_width,
            "gridRowHeightFixed": self.grid_row_height,
            "sort": {
                "enabled": self.sort,
                "reverse": self.sort_reverse,
                "by": self.sort_by,
            },
            # Horizontal and vertical axis (from mixins)
            **self._serialize_grid_config(),
            **self._serialize_grid_format(),
            **self._serialize_custom_range(),
            **self._serialize_custom_ticks(),
            # Horizontal axis (chart-specific)
            "x-grid-labels": self.x_grid_labels,
            "x-grid": self.x_grid_all,
            "grid-lines-x": {
                # "off" is expressed as disabled with an empty type
                "type": "" if self.x_grid == "off" else self.x_grid,
                "enabled": self.x_grid != "off",
            },
            # Vertical axis (chart-specific)
            "grid-lines": self.y_grid,
            "yAxisLabels": {
                "enabled": self.y_grid_labels != "off",
                "alignment": self.y_grid_label_align,
                "placement": "" if self.y_grid_labels == "off" else self.y_grid_labels,
            },
            # Appearance
            "base-color": self.base_color,
            "negativeColor": NegativeColor.serialize(self.negative_color),
            "bar-padding": self.bar_padding,
            "color-category": ColorCategory.serialize(self.color_category),
            "color-by-column": bool(self.color_category),
            **PlotHeight.serialize(
                self.plot_height_mode,
                self.plot_height_fixed,
                self.plot_height_ratio,
            ),
            # The panel list becomes a dict keyed by each panel's column
            "panels": {panel["column"]: panel for panel in self.panels},
            # Tooltips
            "show-tooltips": self.show_tooltips,
            "syncMultipleTooltips": self.sync_multiple_tooltips,
            "tooltip-number-format": self.tooltip_number_format,
            # Labels
            "show-color-key": self.show_color_key,
            "label-colors": self.label_colors,
            "label-margin": self.label_margin,
            **ValueLabels.serialize(
                self.show_value_labels,
                self.value_labels_format,
                placement=self.value_labels_placement,
                chart_type="multiple-column",
            ),
            "xGridLabelAllColumns": self.x_grid_label_all,
            # Annotations
            **self._serialize_annotations(
                text_annotation_class=MultipleColumnTextAnnotation,
                range_annotation_class=MultipleColumnRangeAnnotation,
            ),
        }

        model["metadata"]["visualize"].update(visualize_data)

        # Return the serialized data
        return model

    @classmethod
    def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
        """Parse Datawrapper API response including multiple column chart specific fields.

        Args:
            api_response: The JSON response from the chart metadata endpoint

        Returns:
            Dictionary that can be used to initialize the MultipleColumnChart model
        """
        # Call parent to get base fields
        init_data = super().deserialize_model(api_response)

        # Extract multiple-column-specific sections
        metadata = api_response.get("metadata", {})
        visualize = metadata.get("visualize", {})

        def copy_present(pairs: tuple[tuple[str, str], ...]) -> None:
            """Copy each API key that exists in ``visualize`` to its field name."""
            for api_key, field_name in pairs:
                if api_key in visualize:
                    init_data[field_name] = visualize[api_key]

        # Layout
        copy_present(
            (
                ("gridLayout", "grid_layout"),
                ("gridColumnCount", "grid_column"),
                ("gridColumnCountMobile", "grid_column_mobile"),
                ("gridColumnMinWidth", "grid_column_width"),
                ("gridRowHeightFixed", "grid_row_height"),
            )
        )

        # Parse sort object; anything that is not a dict means "use defaults"
        sort_obj = visualize.get("sort", {})
        if not isinstance(sort_obj, dict):
            sort_obj = {}
        init_data["sort"] = sort_obj.get("enabled", False)
        init_data["sort_reverse"] = sort_obj.get("reverse", False)
        init_data["sort_by"] = sort_obj.get("by", "end")

        # Horizontal and vertical axis (from mixins)
        init_data.update(cls._deserialize_grid_config(visualize))
        init_data.update(cls._deserialize_grid_format(visualize))
        init_data.update(cls._deserialize_custom_range(visualize))
        init_data.update(cls._deserialize_custom_ticks(visualize))

        # Horizontal axis (chart-specific)
        copy_present(
            (
                ("x-grid-labels", "x_grid_labels"),
                ("x-grid", "x_grid_all"),
            )
        )

        # Parse grid-lines-x: x_grid is always set here, so it replaces any
        # x_grid value the mixin helpers above may have produced.
        grid_lines_x = visualize.get("grid-lines-x", {})
        if isinstance(grid_lines_x, dict) and grid_lines_x.get("enabled", False):
            init_data["x_grid"] = grid_lines_x.get("type", "ticks")
        else:
            init_data["x_grid"] = "off"

        # Vertical axis (chart-specific)
        # Parse grid-lines (can be bool or string "show")
        if "grid-lines" in visualize:
            grid_lines_val = visualize["grid-lines"]
            if isinstance(grid_lines_val, str):
                init_data["y_grid"] = grid_lines_val == "show"
            else:
                init_data["y_grid"] = bool(grid_lines_val)

        # Parse yAxisLabels - check both yAxisLabels object and y-grid-labels field
        y_axis_labels = visualize.get("yAxisLabels", {})
        if isinstance(y_axis_labels, dict) and y_axis_labels:
            # If yAxisLabels object exists, use it
            if y_axis_labels.get("enabled", True):
                init_data["y_grid_labels"] = y_axis_labels.get("placement", "outside")
            else:
                init_data["y_grid_labels"] = "off"
            init_data["y_grid_label_align"] = y_axis_labels.get("alignment", "left")
        else:
            # Fall back to the flat y-grid-labels / y-grid-label-align fields
            copy_present(
                (
                    ("y-grid-labels", "y_grid_labels"),
                    ("y-grid-label-align", "y_grid_label_align"),
                )
            )

        # Appearance
        copy_present(
            (
                ("base-color", "base_color"),
                ("bar-padding", "bar_padding"),
            )
        )

        # Parse color-category using utility
        color_data = ColorCategory.deserialize(visualize.get("color-category"))
        init_data["color_category"] = color_data["color_category"]

        # Parse negativeColor
        if "negativeColor" in visualize:
            init_data["negative_color"] = NegativeColor.deserialize(
                visualize["negativeColor"]
            )

        # Plot height
        init_data.update(PlotHeight.deserialize(visualize))

        # Parse panels (dict to list). A panel whose config is not a dict
        # (e.g. None) contributes only its column name instead of raising.
        panels_obj = visualize.get("panels", {})
        if isinstance(panels_obj, dict):
            init_data["panels"] = [
                {"column": col, **(config if isinstance(config, dict) else {})}
                for col, config in panels_obj.items()
            ]
        else:
            init_data["panels"] = []

        # Tooltips
        copy_present(
            (
                ("show-tooltips", "show_tooltips"),
                ("syncMultipleTooltips", "sync_multiple_tooltips"),
                ("tooltip-number-format", "tooltip_number_format"),
            )
        )

        # Labels
        copy_present(
            (
                ("label-colors", "label_colors"),
                ("show-color-key", "show_color_key"),
                ("label-margin", "label_margin"),
                ("xGridLabelAllColumns", "x_grid_label_all"),
            )
        )

        # Parse valueLabels using utility
        init_data.update(
            ValueLabels.deserialize(visualize, chart_type="multiple-column")
        )

        # Annotations
        init_data.update(
            cls._deserialize_annotations(
                visualize,
                text_annotation_class=MultipleColumnTextAnnotation,
                range_annotation_class=MultipleColumnRangeAnnotation,
            )
        )

        return init_data