import {
  clearApiProviders,
  createAssistantMessageEventStream,
  getApiProvider,
  registerBuiltInApiProviders,
  unregisterApiProviders,
} from "@mariozechner/pi-ai";
import { afterEach, describe, expect, it, vi } from "vitest";
import { ensureCustomApiRegistered, getCustomApiRegistrySourceId } from "./custom-api-registry.js";

describe("ensureCustomApiRegistered", () => {
  afterEach(() => {
    unregisterApiProviders(getCustomApiRegistrySourceId("test-custom-api"));
    clearApiProviders();
    registerBuiltInApiProviders();
  });

  it("registers a custom api provider once", () => {
    const streamFn = vi.fn(() => createAssistantMessageEventStream());

    expect(ensureCustomApiRegistered("test-custom-api", streamFn)).toBe(true);
    expect(ensureCustomApiRegistered("test-custom-api", streamFn)).toBe(false);

    const provider = getApiProvider("test-custom-api");
    expect(provider).toBeDefined();
  });

  it("delegates both stream entrypoints to the provided stream function", () => {
    const stream = createAssistantMessageEventStream();
    const streamFn = vi.fn(() => stream);
    ensureCustomApiRegistered("test-custom-api", streamFn);

    const provider = getApiProvider("test-custom-api");
    expect(provider).toBeDefined();

    const model = { api: "test-custom-api", provider: "custom", id: "m" };
    const context = { messages: [] };
    const options = { maxTokens: 32 };

    expect(provider?.stream(model as never, context as never, options as never)).toBe(stream);
    expect(provider?.streamSimple(model as never, context as never, options as never)).toBe(stream);
    expect(streamFn).toHaveBeenCalledTimes(2);
  });
});
