from typing import Any, Literal
import pandas as pd
from pydantic import ConfigDict, Field, field_validator
from .base import BaseChart
from .enums import (
DateFormat,
NumberFormat,
PlotHeightMode,
RegressionMethod,
ScatterAxisPosition,
ScatterGridLines,
ScatterShape,
ScatterSize,
)
from .models import AnnotationsMixin
from .serializers import ColorCategory, PlotHeight
# [docs]  (Sphinx viewcode link artifact from the rendered docs page; not part of the module source)
class ScatterPlot(AnnotationsMixin, BaseChart):
    """A base class for the Datawrapper API's scatter plot chart.

    Fields are declared with hyphenated aliases (e.g. ``x-column``) and can
    also be populated by their Python attribute names. ``serialize_model``
    writes the fields into the ``metadata.axes`` and ``metadata.visualize``
    sections of the payload; ``deserialize_model`` reads those sections back
    into a dictionary of constructor arguments.
    """

    model_config = ConfigDict(
        populate_by_name=True,
        strict=True,
        validate_assignment=True,
        validate_default=True,
        use_enum_values=True,
        json_schema_extra={
            "examples": [
                {
                    "chart-type": "d3-scatter-plot",
                    "title": "GDP vs Life Expectancy",
                    "data": pd.DataFrame(
                        {
                            "Country": ["USA", "China", "India"],
                            "GDP": [60000, 15000, 7000],
                            "Life Expectancy": [79, 76, 69],
                            "Population": [330, 1400, 1380],
                        }
                    ),
                    "x_column": "GDP",
                    "y_column": "Life Expectancy",
                    "size_column": "Population",
                }
            ]
        },
    )

    #: The type of datawrapper chart to create
    chart_type: Literal["d3-scatter-plot"] = Field(
        default="d3-scatter-plot",
        alias="chart-type",
        description="The type of datawrapper chart to create",
    )

    #
    # Horizontal axis
    #

    #: The column to use for the x-axis
    x_column: str | None = Field(
        default=None,
        alias="x-column",
        description="The column to use for the x-axis",
    )

    #: The range for the x-axis
    x_range: tuple[Any, Any] | list[Any] = Field(
        default=("", ""),
        alias="x-range",
        description="The range for the x-axis",
    )

    #: Custom ticks for the x-axis
    x_ticks: list[Any] = Field(
        default_factory=list,
        alias="x-ticks",
        description="Custom ticks for the x-axis",
    )

    #: Set the x-axis on a Logarithmic scale
    x_log: bool = Field(
        default=False,
        alias="x-log",
        description="Set the x-axis on a Logarithmic scale",
    )

    #: Format of the x-axis ticks (use DateFormat or NumberFormat enum or custom format strings)
    x_format: DateFormat | NumberFormat | str = Field(
        default="",
        alias="x-format",
        description="Format of the x-axis ticks. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
    )

    #: The position of the x-axis ticks and labels
    x_position: ScatterAxisPosition | str = Field(
        default="bottom",
        alias="x-position",
        description="The position of the x-axis ticks and labels",
    )

    #: How to display x-axis grid lines
    x_grid_lines: ScatterGridLines | str = Field(
        default="on",
        alias="x-grid-lines",
        description="How to display x-axis grid lines",
    )

    #
    # Vertical axis
    #

    #: The column to use for the y-axis
    y_column: str | None = Field(
        default=None,
        alias="y-column",
        description="The column to use for the y-axis",
    )

    #: The range for the y-axis
    y_range: tuple[Any, Any] | list[Any] = Field(
        default=("", ""),
        alias="y-range",
        description="The range for the y-axis",
    )

    #: Custom ticks for the y-axis
    y_ticks: list[Any] = Field(
        default_factory=list,
        alias="y-ticks",
        description="Custom ticks for the y-axis",
    )

    #: Set the y-axis on a Logarithmic scale
    y_log: bool = Field(
        default=False,
        alias="y-log",
        description="Set the y-axis on a Logarithmic scale",
    )

    #: Format of the y-axis ticks (use DateFormat or NumberFormat enum or custom format strings)
    y_format: DateFormat | NumberFormat | str = Field(
        default="",
        alias="y-format",
        description="Format of the y-axis ticks. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
    )

    #: The position of the y-axis ticks and labels
    # NOTE(review): same default as x_position; confirm "bottom" is the
    # intended value for the vertical axis.
    y_position: ScatterAxisPosition | str = Field(
        default="bottom",
        alias="y-position",
        description="The position of the y-axis ticks and labels",
    )

    #: How to display y-axis grid lines
    y_grid_lines: ScatterGridLines | str = Field(
        default="on",
        alias="y-grid-lines",
        description="How to display y-axis grid lines",
    )

    #
    # Color
    #

    #: The default color (palette index or hex string)
    base_color: str | int = Field(
        default=0,
        alias="base-color",
        description="The default color (palette index or hex string)",
    )

    #: The opacity of the points
    opacity: float = Field(
        default=1.0,
        description="The opacity of the points",
    )

    #: Whether to show an outline stroke on points
    outlines: bool = Field(
        default=False,
        description="Whether to show an outline stroke on points",
    )

    #: The color of the outline stroke
    color_outline: str = Field(
        default="#000000",
        alias="color-outline",
        description="The color of the outline stroke",
    )

    #: Whether to show the color key
    show_color_key: bool = Field(
        default=False,
        alias="show-color-key",
        description="Whether to show the color key",
    )

    #: The column with the color for the points
    color_column: str = Field(
        default="",
        alias="color-column",
        description="The column with the color for the points",
    )

    #: A mapping of layer names to colors
    color_category: dict[str, str] = Field(
        default_factory=dict,
        alias="color-category",
        description="A mapping of layer names to colors",
    )

    #: Dictionary mapping category names to their display labels in the color legend
    category_labels: dict[str, str] = Field(
        default_factory=dict,
        alias="category-labels",
        description="Dictionary mapping category names to their display labels in the color legend",
    )

    #: List defining the order in which categories appear in the chart and legend
    category_order: list[str] = Field(
        default_factory=list,
        alias="category-order",
        description="List defining the order in which categories appear in the chart and legend",
    )

    #: A list of columns to exclude from the color key
    exclude_from_color_key: list[str] = Field(
        default_factory=list,
        alias="exclude-from-color-key",
        description="A list of columns to exclude from the color key",
    )

    #
    # Size
    #

    #: How the size is set
    size: ScatterSize | str = Field(
        default="fixed",
        description="How the size is set",
    )

    #: The fixed size, if it's set that way
    fixed_size: int | float = Field(
        default=5,
        alias="fixed-size",
        description="The fixed size, if it's set that way",
    )

    #: The dynamic column to size with
    size_column: str | None = Field(
        default=None,
        alias="size-column",
        description="The dynamic column to size with",
    )

    #: The maximum size of a dynamic setting
    max_size: int | float = Field(
        default=25,
        alias="max-size",
        description="The maximum size of a dynamic setting",
    )

    #: Whether to reduce the size on mobile phones
    responsive_symbol_size: bool = Field(
        default=False,
        alias="responsive-symbol-size",
        description="Whether to reduce the size on mobile phones",
    )

    #: Whether to show the size legend
    show_size_legend: bool = Field(
        default=False,
        alias="show-size-legend",
        description="Whether to show the size legend",
    )

    #: Where to show the size legend
    size_legend_position: Literal[
        "above",
        "below",
        "inside-left-top",
        "inside-center-top",
        "inside-right-top",
        "inside-left-bottom",
        "inside-center-bottom",
        "inside-right-bottom",
    ] = Field(
        default="above",
        alias="size-legend-position",
        description="Where to show the size legend",
    )

    #: The percentage offset of the size legend on the x-axis
    legend_offset_x: int = Field(
        default=0,
        alias="legend-offset-x",
        description="The percentage offset of the size legend on the x-axis",
    )

    #: The percentage offset of the size legend on the y-axis
    legend_offset_y: int = Field(
        default=0,
        alias="legend-offset-y",
        description="The percentage offset of the size legend on the y-axis",
    )

    #: How to format the values of the size legend
    size_legend_values_format: Literal["auto", "custom"] = Field(
        default="auto",
        alias="size-legend-values-format",
        description="How to format the values of the size legend",
    )

    #: The list of values to include in the size legend
    size_legend_values: list[int | float] = Field(
        default_factory=list,
        alias="size-legend-values",
        description="The list of values to include in the size legend",
    )

    #: Where to put the value labels on the size legend
    size_legend_label_position: Literal["below", "right"] = Field(
        default="below",
        alias="size-legend-label-position",
        description="Where to put the value labels on the size legend",
    )

    #: How to format the size legend label values (use DateFormat or NumberFormat enum or custom format strings)
    size_legend_label_format: DateFormat | NumberFormat | str = Field(
        default="",
        alias="size-legend-label-format",
        description="How to format the size legend label values. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
    )

    #: Whether to show a size legend title
    size_legend_title_enabled: bool = Field(
        default=False,
        alias="size-legend-title-enabled",
        description="Whether to show a size legend title",
    )

    #: What to put in the size legend title
    size_legend_title: str = Field(
        default="",
        alias="size-legend-title",
        description="What to put in the size legend title",
    )

    #: Where to put the size legend title
    size_legend_title_position: Literal["left", "right", "above", "below"] = Field(
        default="left",
        alias="size-legend-title-position",
        description="Where to put the size legend title",
    )

    #: The maximum width of the size legend title in pixels
    size_legend_title_width: int | float = Field(
        default=200,
        alias="size-legend-title-width",
        description="The maximum width of the size legend title in pixels",
    )

    #
    # Shape
    #

    #: How to set the shape
    shape: ScatterShape | str = Field(
        default="fixed",
        description="How to set the shape",
    )

    #: Options for the shape
    fixed_shape: ScatterShape | str = Field(
        default="symbolCircle",
        alias="fixed-shape",
        description="Options for the shape",
    )

    #: The columns to get the variable shapes for
    shape_column: str | None = Field(
        default=None,
        alias="shape-column",
        description="The columns to get the variable shapes for",
    )

    #
    # Trend line
    #

    #: Whether or not to show a regression line
    regression: bool = Field(
        default=False,
        description="Whether or not to show a regression line",
    )

    #: The regression method to use
    regression_method: RegressionMethod | str = Field(
        default="linear",
        alias="regression-method",
        description="The regression method to use",
    )

    #
    # Appearance
    #

    #: How to set the plot height
    plot_height_mode: PlotHeightMode | str = Field(
        default="fixed",
        alias="plot-height-mode",
        description="How to set the plot height",
    )

    #: The fixed height of the plot
    plot_height_fixed: int | float = Field(
        default=300,
        alias="plot-height-fixed",
        description="The fixed height of the plot",
    )

    #: The ratio of the plot height
    plot_height_ratio: float = Field(
        default=0.5,
        alias="plot-height-ratio",
        description="The ratio of the plot height",
    )

    @field_validator("plot_height_mode")
    @classmethod
    def validate_plot_height_mode(cls, v: PlotHeightMode | str) -> PlotHeightMode | str:
        """Validate that plot_height_mode is a valid PlotHeightMode value.

        Args:
            v: The candidate value for ``plot_height_mode``.

        Returns:
            The value unchanged when it is accepted.

        Raises:
            ValueError: If ``v`` is a string that is not one of the
                ``PlotHeightMode`` member values.
        """
        # Only string inputs are checked here; any other value is returned
        # as-is.
        if isinstance(v, str):
            valid_values = [e.value for e in PlotHeightMode]
            if v not in valid_values:
                raise ValueError(f"Invalid value: {v}. Must be one of {valid_values}")
        return v

    #
    # Annotations
    #

    #: Add custom lines on the chart
    custom_lines: str = Field(
        default="",
        alias="custom-lines",
        description="Add custom lines on the chart",
    )

    #
    # Labeling
    #

    #: The column to use for the labels
    label_column: str | None = Field(
        default=None,
        alias="label-column",
        description="The column to use for the labels",
    )

    #: Whether to automatically label symbols
    auto_labels: bool = Field(
        default=True,
        alias="auto-labels",
        description="Whether to automatically label symbols",
    )

    #: Values to add labels for
    add_labels: list[Any] = Field(
        default_factory=list,
        alias="add-labels",
        description="Values to add labels for",
    )

    #: Whether to highlight labeled symbols
    highlight_labeled: bool = Field(
        default=True,
        alias="highlight-labeled",
        description="Whether to highlight labeled symbols",
    )

    #
    # Tooltips
    #

    #: Whether to show tooltips
    tooltip_enabled: bool = Field(
        default=True,
        alias="tooltip-enabled",
        description="Whether to show tooltips",
    )

    #: Tooltip title format
    tooltip_title: str = Field(
        default="",
        alias="tooltip-title",
        description="Tooltip title format",
    )

    #: Tooltip body format
    tooltip_body: str = Field(
        default="",
        alias="tooltip-body",
        description="Tooltip body format",
    )

    #: Whether the tooltip is sticky on click
    tooltip_sticky: bool = Field(
        default=False,
        alias="tooltip-sticky",
        description="Whether the tooltip is sticky on click",
    )

    def serialize_model(self) -> dict:
        """Serialize the model to a dictionary.

        Starts from the parent class's serialized payload, replaces
        ``metadata.axes`` with the column roles that are set, and merges the
        scatter plot settings into ``metadata.visualize``.

        Returns:
            The serialized payload produced by the parent class, extended
            with the scatter plot specific entries.
        """
        # Call the parent class's serialize_model method
        model = super().serialize_model()

        # Only column roles with a truthy value (not None / "") are emitted.
        axis_columns = (
            ("x", self.x_column),
            ("y", self.y_column),
            ("size", self.size_column),
            ("shape", self.shape_column),
            ("labels", self.label_column),
            ("color", self.color_column),
        )
        model["metadata"]["axes"] = {
            role: column for role, column in axis_columns if column
        }

        # Add chart specific properties to visualize section
        model["metadata"]["visualize"].update(
            {
                # Horizontal axis
                "x-axis": {
                    "log": self.x_log,
                    "range": self.x_range,
                    "ticks": self.x_ticks,
                },
                "x-format": self.x_format,
                "x-pos": self.x_position,
                "x-grid-lines": self.x_grid_lines,
                # Vertical axis
                "y-axis": {
                    "log": self.y_log,
                    "range": self.y_range,
                    "ticks": self.y_ticks,
                },
                "y-format": self.y_format,
                "y-pos": self.y_position,
                "y-grid-lines": self.y_grid_lines,
                # Colors
                "base-color": self.base_color,
                "opacity": self.opacity,
                "outlines": self.outlines,
                "color-outline": self.color_outline,
                "show-color-key": self.show_color_key,
                "color-category": ColorCategory.serialize(
                    self.color_category,
                    self.category_labels,
                    self.category_order,
                    self.exclude_from_color_key,
                ),
                "color-by-column": bool(self.color_category),
                # Size
                "size": self.size,
                "fixed-size": self.fixed_size,
                "max-size": self.max_size,
                "responsive-symbol-size": self.responsive_symbol_size,
                "show-size-legend": self.show_size_legend,
                "size-legend-position": self.size_legend_position,
                "legend-offset-x": self.legend_offset_x,
                "legend-offset-y": self.legend_offset_y,
                # The API key is "-setting" while the model field is "_format".
                "size-legend-values-setting": self.size_legend_values_format,
                "size-legend-values": self.size_legend_values,
                "size-legend-label-position": self.size_legend_label_position,
                "size-legend-label-format": self.size_legend_label_format,
                "size-legend-title-enabled": self.size_legend_title_enabled,
                "size-legend-title": self.size_legend_title,
                "size-legend-title-position": self.size_legend_title_position,
                "size-legend-title-width": self.size_legend_title_width,
                # Shapes
                "shape": self.shape,
                "fixed-shape": self.fixed_shape,
                # Trend line
                "regression": self.regression,
                "regression-method": self.regression_method,
                # Appearance
                **PlotHeight.serialize(
                    self.plot_height_mode,
                    self.plot_height_fixed,
                    self.plot_height_ratio,
                ),
                # Annotations
                **self._serialize_annotations(),
                "custom-lines": self.custom_lines,
                # Labeling
                "auto-labels": self.auto_labels,
                "add-labels": self.add_labels,
                "highlight-labeled": self.highlight_labeled,
                # Tooltips
                "tooltip": {
                    "body": self.tooltip_body,
                    "title": self.tooltip_title,
                    "sticky": self.tooltip_sticky,
                    "enabled": self.tooltip_enabled,
                    "migrated": True,
                },
            }
        )

        # Return the serialized data
        return model

    @classmethod
    def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
        """Parse Datawrapper API response including scatter plot specific fields.

        Args:
            api_response: The JSON response from the chart metadata endpoint

        Returns:
            Dictionary that can be used to initialize the ScatterPlot model
        """
        # Call parent to get base fields
        init_data = super().deserialize_model(api_response)

        # Extract scatter-specific sections
        metadata = api_response.get("metadata", {})
        visualize = metadata.get("visualize", {})
        axes = metadata.get("axes", {})

        def copy_present(key_map: tuple[tuple[str, str], ...]) -> None:
            """Copy each ``visualize`` key that is present into ``init_data``."""
            for api_key, field_name in key_map:
                if api_key in visualize:
                    init_data[field_name] = visualize[api_key]

        # Parse axes columns. These fields are always set (None when absent);
        # the color column is only set when the API response provides one.
        init_data["x_column"] = axes.get("x")
        init_data["y_column"] = axes.get("y")
        init_data["size_column"] = axes.get("size")
        init_data["shape_column"] = axes.get("shape")
        init_data["label_column"] = axes.get("labels")
        if "color" in axes:
            init_data["color_column"] = axes["color"]

        # Parse the nested "x-axis" / "y-axis" sections. A value that is not
        # a dict is treated like an empty section, so every entry falls back
        # to its default.
        for axis in ("x", "y"):
            axis_settings = visualize.get(f"{axis}-axis", {})
            if not isinstance(axis_settings, dict):
                axis_settings = {}
            init_data[f"{axis}_log"] = axis_settings.get("log", False)
            init_data[f"{axis}_range"] = axis_settings.get("range", ["", ""])
            init_data[f"{axis}_ticks"] = axis_settings.get("ticks", [])

        # Axis formatting and colors
        copy_present(
            (
                ("x-format", "x_format"),
                ("x-pos", "x_position"),
                ("x-grid-lines", "x_grid_lines"),
                ("y-format", "y_format"),
                ("y-pos", "y_position"),
                ("y-grid-lines", "y_grid_lines"),
                ("base-color", "base_color"),
                ("opacity", "opacity"),
                ("outlines", "outlines"),
                ("color-outline", "color_outline"),
                ("show-color-key", "show_color_key"),
            )
        )

        # Parse color-category using utility
        init_data.update(ColorCategory.deserialize(visualize.get("color-category")))

        # Size, shape and trend line
        copy_present(
            (
                ("size", "size"),
                ("fixed-size", "fixed_size"),
                ("max-size", "max_size"),
                ("responsive-symbol-size", "responsive_symbol_size"),
                ("show-size-legend", "show_size_legend"),
                ("size-legend-position", "size_legend_position"),
                ("legend-offset-x", "legend_offset_x"),
                ("legend-offset-y", "legend_offset_y"),
                # The API key is "-setting" while the model field is "_format".
                ("size-legend-values-setting", "size_legend_values_format"),
                ("size-legend-values", "size_legend_values"),
                ("size-legend-label-position", "size_legend_label_position"),
                ("size-legend-label-format", "size_legend_label_format"),
                ("size-legend-title-enabled", "size_legend_title_enabled"),
                ("size-legend-title", "size_legend_title"),
                ("size-legend-title-position", "size_legend_title_position"),
                ("size-legend-title-width", "size_legend_title_width"),
                ("shape", "shape"),
                ("fixed-shape", "fixed_shape"),
                ("regression", "regression"),
                ("regression-method", "regression_method"),
            )
        )

        # Appearance
        init_data.update(PlotHeight.deserialize(visualize))

        # Annotations
        init_data.update(cls._deserialize_annotations(visualize))

        # Custom lines and labeling
        copy_present(
            (
                ("custom-lines", "custom_lines"),
                ("auto-labels", "auto_labels"),
                ("add-labels", "add_labels"),
                ("highlight-labeled", "highlight_labeled"),
            )
        )

        # Tooltips: as with the axis sections, a non-dict value yields the
        # defaults for every tooltip field.
        tooltip = visualize.get("tooltip", {})
        if not isinstance(tooltip, dict):
            tooltip = {}
        init_data["tooltip_enabled"] = tooltip.get("enabled", True)
        init_data["tooltip_title"] = tooltip.get("title", "")
        init_data["tooltip_body"] = tooltip.get("body", "")
        init_data["tooltip_sticky"] = tooltip.get("sticky", False)

        return init_data