
Add a Custom Provider

Need to query a model not natively supported by ChainForge? Or maybe you have an LLM chain set up somewhere, and you'd like to compare its responses against other models like GPT-4.

Adding a custom provider is as simple as writing a Python function, decorating it a bit, and dropping it into the ChainForge settings window. For instance, the advanced example below adds support for Cohere.

To extend the provider list in ChainForge:

  1. Write a completion function in Python that conforms to CustomProviderProtocol.
  2. Register it in ChainForge by decorating your function with @provider and giving it a name. See the examples below.
  3. Drop the script into the global settings screen (the Custom Providers tab).

Once added, custom provider scripts are cached by ChainForge and automatically re-run every time you load the application.

Danger

Only add Python scripts you trust. ChainForge will execute the Python code inside, so be careful.

Simple Example: Reverse the input prompt

Here's a naive provider that just returns the prompt reversed:

from chainforge.providers import provider

@provider(name="Mirror", emoji="🪞")
def mirror_the_prompt(prompt: str, **kwargs) -> str:
    return prompt[::-1]

Try saving this code as a Python script (.py), then dropping it into ChainForge settings!
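Since the decorated function is still an ordinary Python function (the @provider decorator only registers it with ChainForge), you can sanity-check its logic in a plain Python session before dropping it into ChainForge. Here the decorator is omitted so the snippet runs without ChainForge installed:

```python
# The bare provider function, minus the @provider decorator, which only
# handles registration with ChainForge and doesn't change the logic:
def mirror_the_prompt(prompt: str, **kwargs) -> str:
    return prompt[::-1]

print(mirror_the_prompt("Hello, world!"))  # -> !dlrow ,olleH
```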

Advanced Example: Cohere API

Here's a custom model provider to support Cohere AI text completions through their Python API. It adds different model names for Cohere's models, and extends the settings window with two parameters, temperature and max_tokens.

(Note: You must have the cohere package installed and an API key.)

from chainforge.providers import provider
import cohere

# Init the Cohere client (replace with your API key):
co = cohere.Client('<YOUR_COHERE_API_KEY>')

# JSON schemas to pass to react-jsonschema-form: one for this provider's settings, and one to describe the settings UI.
COHERE_SETTINGS_SCHEMA = {
  "settings": {
    "temperature": {
      "type": "number",
      "title": "temperature",
      "description": "Controls the 'creativity' or randomness of the response.",
      "default": 0.75,
      "minimum": 0,
      "maximum": 5.0,
      "multipleOf": 0.01,
    },
    "max_tokens": {
      "type": "integer",
      "title": "max_tokens",
      "description": "Maximum number of tokens to generate in the response.",
      "default": 100,
      "minimum": 1,
      "maximum": 1024,
    },
  },
  "ui": {
    "temperature": {
      "ui:help": "Defaults to 1.0.",
      "ui:widget": "range"
    },
    "max_tokens": {
      "ui:help": "Defaults to 100.",
      "ui:widget": "range"
    },
  }
}

# Our custom model provider for Cohere's text generation API.
@provider(name="Cohere",
          emoji="🖇", 
          models=['command', 'command-nightly', 'command-light', 'command-light-nightly'],
          rate_limit="sequential", # enter "sequential" for blocking; an integer N > 0 means N is the max number of requests per minute.
          settings_schema=COHERE_SETTINGS_SCHEMA)
def CohereCompletion(prompt: str, model: str, temperature: float = 0.75, **kwargs) -> str:
    print(f"Calling Cohere model {model} with prompt '{prompt}'...")
    response = co.generate(model=model, prompt=prompt, temperature=temperature, **kwargs)
    return response.generations[0].text

This example uses the models keyword argument in the @provider decorator to define different model names for this provider that we can then change in its settings window. It also adds custom settings with temperature and max_tokens parameters. The settings schema is in react-jsonschema-form format. You can get a sense for what settings JSON you might write by looking at the schemas already defined in the ChainForge source code.

[Screenshot: the Cohere provider's settings window in ChainForge]

Defining a function conforming to CustomProviderProtocol

CustomProviderProtocol is defined as:

from typing import Any, Optional, Protocol
from chainforge.providers import ChatHistory  # ChainForge's list-of-chat-messages type

class CustomProviderProtocol(Protocol):
  def __call__(self,
               prompt: str,
               model: Optional[str],
               chat_history: Optional[ChatHistory],
               **kwargs: Any) -> str:
      ...

Parameters:

  • prompt: Text to prompt the model. (If it's a chat model, this is the new message to send.)
  • model: Optional. The name of the particular model to use, from the CF settings window. Useful when you have multiple models for a single provider.
  • chat_history: Optional. Providers may be passed a past chat context as a list of chat messages in OpenAI format. Chat history does not include the new message to send off (which is passed instead as the prompt parameter). This is only relevant when using Chat Turn nodes.
  • kwargs: Any other parameters to pass the provider call, like temperature. Parameter names are the keynames in your provider's settings_schema, passed to the @provider decorator. Only relevant if you are defining a custom settings JSON to edit provider/model settings in ChainForge.
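To illustrate how these parameters fit together, here is a hedged sketch of a chat-aware provider. The function name and the echo behavior are hypothetical stand-ins for a real API call, and the chat history is shown as a list of OpenAI-style message dicts per the description above:

```python
from typing import Any, Dict, List, Optional

# Hypothetical chat-aware provider body. A real provider would forward
# `messages` to an actual chat API; this one just summarizes the context
# to show how `prompt` and `chat_history` combine.
def echo_chat_provider(prompt: str,
                       model: Optional[str] = None,
                       chat_history: Optional[List[Dict[str, str]]] = None,
                       **kwargs: Any) -> str:
    # Append the new user message to the past context (OpenAI-style dicts):
    messages = (chat_history or []) + [{"role": "user", "content": prompt}]
    return f"[{model or 'default'}] saw {len(messages)} message(s); last: {messages[-1]['content']}"

history = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
print(echo_chat_provider("What is 2+2?", model="toy", chat_history=history))
# -> [toy] saw 3 message(s); last: What is 2+2?
```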

Full description of @provider

@provider is a decorator for registering custom response provider methods or classes (Callables) that conform to CustomProviderProtocol. Use it by decorating your CustomProviderProtocol-conforming function:

@provider(name='My Provider', emoji='✨', ...)
def my_completion_func(prompt: str, **kwargs):
  ...

You must provide name and emoji kwargs, but there are more optional arguments:

  • name: The name of your custom provider. Required. (Must be unique; cannot be blank.)
  • emoji: The emoji to use as the default icon for your provider in the CF interface. Required.
  • models: A list of models that your provider supports, which you want to be able to choose between in the settings window. If you're just calling a single model, you can omit this.
  • rate_limit: If an integer N > 0, at most N requests are sent per minute. To force requests to be sequential (wait until each request returns before sending another), enter "sequential". Defaults to sequential.
  • settings_schema: an optional JSON Schema specifying the name of your provider in the ChainForge UI, the available settings, and the UI for those settings. The settings and UI specs are in react-jsonschema-form format.

    Specifically, your settings_schema dict should have keys:

    {
      "settings": <JSON dict of the schema properties for your settings form, in react-jsonschema-form format (https://rjsf-team.github.io/react-jsonschema-form/docs/)>,
      "ui": <JSON dict of the UI Schema for your settings form, in react-jsonschema-form format (see the UISchema example at https://rjsf-team.github.io/react-jsonschema-form/)>
    }
    

    You may look to adapt an existing schema from ModelSettingsSchemas.js in the ChainForge source code, BUT with the following things to keep in mind:

    • the value of "settings" should just be the value of "properties" in the full schema
    • don't include a 'shortname' property; this will be added by default and set to the value of name
    • don't include a 'model' property; this will be populated by the list you passed to models (if any)
    • the keynames of all properties of the schema should be valid as variable names for Python keyword args; i.e., no spaces

    Finally, if you want temperature to appear in the ChainForge UI, you must name your settings schema property temperature, and give it minimum and maximum values.
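Putting those rules together, here is a minimal hypothetical settings_schema: no 'shortname' or 'model' properties, keynames that are valid Python keyword-argument names, and a temperature property with explicit minimum and maximum so it appears in the ChainForge UI:

```python
# A minimal hypothetical settings_schema following the rules above.
MY_SETTINGS_SCHEMA = {
    "settings": {
        # Named 'temperature' with min/max so ChainForge can surface it:
        "temperature": {
            "type": "number",
            "title": "temperature",
            "default": 1.0,
            "minimum": 0.0,
            "maximum": 2.0,
            "multipleOf": 0.01,
        },
    },
    "ui": {
        "temperature": {"ui:widget": "range"},
    },
}

# Every settings keyname must be usable as a Python keyword argument:
assert all(k.isidentifier() for k in MY_SETTINGS_SCHEMA["settings"])
```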

    Warning

Only textarea, range, enum, and text input UI widgets are properly supported from react-jsonschema-form; you can try other widget types, but the CSS may not display properly. This is because Mantine's CSS clashes with react-jsonschema-form's Bootstrap CSS. If you'd like more widgets, you'll have to tweak the CSS and submit a pull request to the ChainForge source code.