Commit d076164b authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

update plotting code

parent b0581ef2
Loading
Loading
Loading
Loading
+131 −11
Original line number Diff line number Diff line
@@ -69,19 +69,138 @@ def plot_loudness_by_bandwidth(df, in_fmt, out_fmt, out_dir):
        ax = axes[idx]
        bw_data = filtered_df[filtered_df["bandwidth"] == bw]

        if input_loudness is not None:
            # highlight ±1 LKFS green
            ax.axhspan(
                input_loudness - 1,
                input_loudness + 1,
                color="xkcd:palegreen",
                alpha=0.2,
                zorder=0,
            )

            # highlight ±1 to ±3 LKFS yellow
            ax.axhspan(
                input_loudness + 1,
                input_loudness + 3,
                color="xkcd:yellow",
                alpha=0.2,
                zorder=0,
            )
            ax.axhspan(
                input_loudness - 3,
                input_loudness - 1,
                color="xkcd:yellow",
                alpha=0.2,
                zorder=0,
            )

            # highlight ±3 to ±6 LKFS orange
            ax.axhspan(
                input_loudness + 3,
                input_loudness + 6,
                color="xkcd:orange",
                alpha=0.2,
                zorder=0,
            )
            ax.axhspan(
                input_loudness - 6,
                input_loudness - 3,
                color="xkcd:orange",
                alpha=0.2,
                zorder=0,
            )

            # red beyond ±6 LKFS
            ax.axhspan(
                input_loudness + 6, y_max, color="xkcd:crimson", alpha=0.2, zorder=0
            )
            ax.axhspan(
                y_min, input_loudness - 6, color="xkcd:crimson", alpha=0.2, zorder=0
            )

        legend_added = {True: False, False: False}

        # plot for each dtx value
        for dtx_val in [True, False]:
            subset = bw_data[bw_data["dtx"] == dtx_val]

            if not subset.empty:
                # equal spacing for bitrates on x axis
                x_positions = [bitrate_to_idx[br] for br in subset["bitrate"]]
                y_values = subset["output_loudness"]
                for _, row in subset.iterrows():
                    br = row["bitrate"]
                    y_val = row["output_loudness"]
                    x_pos = bitrate_to_idx[br]

                    # Clamp out of range values top top/bottom with ^ and v markers and a value textbox
                    if y_val > y_max:
                        ax.scatter(
                            x_pos,
                            y_max - 0.3,
                            marker="^",
                            s=150,
                            color="red",
                            edgecolors="darkred",
                            linewidths=1.5,
                            alpha=0.9,
                            zorder=5,
                        )
                        ax.text(
                            x_pos,
                            y_max - 0.8,
                            f"{y_val:.1f}",
                            ha="center",
                            va="top",
                            fontsize=7,
                            fontweight="bold",
                            bbox=dict(
                                boxstyle="round,pad=0.3",
                                facecolor="white",
                                edgecolor="red",
                                alpha=0.9,
                            ),
                        )
                    elif y_val < y_min:
                        ax.scatter(
                            x_pos,
                            y_min + 0.3,
                            marker="v",
                            s=150,
                            color="red",
                            edgecolors="darkred",
                            linewidths=1.5,
                            alpha=0.9,
                            zorder=5,
                        )
                        ax.text(
                            x_pos,
                            y_min + 0.8,
                            f"{y_val:.1f}",
                            ha="center",
                            va="bottom",
                            fontsize=7,
                            fontweight="bold",
                            bbox=dict(
                                boxstyle="round,pad=0.3",
                                facecolor="white",
                                edgecolor="red",
                                alpha=0.9,
                            ),
                        )
                    else:
                        marker = "x" if dtx_val else "o"
                label = f"dtx={dtx_val}" if show_legend else None
                        label = None
                        if show_legend and not legend_added[dtx_val]:
                            label = f"dtx={dtx_val}"
                            legend_added[dtx_val] = True

                        ax.scatter(
                    x_positions, y_values, label=label, marker=marker, s=80, alpha=0.7
                            x_pos,
                            y_val,
                            label=label,
                            marker=marker,
                            s=80,
                            alpha=0.7,
                            zorder=3,
                        )

        # plot input loudness as horizontal reference line
@@ -94,6 +213,7 @@ def plot_loudness_by_bandwidth(df, in_fmt, out_fmt, out_dir):
                linewidth=2,
                alpha=0.7,
                label=label_input,
                zorder=2,
            )

        ax.set_xticks(range(len(bitrates)))
@@ -110,8 +230,8 @@ def plot_loudness_by_bandwidth(df, in_fmt, out_fmt, out_dir):
        ax.set_xlabel("Bitrate (kbps)", fontsize=11)
        ax.set_ylabel("Output Loudness (LKFS)", fontsize=11)
        ax.set_title(f"{bw.upper()}", fontsize=12)
        ax.grid(True, which="minor", alpha=0.3)
        ax.grid(True, which="major", alpha=0.5)
        ax.grid(True, which="minor", alpha=0.3, zorder=1)
        ax.grid(True, which="major", alpha=0.5, zorder=1)

        # only show legend if there are multiple DTX values
        if show_legend:
@@ -121,7 +241,7 @@ def plot_loudness_by_bandwidth(df, in_fmt, out_fmt, out_dir):
    fig.suptitle(title, fontsize=14)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(out_file)
    plt.savefig(out_file, dpi=150)
    plt.close()