Source code for datawrapper.charts.arrow

from typing import Any, Literal

import pandas as pd
from pydantic import ConfigDict, Field, field_validator

from .base import BaseChart
from .enums import DateFormat, NumberFormat, ReplaceFlagsType
from .serializers import ColorCategory, CustomRange, ReplaceFlags


class ArrowChart(BaseChart):
    """A base class for the Datawrapper API's arrow chart (``d3-arrow-plot``).

    Fields are declared with hyphenated aliases matching the Datawrapper
    metadata keys; ``populate_by_name=True`` also allows the snake_case names.
    """

    model_config = ConfigDict(
        populate_by_name=True,
        strict=True,
        validate_assignment=True,
        validate_default=True,
        use_enum_values=True,
        json_schema_extra={
            "examples": [
                {
                    "chart-type": "d3-arrow-plot",
                    "title": "Population Change by Region",
                    "source_name": "Census Bureau",
                    "data": pd.DataFrame(
                        {
                            "Region": ["North", "South", "East", "West"],
                            "2020": [100, 150, 120, 90],
                            "2023": [110, 160, 115, 95],
                        }
                    ),
                    "start_column": "2020",
                    "end_column": "2023",
                    "thick_arrows": True,
                }
            ]
        },
    )

    #: The type of datawrapper chart to create
    chart_type: Literal["d3-arrow-plot"] = Field(
        default="d3-arrow-plot",
        alias="chart-type",
        description="The type of datawrapper chart to create",
    )

    #
    # Customize arrows
    #

    #: The base color for the arrows
    base_color: str | int = Field(
        default=0,
        alias="base-color",
        description="The base color for the arrows",
    )

    #: A mapping of layer names to colors
    color_category: dict[str, str] = Field(
        default_factory=dict,
        alias="color-category",
        description="A mapping of layer names to colors",
    )

    #: Thicken the arrows
    thick_arrows: bool = Field(
        default=True,
        alias="thick-arrows",
        description="Thicken the arrows",
    )

    #: Show the y-axis grid lines
    y_grid: str = Field(
        default="on",
        alias="y-grid",
        description="Show the y-axis grid lines",
    )

    #: Whether to replace country codes with flags (use ReplaceFlagsType enum or raw string)
    replace_flags: ReplaceFlagsType | str = Field(
        default="off",
        alias="replace-flags",
        description="Whether to replace country codes with flags. Use ReplaceFlagsType enum for type safety or provide raw strings.",
    )

    #
    # Sorting & ordering
    #

    #: Whether to sort the ranges
    sort_ranges: bool = Field(
        default=False,
        alias="sort-ranges",
        description="Whether to sort the ranges",
    )

    #: How to sort the ranges
    sort_by: Literal["end", "start", "difference", "change"] = Field(
        default="end",
        alias="sort-by",
        description="How to sort the ranges",
    )

    #: Reverse the order of the ranges
    reverse_order: bool = Field(
        default=False,
        alias="reverse-order",
        description="Reverse the order of the ranges",
    )

    #
    # Labels & formatting
    #

    #: The number format for value labels (use DateFormat or NumberFormat enum or custom format strings)
    value_label_format: DateFormat | NumberFormat | str = Field(
        default="",
        alias="value-label-format",
        description="The number format for value labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
    )

    #: The field you want to use for the value labels
    range_value_labels: str = Field(
        default="",
        alias="range-value-labels",
        description="The field you want to use for the value labels",
    )

    #
    # Axes
    #

    #: The custom range for the x axis
    custom_range: list[Any] | tuple[Any, Any] = Field(
        default_factory=lambda: ["", ""],
        alias="custom-range",
        description="The custom range for the x axis",
    )

    #: The type of range on the x-axis
    range_extent: Literal["nice", "custom", "data"] = Field(
        default="nice",
        alias="range-extent",
        description="The type of range on the x-axis",
    )

    #: The column that arrows should start at
    start_column: str | None = Field(
        default=None,
        description="The column that arrows should start at",
    )

    #: The column that arrows should end at
    end_column: str | None = Field(
        default=None,
        description="The column that arrows should end at",
    )

    #: The column to color by
    color_column: str | None = Field(
        default=None,
        description="The column to color by",
    )

    #: The column to label by
    label_column: str | None = Field(
        default=None,
        description="The column to label by",
    )

    #: The column to group arrows by
    groups_column: str | None = Field(
        default=None,
        description="The column to group arrows by",
    )

    #
    # Features
    #

    #: Label on the first arrow that shows column names
    arrow_key: bool = Field(
        default=False,
        alias="arrow-key",
        description="Label on the first arrow that shows column names",
    )

    @field_validator("replace_flags")
    @classmethod
    def validate_replace_flags(
        cls, v: ReplaceFlagsType | str
    ) -> ReplaceFlagsType | str:
        """Validate that replace_flags is a valid ReplaceFlagsType value.

        Args:
            v: The candidate value, either a ``ReplaceFlagsType`` member or a
                raw string.

        Returns:
            The value unchanged when it is valid.

        Raises:
            ValueError: If ``v`` is a string that is not one of the
                ``ReplaceFlagsType`` values.
        """
        # Non-string values have already been type-checked as ReplaceFlagsType
        # by pydantic; only raw strings need membership checking here.
        if isinstance(v, str):
            valid_values = [e.value for e in ReplaceFlagsType]
            if v not in valid_values:
                raise ValueError(
                    f"Invalid replace_flags: {v}. Must be one of {valid_values}"
                )
        return v

    def serialize_model(self) -> dict:
        """Serialize the model to a dictionary.

        Extends the parent serialization by adding arrow-specific keys to
        ``metadata["visualize"]``, and adds a ``metadata["axes"]`` section
        containing only the column fields that are set.

        Returns:
            The serialized chart payload.
        """
        model = super().serialize_model()

        # NOTE(review): "color-by-column" is derived from color_category while
        # "group-by-column" is derived from groups_column; confirm whether
        # setting only color_column should also enable color-by-column.
        model["metadata"]["visualize"].update(
            {
                "y-grid": self.y_grid,
                "reverse-order": self.reverse_order,
                "thick-arrows": self.thick_arrows,
                "base-color": self.base_color,
                "color-category": ColorCategory.serialize(self.color_category),
                "range-value-labels": self.range_value_labels,
                "sort-range": {
                    "by": self.sort_by,
                    "enabled": self.sort_ranges,
                },
                "custom-range": CustomRange.serialize(self.custom_range),
                "range-extent": self.range_extent,
                "value-label-format": self.value_label_format,
                "color-by-column": bool(self.color_category),
                "group-by-column": self.groups_column is not None,
                "replace-flags": ReplaceFlags.serialize(self.replace_flags),
                "show-arrow-key": self.arrow_key,
            }
        )

        # The axes section lives beside (not inside) "visualize"; only the
        # column assignments that were actually set are emitted, in this order.
        axes_candidates = {
            "start": self.start_column,
            "end": self.end_column,
            "colors": self.color_column,
            "labels": self.label_column,
            "groups": self.groups_column,
        }
        axes_dict = {
            key: value for key, value in axes_candidates.items() if value is not None
        }

        # Omit the section entirely rather than sending an empty object.
        if axes_dict:
            model["metadata"]["axes"] = axes_dict

        return model

    @classmethod
    def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
        """Parse Datawrapper API response including arrow chart specific fields.

        Args:
            api_response: The JSON response from the chart metadata endpoint

        Returns:
            Dictionary that can be used to initialize the ArrowChart model
        """
        # Call parent to get base fields
        init_data = super().deserialize_model(api_response)

        # ``.get(key, {})`` still yields None when the key is present with a
        # null value, which would make the ``in`` checks below raise; coerce
        # falsy values to an empty dict instead.
        metadata = api_response.get("metadata") or {}
        visualize = metadata.get("visualize") or {}
        axes = metadata.get("axes")
        if not isinstance(axes, dict):
            # Treat null or any non-object axes payload as "no axes set".
            axes = {}

        # Customize arrows
        if "y-grid" in visualize:
            init_data["y_grid"] = visualize["y-grid"]
        if "reverse-order" in visualize:
            init_data["reverse_order"] = visualize["reverse-order"]
        if "thick-arrows" in visualize:
            init_data["thick_arrows"] = visualize["thick-arrows"]

        # Base color
        if "base-color" in visualize:
            init_data["base_color"] = visualize["base-color"]

        # Parse color-category using utility
        color_data = ColorCategory.deserialize(visualize.get("color-category"))
        init_data["color_category"] = color_data["color_category"]

        # Labels & formatting
        if "range-value-labels" in visualize:
            init_data["range_value_labels"] = visualize["range-value-labels"]
        if "value-label-format" in visualize:
            init_data["value_label_format"] = visualize["value-label-format"]

        # Sorting & ordering: "sort-range" is a nested {"by", "enabled"} object;
        # fall back to the field defaults when it is missing or malformed.
        sort_range_obj = visualize.get("sort-range", {})
        if isinstance(sort_range_obj, dict):
            init_data["sort_by"] = sort_range_obj.get("by", "end")
            init_data["sort_ranges"] = sort_range_obj.get("enabled", False)
        else:
            init_data["sort_by"] = "end"
            init_data["sort_ranges"] = False

        # Parse replace-flags using utility
        if "replace-flags" in visualize:
            init_data["replace_flags"] = ReplaceFlags.deserialize(
                visualize["replace-flags"]
            )

        # Axes
        init_data["custom_range"] = CustomRange.deserialize(
            visualize.get("custom-range")
        )
        if "range-extent" in visualize:
            init_data["range_extent"] = visualize["range-extent"]

        # Map the axes section keys back onto the model's column fields.
        axes_to_fields = {
            "start": "start_column",
            "end": "end_column",
            "colors": "color_column",
            "labels": "label_column",
            "groups": "groups_column",
        }
        for axis_key, field_name in axes_to_fields.items():
            if axis_key in axes:
                init_data[field_name] = axes[axis_key]

        # Features
        if "show-arrow-key" in visualize:
            init_data["arrow_key"] = visualize["show-arrow-key"]

        return init_data