From 840874e1b579bcd5c57779f997926f313eb300b4 Mon Sep 17 00:00:00 2001 From: notowen333 <51685858+notowen333@users.noreply.github.com> Date: Thu, 28 May 2026 10:43:06 -0400 Subject: [PATCH 1/4] feat: add per-invocation limits (turns / outputTokens / totalTokens) (#1106) --- strands-ts/src/agent/__tests__/agent.test.ts | 274 +++++++++++++++++++ strands-ts/src/agent/agent.ts | 73 +++++ strands-ts/src/types/agent.ts | 54 ++++ strands-ts/src/types/messages.ts | 8 +- 4 files changed, 408 insertions(+), 1 deletion(-) diff --git a/strands-ts/src/agent/__tests__/agent.test.ts b/strands-ts/src/agent/__tests__/agent.test.ts index abed0afad9..056df3de1d 100644 --- a/strands-ts/src/agent/__tests__/agent.test.ts +++ b/strands-ts/src/agent/__tests__/agent.test.ts @@ -1926,4 +1926,278 @@ describe('normalizeToolUseNames', () => { setterSpy.mockRestore() }) }) + + describe('limits', () => { + const toolUseTurn = ( + toolUseId: string, + usage: { inputTokens: number; outputTokens: number; totalTokens: number } + ): Parameters => [ + { type: 'toolUseBlock', name: 'loop', toolUseId, input: {} }, + { usage }, + ] + + const passthroughTool = (): ReturnType => + createMockTool( + 'loop', + (context) => + new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success' as const, + content: [new TextBlock('ok')], + }) + ) + + describe('when limits.turns is reached', () => { + it('runs the cycle to completion and bails at top of next iteration', async () => { + const model = new MockMessageModel() + .addTurn(...toolUseTurn('tool-1', { inputTokens: 10, outputTokens: 5, totalTokens: 15 })) + .addTurn(...toolUseTurn('tool-2', { inputTokens: 20, outputTokens: 5, totalTokens: 25 })) + + const agent = new Agent({ model, tools: [passthroughTool()] }) + + const result = await agent.invoke('go', { limits: { turns: 1 } }) + + // Bail after tools — lastMessage is the user toolResult, so we don't + // use expectAgentResult (which assumes role 'assistant'). + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitTurns', + lastMessage: expect.objectContaining({ + role: 'user', + content: expect.arrayContaining([expect.any(ToolResultBlock)]), + }), + metrics: expectLoopMetrics({ + cycleCount: 1, + toolNames: ['loop'], + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }), + }) + ) + }) + }) + + describe('when limits is generous', () => { + it('does not trip and the model ends naturally', async () => { + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'done' }, + { usage: { inputTokens: 5, outputTokens: 5, totalTokens: 10 } } + ) + const agent = new Agent({ model }) + + const result = await agent.invoke('go', { limits: { turns: 5, outputTokens: 1000, totalTokens: 1000 } }) + + expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'done', + cycleCount: 1, + usage: { inputTokens: 5, outputTokens: 5, totalTokens: 10 }, + }) + ) + }) + }) + + describe('when limits.outputTokens is reached', () => { + it('returns limitOutputTokens once cumulative outputTokens hits the cap', async () => { + const model = new MockMessageModel() + .addTurn(...toolUseTurn('tool-1', { inputTokens: 10, outputTokens: 60, totalTokens: 70 })) + .addTurn(...toolUseTurn('tool-2', { inputTokens: 10, outputTokens: 60, totalTokens: 70 })) + + const agent = new Agent({ model, tools: [passthroughTool()] }) + + const result = await agent.invoke('go', { limits: { outputTokens: 100 } }) + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitOutputTokens', + lastMessage: expect.objectContaining({ + role: 'user', + content: expect.arrayContaining([expect.any(ToolResultBlock)]), + }), + metrics: expectLoopMetrics({ + cycleCount: 2, + toolNames: ['loop'], + usage: { inputTokens: 20, outputTokens: 120, totalTokens: 140 }, + }), + }) + ) + }) + + it('uses at-most (>=) semantics: stops when count exactly equals the cap', async () => { + const model = new MockMessageModel() + .addTurn(...toolUseTurn('tool-1', { inputTokens: 10, outputTokens: 100, totalTokens: 110 })) + .addTurn(...toolUseTurn('tool-2', { inputTokens: 10, outputTokens: 100, totalTokens: 110 })) + + const agent = new Agent({ model, tools: [passthroughTool()] }) + + const result = await agent.invoke('go', { limits: { outputTokens: 100 } }) + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitOutputTokens', + metrics: expectLoopMetrics({ + cycleCount: 1, + toolNames: ['loop'], + usage: { inputTokens: 10, outputTokens: 100, totalTokens: 110 }, + }), + }) + ) + }) + }) + + describe('when limits.totalTokens is reached', () => { + it('returns limitTotalTokens once cumulative totalTokens hits the cap', async () => { + const model = new MockMessageModel() + .addTurn(...toolUseTurn('tool-1', { inputTokens: 200, outputTokens: 100, totalTokens: 300 })) + .addTurn(...toolUseTurn('tool-2', { inputTokens: 200, outputTokens: 100, totalTokens: 300 })) + + const agent = new Agent({ model, tools: [passthroughTool()] }) + + const result = await agent.invoke('go', { limits: { totalTokens: 500 } }) + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitTotalTokens', + lastMessage: expect.objectContaining({ + role: 'user', + content: expect.arrayContaining([expect.any(ToolResultBlock)]), + }), + metrics: expectLoopMetrics({ + cycleCount: 2, + toolNames: ['loop'], + usage: { inputTokens: 400, outputTokens: 200, totalTokens: 600 }, + }), + }) + ) + }) + }) + + describe('when the model ends naturally on the same turn the limit would trip', () => { + it('returns endTurn — the model answer wins', async () => { + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'final answer' }, + { usage: { inputTokens: 300, outputTokens: 300, totalTokens: 600 } } + ) + const agent = new Agent({ model }) + + const result = await agent.invoke('go', { limits: { totalTokens: 500 } }) + + expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'final answer', + cycleCount: 1, + usage: { inputTokens: 300, outputTokens: 300, totalTokens: 600 }, + }) + ) + }) + }) + + describe('when multiple limits trip simultaneously', () => { + const heavyUsage = { inputTokens: 100, outputTokens: 100, totalTokens: 200 } + + const buildAgent = (): Agent => { + const model = new MockMessageModel() + .addTurn(...toolUseTurn('tool-1', heavyUsage)) + .addTurn(...toolUseTurn('tool-2', heavyUsage)) + return new Agent({ model, tools: [passthroughTool()] }) + } + + it('prefers turns when all three trip', async () => { + const result = await buildAgent().invoke('go', { + limits: { turns: 1, totalTokens: 1, outputTokens: 1 }, + }) + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitTurns', + metrics: expectLoopMetrics({ cycleCount: 1, toolNames: ['loop'] }), + }) + ) + }) + + it('prefers totalTokens over outputTokens', async () => { + const result = await buildAgent().invoke('go', { limits: { totalTokens: 1, outputTokens: 1 } }) + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitTotalTokens', + metrics: expectLoopMetrics({ cycleCount: 1, toolNames: ['loop'] }), + }) + ) + }) + + it('falls back to outputTokens when no higher-priority cap is set', async () => { + const result = await buildAgent().invoke('go', { limits: { outputTokens: 1 } }) + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'limitOutputTokens', + metrics: expectLoopMetrics({ cycleCount: 1, toolNames: ['loop'] }), + }) + ) + }) + }) + + describe('when the same agent is reused across invocations', () => { + it('scopes the limit to the current invocation, not lifetime', async () => { + // Each turn uses 50 output tokens. With limits.outputTokens: 75, a single + // invocation tolerates one turn but trips on the second. If the cap + // were lifetime-scoped, the second `invoke()` would trip on its first + // turn (75 cumulative across both calls). + const model = new MockMessageModel() + .addTurn( + { type: 'textBlock', text: 'first' }, + { usage: { inputTokens: 10, outputTokens: 50, totalTokens: 60 } } + ) + .addTurn( + { type: 'textBlock', text: 'second' }, + { usage: { inputTokens: 10, outputTokens: 50, totalTokens: 60 } } + ) + + const agent = new Agent({ model }) + + const r1 = await agent.invoke('go', { limits: { outputTokens: 75 } }) + expect(r1.stopReason).toBe('endTurn') + expect(r1.metrics?.latestAgentInvocation?.cycles.length).toBe(1) + + const r2 = await agent.invoke('go again', { limits: { outputTokens: 75 } }) + expect(r2.stopReason).toBe('endTurn') + expect(r2.metrics?.latestAgentInvocation?.cycles.length).toBe(1) + }) + }) + + describe('when a limit is invalid', () => { + it.each([ + ['negative', { limits: { turns: -1 } }], + ['zero', { limits: { turns: 0 } }], + ['NaN', { limits: { outputTokens: NaN } }], + ['Infinity', { limits: { totalTokens: Infinity } }], + ])('rejects %s with TypeError', async (_label, options) => { + const agent = new Agent({ model: new MockMessageModel().addTurn({ type: 'textBlock', text: 'never reached' }) }) + await expect(agent.invoke('go', options)).rejects.toThrow(TypeError) + }) + }) + + describe('when invoked via stream()', () => { + it('returns limitTurns as the generator return value', async () => { + const model = new MockMessageModel() + .addTurn(...toolUseTurn('tool-1', { inputTokens: 10, outputTokens: 5, totalTokens: 15 })) + .addTurn(...toolUseTurn('tool-2', { inputTokens: 10, outputTokens: 5, totalTokens: 15 })) + + const agent = new Agent({ model, tools: [passthroughTool()] }) + + const { result } = await collectGenerator(agent.stream('go', { limits: { turns: 1 } })) + + expect(result).toEqual(expect.objectContaining({ type: 'agentResult', stopReason: 'limitTurns' })) + }) + }) + }) }) diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts index 41c674ed7a..7fb3d64841 100644 --- a/strands-ts/src/agent/agent.ts +++ b/strands-ts/src/agent/agent.ts @@ -15,6 +15,7 @@ import { type ContentBlockData, Message, type MessageData, + type StopReason, type SystemPrompt, type SystemPromptData, TextBlock, @@ -493,6 +494,64 @@ export class Agent implements LocalAgent, InvokableAgent { } } + /** + * Validates the per-invocation budget caps in {@link InvokeOptions.limits}. + * Called once at the top of `_stream` so bad inputs fail fast with a clear + * error instead of silently no-op'ing (`NaN`, `Infinity`) or tripping + * pathologically (zero swallows the user input; negative trips immediately). + * + * Each cap, when set, must be a positive finite number. Fractional values + * are accepted — harmless, and useful for token budgets derived from + * arithmetic. + */ + private _validateLimits(options: InvokeOptions | undefined): void { + if (!options?.limits) return + const assertPositive = (name: string, value: number | undefined): void => { + if (value !== undefined && (!Number.isFinite(value) || value <= 0)) { + throw new TypeError(`${name} must be a positive finite number, got ${value}`) + } + } + assertPositive('limits.turns', options.limits.turns) + assertPositive('limits.outputTokens', options.limits.outputTokens) + assertPositive('limits.totalTokens', options.limits.totalTokens) + } + + /** + * Evaluates the per-invocation budget caps in {@link InvokeOptions.limits} + * against the current invocation's metrics. Called at the top of each + * agent-loop iteration, after `_throwIfCancelled` and before `startCycle`. + * + * Reads from {@link AgentMetrics.latestAgentInvocation} (scoped to the + * current invocation) — not `cycleCount` / `accumulatedUsage`, which are + * lifetime accumulators that would cause caps to fire prematurely on the + * second `invoke()` call against a reused agent. + * + * Priority on simultaneous trip: turns → totalTokens → outputTokens. + * + * Returns the {@link StopReason} the loop should terminate with, or + * `undefined` if every configured cap is still within budget. + */ + private _checkLimits(options: InvokeOptions | undefined): StopReason | undefined { + const limits = options?.limits + if (!limits) return undefined + const invocation = this._meter.metrics.latestAgentInvocation + if (!invocation) return undefined + + const cycleCount = invocation.cycles.length + const { outputTokens, totalTokens } = invocation.usage + + if (limits.turns !== undefined && cycleCount >= limits.turns) { + return 'limitTurns' + } + if (limits.totalTokens !== undefined && totalTokens >= limits.totalTokens) { + return 'limitTotalTokens' + } + if (limits.outputTokens !== undefined && outputTokens >= limits.outputTokens) { + return 'limitOutputTokens' + } + return undefined + } + /** * The tools this agent can use. */ @@ -862,6 +921,8 @@ export class Agent implements LocalAgent, InvokableAgent { let currentArgs: InvokeArgs | undefined = args let result: AgentResult | undefined + this._validateLimits(options) + // Resolve structured output schema from per-invocation options or constructor config const structuredOutputSchema = options?.structuredOutputSchema ?? this._structuredOutputSchema const structuredOutputTool = structuredOutputSchema ? new StructuredOutputTool(structuredOutputSchema) : undefined @@ -929,6 +990,18 @@ export class Agent implements LocalAgent, InvokableAgent { while (true) { this._throwIfCancelled() + const limitStopReason = this._checkLimits(options) + if (limitStopReason) { + result = new AgentResult({ + stopReason: limitStopReason, + lastMessage: this.messages.at(-1)!, + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + invocationState, + }) + return result + } + // Start metrics cycle tracking const { cycleId, startTime: cycleStartTime } = this._meter.startCycle() diff --git a/strands-ts/src/types/agent.ts b/strands-ts/src/types/agent.ts index a842feab70..a352e4c8ae 100644 --- a/strands-ts/src/types/agent.ts +++ b/strands-ts/src/types/agent.ts @@ -118,6 +118,60 @@ export interface InvokeOptions { * ``` */ cancelSignal?: AbortSignal + + /** + * Per-invocation budget caps. Each cap, when set, bounds the agent loop + * for this `invoke()` / `stream()` call only — counters are not cumulative + * across reuses of the same agent. + * + * Caps are checked at the top of each loop iteration. Tools requested by + * the previous turn always run to completion before a cap fires, so + * `agent.messages` remains in a reinvokable state. + * + * Each cap, when set, must be a positive finite number. Omit any field + * (or `limits` itself) for no limit on that dimension. + * + * Priority on simultaneous trip (highest first): `turns`, `totalTokens`, + * `outputTokens`. The corresponding `stopReason` is `'limitTurns'`, + * `'limitTotalTokens'`, or `'limitOutputTokens'`. + */ + limits?: { + /** + * Maximum number of agent loop iterations (turns). A turn is one model + * call plus any tool execution that follows. Counted against + * `metrics.latestAgentInvocation.cycles.length`. + */ + turns?: number + + /** + * Maximum cumulative model-generated tokens, summed across every model + * call in the agent loop + * (`metrics.latestAgentInvocation.usage.outputTokens`). + * + * Distinct from per-call provider-level `maxTokens` settings (e.g. + * `GoogleModelConfig.params.maxOutputTokens`), which bound a single + * model call's output. This cap bounds the loop's cumulative output + * across however many calls it makes. + * + * Soft cap: a single oversized model response can overshoot the budget. + * The agent stops at the first turn boundary on or after the budget is + * reached; it does not bound any individual model call. + */ + outputTokens?: number + + /** + * Maximum cumulative input + output tokens + * (`metrics.latestAgentInvocation.usage.totalTokens`). Each model + * call's input includes prior turns, so this counter compounds across + * the run — it approximates the total token spend you would be billed + * for. + * + * Soft cap: a single oversized model response can overshoot the budget. + * The agent stops at the first turn boundary on or after the budget is + * reached; it does not bound any individual model call. + */ + totalTokens?: number + } } /** diff --git a/strands-ts/src/types/messages.ts b/strands-ts/src/types/messages.ts index a5bd8f24f4..8fe3a912a3 100644 --- a/strands-ts/src/types/messages.ts +++ b/strands-ts/src/types/messages.ts @@ -657,7 +657,10 @@ export class JsonBlock implements JsonBlockData, JSONSerializable * - `endTurn` - Natural end of the model's turn * - `guardrailIntervened` - A guardrail policy stopped generation * - `interrupt` - Agent execution was interrupted for human input - * - `maxTokens` - Maximum token limit was reached + * - `maxTokens` - The model provider's per-call token cap was reached + * - `limitOutputTokens` - Agent loop stopped because `InvokeOptions.limits.outputTokens` was reached + * - `limitTotalTokens` - Agent loop stopped because `InvokeOptions.limits.totalTokens` was reached + * - `limitTurns` - Agent loop stopped because `InvokeOptions.limits.turns` was reached * - `pauseTurn` - Model paused a long-running turn; the response should be sent back to continue * - `refusal` - A streaming classifier intervened to handle a potential policy violation * - `stopSequence` - A stop sequence was encountered @@ -671,6 +674,9 @@ export type StopReason = | 'guardrailIntervened' | 'interrupt' | 'maxTokens' + | 'limitOutputTokens' + | 'limitTotalTokens' + | 'limitTurns' | 'pauseTurn' | 'refusal' | 'stopSequence' From 36c16f1390d6d1c10add26e51481fd72fbdd44ba Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Thu, 28 May 2026 12:07:49 -0400 Subject: [PATCH 2/4] feat(wit): align type names with TS SDK (#1107) Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> --- .github/workflows/typescript-pr-and-push.yml | 9 + .github/workflows/typescript-ts-check.yml | 9 + .github/workflows/wasm-py-check.yml | 52 + package-lock.json | 37 +- package.json | 2 +- pyproject.toml | 4 +- strandly/package.json | 10 +- strandly/src/cli.ts | 23 +- strandly/tsconfig.json | 5 +- strands-py-wasm/pyproject.toml | 8 +- strands-py-wasm/src/strands/__init__.py | 877 +-- strands-py-wasm/src/strands/_generated.py | 5438 ----------------- .../src/strands/_generated/__init__.py | 234 + .../_generated/strands_agent/__init__.py | 1 + .../strands/_generated/strands_agent/api.py | 221 + .../_generated/strands_agent/conversation.py | 57 + .../strands_agent/edge_handler_registry.py | 44 + .../strands_agent/elicitation_handler.py | 68 + .../_generated/strands_agent/host_log.py | 36 + .../strands/_generated/strands_agent/mcp.py | 114 + .../_generated/strands_agent/messages.py | 424 ++ .../strands_agent/model_provider.py | 55 + .../_generated/strands_agent/models.py | 162 + .../_generated/strands_agent/multi_agent.py | 269 + .../strands/_generated/strands_agent/retry.py | 97 + .../_generated/strands_agent/sessions.py | 253 + .../strands_agent/snapshot_storage.py | 62 + .../strands_agent/snapshot_trigger_handler.py | 45 + .../_generated/strands_agent/streaming.py | 443 ++ .../_generated/strands_agent/tool_provider.py | 13 + .../strands/_generated/strands_agent/tools.py | 130 + .../_generated/strands_agent/vended.py | 116 + .../_generated/wasi_clocks/__init__.py | 1 + .../_generated/wasi_clocks/monotonic_clock.py | 18 + .../_generated/wasi_clocks/wall_clock.py | 18 + .../strands/_generated/wasi_io/__init__.py | 1 + .../src/strands/_generated/wasi_io/error.py | 41 + .../src/strands/_generated/wasi_io/poll.py | 28 + .../src/strands/_generated/wasi_io/streams.py | 102 + strands-py-wasm/src/strands/_runtime.py | 39 +- strands-py-wasm/src/strands/types.py | 67 + strands-wasm/entry.ts | 3 - wit/agent.wit | 8 +- wit/conversation.wit | 15 +- wit/mcp.wit | 12 +- wit/models.wit | 20 +- wit/multiagent.wit | 10 +- wit/retry.wit | 12 +- wit/sessions.wit | 22 +- wit/vended.wit | 24 +- 50 files changed, 3636 insertions(+), 6123 deletions(-) create mode 100644 .github/workflows/wasm-py-check.yml delete mode 100644 strands-py-wasm/src/strands/_generated.py create mode 100644 strands-py-wasm/src/strands/_generated/__init__.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/__init__.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/api.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/conversation.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/edge_handler_registry.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/elicitation_handler.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/host_log.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/mcp.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/messages.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/model_provider.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/models.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/multi_agent.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/retry.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/sessions.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/snapshot_storage.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/snapshot_trigger_handler.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/streaming.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/tool_provider.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/tools.py create mode 100644 strands-py-wasm/src/strands/_generated/strands_agent/vended.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_clocks/__init__.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_clocks/monotonic_clock.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_clocks/wall_clock.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_io/__init__.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_io/error.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_io/poll.py create mode 100644 strands-py-wasm/src/strands/_generated/wasi_io/streams.py create mode 100644 strands-py-wasm/src/strands/types.py diff --git a/.github/workflows/typescript-pr-and-push.yml b/.github/workflows/typescript-pr-and-push.yml index 5d539f95e0..b23b507dc2 100644 --- a/.github/workflows/typescript-pr-and-push.yml +++ b/.github/workflows/typescript-pr-and-push.yml @@ -13,6 +13,7 @@ on: - 'package.json' - 'package-lock.json' - '.github/workflows/typescript-*' + - '.github/workflows/wasm-*' merge_group: types: [checks_requested] push: @@ -26,6 +27,7 @@ on: - 'package.json' - 'package-lock.json' - '.github/workflows/typescript-*' + - '.github/workflows/wasm-*' workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} @@ -46,6 +48,13 @@ jobs: with: ref: ${{ github.event.pull_request.head.sha }} + call-py-check: + uses: ./.github/workflows/wasm-py-check.yml + permissions: + contents: read + with: + ref: ${{ github.event.pull_request.head.sha }} + call-ts-test: uses: ./.github/workflows/typescript-ts-test.yml permissions: diff --git a/.github/workflows/typescript-ts-check.yml b/.github/workflows/typescript-ts-check.yml index 27accf3da2..1564fde45e 100644 --- a/.github/workflows/typescript-ts-check.yml +++ b/.github/workflows/typescript-ts-check.yml @@ -30,6 +30,15 @@ jobs: - name: Install dependencies run: npm ci + - name: Verify package-lock.json is in sync with package.json + run: | + npm install --package-lock-only --ignore-scripts + if ! git diff --quiet -- package-lock.json; then + echo "::error::package-lock.json drifted from package.json -- run 'npm install' and commit the lockfile" + git diff --stat -- package-lock.json + exit 1 + fi + - name: Build run: npm run build diff --git a/.github/workflows/wasm-py-check.yml b/.github/workflows/wasm-py-check.yml new file mode 100644 index 0000000000..fed29bce4d --- /dev/null +++ b/.github/workflows/wasm-py-check.yml @@ -0,0 +1,52 @@ +name: Python Check + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + py-check: + name: strandly check --py + permissions: + contents: read + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 22 + package-manager-cache: false + + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install wasm-tools + run: brew install wasm-tools + + - name: Bootstrap strandly + venv + run: | + npm ci + echo "$GITHUB_WORKSPACE/node_modules/.bin" >> "$GITHUB_PATH" + python -m venv .venv + .venv/bin/pip install -e . + + - name: Build strands-wasm + run: strandly build --wasm + + - name: Install strands-py-wasm + run: .venv/bin/pip install -e strands-py-wasm + + - name: strandly check --py + run: strandly check --py diff --git a/package-lock.json b/package-lock.json index 6ce38e70bf..45d44e6344 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1628,6 +1628,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -4916,6 +4917,7 @@ }, "node_modules/esbuild": { "version": "0.27.7", + "dev": true, "hasInstallScript": true, "license": "MIT", "bin": { @@ -4960,6 +4962,7 @@ "cpu": [ "ppc64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -4976,6 +4979,7 @@ "cpu": [ "arm" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -4992,6 +4996,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5008,6 +5013,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5024,6 +5030,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5040,6 +5047,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5056,6 +5064,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5072,6 +5081,7 @@ "cpu": [ "arm" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5088,6 +5098,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5104,6 +5115,7 @@ "cpu": [ "ia32" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5120,6 +5132,7 @@ "cpu": [ "loong64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5136,6 +5149,7 @@ "cpu": [ "mips64el" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5152,6 +5166,7 @@ "cpu": [ "ppc64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5168,6 +5183,7 @@ "cpu": [ "riscv64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5184,6 +5200,7 @@ "cpu": [ "s390x" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5200,6 +5217,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5216,6 +5234,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5232,6 +5251,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5248,6 +5268,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5264,6 +5285,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5280,6 +5302,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5296,6 +5319,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5312,6 +5336,7 @@ "cpu": [ "arm64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5328,6 +5353,7 @@ "cpu": [ "ia32" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -5344,6 +5370,7 @@ "cpu": [ "x64" ], + "dev": true, "license": "MIT", "optional": true, "os": [ @@ -6137,6 +6164,7 @@ }, "node_modules/get-tsconfig": { "version": "4.14.0", + "dev": true, "license": "MIT", "dependencies": { "resolve-pkg-maps": "^1.0.0" @@ -7436,6 +7464,7 @@ }, "node_modules/resolve-pkg-maps": { "version": "1.0.0", + "dev": true, "license": "MIT", "funding": { "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" @@ -8058,6 +8087,7 @@ }, "node_modules/tsx": { "version": "4.21.0", + "dev": true, "license": "MIT", "dependencies": { "esbuild": "~0.27.0", @@ -8075,11 +8105,13 @@ }, "node_modules/tsx/node_modules/fsevents": { "version": "2.3.3", + "dev": true, "license": "MIT", "optional": true, "os": [ "darwin" ], + "peer": true, "engines": { "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } @@ -8596,11 +8628,10 @@ "name": "@strands-agents/strandly", "version": "0.0.1", "dependencies": { - "commander": "^14", - "tsx": "^4.21.0" + "commander": "^14" }, "bin": { - "strandly": "src/cli.ts" + "strandly": "dist/cli.js" }, "devDependencies": { "@types/node": "^22", diff --git a/package.json b/package.json index c50c43d928..1fa4498851 100644 --- a/package.json +++ b/package.json @@ -13,7 +13,7 @@ }, "scripts": { "dev": "strandly", - "prepare": "husky && npm run build", + "prepare": "husky", "build": "npm run build -w strands-ts", "test": "npm run test -w strands-ts", "test:coverage": "npm run test:coverage -w strands-ts", diff --git a/pyproject.toml b/pyproject.toml index 81c1aef125..c8ba98c74e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,8 @@ cache_dir = ".pytest_cache" [tool.ruff] line-length = 120 include = ["strands-py-wasm/src/**/*.py", "strandly/scripts/**/*.py"] -# _generated.py is machine-written; neither lint nor format enforce rules on it. -extend-exclude = ["strands-py-wasm/src/strands/_generated.py"] +# _generated/ is machine-written; neither lint nor format enforce rules on it. +extend-exclude = ["strands-py-wasm/src/strands/_generated/"] [tool.ruff.lint] select = [ diff --git a/strandly/package.json b/strandly/package.json index 27baf2160b..25fc7fcb45 100644 --- a/strandly/package.json +++ b/strandly/package.json @@ -4,14 +4,18 @@ "private": true, "type": "module", "bin": { - "strandly": "./src/cli.ts" + "strandly": "./dist/cli.js" }, + "files": [ + "dist" + ], "scripts": { + "build": "tsc", + "prepare": "tsc", "type-check": "tsc --noEmit" }, "dependencies": { - "commander": "^14", - "tsx": "^4.21.0" + "commander": "^14" }, "devDependencies": { "@types/node": "^22", diff --git a/strandly/src/cli.ts b/strandly/src/cli.ts index 843886e7de..181b19726d 100755 --- a/strandly/src/cli.ts +++ b/strandly/src/cli.ts @@ -1,4 +1,4 @@ -#!/usr/bin/env tsx +#!/usr/bin/env node import { execSync } from 'node:child_process' import { existsSync, readdirSync, readFileSync, writeFileSync } from 'node:fs' @@ -210,9 +210,14 @@ function test(opts?: { py?: boolean; ts?: boolean; file?: string }): void { function check(opts?: { ts?: boolean; wasm?: boolean; py?: boolean }): void { const all = !opts?.ts && !opts?.wasm && !opts?.py - if (all || opts?.py) py('ruff check src/strands') + if (all || opts?.py) { + py('ruff check src/strands') + py('ruff format --check src/strands') + py('pyright src/strands') + } if (all || opts?.ts) run('npm run type-check -w strands-ts') if (all || opts?.wasm) run('npm run type-check -w strands-wasm') + if (all || opts?.py || opts?.wasm) generate({ check: true }) } function fmt(opts?: { check?: boolean }): void { @@ -242,7 +247,10 @@ function generate(opts?: { check?: boolean }): void { } // Generate Python types from WIT via wasmtime-py's component bindgen. - run(`${VENV}/bin/python -m wasmtime.component.bindgen wit -o strands-py-wasm/src/strands/_generated.py`) + // Output is a package: one module per WIT interface plus an __init__.py. + // The hand-written ``strands.types`` module re-exports the curated public + // subset from this private ``_generated`` package. + run(`${VENV}/bin/python -m wasmtime.component.bindgen wit -o strands-py-wasm/src/strands/_generated`) // Ensure TS + WASM are built first. if (!existsSync(join(ROOT, 'strands-wasm/dist/strands-agent.wasm'))) { @@ -251,12 +259,13 @@ function generate(opts?: { check?: boolean }): void { if (opts?.check) { try { - execSync('git diff --quiet -- strands-wasm/generated/ strands-ts/generated/ strands-py-wasm/src/strands/_generated.py', { - cwd: ROOT, - }) + execSync( + 'git diff --quiet -- strands-wasm/generated/ strands-ts/generated/ strands-py-wasm/src/strands/_generated/', + { cwd: ROOT } + ) } catch { console.error("error: generated files are out of date -- run 'strandly generate' and commit") - run('git diff --stat -- strands-wasm/generated/ strands-ts/generated/ strands-py-wasm/src/strands/_generated.py') + run('git diff --stat -- strands-wasm/generated/ strands-ts/generated/ strands-py-wasm/src/strands/_generated/') process.exit(1) } } diff --git a/strandly/tsconfig.json b/strandly/tsconfig.json index 6d68507881..345f8ec66a 100644 --- a/strandly/tsconfig.json +++ b/strandly/tsconfig.json @@ -5,8 +5,9 @@ "moduleResolution": "NodeNext", "types": ["node"], "strict": true, - "noEmit": true, - "skipLibCheck": true + "skipLibCheck": true, + "outDir": "dist", + "rootDir": "src" }, "include": ["src"] } diff --git a/strands-py-wasm/pyproject.toml b/strands-py-wasm/pyproject.toml index f2a025cd2b..ad895ac033 100644 --- a/strands-py-wasm/pyproject.toml +++ b/strands-py-wasm/pyproject.toml @@ -28,8 +28,8 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ - # Imports as `wasmtime`. Pinned to pgrayy fork until upstream PRs land. - "pgrayy-wasmtime>=46.0.6,<47.0.0", + # Imports as `wasmtime`. Pinned to pgrayy/wasm-deps git URL until upstream PRs land. + "pgrayy-wasmtime @ git+https://github.com/pgrayy/wasm-deps.git@983c56c037321a50733355ea17f3a192899194dc#subdirectory=wasmtime-py", ] @@ -43,6 +43,10 @@ Homepage = "https://github.com/strands-agents/sdk-typescript" Documentation = "https://strandsagents.com" +[tool.hatch.metadata] +# pgrayy-wasmtime is pinned via a git URL until upstream PRs land. +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["src/strands"] diff --git a/strands-py-wasm/src/strands/__init__.py b/strands-py-wasm/src/strands/__init__.py index 424311a1a8..64fd46d8bc 100644 --- a/strands-py-wasm/src/strands/__init__.py +++ b/strands-py-wasm/src/strands/__init__.py @@ -1,14 +1,15 @@ """Strands Agents SDK Python surface. -Wire types live in :mod:`strands._generated`, auto-generated by ``bindgen``. -Those classes are accepted directly by wasmtime-py at the FFI boundary: -kebab-case record attributes, ``Variant(tag, payload)`` for tagged variants, -raw payloads for untagged ones. - -This module wraps the wire types with ergonomic Python helpers: builders -that take snake_case kwargs, factories that fill variant arms, and SDK-level -event classes that lift wire-shape ``StreamEvent`` values into a typed Python -class hierarchy so users can dispatch with ``isinstance``. +Wire types live in :mod:`strands.types` (a re-export of the machine-generated +:mod:`strands._generated` package). Records are dataclasses; variant arms are +nested ``VariantCase`` subclasses (e.g. ``StreamEvent.TextDelta``) that pass +``isinstance`` and ``match`` natively. + +This module overrides a handful of generated config dataclasses to fold in +ergonomic transforms (seconds → nanoseconds, ``**extras`` → JSON-encoded +``additional_config``, dict → list-of-pairs). Variant-arm wrapping is handled +at the boundary that consumes the value — Agent for its own slots, +SessionManager for storage, McpClientConfig for transport, etc. """ from __future__ import annotations @@ -21,8 +22,31 @@ class hierarchy so users can dispatch with ``isinstance``. from dataclasses import asdict, is_dataclass from typing import Any, Protocol, TypeVar, get_type_hints, runtime_checkable -from strands import _generated as _t -from strands._generated import * # noqa: F401,F403 re-export every generated type +from wasmtime.component import VariantCase as _WitVariantCase + +from strands import types + +# Hot-path types users import directly from ``strands``. Wire types not +# given an SDK shim below (FileStorage, S3Storage, ...) are passthroughs; +# re-exported here so ``from strands import FileStorage`` works without +# making users reach into ``strands.types``. +from strands.types import ( # noqa: F401 + AgentSkills, + ContentBlock, + ContextOffloader, + CustomStorage, + FileStorage, + Interrupt, + Metrics, + Role, + S3Storage, + SlidingWindowConversationManager, + StopReason, + StreamEvent, + SummarizingConversationManager, + ToolError, + Usage, +) class StrandsError(Exception): @@ -40,7 +64,7 @@ class ContextWindowOverflowError(_ModelErrorBase): class MaxTokensError(_ModelErrorBase): """Model stopped generating because it hit the max-tokens budget.""" - def __init__(self, message: str, partial_message: _t.Message | None = None) -> None: + def __init__(self, message: str, partial_message: types.Message | None = None) -> None: super().__init__(message) self.partial_message = partial_message @@ -73,228 +97,240 @@ class SessionError(StrandsError): """Session storage read/write failed.""" -# Inputs Message / PromptInput accept. Plain strings auto-wrap as text. -_ContentInput = ( - str - | _t.TextBlock - | _t.JsonBlock - | _t.ToolUseBlock - | _t.ToolResultBlock - | _t.ReasoningBlock - | _t.CachePointBlock - | _t.ImageBlock - | _t.VideoBlock - | _t.DocumentBlock - | _t.CitationsBlock - | _t.InterruptResponseBlock - | Any # already-built ContentBlock variants are passed through -) +def _extras_to_json(extras: dict[str, Any] | None) -> str | None: + return json.dumps(extras) if extras else None + + +_CONTENT_ARM_BY_TYPE: dict[type, type] = { + types.TextBlock: types.ContentBlock.Text, + types.JsonBlock: types.ContentBlock.Json, + types.ToolUseBlock: types.ContentBlock.ToolUse, + types.ToolResultBlock: types.ContentBlock.ToolResult, + types.ReasoningBlock: types.ContentBlock.Reasoning, + types.CachePointBlock: types.ContentBlock.CachePoint, + types.ImageBlock: types.ContentBlock.Image, + types.VideoBlock: types.ContentBlock.Video, + types.DocumentBlock: types.ContentBlock.Document, + types.CitationsBlock: types.ContentBlock.Citations, + types.InterruptResponseBlock: types.ContentBlock.InterruptResponse, +} def _as_content_block(item: Any) -> Any: """Wrap any accepted content shape as a ``ContentBlock`` variant arm.""" if isinstance(item, str): - return _t.ContentBlock_Text(_t.TextBlock(text=item)) - if isinstance(item, _t.TextBlock): - return _t.ContentBlock_Text(item) - if isinstance(item, _t.JsonBlock): - return _t.ContentBlock_Json(item) - if isinstance(item, _t.ToolUseBlock): - return _t.ContentBlock_ToolUse(item) - if isinstance(item, _t.ToolResultBlock): - return _t.ContentBlock_ToolResult(item) - if isinstance(item, _t.ReasoningBlock): - return _t.ContentBlock_Reasoning(item) - if isinstance(item, _t.CachePointBlock): - return _t.ContentBlock_CachePoint(item) - if isinstance(item, _t.ImageBlock): - return _t.ContentBlock_Image(item) - if isinstance(item, _t.VideoBlock): - return _t.ContentBlock_Video(item) - if isinstance(item, _t.DocumentBlock): - return _t.ContentBlock_Document(item) - if isinstance(item, _t.CitationsBlock): - return _t.ContentBlock_Citations(item) - if isinstance(item, _t.InterruptResponseBlock): - return _t.ContentBlock_InterruptResponse(item) - return item # already a ContentBlock variant - - -class ImageBlock(_t.ImageBlock): - def __init__( - self, - *, - format: str, - bytes: bytes | None = None, - url: str | None = None, - s3: _t.S3Location | None = None, - ) -> None: - provided = [x for x in (bytes, url, s3) if x is not None] - if len(provided) != 1: - raise ValueError("ImageBlock requires exactly one of bytes, url, or s3") - if bytes is not None: - source = _t.ImageSource_Bytes(bytes) - elif url is not None: - source = _t.ImageSource_Url(url) - else: - assert s3 is not None - source = _t.ImageSource_S3(s3) - super().__init__(format=format, source=source) + return types.ContentBlock.Text(types.TextBlock(text=item)) + for block_type, arm in _CONTENT_ARM_BY_TYPE.items(): + if isinstance(item, block_type): + return arm(item) + return item # already a ContentBlock variant arm + + +_MODEL_ARM_BY_TYPE: dict[type, type] = { + types.BedrockModel: types.ModelConfig.Bedrock, + types.AnthropicModel: types.ModelConfig.Anthropic, + types.OpenaiModel: types.ModelConfig.Openai, + types.GoogleModel: types.ModelConfig.Gemini, + types.CustomModel: types.ModelConfig.Custom, +} + +_CM_ARM_BY_TYPE: dict[type, type] = { + types.SlidingWindowConversationManager: types.ConversationManagerConfig.SlidingWindow, + types.SummarizingConversationManager: types.ConversationManagerConfig.Summarizing, +} + +_VENDED_TOOL_ARM_BY_TYPE: dict[type, type] = { + types.BashTool: types.VendedTool.Bash, + types.FileEditorTool: types.VendedTool.FileEditor, + types.HttpRequestTool: types.VendedTool.HttpRequest, + types.NotebookTool: types.VendedTool.Notebook, +} + +_VENDED_PLUGIN_ARM_BY_TYPE: dict[type, type] = { + types.AgentSkills: types.VendedPlugin.Skills, + types.ContextOffloader: types.VendedPlugin.ContextOffloader, +} + + +def _wrap(value: Any, arm_table: dict[type, type]) -> Any: + """Wrap ``value`` in the variant arm whose payload type matches its MRO. + + Walks the MRO so SDK ergonomic subclasses (``BedrockModel`` extends + ``types.BedrockModel``) hit the same arm as the raw bindgen type. Returns + the value unchanged if it's already an arm (so passing a fully-constructed + ``ModelConfig.Bedrock(...)`` is idempotent) or doesn't match anything. + """ + if value is None or isinstance(value, _WitVariantCase): + return value + for cls in type(value).__mro__: + arm = arm_table.get(cls) + if arm is not None: + return arm(value) + return value -class VideoBlock(_t.VideoBlock): +class Message(types.Message): def __init__( self, *, - format: str, - bytes: bytes | None = None, - s3: _t.S3Location | None = None, + role: types.Role, + content: Iterable[Any], + metadata: types.MessageMetadata | None = None, ) -> None: - if (bytes is None) == (s3 is None): - raise ValueError("VideoBlock requires exactly one of bytes or s3") - source = _t.VideoSource_Bytes(bytes) if bytes is not None else _t.VideoSource_S3(s3) - super().__init__(format=format, source=source) + super().__init__( + role=role, + content=[_as_content_block(c) for c in content], + metadata=metadata, + ) + + @classmethod + def user(cls, *content: Any, metadata: types.MessageMetadata | None = None) -> Message: + return cls(role=types.Role.USER, content=content, metadata=metadata) + @classmethod + def assistant(cls, *content: Any, metadata: types.MessageMetadata | None = None) -> Message: + return cls(role=types.Role.ASSISTANT, content=content, metadata=metadata) -class DocumentBlock(_t.DocumentBlock): + +class BedrockModel(types.BedrockModel): def __init__( self, + model_id: str = "us.anthropic.claude-opus-4-7-v1:0", *, - name: str, - format: str, - bytes: bytes | None = None, - text: str | None = None, - content: list[_t.TextBlock] | None = None, - s3: _t.S3Location | None = None, - citations: bool = False, - context: str | None = None, + region: str | None = None, + access_key_id: str | None = None, + secret_access_key: str | None = None, + session_token: str | None = None, + **extras: Any, ) -> None: - provided = [x for x in (bytes, text, content, s3) if x is not None] - if len(provided) != 1: - raise ValueError("DocumentBlock requires exactly one of bytes, text, content, or s3") - if bytes is not None: - source = _t.DocumentSource_Bytes(bytes) - elif text is not None: - source = _t.DocumentSource_Text(text) - elif content is not None: - source = _t.DocumentSource_Content(content) - else: - assert s3 is not None - source = _t.DocumentSource_S3(s3) + # The wasm bundle links the AWS SDK browser build, which has no credential + # chain. Resolve via botocore so users get the same behavior they'd get + # from any other Python AWS app (env vars, ~/.aws, SSO, IMDS, etc.). + if access_key_id is None and secret_access_key is None: + try: + import botocore.session + + creds = botocore.session.Session().get_credentials() + if creds is not None: + frozen = creds.get_frozen_credentials() + access_key_id = frozen.access_key + secret_access_key = frozen.secret_key + session_token = frozen.token + except ImportError: + pass + super().__init__( - name=name, - format=format, - source=source, - citations=_t.DocumentCitationsConfig(enabled=citations) if citations else None, - context=context, + model_id=model_id, + region=region, + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + additional_config=_extras_to_json(extras), ) -class InterruptResponseBlock(_t.InterruptResponseBlock): - def __init__(self, *, interrupt_id: str, response: Any) -> None: - payload = response if isinstance(response, str) else json.dumps(response) - super().__init__(interrupt_id=interrupt_id, response=payload) +class AnthropicModel(types.AnthropicModel): + def __init__(self, model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> None: + super().__init__(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) -class Message(_t.Message): +class OpenaiModel(types.OpenaiModel): + def __init__(self, model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> None: + super().__init__(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) + + +class GoogleModel(types.GoogleModel): + def __init__(self, model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> None: + super().__init__(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) + + +class CustomModel(types.CustomModel): def __init__( self, + provider_id: str, *, - role: _t.Role, - content: Iterable[Any], - metadata: _t.MessageMetadata | None = None, + model_id: str | None = None, + stateful: bool = False, + **extras: Any, ) -> None: super().__init__( - role=role, - content=[_as_content_block(c) for c in content], - metadata=metadata, + provider_id=provider_id, + model_id=model_id, + additional_config=_extras_to_json(extras), + stateful=stateful, ) - @classmethod - def user(cls, *content: Any, metadata: _t.MessageMetadata | None = None) -> Message: - return cls(role=_t.Role.USER, content=content, metadata=metadata) - @classmethod - def assistant(cls, *content: Any, metadata: _t.MessageMetadata | None = None) -> Message: - return cls(role=_t.Role.ASSISTANT, content=content, metadata=metadata) +def _json_default(obj: Any) -> Any: + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + if hasattr(obj, "__dict__"): + return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")} + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") -def _extras_to_json(extras: dict[str, Any] | None) -> str | None: - return json.dumps(extras) if extras else None +def agent_node(*, id: str, agent_config: Any, description: str | None, timeout: int | None) -> types.AgentNode: + """Build an ``AgentNode`` with a JSON-encoded nested agent config. + ``agent_config`` may be a string (passed through), a dataclass, or any + object with a ``__dict__``; serialization happens here so call sites + don't import :mod:`json`. + """ + encoded = agent_config if isinstance(agent_config, str) else json.dumps(agent_config, default=_json_default) + return types.AgentNode(id=id, description=description, timeout=timeout, agent_config=encoded) -def BedrockModel( - model_id: str = "us.anthropic.claude-opus-4-7-v1:0", - *, - region: str | None = None, - access_key_id: str | None = None, - secret_access_key: str | None = None, - session_token: str | None = None, - **extras: Any, -) -> Any: - """Build a Bedrock ``model-config`` value.""" - # The wasm bundle links the AWS SDK browser build, which has no credential - # chain. Resolve via botocore so users get the same behavior they'd get - # from any other Python AWS app (env vars, ~/.aws, SSO, IMDS, etc.). - if access_key_id is None and secret_access_key is None: - try: - import botocore.session - - creds = botocore.session.Session().get_credentials() - if creds is not None: - frozen = creds.get_frozen_credentials() - access_key_id = frozen.access_key - secret_access_key = frozen.secret_key - session_token = frozen.token - except ImportError: - pass - return _t.ModelConfig_Bedrock( - _t.BedrockConfig( - model_id=model_id, - region=region, - access_key_id=access_key_id, - secret_access_key=secret_access_key, - session_token=session_token, - additional_config=_extras_to_json(extras), - ) - ) +def multi_agent_node(*, id: str, orchestrator: Any, description: str | None) -> types.MultiAgentNode: + """Build a ``MultiAgentNode`` with a JSON-encoded nested orchestrator.""" + encoded = orchestrator if isinstance(orchestrator, str) else json.dumps(orchestrator, default=_json_default) + return types.MultiAgentNode(id=id, description=description, orchestrator=encoded) -def AnthropicModel(model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> Any: - return _t.ModelConfig_Anthropic( - _t.AnthropicConfig(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) - ) +class StdioTransport(types.StdioTransport): + """``StdioTransport`` with ``dict[str, str]`` env shorthand.""" + + def __init__( + self, + *, + command: str, + args: list[str], + env: dict[str, str], + cwd: str | None, + ) -> None: + super().__init__( + command=command, + args=args, + env=[types.EnvVar(key=k, value=v) for k, v in env.items()], + cwd=cwd, + ) -def OpenAIModel(model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> Any: - return _t.ModelConfig_Openai( - _t.OpenaiConfig(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) - ) +class HttpTransport(types.HttpTransport): + """``HttpTransport`` with ``dict[str, str]`` headers shorthand.""" + def __init__(self, *, url: str, headers: dict[str, str]) -> None: + super().__init__( + url=url, + headers=[types.HttpHeader(name=k, value=v) for k, v in headers.items()], + ) -def GoogleModel(model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> Any: - return _t.ModelConfig_Gemini( - _t.GeminiConfig(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) - ) +class SseTransport(types.SseTransport): + """``SseTransport`` with ``dict[str, str]`` headers shorthand.""" -def CustomModel( - provider_id: str, - *, - model_id: str | None = None, - stateful: bool = False, - **extras: Any, -) -> Any: - """Host-implemented provider. Pair with a ``model-provider`` callback.""" - return _t.ModelConfig_Custom( - _t.CustomModelConfig( - provider_id=provider_id, - model_id=model_id, - additional_config=_extras_to_json(extras), - stateful=stateful, + def __init__(self, *, url: str, headers: dict[str, str]) -> None: + super().__init__( + url=url, + headers=[types.HttpHeader(name=k, value=v) for k, v in headers.items()], ) - ) + + +class InterruptResponse(types.RespondArgs): + """Reply to a paused agent via ``response-stream.respond``.""" + + def __init__(self, *, interrupt_id: str, response: Any) -> None: + payload = response if isinstance(response, str) else json.dumps(response) + super().__init__(interrupt_id=interrupt_id, response=payload) class PydanticTool: @@ -316,8 +352,8 @@ def __init__( self.input_schema = input_model.model_json_schema() self.func = func - def to_spec(self) -> _t.ToolSpec: - return _t.ToolSpec( + def to_spec(self) -> types.ToolSpec: + return types.ToolSpec( name=self.name, description=self.description, input_schema=json.dumps(self.input_schema), @@ -345,8 +381,8 @@ def __init__( self.input_schema = input_schema self.func = func - def to_spec(self) -> _t.ToolSpec: - return _t.ToolSpec( + def to_spec(self) -> types.ToolSpec: + return types.ToolSpec( name=self.name, description=self.description, input_schema=json.dumps(self.input_schema), @@ -359,18 +395,18 @@ def invoke(self, raw_input: str) -> list[Any]: def _coerce_tool_result(result: Any) -> list[Any]: if isinstance(result, str): - return [_t.ToolResultContent_Text(_t.TextBlock(text=result))] - if isinstance(result, _t.TextBlock): - return [_t.ToolResultContent_Text(result)] - if isinstance(result, _t.JsonBlock): - return [_t.ToolResultContent_Json(result)] + return [types.ToolResultContent.Text(types.TextBlock(text=result))] + if isinstance(result, types.TextBlock): + return [types.ToolResultContent.Text(result)] + if isinstance(result, types.JsonBlock): + return [types.ToolResultContent.Json(result)] if isinstance(result, dict): - return [_t.ToolResultContent_Json(_t.JsonBlock(json=json.dumps(result)))] + return [types.ToolResultContent.Json(types.JsonBlock(json=json.dumps(result)))] if is_dataclass(result) and not isinstance(result, type): - return [_t.ToolResultContent_Json(_t.JsonBlock(json=json.dumps(asdict(result))))] + return [types.ToolResultContent.Json(types.JsonBlock(json=json.dumps(asdict(result))))] if isinstance(result, list): return result - return [_t.ToolResultContent_Text(_t.TextBlock(text=str(result)))] + return [types.ToolResultContent.Text(types.TextBlock(text=str(result)))] def _py_type_to_schema(py_type: Any) -> dict[str, Any]: @@ -442,338 +478,6 @@ def wrap(f: Callable[..., Any]) -> Tool: return wrap(func) if func is not None else wrap -def NullConversationManager() -> Any: - """No management. History grows without bound.""" - return _t.ConversationManagerConfig_None() - - -def SlidingWindowConversationManager( - *, window_size: int = 40, should_truncate_results: bool = True -) -> Any: - return _t.ConversationManagerConfig_SlidingWindow( - _t.SlidingWindowConfig(window_size=window_size, should_truncate_results=should_truncate_results) - ) - - -def SummarizingConversationManager( - *, - summary_ratio: float = 0.3, - preserve_recent_messages: int = 10, - summarization_system_prompt: str | None = None, - summarization_model: Any = None, -) -> Any: - return _t.ConversationManagerConfig_Summarizing( - _t.SummarizingConfig( - summary_ratio=summary_ratio, - preserve_recent_messages=preserve_recent_messages, - summarization_system_prompt=summarization_system_prompt, - summarization_model=summarization_model, - ) - ) - - -def _seconds_to_ns(seconds: float) -> int: - return int(seconds * 1_000_000_000) - - -def _optional_ns(seconds: float | None) -> int | None: - return None if seconds is None else _seconds_to_ns(seconds) - - -def ConstantBackoff(*, delay: float = 1.0) -> Any: - return _t.BackoffStrategy_Constant(_t.ConstantBackoffConfig(delay=_seconds_to_ns(delay))) - - -def LinearBackoff( - *, - base: float = 1.0, - max: float = 30.0, - jitter: _t.JitterKind = _t.JitterKind.FULL, -) -> Any: - return _t.BackoffStrategy_Linear( - _t.LinearBackoffConfig(base=_seconds_to_ns(base), max=_seconds_to_ns(max), jitter=jitter) - ) - - -def ExponentialBackoff( - *, - base: float = 1.0, - max: float = 30.0, - factor: float = 2.0, - jitter: _t.JitterKind = _t.JitterKind.FULL, -) -> Any: - return _t.BackoffStrategy_Exponential( - _t.ExponentialBackoffConfig( - base=_seconds_to_ns(base), - max=_seconds_to_ns(max), - factor=factor, - jitter=jitter, - ) - ) - - -class ModelRetryStrategy(_t.ModelRetryStrategy): - def __init__( - self, - *, - max_attempts: int = 6, - backoff: Any = None, - total_budget: float | None = None, - ) -> None: - super().__init__( - max_attempts=max_attempts, - backoff=backoff if backoff is not None else ExponentialBackoff(), - total_budget=_optional_ns(total_budget), - ) - - -def FileStorage(base_dir: str) -> Any: - return _t.StorageConfig_File(_t.FileStorageConfig(base_dir=base_dir)) - - -def S3Storage(*, bucket: str, region: str | None = None, prefix: str | None = None) -> Any: - return _t.StorageConfig_S3(_t.S3StorageConfig(bucket=bucket, region=region, prefix=prefix)) - - -def CustomStorage(backend_id: str) -> Any: - """Host-implemented backend. Pair with a ``snapshot-storage`` handler.""" - return _t.StorageConfig_Custom(_t.CustomStorageConfig(backend_id=backend_id)) - - -class SessionManager(_t.SessionConfig): - """Attach session persistence to an agent.""" - - def __init__( - self, - *, - session_id: str, - storage: Any, - save_latest: Any = None, - ) -> None: - super().__init__(session_id=session_id, storage=storage, save_latest=save_latest) - - -def _coerce_nested_config(value: Any) -> str: - if isinstance(value, str): - return value - return json.dumps(value, default=_json_default) - - -def _json_default(obj: Any) -> Any: - if is_dataclass(obj) and not isinstance(obj, type): - return asdict(obj) - if hasattr(obj, "__dict__"): - return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")} - raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") - - -def AgentNode( - *, - id: str, - agent_config: Any, - description: str | None = None, - timeout: float | None = None, -) -> Any: - return _t.NodeConfig_Agent( - _t.AgentNodeConfig( - id=id, - description=description, - timeout=_optional_ns(timeout), - agent_config=_coerce_nested_config(agent_config), - ) - ) - - -def MultiAgentNode(*, id: str, orchestrator: Any, description: str | None = None) -> Any: - return _t.NodeConfig_MultiAgent( - _t.MultiAgentNodeConfig( - id=id, - description=description, - orchestrator=_coerce_nested_config(orchestrator), - ) - ) - - -class Graph(_t.GraphConfig): - def __init__( - self, - *, - id: str, - nodes: list[Any], - edges: list[Any] | None = None, - sources: list[str] | None = None, - max_concurrency: int | None = None, - max_steps: int | None = None, - timeout: float | None = None, - node_timeout: float | None = None, - ) -> None: - super().__init__( - id=id, - nodes=nodes, - edges=edges or [], - sources=sources or [], - max_concurrency=max_concurrency, - max_steps=max_steps, - timeout=_optional_ns(timeout), - node_timeout=_optional_ns(node_timeout), - ) - - -class Swarm(_t.SwarmConfig): - def __init__( - self, - *, - id: str, - nodes: list[Any], - start_node_id: str, - max_steps: int | None = None, - timeout: float | None = None, - node_timeout: float | None = None, - ) -> None: - super().__init__( - id=id, - nodes=nodes, - start_node_id=start_node_id, - max_steps=max_steps, - timeout=_optional_ns(timeout), - node_timeout=_optional_ns(node_timeout), - ) - - -def BashTool(*, default_timeout: int | None = None) -> Any: - return _t.VendedTool_Bash(_t.BashToolConfig(default_timeout_s=default_timeout)) - - -def FileEditorTool(*, workspace_root: str | None = None) -> Any: - return _t.VendedTool_FileEditor(_t.FileEditorToolConfig(workspace_root=workspace_root)) - - -def HttpRequestTool(*, allowed_hosts: list[str] | None = None, max_response_bytes: int = 0) -> Any: - return _t.VendedTool_HttpRequest( - _t.HttpRequestToolConfig( - allowed_hosts=allowed_hosts or [], - max_response_bytes=max_response_bytes, - ) - ) - - -def NotebookTool(*, workspace_root: str | None = None) -> Any: - return _t.VendedTool_Notebook(_t.NotebookToolConfig(workspace_root=workspace_root)) - - -def SkillsPlugin( - *, - skills: list[str], - strict: bool = False, - max_resource_files: int | None = None, - state_key: str | None = None, -) -> Any: - return _t.VendedPlugin_Skills( - _t.SkillsPluginConfig( - skills=[_t.SkillSource(path=p) for p in skills], - strict=strict, - max_resource_files=max_resource_files, - state_key=state_key, - ) - ) - - -def ContextOffloaderPlugin( - *, - max_result_tokens: int | None = None, - preview_tokens: int | None = None, - include_retrieval_tool: bool = True, -) -> Any: - return _t.VendedPlugin_ContextOffloader( - _t.ContextOffloaderPluginConfig( - max_result_tokens=max_result_tokens, - preview_tokens=preview_tokens, - include_retrieval_tool=include_retrieval_tool, - ) - ) - - -def StdioMcpTransport( - *, - command: str, - args: list[str] | None = None, - env: dict[str, str] | None = None, - cwd: str | None = None, -) -> Any: - """Launch an MCP server as a subprocess and talk to it over stdio.""" - return _t.McpTransport_Stdio( - _t.StdioTransportConfig( - command=command, - args=args or [], - env=[_t.EnvVar(key=k, value=v) for k, v in (env or {}).items()], - cwd=cwd, - ) - ) - - -def StreamableHttpMcpTransport(*, url: str, headers: dict[str, str] | None = None) -> Any: - """Talk to a hosted MCP server over streamable HTTP.""" - return _t.McpTransport_StreamableHttp( - _t.HttpTransportConfig( - url=url, - headers=[_t.HttpHeader(name=k, value=v) for k, v in (headers or {}).items()], - ) - ) - - -def SseMcpTransport(*, url: str, headers: dict[str, str] | None = None) -> Any: - """Legacy SSE transport.""" - return _t.McpTransport_Sse( - _t.SseTransportConfig( - url=url, - headers=[_t.HttpHeader(name=k, value=v) for k, v in (headers or {}).items()], - ) - ) - - -class McpClient(_t.McpClientConfig): - """Declare an MCP client the host should open and route tools from.""" - - def __init__( - self, - *, - client_id: str, - transport: Any, - application_name: str | None = None, - application_version: str | None = None, - tasks_ttl: float | None = None, - tasks_poll_timeout: float | None = None, - elicitation_enabled: bool = False, - fail_open: bool = False, - disable_instrumentation: bool = False, - ) -> None: - tasks = None - if tasks_ttl is not None or tasks_poll_timeout is not None: - tasks = _t.TasksConfig( - ttl=_seconds_to_ns(tasks_ttl if tasks_ttl is not None else 60.0), - poll_timeout=_seconds_to_ns(tasks_poll_timeout if tasks_poll_timeout is not None else 300.0), - ) - super().__init__( - client_id=client_id, - application_name=application_name, - application_version=application_version, - transport=transport, - tasks_config=tasks, - elicitation_enabled=elicitation_enabled, - fail_open=fail_open, - disable_instrumentation=disable_instrumentation, - ) - - -class InterruptResponse(_t.RespondArgs): - """Reply to a paused agent via ``response-stream.respond``.""" - - def __init__(self, *, interrupt_id: str, response: Any) -> None: - payload = response if isinstance(response, str) else json.dumps(response) - super().__init__(interrupt_id=interrupt_id, response=payload) - - - _ToolInput = Tool | PydanticTool | Callable[..., Any] _ToolChoiceInput = Any # tagged ToolChoice variant or "name" shorthand or None @@ -787,11 +491,14 @@ def _coerce_tool(item: _ToolInput) -> Tool | PydanticTool: def _coerce_prompt(value: Any) -> Any: - """Coerce a string, content blocks, or pre-built value to a ``prompt-input``.""" + """Coerce a string or iterable of content blocks to a ``prompt-input``.""" if isinstance(value, str): return value - if isinstance(value, list): - return [_as_content_block(c) for c in value] + if isinstance(value, types.Message): + raise TypeError( + "stream_async/invoke take content blocks as input, not Messages. " + "Pass conversation history via Agent(messages=[...]) instead." + ) if hasattr(value, "__iter__") and not isinstance(value, (bytes, str)): return [_as_content_block(c) for c in value] return value @@ -801,7 +508,7 @@ def _coerce_tool_choice(value: _ToolChoiceInput) -> Any: if value is None: return None if isinstance(value, str): - return _t.ToolChoice_Named(value) + return types.ToolChoice.Named(value) return value @@ -811,24 +518,24 @@ class Agent: def __init__( self, *, - model: Any = None, - messages: list[Any] | None = None, - system_prompt: Any = None, + model: types.ModelInput | None = None, + messages: list[types.Message] | None = None, + system_prompt: str | list[Any] | None = None, tools: list[_ToolInput] | None = None, - agent_tools: list[Any] | None = None, - vended_tools: list[Any] | None = None, - vended_plugins: list[Any] | None = None, - mcp_clients: list[Any] | None = None, + agent_tools: list[types.AgentAsToolConfig] | None = None, + vended_tools: list[types.VendedToolInput] | None = None, + vended_plugins: list[types.VendedPluginInput] | None = None, + mcp_clients: list[types.McpClientConfig] | None = None, name: str | None = None, id: str | None = None, description: str | None = None, - tool_executor: Any = None, + tool_executor: types.ToolExecutorStrategy | None = None, display_output: bool | None = None, - trace_attributes: list[Any] | None = None, - trace_context: Any = None, - session: Any = None, - conversation_manager: Any = None, - retry: Any = None, + trace_attributes: list[types.TraceAttribute] | None = None, + trace_context: types.TraceContext | None = None, + session: types.SessionManager | None = None, + conversation_manager: types.ConversationManagerInput | None = None, + retry: types.RetryConfig | None = None, structured_output_schema: str | None = None, app_state: dict[str, Any] | None = None, model_state: dict[str, Any] | None = None, @@ -836,17 +543,22 @@ def __init__( self._tools: list[Tool | PydanticTool] = [_coerce_tool(t) for t in (tools or [])] identity = None if name is not None or id is not None or description is not None: - identity = _t.AgentIdentity(name=name, id=id, description=description) + identity = types.AgentIdentity(name=name, id=id, description=description) + + wrapped_vended_tools = [_wrap(v, _VENDED_TOOL_ARM_BY_TYPE) for v in vended_tools] if vended_tools else None + wrapped_vended_plugins = ( + [_wrap(p, _VENDED_PLUGIN_ARM_BY_TYPE) for p in vended_plugins] if vended_plugins else None + ) - self._config = _t.AgentConfig( - model=model, + self._config = types.AgentConfig( + model=_wrap(model, _MODEL_ARM_BY_TYPE), model_params=None, messages=messages, system_prompt=(_coerce_prompt(system_prompt) if system_prompt is not None else None), tools=[t.to_spec() for t in self._tools] or None, agent_tools=agent_tools, - vended_tools=vended_tools, - vended_plugins=vended_plugins, + vended_tools=wrapped_vended_tools, + vended_plugins=wrapped_vended_plugins, mcp_clients=mcp_clients, identity=identity, tool_executor=tool_executor, @@ -854,7 +566,7 @@ def __init__( trace_attributes=trace_attributes, trace_context=trace_context, session=session, - conversation_manager=conversation_manager, + conversation_manager=_wrap(conversation_manager, _CM_ARM_BY_TYPE), retry=retry, structured_output_schema=structured_output_schema, app_state=json.dumps(app_state) if app_state else None, @@ -863,7 +575,7 @@ def __init__( self._runtime: Any = None @property - def config(self) -> _t.AgentConfig: + def config(self) -> types.AgentConfig: return self._config def _ensure_runtime(self) -> Any: @@ -890,9 +602,9 @@ def _build_invoke_args( tools: list[_ToolInput] | None, tool_choice: _ToolChoiceInput, structured_output_schema: str | None, - ) -> _t.InvokeArgs: + ) -> types.InvokeArgs: extra_tools = [_coerce_tool(t).to_spec() for t in (tools or [])] or None - return _t.InvokeArgs( + return types.InvokeArgs( input=_coerce_prompt(prompt), tools=extra_tools, tool_choice=_coerce_tool_choice(tool_choice), @@ -906,7 +618,7 @@ async def stream_async( tools: list[_ToolInput] | None = None, tool_choice: _ToolChoiceInput = None, structured_output_schema: str | None = None, - ) -> AsyncIterator[_t.StreamEvent]: + ) -> AsyncIterator[types.StreamEvent]: """Yield :class:`StreamEvent` arms as the agent runs.""" runtime = await self._ensure_runtime_async() args = self._build_invoke_args(prompt, tools, tool_choice, structured_output_schema) @@ -952,8 +664,7 @@ def invoke( pass else: raise RuntimeError( - "Agent.invoke() cannot run inside an existing event loop. " - "Use 'await agent.invoke_async(...)' instead." + "Agent.invoke() cannot run inside an existing event loop. Use 'await agent.invoke_async(...)' instead." ) return asyncio.run( self.invoke_async( @@ -972,12 +683,12 @@ def cancel(self) -> None: async def respond(self, interrupt_id: str, response: Any) -> None: runtime = await self._ensure_runtime_async() payload = response if isinstance(response, str) else json.dumps(response) - await runtime.respond(_t.RespondArgs(interrupt_id=interrupt_id, response=payload)) + await runtime.respond(types.RespondArgs(interrupt_id=interrupt_id, response=payload)) - async def get_messages(self) -> list[_t.Message]: + async def get_messages(self) -> list[types.Message]: return await (await self._ensure_runtime_async()).get_messages() - async def set_messages(self, messages: list[_t.Message]) -> None: + async def set_messages(self, messages: list[types.Message]) -> None: await (await self._ensure_runtime_async()).set_messages(messages) async def get_app_state(self) -> dict[str, Any]: @@ -997,29 +708,29 @@ class _AgentResultAccumulator: """Folds the stream of events into the fields of an :class:`AgentResult`.""" def __init__(self) -> None: - self._stop: _t.StopEvent | None = None - self._last_message: _t.Message | None = None - self._interrupts: list[_t.Interrupt] = [] + self._stop: types.StopEvent | None = None + self._last_message: types.Message | None = None + self._interrupts: list[types.Interrupt] = [] - def consume(self, event: _t.StreamEvent) -> None: - if isinstance(event, _t.StreamEvent_MessageAdded): + def consume(self, event: types.StreamEvent) -> None: + if isinstance(event, types.StreamEvent.MessageAdded): self._last_message = event.value.message - elif isinstance(event, _t.StreamEvent_ModelMessage): + elif isinstance(event, types.StreamEvent.ModelMessage): self._last_message = event.value.message - elif isinstance(event, _t.StreamEvent_Stop): + elif isinstance(event, types.StreamEvent.Stop): self._stop = event.value - elif isinstance(event, _t.StreamEvent_AgentResult): + elif isinstance(event, types.StreamEvent.AgentResult): self._stop = event.value.stop - elif isinstance(event, _t.StreamEvent_Interrupt): + elif isinstance(event, types.StreamEvent.Interrupt): self._interrupts.append(event.value) def finalize(self, agent: Agent) -> AgentResult: stop = self._stop last = self._last_message if last is None: - last = _t.Message(role=_t.Role.ASSISTANT, content=[], metadata=None) + last = types.Message(role=types.Role.ASSISTANT, content=[], metadata=None) return AgentResult( - stop_reason=stop.reason if stop is not None else _t.StopReason.END_TURN, + stop_reason=stop.reason if stop is not None else types.StopReason.END_TURN, last_message=last, usage=stop.usage if stop is not None else None, metrics=None, @@ -1043,7 +754,7 @@ class HookRegistry: """Register callbacks keyed by ``StreamEvent`` arm class. Each arm of the wire ``stream-event`` variant is a distinct Python class - (``StreamEvent_TextDelta``, ``StreamEvent_Stop``, ...). Subscribers match + (``StreamEvent.TextDelta``, ``StreamEvent.Stop``, ...). Subscribers match by exact class. Callbacks for arms whose name begins with ``After`` dispatch in reverse @@ -1098,14 +809,14 @@ class AgentResult: def __init__( self, *, - stop_reason: _t.StopReason, - last_message: _t.Message, + stop_reason: types.StopReason, + last_message: types.Message, invocation_state: dict[str, Any] | None = None, - traces: list[_t.AgentTrace] | None = None, - metrics: _t.AgentMetrics | None = None, - usage: _t.Usage | None = None, + traces: list[types.AgentTrace] | None = None, + metrics: types.AgentMetrics | None = None, + usage: types.Usage | None = None, structured_output: Any = None, - interrupts: list[_t.Interrupt] | None = None, + interrupts: list[types.Interrupt] | None = None, ) -> None: self.stop_reason = stop_reason self.last_message = last_message diff --git a/strands-py-wasm/src/strands/_generated.py b/strands-py-wasm/src/strands/_generated.py deleted file mode 100644 index 624367c77b..0000000000 --- a/strands-py-wasm/src/strands/_generated.py +++ /dev/null @@ -1,5438 +0,0 @@ -"""Auto-generated by bindgen. Do not edit. - -Wire-shape Python types for the WIT world. Constructed values are accepted -directly by wasmtime-py without further marshaling — kebab-case record -attributes, ``Variant(tag, payload)`` for tagged variants, raw payloads for -untagged ones. -""" - -from __future__ import annotations - -from typing import Any, Optional, Union - -from wasmtime.component import Variant as _WitVariant -from wasmtime.component import VariantCase as _WitVariantCase - - -def ok(value: Any = None) -> _WitVariant: - """Wrap ``value`` as the ``ok`` arm of a ``result``.""" - return _WitVariant("ok", value) - - -def err(value: Any = None) -> _WitVariant: - """Wrap ``value`` as the ``err`` arm of a ``result``.""" - return _WitVariant("err", value) - - -class Error: - """A resource which represents some error information. - -The only method provided by this resource is `to-debug-string`, -which provides some human-readable information about the error. - -In the `wasi:io` package, this resource is returned through the -`wasi:io/streams/stream-error` type. - -To provide more specific error information, other interfaces may -offer functions to "downcast" this error into more specific types. For example, -errors returned from streams derived from filesystem types can be described using -the filesystem's own error-code type. This is done using the function -`wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` -parameter and returns an `option`. - -The set of functions which can "downcast" an `error` into a more -concrete type is open.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def to_debug_string(self) -> str: - return self._invoke('[method]error.to-debug-string', (self._handle,)) - - -class Pollable: - """`pollable` represents a single I/O event which may be ready, or not.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def ready(self) -> bool: - return self._invoke('[method]pollable.ready', (self._handle,)) - - def block(self) -> None: - return self._invoke('[method]pollable.block', (self._handle,)) - - -StreamError = Any | None -"""An error for input-stream and output-stream operations.""" - -class InputStream: - """An input bytestream. - -`input-stream`s are *non-blocking* to the extent practical on underlying -platforms. I/O operations always return promptly; if fewer bytes are -promptly available than requested, they return the number of bytes promptly -available, which could even be zero. To wait for data to be available, -use the `subscribe` function to obtain a `pollable` which can be polled -for using `wasi:io/poll`.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def read(self, len: int) -> Any: - return self._invoke('[method]input-stream.read', (self._handle, len,)) - - def blocking_read(self, len: int) -> Any: - return self._invoke('[method]input-stream.blocking-read', (self._handle, len,)) - - def skip(self, len: int) -> Any: - return self._invoke('[method]input-stream.skip', (self._handle, len,)) - - def blocking_skip(self, len: int) -> Any: - return self._invoke('[method]input-stream.blocking-skip', (self._handle, len,)) - - def subscribe(self) -> Any: - return self._invoke('[method]input-stream.subscribe', (self._handle,)) - - -class OutputStream: - """An output bytestream. - -`output-stream`s are *non-blocking* to the extent practical on -underlying platforms. Except where specified otherwise, I/O operations also -always return promptly, after the number of bytes that can be written -promptly, which could even be zero. To wait for the stream to be ready to -accept data, the `subscribe` function to obtain a `pollable` which can be -polled for using `wasi:io/poll`. - -Dropping an `output-stream` while there's still an active write in -progress may result in the data being lost. Before dropping the stream, -be sure to fully flush your writes.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def check_write(self) -> Any: - return self._invoke('[method]output-stream.check-write', (self._handle,)) - - def write(self, contents: bytes) -> Any: - return self._invoke('[method]output-stream.write', (self._handle, contents,)) - - def blocking_write_and_flush(self, contents: bytes) -> Any: - return self._invoke('[method]output-stream.blocking-write-and-flush', (self._handle, contents,)) - - def flush(self) -> Any: - return self._invoke('[method]output-stream.flush', (self._handle,)) - - def blocking_flush(self) -> Any: - return self._invoke('[method]output-stream.blocking-flush', (self._handle,)) - - def subscribe(self) -> Any: - return self._invoke('[method]output-stream.subscribe', (self._handle,)) - - def write_zeroes(self, len: int) -> Any: - return self._invoke('[method]output-stream.write-zeroes', (self._handle, len,)) - - def blocking_write_zeroes_and_flush(self, len: int) -> Any: - return self._invoke('[method]output-stream.blocking-write-zeroes-and-flush', (self._handle, len,)) - - def splice(self, src: Any, len: int) -> Any: - return self._invoke('[method]output-stream.splice', (self._handle, src, len,)) - - def blocking_splice(self, src: Any, len: int) -> Any: - return self._invoke('[method]output-stream.blocking-splice', (self._handle, src, len,)) - - -class Datetime: - """A time and date in seconds plus nanoseconds.""" - def __init__( - self, - *, - seconds: int, - nanoseconds: int, - ) -> None: - setattr(self, 'seconds', seconds) - setattr(self, 'nanoseconds', nanoseconds) - - def __repr__(self) -> str: - return f'Datetime(seconds={getattr(self, 'seconds')!r}, nanoseconds={getattr(self, 'nanoseconds')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Datetime): - return NotImplemented - return getattr(self, 'seconds') == getattr(other, 'seconds') and getattr(self, 'nanoseconds') == getattr(other, 'nanoseconds') - - def __hash__(self) -> int: - return id(self) - - -class LogLevel(str): - """Severity level of a log entry.""" - __slots__ = () - - TRACE: 'LogLevel' - DEBUG: 'LogLevel' - INFO: 'LogLevel' - WARN: 'LogLevel' - ERROR: 'LogLevel' - -LogLevel.TRACE = LogLevel('trace') # type: ignore[attr-defined] -LogLevel.DEBUG = LogLevel('debug') # type: ignore[attr-defined] -LogLevel.INFO = LogLevel('info') # type: ignore[attr-defined] -LogLevel.WARN = LogLevel('warn') # type: ignore[attr-defined] -LogLevel.ERROR = LogLevel('error') # type: ignore[attr-defined] - - -class LogEntry: - """A single structured log entry.""" - def __init__( - self, - *, - level: LogLevel, - message: str, - context: Optional[str], - ) -> None: - setattr(self, 'level', level) - setattr(self, 'message', message) - setattr(self, 'context', context) - - def __repr__(self) -> str: - return f'LogEntry(level={getattr(self, 'level')!r}, message={getattr(self, 'message')!r}, context={getattr(self, 'context')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, LogEntry): - return NotImplemented - return getattr(self, 'level') == getattr(other, 'level') and getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'context') == getattr(other, 'context') - - def __hash__(self) -> int: - return id(self) - - -class ElicitRequest: - """Request for user input.""" - def __init__( - self, - *, - client_id: str, - message: str, - request: str, - ) -> None: - setattr(self, 'client-id', client_id) - setattr(self, 'message', message) - setattr(self, 'request', request) - - @property - def client_id(self) -> str: - return getattr(self, 'client-id') - - def __repr__(self) -> str: - return f'ElicitRequest(client_id={getattr(self, 'client-id')!r}, message={getattr(self, 'message')!r}, request={getattr(self, 'request')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ElicitRequest): - return NotImplemented - return getattr(self, 'client-id') == getattr(other, 'client-id') and getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'request') == getattr(other, 'request') - - def __hash__(self) -> int: - return id(self) - - -class ElicitAction(str): - """Outcome of an elicitation request.""" - __slots__ = () - - ACCEPT: 'ElicitAction' - DECLINE: 'ElicitAction' - CANCEL: 'ElicitAction' - -ElicitAction.ACCEPT = ElicitAction('accept') # type: ignore[attr-defined] -ElicitAction.DECLINE = ElicitAction('decline') # type: ignore[attr-defined] -ElicitAction.CANCEL = ElicitAction('cancel') # type: ignore[attr-defined] - - -class ElicitResponse: - """Response to an elicitation request.""" - def __init__( - self, - *, - action: ElicitAction, - content: Optional[str], - ) -> None: - setattr(self, 'action', action) - setattr(self, 'content', content) - - def __repr__(self) -> str: - return f'ElicitResponse(action={getattr(self, 'action')!r}, content={getattr(self, 'content')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ElicitResponse): - return NotImplemented - return getattr(self, 'action') == getattr(other, 'action') and getattr(self, 'content') == getattr(other, 'content') - - def __hash__(self) -> int: - return id(self) - - -class ElicitationError: - """Why an elicitation call failed.""" - pass - -class ElicitationError_UnknownClient(ElicitationError, _WitVariantCase): - """No handler registered for the given `client-id`.""" - tag = 'unknown-client' - -class ElicitationError_HandlerFailed(ElicitationError, _WitVariantCase): - """Handler raised an exception.""" - tag = 'handler-failed' - -class ElicitationError_TimedOut(ElicitationError, _WitVariantCase): - """Request timed out waiting for a human response.""" - tag = 'timed-out' - -_ElicitationError_CASES: dict[str, type] = { - 'unknown-client': ElicitationError_UnknownClient, - 'handler-failed': ElicitationError_HandlerFailed, - 'timed-out': ElicitationError_TimedOut, -} - -def _ElicitationError_lift(raw: _WitVariant) -> ElicitationError: - cls = _ElicitationError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ElicitationError arm: {raw.tag!r}') - return cls(raw.payload) -ElicitationError.lift = staticmethod(_ElicitationError_lift) # type: ignore[attr-defined] - -class TextBlock: - """Plain text.""" - def __init__( - self, - *, - text: str, - ) -> None: - setattr(self, 'text', text) - - def __repr__(self) -> str: - return f'TextBlock(text={getattr(self, 'text')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TextBlock): - return NotImplemented - return getattr(self, 'text') == getattr(other, 'text') - - def __hash__(self) -> int: - return id(self) - - -class S3Location: - """Object stored in Amazon S3.""" - def __init__( - self, - *, - uri: str, - bucket_owner: Optional[str], - ) -> None: - setattr(self, 'uri', uri) - setattr(self, 'bucket-owner', bucket_owner) - - @property - def bucket_owner(self) -> Optional[str]: - return getattr(self, 'bucket-owner') - - def __repr__(self) -> str: - return f'S3Location(uri={getattr(self, 'uri')!r}, bucket_owner={getattr(self, 'bucket-owner')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, S3Location): - return NotImplemented - return getattr(self, 'uri') == getattr(other, 'uri') and getattr(self, 'bucket-owner') == getattr(other, 'bucket-owner') - - def __hash__(self) -> int: - return id(self) - - -ImageSource = bytes | str | S3Location -"""Source of image bytes.""" - -class ImageBlock: - """Image attached to a message.""" - def __init__( - self, - *, - format: str, - source: ImageSource, - ) -> None: - setattr(self, 'format', format) - setattr(self, 'source', source) - - def __repr__(self) -> str: - return f'ImageBlock(format={getattr(self, 'format')!r}, source={getattr(self, 'source')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ImageBlock): - return NotImplemented - return getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'source') == getattr(other, 'source') - - def __hash__(self) -> int: - return id(self) - - -VideoSource = bytes | S3Location -"""Source of video bytes.""" - -class VideoBlock: - """Video attached to a message.""" - def __init__( - self, - *, - format: str, - source: VideoSource, - ) -> None: - setattr(self, 'format', format) - setattr(self, 'source', source) - - def __repr__(self) -> str: - return f'VideoBlock(format={getattr(self, 'format')!r}, source={getattr(self, 'source')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, VideoBlock): - return NotImplemented - return getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'source') == getattr(other, 'source') - - def __hash__(self) -> int: - return id(self) - - -DocumentSource = bytes | str | list[TextBlock] | S3Location -"""Source of document bytes.""" - -class DocumentCitationsConfig: - """Citation configuration attached to a document.""" - def __init__( - self, - *, - enabled: bool, - ) -> None: - setattr(self, 'enabled', enabled) - - def __repr__(self) -> str: - return f'DocumentCitationsConfig(enabled={getattr(self, 'enabled')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DocumentCitationsConfig): - return NotImplemented - return getattr(self, 'enabled') == getattr(other, 'enabled') - - def __hash__(self) -> int: - return id(self) - - -class DocumentBlock: - """Document attached to a message.""" - def __init__( - self, - *, - name: str, - format: str, - source: DocumentSource, - citations: Optional[DocumentCitationsConfig], - context: Optional[str], - ) -> None: - setattr(self, 'name', name) - setattr(self, 'format', format) - setattr(self, 'source', source) - setattr(self, 'citations', citations) - setattr(self, 'context', context) - - def __repr__(self) -> str: - return f'DocumentBlock(name={getattr(self, 'name')!r}, format={getattr(self, 'format')!r}, source={getattr(self, 'source')!r}, citations={getattr(self, 'citations')!r}, context={getattr(self, 'context')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DocumentBlock): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'source') == getattr(other, 'source') and getattr(self, 'citations') == getattr(other, 'citations') and getattr(self, 'context') == getattr(other, 'context') - - def __hash__(self) -> int: - return id(self) - - -class ReasoningBlock: - """Model's thought process. Either plain reasoning (with an optional -signature) or an opaque redacted blob.""" - def __init__( - self, - *, - text: Optional[str], - signature: Optional[str], - redacted_content: Optional[bytes], - ) -> None: - setattr(self, 'text', text) - setattr(self, 'signature', signature) - setattr(self, 'redacted-content', redacted_content) - - @property - def redacted_content(self) -> Optional[bytes]: - return getattr(self, 'redacted-content') - - def __repr__(self) -> str: - return f'ReasoningBlock(text={getattr(self, 'text')!r}, signature={getattr(self, 'signature')!r}, redacted_content={getattr(self, 'redacted-content')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ReasoningBlock): - return NotImplemented - return getattr(self, 'text') == getattr(other, 'text') and getattr(self, 'signature') == getattr(other, 'signature') and getattr(self, 'redacted-content') == getattr(other, 'redacted-content') - - def __hash__(self) -> int: - return id(self) - - -class CacheKind(str): - """Prompt-caching kind. More arms will be added as providers surface -additional cache tiers (e.g. Anthropic's `ephemeral`).""" - __slots__ = () - - DEFAULT_CACHE: 'CacheKind' - -CacheKind.DEFAULT_CACHE = CacheKind('default-cache') # type: ignore[attr-defined] - - -class CachePointBlock: - """Marks a caching boundary in the prompt.""" - def __init__( - self, - *, - kind: CacheKind, - ) -> None: - setattr(self, 'kind', kind) - - def __repr__(self) -> str: - return f'CachePointBlock(kind={getattr(self, 'kind')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CachePointBlock): - return NotImplemented - return getattr(self, 'kind') == getattr(other, 'kind') - - def __hash__(self) -> int: - return id(self) - - -class GuardQualifier(str): - """How a piece of guard content should be evaluated.""" - __slots__ = () - - GROUNDING_SOURCE: 'GuardQualifier' - QUERY: 'GuardQualifier' - GUARD_CONTENT: 'GuardQualifier' - -GuardQualifier.GROUNDING_SOURCE = GuardQualifier('grounding-source') # type: ignore[attr-defined] -GuardQualifier.QUERY = GuardQualifier('query') # type: ignore[attr-defined] -GuardQualifier.GUARD_CONTENT = GuardQualifier('guard-content') # type: ignore[attr-defined] - - -class GuardContentText: - """Text submitted to a guardrail for evaluation.""" - def __init__( - self, - *, - qualifiers: list[GuardQualifier], - text: str, - ) -> None: - setattr(self, 'qualifiers', qualifiers) - setattr(self, 'text', text) - - def __repr__(self) -> str: - return f'GuardContentText(qualifiers={getattr(self, 'qualifiers')!r}, text={getattr(self, 'text')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, GuardContentText): - return NotImplemented - return getattr(self, 'qualifiers') == getattr(other, 'qualifiers') and getattr(self, 'text') == getattr(other, 'text') - - def __hash__(self) -> int: - return id(self) - - -class GuardContentImage: - """Image submitted to a guardrail for evaluation.""" - def __init__( - self, - *, - format: str, - bytes: bytes, - ) -> None: - setattr(self, 'format', format) - setattr(self, 'bytes', bytes) - - def __repr__(self) -> str: - return f'GuardContentImage(format={getattr(self, 'format')!r}, bytes={getattr(self, 'bytes')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, GuardContentImage): - return NotImplemented - return getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'bytes') == getattr(other, 'bytes') - - def __hash__(self) -> int: - return id(self) - - -class GuardContentBlock: - """Content submitted to a guardrail for evaluation.""" - pass - -class GuardContentBlock_Text(GuardContentBlock, _WitVariantCase): - """Text guard content.""" - tag = 'text' - -class GuardContentBlock_Image(GuardContentBlock, _WitVariantCase): - """Image guard content.""" - tag = 'image' - -_GuardContentBlock_CASES: dict[str, type] = { - 'text': GuardContentBlock_Text, - 'image': GuardContentBlock_Image, -} - -def _GuardContentBlock_lift(raw: _WitVariant) -> GuardContentBlock: - cls = _GuardContentBlock_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown GuardContentBlock arm: {raw.tag!r}') - return cls(raw.payload) -GuardContentBlock.lift = staticmethod(_GuardContentBlock_lift) # type: ignore[attr-defined] - -class DocumentRange: - """Range within a source document (characters, pages, or chunks).""" - def __init__( - self, - *, - document_index: int, - start: int, - end: int, - ) -> None: - setattr(self, 'document-index', document_index) - setattr(self, 'start', start) - setattr(self, 'end', end) - - @property - def document_index(self) -> int: - return getattr(self, 'document-index') - - def __repr__(self) -> str: - return f'DocumentRange(document_index={getattr(self, 'document-index')!r}, start={getattr(self, 'start')!r}, end={getattr(self, 'end')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DocumentRange): - return NotImplemented - return getattr(self, 'document-index') == getattr(other, 'document-index') and getattr(self, 'start') == getattr(other, 'start') and getattr(self, 'end') == getattr(other, 'end') - - def __hash__(self) -> int: - return id(self) - - -class SearchResultRange: - """Range within a search result.""" - def __init__( - self, - *, - search_result_index: int, - start: int, - end: int, - ) -> None: - setattr(self, 'search-result-index', search_result_index) - setattr(self, 'start', start) - setattr(self, 'end', end) - - @property - def search_result_index(self) -> int: - return getattr(self, 'search-result-index') - - def __repr__(self) -> str: - return f'SearchResultRange(search_result_index={getattr(self, 'search-result-index')!r}, start={getattr(self, 'start')!r}, end={getattr(self, 'end')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SearchResultRange): - return NotImplemented - return getattr(self, 'search-result-index') == getattr(other, 'search-result-index') and getattr(self, 'start') == getattr(other, 'start') and getattr(self, 'end') == getattr(other, 'end') - - def __hash__(self) -> int: - return id(self) - - -class WebLocation: - """Web citation target.""" - def __init__( - self, - *, - url: str, - domain: Optional[str], - ) -> None: - setattr(self, 'url', url) - setattr(self, 'domain', domain) - - def __repr__(self) -> str: - return f'WebLocation(url={getattr(self, 'url')!r}, domain={getattr(self, 'domain')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, WebLocation): - return NotImplemented - return getattr(self, 'url') == getattr(other, 'url') and getattr(self, 'domain') == getattr(other, 'domain') - - def __hash__(self) -> int: - return id(self) - - -class CitationLocation: - """Anchor a citation points to.""" - pass - -class CitationLocation_DocumentChar(CitationLocation, _WitVariantCase): - """Character range within a document.""" - tag = 'document-char' - -class CitationLocation_DocumentPage(CitationLocation, _WitVariantCase): - """Page range within a document.""" - tag = 'document-page' - -class CitationLocation_DocumentChunk(CitationLocation, _WitVariantCase): - """Chunk range within a document.""" - tag = 'document-chunk' - -class CitationLocation_SearchResult(CitationLocation, _WitVariantCase): - """Range within a search result.""" - tag = 'search-result' - -class CitationLocation_Web(CitationLocation, _WitVariantCase): - """Web page.""" - tag = 'web' - -_CitationLocation_CASES: dict[str, type] = { - 'document-char': CitationLocation_DocumentChar, - 'document-page': CitationLocation_DocumentPage, - 'document-chunk': CitationLocation_DocumentChunk, - 'search-result': CitationLocation_SearchResult, - 'web': CitationLocation_Web, -} - -def _CitationLocation_lift(raw: _WitVariant) -> CitationLocation: - cls = _CitationLocation_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown CitationLocation arm: {raw.tag!r}') - return cls(raw.payload) -CitationLocation.lift = staticmethod(_CitationLocation_lift) # type: ignore[attr-defined] - -class CitationText: - """Text fragment from a source or a generated answer.""" - def __init__( - self, - *, - text: str, - ) -> None: - setattr(self, 'text', text) - - def __repr__(self) -> str: - return f'CitationText(text={getattr(self, 'text')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CitationText): - return NotImplemented - return getattr(self, 'text') == getattr(other, 'text') - - def __hash__(self) -> int: - return id(self) - - -class Citation: - """Link from generated content back to a source location.""" - def __init__( - self, - *, - location: CitationLocation, - source: str, - source_content: list[CitationText], - title: str, - ) -> None: - setattr(self, 'location', location) - setattr(self, 'source', source) - setattr(self, 'source-content', source_content) - setattr(self, 'title', title) - - @property - def source_content(self) -> list[CitationText]: - return getattr(self, 'source-content') - - def __repr__(self) -> str: - return f'Citation(location={getattr(self, 'location')!r}, source={getattr(self, 'source')!r}, source_content={getattr(self, 'source-content')!r}, title={getattr(self, 'title')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Citation): - return NotImplemented - return getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'source') == getattr(other, 'source') and getattr(self, 'source-content') == getattr(other, 'source-content') and getattr(self, 'title') == getattr(other, 'title') - - def __hash__(self) -> int: - return id(self) - - -class CitationsBlock: - """Citations emitted by the model when citations are enabled.""" - def __init__( - self, - *, - citations: list[Citation], - content: list[CitationText], - ) -> None: - setattr(self, 'citations', citations) - setattr(self, 'content', content) - - def __repr__(self) -> str: - return f'CitationsBlock(citations={getattr(self, 'citations')!r}, content={getattr(self, 'content')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CitationsBlock): - return NotImplemented - return getattr(self, 'citations') == getattr(other, 'citations') and getattr(self, 'content') == getattr(other, 'content') - - def __hash__(self) -> int: - return id(self) - - -class ToolUseBlock: - """Model's request to call a tool.""" - def __init__( - self, - *, - name: str, - tool_use_id: str, - input: str, - reasoning_signature: Optional[str], - ) -> None: - setattr(self, 'name', name) - setattr(self, 'tool-use-id', tool_use_id) - setattr(self, 'input', input) - setattr(self, 'reasoning-signature', reasoning_signature) - - @property - def tool_use_id(self) -> str: - return getattr(self, 'tool-use-id') - - @property - def reasoning_signature(self) -> Optional[str]: - return getattr(self, 'reasoning-signature') - - def __repr__(self) -> str: - return f'ToolUseBlock(name={getattr(self, 'name')!r}, tool_use_id={getattr(self, 'tool-use-id')!r}, input={getattr(self, 'input')!r}, reasoning_signature={getattr(self, 'reasoning-signature')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolUseBlock): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') and getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'reasoning-signature') == getattr(other, 'reasoning-signature') - - def __hash__(self) -> int: - return id(self) - - -class ToolResultStatus(str): - """Whether a tool invocation succeeded. Richer classification lives on `tools.tool-error`.""" - __slots__ = () - - SUCCESS: 'ToolResultStatus' - ERROR: 'ToolResultStatus' - -ToolResultStatus.SUCCESS = ToolResultStatus('success') # type: ignore[attr-defined] -ToolResultStatus.ERROR = ToolResultStatus('error') # type: ignore[attr-defined] - - -class JsonBlock: - """Structured JSON payload. Used for tool results and agent-as-tool -outputs that carry schema-validated data, not prose.""" - def __init__( - self, - *, - json: str, - ) -> None: - setattr(self, 'json', json) - - def __repr__(self) -> str: - return f'JsonBlock(json={getattr(self, 'json')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, JsonBlock): - return NotImplemented - return getattr(self, 'json') == getattr(other, 'json') - - def __hash__(self) -> int: - return id(self) - - -class ToolResultContent: - """Block valid inside `tool-result-block.content`. Narrower than `content-block`.""" - pass - -class ToolResultContent_Text(ToolResultContent, _WitVariantCase): - """Text output.""" - tag = 'text' - -class ToolResultContent_Json(ToolResultContent, _WitVariantCase): - """Structured JSON output.""" - tag = 'json' - -class ToolResultContent_Image(ToolResultContent, _WitVariantCase): - """Image output.""" - tag = 'image' - -class ToolResultContent_Video(ToolResultContent, _WitVariantCase): - """Video output.""" - tag = 'video' - -class ToolResultContent_Document(ToolResultContent, _WitVariantCase): - """Document output.""" - tag = 'document' - -_ToolResultContent_CASES: dict[str, type] = { - 'text': ToolResultContent_Text, - 'json': ToolResultContent_Json, - 'image': ToolResultContent_Image, - 'video': ToolResultContent_Video, - 'document': ToolResultContent_Document, -} - -def _ToolResultContent_lift(raw: _WitVariant) -> ToolResultContent: - cls = _ToolResultContent_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ToolResultContent arm: {raw.tag!r}') - return cls(raw.payload) -ToolResultContent.lift = staticmethod(_ToolResultContent_lift) # type: ignore[attr-defined] - -class ToolResultBlock: - """Outcome of a tool execution.""" - def __init__( - self, - *, - tool_use_id: str, - status: ToolResultStatus, - content: list[ToolResultContent], - ) -> None: - setattr(self, 'tool-use-id', tool_use_id) - setattr(self, 'status', status) - setattr(self, 'content', content) - - @property - def tool_use_id(self) -> str: - return getattr(self, 'tool-use-id') - - def __repr__(self) -> str: - return f'ToolResultBlock(tool_use_id={getattr(self, 'tool-use-id')!r}, status={getattr(self, 'status')!r}, content={getattr(self, 'content')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolResultBlock): - return NotImplemented - return getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') and getattr(self, 'status') == getattr(other, 'status') and getattr(self, 'content') == getattr(other, 'content') - - def __hash__(self) -> int: - return id(self) - - -class InterruptResponseBlock: - """User response to a previously-raised interrupt. Supplied on the -next invocation to resume the paused agent.""" - def __init__( - self, - *, - interrupt_id: str, - response: str, - ) -> None: - setattr(self, 'interrupt-id', interrupt_id) - setattr(self, 'response', response) - - @property - def interrupt_id(self) -> str: - return getattr(self, 'interrupt-id') - - def __repr__(self) -> str: - return f'InterruptResponseBlock(interrupt_id={getattr(self, 'interrupt-id')!r}, response={getattr(self, 'response')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InterruptResponseBlock): - return NotImplemented - return getattr(self, 'interrupt-id') == getattr(other, 'interrupt-id') and getattr(self, 'response') == getattr(other, 'response') - - def __hash__(self) -> int: - return id(self) - - -class ContentBlock: - """Any block that can appear inside a message.""" - pass - -class ContentBlock_Text(ContentBlock, _WitVariantCase): - """Plain text.""" - tag = 'text' - -class ContentBlock_Json(ContentBlock, _WitVariantCase): - """Structured JSON payload.""" - tag = 'json' - -class ContentBlock_ToolUse(ContentBlock, _WitVariantCase): - """Model requested a tool call.""" - tag = 'tool-use' - -class ContentBlock_ToolResult(ContentBlock, _WitVariantCase): - """Tool call completed.""" - tag = 'tool-result' - -class ContentBlock_Reasoning(ContentBlock, _WitVariantCase): - """Model reasoning.""" - tag = 'reasoning' - -class ContentBlock_CachePoint(ContentBlock, _WitVariantCase): - """Caching boundary marker.""" - tag = 'cache-point' - -class ContentBlock_GuardContent(ContentBlock, _WitVariantCase): - """Content submitted for guardrail evaluation.""" - tag = 'guard-content' - -class ContentBlock_Image(ContentBlock, _WitVariantCase): - """Image.""" - tag = 'image' - -class ContentBlock_Video(ContentBlock, _WitVariantCase): - """Video.""" - tag = 'video' - -class ContentBlock_Document(ContentBlock, _WitVariantCase): - """Document.""" - tag = 'document' - -class ContentBlock_Citations(ContentBlock, _WitVariantCase): - """Citations emitted by the model.""" - tag = 'citations' - -class ContentBlock_InterruptResponse(ContentBlock, _WitVariantCase): - """Response to a prior interrupt, supplied when resuming.""" - tag = 'interrupt-response' - -_ContentBlock_CASES: dict[str, type] = { - 'text': ContentBlock_Text, - 'json': ContentBlock_Json, - 'tool-use': ContentBlock_ToolUse, - 'tool-result': ContentBlock_ToolResult, - 'reasoning': ContentBlock_Reasoning, - 'cache-point': ContentBlock_CachePoint, - 'guard-content': ContentBlock_GuardContent, - 'image': ContentBlock_Image, - 'video': ContentBlock_Video, - 'document': ContentBlock_Document, - 'citations': ContentBlock_Citations, - 'interrupt-response': ContentBlock_InterruptResponse, -} - -def _ContentBlock_lift(raw: _WitVariant) -> ContentBlock: - cls = _ContentBlock_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ContentBlock arm: {raw.tag!r}') - return cls(raw.payload) -ContentBlock.lift = staticmethod(_ContentBlock_lift) # type: ignore[attr-defined] - -class Role(str): - """Who a message is from.""" - __slots__ = () - - USER: 'Role' - ASSISTANT: 'Role' - -Role.USER = Role('user') # type: ignore[attr-defined] -Role.ASSISTANT = Role('assistant') # type: ignore[attr-defined] - - -class Usage: - """Token consumption for a model invocation.""" - def __init__( - self, - *, - input_tokens: int, - output_tokens: int, - total_tokens: int, - cache_read_input_tokens: Optional[int], - cache_write_input_tokens: Optional[int], - ) -> None: - setattr(self, 'input-tokens', input_tokens) - setattr(self, 'output-tokens', output_tokens) - setattr(self, 'total-tokens', total_tokens) - setattr(self, 'cache-read-input-tokens', cache_read_input_tokens) - setattr(self, 'cache-write-input-tokens', cache_write_input_tokens) - - @property - def input_tokens(self) -> int: - return getattr(self, 'input-tokens') - - @property - def output_tokens(self) -> int: - return getattr(self, 'output-tokens') - - @property - def total_tokens(self) -> int: - return getattr(self, 'total-tokens') - - @property - def cache_read_input_tokens(self) -> Optional[int]: - return getattr(self, 'cache-read-input-tokens') - - @property - def cache_write_input_tokens(self) -> Optional[int]: - return getattr(self, 'cache-write-input-tokens') - - def __repr__(self) -> str: - return f'Usage(input_tokens={getattr(self, 'input-tokens')!r}, output_tokens={getattr(self, 'output-tokens')!r}, total_tokens={getattr(self, 'total-tokens')!r}, cache_read_input_tokens={getattr(self, 'cache-read-input-tokens')!r}, cache_write_input_tokens={getattr(self, 'cache-write-input-tokens')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Usage): - return NotImplemented - return getattr(self, 'input-tokens') == getattr(other, 'input-tokens') and getattr(self, 'output-tokens') == getattr(other, 'output-tokens') and getattr(self, 'total-tokens') == getattr(other, 'total-tokens') and getattr(self, 'cache-read-input-tokens') == getattr(other, 'cache-read-input-tokens') and getattr(self, 'cache-write-input-tokens') == getattr(other, 'cache-write-input-tokens') - - def __hash__(self) -> int: - return id(self) - - -class Metrics: - """Performance metrics for a model invocation.""" - def __init__( - self, - *, - latency_ms: float, - ) -> None: - setattr(self, 'latency-ms', latency_ms) - - @property - def latency_ms(self) -> float: - return getattr(self, 'latency-ms') - - def __repr__(self) -> str: - return f'Metrics(latency_ms={getattr(self, 'latency-ms')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Metrics): - return NotImplemented - return getattr(self, 'latency-ms') == getattr(other, 'latency-ms') - - def __hash__(self) -> int: - return id(self) - - -class MessageMetadata: - """Metadata attached to a message. Not sent to model providers; persisted -alongside the message for bookkeeping.""" - def __init__( - self, - *, - usage: Optional[Usage], - metrics: Optional[Metrics], - custom: Optional[str], - ) -> None: - setattr(self, 'usage', usage) - setattr(self, 'metrics', metrics) - setattr(self, 'custom', custom) - - def __repr__(self) -> str: - return f'MessageMetadata(usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r}, custom={getattr(self, 'custom')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MessageMetadata): - return NotImplemented - return getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') and getattr(self, 'custom') == getattr(other, 'custom') - - def __hash__(self) -> int: - return id(self) - - -class Message: - """A complete message in a conversation.""" - def __init__( - self, - *, - role: Role, - content: list[ContentBlock], - metadata: Optional[MessageMetadata], - ) -> None: - setattr(self, 'role', role) - setattr(self, 'content', content) - setattr(self, 'metadata', metadata) - - def __repr__(self) -> str: - return f'Message(role={getattr(self, 'role')!r}, content={getattr(self, 'content')!r}, metadata={getattr(self, 'metadata')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Message): - return NotImplemented - return getattr(self, 'role') == getattr(other, 'role') and getattr(self, 'content') == getattr(other, 'content') and getattr(self, 'metadata') == getattr(other, 'metadata') - - def __hash__(self) -> int: - return id(self) - - -PromptInput = str | list[ContentBlock] -"""A prompt-style input: either prose or structured content. Used for -both system prompts and user input.""" - -class AnthropicConfig: - """Anthropic API model configuration.""" - def __init__( - self, - *, - model_id: Optional[str], - api_key: Optional[str], - additional_config: Optional[str], - ) -> None: - setattr(self, 'model-id', model_id) - setattr(self, 'api-key', api_key) - setattr(self, 'additional-config', additional_config) - - @property - def model_id(self) -> Optional[str]: - return getattr(self, 'model-id') - - @property - def api_key(self) -> Optional[str]: - return getattr(self, 'api-key') - - @property - def additional_config(self) -> Optional[str]: - return getattr(self, 'additional-config') - - def __repr__(self) -> str: - return f'AnthropicConfig(model_id={getattr(self, 'model-id')!r}, api_key={getattr(self, 'api-key')!r}, additional_config={getattr(self, 'additional-config')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AnthropicConfig): - return NotImplemented - return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'api-key') == getattr(other, 'api-key') and getattr(self, 'additional-config') == getattr(other, 'additional-config') - - def __hash__(self) -> int: - return id(self) - - -class BedrockConfig: - """AWS Bedrock model configuration.""" - def __init__( - self, - *, - model_id: str, - region: Optional[str], - access_key_id: Optional[str], - secret_access_key: Optional[str], - session_token: Optional[str], - additional_config: Optional[str], - ) -> None: - setattr(self, 'model-id', model_id) - setattr(self, 'region', region) - setattr(self, 'access-key-id', access_key_id) - setattr(self, 'secret-access-key', secret_access_key) - setattr(self, 'session-token', session_token) - setattr(self, 'additional-config', additional_config) - - @property - def model_id(self) -> str: - return getattr(self, 'model-id') - - @property - def access_key_id(self) -> Optional[str]: - return getattr(self, 'access-key-id') - - @property - def secret_access_key(self) -> Optional[str]: - return getattr(self, 'secret-access-key') - - @property - def session_token(self) -> Optional[str]: - return getattr(self, 'session-token') - - @property - def additional_config(self) -> Optional[str]: - return getattr(self, 'additional-config') - - def __repr__(self) -> str: - return f'BedrockConfig(model_id={getattr(self, 'model-id')!r}, region={getattr(self, 'region')!r}, access_key_id={getattr(self, 'access-key-id')!r}, secret_access_key={getattr(self, 'secret-access-key')!r}, session_token={getattr(self, 'session-token')!r}, additional_config={getattr(self, 'additional-config')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, BedrockConfig): - return NotImplemented - return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'region') == getattr(other, 'region') and getattr(self, 'access-key-id') == getattr(other, 'access-key-id') and getattr(self, 'secret-access-key') == getattr(other, 'secret-access-key') and getattr(self, 'session-token') == getattr(other, 'session-token') and getattr(self, 'additional-config') == getattr(other, 'additional-config') - - def __hash__(self) -> int: - return id(self) - - -class OpenaiConfig: - """OpenAI API model configuration.""" - def __init__( - self, - *, - model_id: Optional[str], - api_key: Optional[str], - additional_config: Optional[str], - ) -> None: - setattr(self, 'model-id', model_id) - setattr(self, 'api-key', api_key) - setattr(self, 'additional-config', additional_config) - - @property - def model_id(self) -> Optional[str]: - return getattr(self, 'model-id') - - @property - def api_key(self) -> Optional[str]: - return getattr(self, 'api-key') - - @property - def additional_config(self) -> Optional[str]: - return getattr(self, 'additional-config') - - def __repr__(self) -> str: - return f'OpenaiConfig(model_id={getattr(self, 'model-id')!r}, api_key={getattr(self, 'api-key')!r}, additional_config={getattr(self, 'additional-config')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, OpenaiConfig): - return NotImplemented - return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'api-key') == getattr(other, 'api-key') and getattr(self, 'additional-config') == getattr(other, 'additional-config') - - def __hash__(self) -> int: - return id(self) - - -class GeminiConfig: - """Google Gemini API model configuration.""" - def __init__( - self, - *, - model_id: Optional[str], - api_key: Optional[str], - additional_config: Optional[str], - ) -> None: - setattr(self, 'model-id', model_id) - setattr(self, 'api-key', api_key) - setattr(self, 'additional-config', additional_config) - - @property - def model_id(self) -> Optional[str]: - return getattr(self, 'model-id') - - @property - def api_key(self) -> Optional[str]: - return getattr(self, 'api-key') - - @property - def additional_config(self) -> Optional[str]: - return getattr(self, 'additional-config') - - def __repr__(self) -> str: - return f'GeminiConfig(model_id={getattr(self, 'model-id')!r}, api_key={getattr(self, 'api-key')!r}, additional_config={getattr(self, 'additional-config')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, GeminiConfig): - return NotImplemented - return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'api-key') == getattr(other, 'api-key') and getattr(self, 'additional-config') == getattr(other, 'additional-config') - - def __hash__(self) -> int: - return id(self) - - -class CustomModelConfig: - """Custom model provider supplied by your application.""" - def __init__( - self, - *, - provider_id: str, - model_id: Optional[str], - additional_config: Optional[str], - stateful: bool, - ) -> None: - setattr(self, 'provider-id', provider_id) - setattr(self, 'model-id', model_id) - setattr(self, 'additional-config', additional_config) - setattr(self, 'stateful', stateful) - - @property - def provider_id(self) -> str: - return getattr(self, 'provider-id') - - @property - def model_id(self) -> Optional[str]: - return getattr(self, 'model-id') - - @property - def additional_config(self) -> Optional[str]: - return getattr(self, 'additional-config') - - def __repr__(self) -> str: - return f'CustomModelConfig(provider_id={getattr(self, 'provider-id')!r}, model_id={getattr(self, 'model-id')!r}, additional_config={getattr(self, 'additional-config')!r}, stateful={getattr(self, 'stateful')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CustomModelConfig): - return NotImplemented - return getattr(self, 'provider-id') == getattr(other, 'provider-id') and getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'additional-config') == getattr(other, 'additional-config') and getattr(self, 'stateful') == getattr(other, 'stateful') - - def __hash__(self) -> int: - return id(self) - - -class ModelConfig: - """Which model provider the agent should use.""" - pass - -class ModelConfig_Anthropic(ModelConfig, _WitVariantCase): - """Anthropic API.""" - tag = 'anthropic' - -class ModelConfig_Bedrock(ModelConfig, _WitVariantCase): - """AWS Bedrock.""" - tag = 'bedrock' - -class ModelConfig_Openai(ModelConfig, _WitVariantCase): - """OpenAI API.""" - tag = 'openai' - -class ModelConfig_Gemini(ModelConfig, _WitVariantCase): - """Google Gemini API.""" - tag = 'gemini' - -class ModelConfig_Custom(ModelConfig, _WitVariantCase): - """Custom provider supplied by your application. Implement the -`model-provider` interface to serve it.""" - tag = 'custom' - -_ModelConfig_CASES: dict[str, type] = { - 'anthropic': ModelConfig_Anthropic, - 'bedrock': ModelConfig_Bedrock, - 'openai': ModelConfig_Openai, - 'gemini': ModelConfig_Gemini, - 'custom': ModelConfig_Custom, -} - -def _ModelConfig_lift(raw: _WitVariant) -> ModelConfig: - cls = _ModelConfig_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ModelConfig arm: {raw.tag!r}') - return cls(raw.payload) -ModelConfig.lift = staticmethod(_ModelConfig_lift) # type: ignore[attr-defined] - -class ModelParams: - """Sampling parameters applied to every call on the chosen provider.""" - def __init__( - self, - *, - max_tokens: Optional[int], - temperature: Optional[float], - top_p: Optional[float], - ) -> None: - setattr(self, 'max-tokens', max_tokens) - setattr(self, 'temperature', temperature) - setattr(self, 'top-p', top_p) - - @property - def max_tokens(self) -> Optional[int]: - return getattr(self, 'max-tokens') - - @property - def top_p(self) -> Optional[float]: - return getattr(self, 'top-p') - - def __repr__(self) -> str: - return f'ModelParams(max_tokens={getattr(self, 'max-tokens')!r}, temperature={getattr(self, 'temperature')!r}, top_p={getattr(self, 'top-p')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelParams): - return NotImplemented - return getattr(self, 'max-tokens') == getattr(other, 'max-tokens') and getattr(self, 'temperature') == getattr(other, 'temperature') and getattr(self, 'top-p') == getattr(other, 'top-p') - - def __hash__(self) -> int: - return id(self) - - -class ModelError: - """Why a model call failed. Retry logic keys off of which arm fires, so -implementations should pick the narrowest one that fits.""" - pass - -class ModelError_UnknownProvider(ModelError, _WitVariantCase): - """No provider registered for the given `provider-id`.""" - tag = 'unknown-provider' - -class ModelError_InvalidRequest(ModelError, _WitVariantCase): - """Provider refused the request due to malformed input.""" - tag = 'invalid-request' - -class ModelError_Unauthorized(ModelError, _WitVariantCase): - """Caller lacks permission (missing or expired credentials).""" - tag = 'unauthorized' - -class ModelError_Throttled(ModelError, _WitVariantCase): - """Provider returned a rate-limit error. Retry after a backoff.""" - tag = 'throttled' - -class ModelError_ServerError(ModelError, _WitVariantCase): - """Provider returned a server-side error. Retry may succeed.""" - tag = 'server-error' - -class ModelError_ContextWindowExceeded(ModelError, _WitVariantCase): - """Request exceeded the model's context window.""" - tag = 'context-window-exceeded' - -class ModelError_ContentFiltered(ModelError, _WitVariantCase): - """Content was rejected by provider safety policy.""" - tag = 'content-filtered' - -class ModelError_Transient(ModelError, _WitVariantCase): - """Transient network or transport failure. Retry may succeed.""" - tag = 'transient' - -class ModelError_Internal(ModelError, _WitVariantCase): - """Catch-all for internal failures.""" - tag = 'internal' - -_ModelError_CASES: dict[str, type] = { - 'unknown-provider': ModelError_UnknownProvider, - 'invalid-request': ModelError_InvalidRequest, - 'unauthorized': ModelError_Unauthorized, - 'throttled': ModelError_Throttled, - 'server-error': ModelError_ServerError, - 'context-window-exceeded': ModelError_ContextWindowExceeded, - 'content-filtered': ModelError_ContentFiltered, - 'transient': ModelError_Transient, - 'internal': ModelError_Internal, -} - -def _ModelError_lift(raw: _WitVariant) -> ModelError: - cls = _ModelError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ModelError arm: {raw.tag!r}') - return cls(raw.payload) -ModelError.lift = staticmethod(_ModelError_lift) # type: ignore[attr-defined] - -class SlidingWindowConfig: - """Sliding-window strategy: trim oldest messages once the conversation -exceeds `window-size`.""" - def __init__( - self, - *, - window_size: int, - should_truncate_results: bool, - ) -> None: - setattr(self, 'window-size', window_size) - setattr(self, 'should-truncate-results', should_truncate_results) - - @property - def window_size(self) -> int: - return getattr(self, 'window-size') - - @property - def should_truncate_results(self) -> bool: - return getattr(self, 'should-truncate-results') - - def __repr__(self) -> str: - return f'SlidingWindowConfig(window_size={getattr(self, 'window-size')!r}, should_truncate_results={getattr(self, 'should-truncate-results')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SlidingWindowConfig): - return NotImplemented - return getattr(self, 'window-size') == getattr(other, 'window-size') and getattr(self, 'should-truncate-results') == getattr(other, 'should-truncate-results') - - def __hash__(self) -> int: - return id(self) - - -class SummarizingConfig: - """Summarizing strategy: once the conversation grows, summarize older -messages into a single summary message and keep the rest verbatim.""" - def __init__( - self, - *, - summary_ratio: float, - preserve_recent_messages: int, - summarization_system_prompt: Optional[str], - summarization_model: Optional[ModelConfig], - ) -> None: - setattr(self, 'summary-ratio', summary_ratio) - setattr(self, 'preserve-recent-messages', preserve_recent_messages) - setattr(self, 'summarization-system-prompt', summarization_system_prompt) - setattr(self, 'summarization-model', summarization_model) - - @property - def summary_ratio(self) -> float: - return getattr(self, 'summary-ratio') - - @property - def preserve_recent_messages(self) -> int: - return getattr(self, 'preserve-recent-messages') - - @property - def summarization_system_prompt(self) -> Optional[str]: - return getattr(self, 'summarization-system-prompt') - - @property - def summarization_model(self) -> Optional[ModelConfig]: - return getattr(self, 'summarization-model') - - def __repr__(self) -> str: - return f'SummarizingConfig(summary_ratio={getattr(self, 'summary-ratio')!r}, preserve_recent_messages={getattr(self, 'preserve-recent-messages')!r}, summarization_system_prompt={getattr(self, 'summarization-system-prompt')!r}, summarization_model={getattr(self, 'summarization-model')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SummarizingConfig): - return NotImplemented - return getattr(self, 'summary-ratio') == getattr(other, 'summary-ratio') and getattr(self, 'preserve-recent-messages') == getattr(other, 'preserve-recent-messages') and getattr(self, 'summarization-system-prompt') == getattr(other, 'summarization-system-prompt') and getattr(self, 'summarization-model') == getattr(other, 'summarization-model') - - def __hash__(self) -> int: - return id(self) - - -class ConversationManagerConfig: - """Which conversation manager the agent uses.""" - pass - -class ConversationManagerConfig_None(ConversationManagerConfig, _WitVariantCase): - """No conversation management. History grows without bound and -context-overflow errors propagate to the caller.""" - tag = 'none' - -class ConversationManagerConfig_SlidingWindow(ConversationManagerConfig, _WitVariantCase): - """Sliding-window trimming.""" - tag = 'sliding-window' - -class ConversationManagerConfig_Summarizing(ConversationManagerConfig, _WitVariantCase): - """Summarization of older messages.""" - tag = 'summarizing' - -_ConversationManagerConfig_CASES: dict[str, type] = { - 'none': ConversationManagerConfig_None, - 'sliding-window': ConversationManagerConfig_SlidingWindow, - 'summarizing': ConversationManagerConfig_Summarizing, -} - -def _ConversationManagerConfig_lift(raw: _WitVariant) -> ConversationManagerConfig: - cls = _ConversationManagerConfig_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ConversationManagerConfig arm: {raw.tag!r}') - return cls(raw.payload) -ConversationManagerConfig.lift = staticmethod(_ConversationManagerConfig_lift) # type: ignore[attr-defined] - -class JitterKind(str): - """How much random variation to apply to computed delays.""" - __slots__ = () - - NONE: 'JitterKind' - FULL: 'JitterKind' - EQUAL: 'JitterKind' - DECORRELATED: 'JitterKind' - -JitterKind.NONE = JitterKind('none') # type: ignore[attr-defined] -JitterKind.FULL = JitterKind('full') # type: ignore[attr-defined] -JitterKind.EQUAL = JitterKind('equal') # type: ignore[attr-defined] -JitterKind.DECORRELATED = JitterKind('decorrelated') # type: ignore[attr-defined] - - -class ConstantBackoffConfig: - """Fixed delay between attempts.""" - def __init__( - self, - *, - delay: int, - ) -> None: - setattr(self, 'delay', delay) - - def __repr__(self) -> str: - return f'ConstantBackoffConfig(delay={getattr(self, 'delay')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ConstantBackoffConfig): - return NotImplemented - return getattr(self, 'delay') == getattr(other, 'delay') - - def __hash__(self) -> int: - return id(self) - - -class LinearBackoffConfig: - """Delay grows linearly with attempt number.""" - def __init__( - self, - *, - base: int, - max: int, - jitter: JitterKind, - ) -> None: - setattr(self, 'base', base) - setattr(self, 'max', max) - setattr(self, 'jitter', jitter) - - def __repr__(self) -> str: - return f'LinearBackoffConfig(base={getattr(self, 'base')!r}, max={getattr(self, 'max')!r}, jitter={getattr(self, 'jitter')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, LinearBackoffConfig): - return NotImplemented - return getattr(self, 'base') == getattr(other, 'base') and getattr(self, 'max') == getattr(other, 'max') and getattr(self, 'jitter') == getattr(other, 'jitter') - - def __hash__(self) -> int: - return id(self) - - -class ExponentialBackoffConfig: - """Delay grows exponentially with attempt number.""" - def __init__( - self, - *, - base: int, - max: int, - factor: float, - jitter: JitterKind, - ) -> None: - setattr(self, 'base', base) - setattr(self, 'max', max) - setattr(self, 'factor', factor) - setattr(self, 'jitter', jitter) - - def __repr__(self) -> str: - return f'ExponentialBackoffConfig(base={getattr(self, 'base')!r}, max={getattr(self, 'max')!r}, factor={getattr(self, 'factor')!r}, jitter={getattr(self, 'jitter')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ExponentialBackoffConfig): - return NotImplemented - return getattr(self, 'base') == getattr(other, 'base') and getattr(self, 'max') == getattr(other, 'max') and getattr(self, 'factor') == getattr(other, 'factor') and getattr(self, 'jitter') == getattr(other, 'jitter') - - def __hash__(self) -> int: - return id(self) - - -class BackoffStrategy: - """Backoff curve applied between attempts.""" - pass - -class BackoffStrategy_Constant(BackoffStrategy, _WitVariantCase): - """Fixed delay.""" - tag = 'constant' - -class BackoffStrategy_Linear(BackoffStrategy, _WitVariantCase): - """Linear growth.""" - tag = 'linear' - -class BackoffStrategy_Exponential(BackoffStrategy, _WitVariantCase): - """Exponential growth.""" - tag = 'exponential' - -_BackoffStrategy_CASES: dict[str, type] = { - 'constant': BackoffStrategy_Constant, - 'linear': BackoffStrategy_Linear, - 'exponential': BackoffStrategy_Exponential, -} - -def _BackoffStrategy_lift(raw: _WitVariant) -> BackoffStrategy: - cls = _BackoffStrategy_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown BackoffStrategy arm: {raw.tag!r}') - return cls(raw.payload) -BackoffStrategy.lift = staticmethod(_BackoffStrategy_lift) # type: ignore[attr-defined] - -class ModelRetryStrategy: - """A single retry strategy. Default is exponential backoff, full jitter, 6 attempts.""" - def __init__( - self, - *, - max_attempts: int, - backoff: BackoffStrategy, - total_budget: Optional[int], - ) -> None: - setattr(self, 'max-attempts', max_attempts) - setattr(self, 'backoff', backoff) - setattr(self, 'total-budget', total_budget) - - @property - def max_attempts(self) -> int: - return getattr(self, 'max-attempts') - - @property - def total_budget(self) -> Optional[int]: - return getattr(self, 'total-budget') - - def __repr__(self) -> str: - return f'ModelRetryStrategy(max_attempts={getattr(self, 'max-attempts')!r}, backoff={getattr(self, 'backoff')!r}, total_budget={getattr(self, 'total-budget')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelRetryStrategy): - return NotImplemented - return getattr(self, 'max-attempts') == getattr(other, 'max-attempts') and getattr(self, 'backoff') == getattr(other, 'backoff') and getattr(self, 'total-budget') == getattr(other, 'total-budget') - - def __hash__(self) -> int: - return id(self) - - -class RetryConfig: - """Retry configuration attached to an agent. -Every strategy observes every failure; the first to request a delay wins. -Empty list disables retries; omitting `agent-config.retry` applies a default -single exponential strategy. Two strategies with the same `backoff` arm -surface as `agent-error::invalid-input`.""" - def __init__( - self, - *, - strategies: list[ModelRetryStrategy], - ) -> None: - setattr(self, 'strategies', strategies) - - def __repr__(self) -> str: - return f'RetryConfig(strategies={getattr(self, 'strategies')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, RetryConfig): - return NotImplemented - return getattr(self, 'strategies') == getattr(other, 'strategies') - - def __hash__(self) -> int: - return id(self) - - -class FileStorageConfig: - """Local filesystem snapshot storage.""" - def __init__( - self, - *, - base_dir: str, - ) -> None: - setattr(self, 'base-dir', base_dir) - - @property - def base_dir(self) -> str: - return getattr(self, 'base-dir') - - def __repr__(self) -> str: - return f'FileStorageConfig(base_dir={getattr(self, 'base-dir')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, FileStorageConfig): - return NotImplemented - return getattr(self, 'base-dir') == getattr(other, 'base-dir') - - def __hash__(self) -> int: - return id(self) - - -class S3StorageConfig: - """S3 snapshot storage.""" - def __init__( - self, - *, - bucket: str, - region: Optional[str], - prefix: Optional[str], - ) -> None: - setattr(self, 'bucket', bucket) - setattr(self, 'region', region) - setattr(self, 'prefix', prefix) - - def __repr__(self) -> str: - return f'S3StorageConfig(bucket={getattr(self, 'bucket')!r}, region={getattr(self, 'region')!r}, prefix={getattr(self, 'prefix')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, S3StorageConfig): - return NotImplemented - return getattr(self, 'bucket') == getattr(other, 'bucket') and getattr(self, 'region') == getattr(other, 'region') and getattr(self, 'prefix') == getattr(other, 'prefix') - - def __hash__(self) -> int: - return id(self) - - -class CustomStorageConfig: - """Reference to an application-implemented storage backend.""" - def __init__( - self, - *, - backend_id: str, - ) -> None: - setattr(self, 'backend-id', backend_id) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - def __repr__(self) -> str: - return f'CustomStorageConfig(backend_id={getattr(self, 'backend-id')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CustomStorageConfig): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') - - def __hash__(self) -> int: - return id(self) - - -class StorageConfig: - """Where to persist session snapshots.""" - pass - -class StorageConfig_File(StorageConfig, _WitVariantCase): - """Local filesystem.""" - tag = 'file' - -class StorageConfig_S3(StorageConfig, _WitVariantCase): - """Amazon S3.""" - tag = 's3' - -class StorageConfig_Custom(StorageConfig, _WitVariantCase): - """Application-implemented backend.""" - tag = 'custom' - -_StorageConfig_CASES: dict[str, type] = { - 'file': StorageConfig_File, - 's3': StorageConfig_S3, - 'custom': StorageConfig_Custom, -} - -def _StorageConfig_lift(raw: _WitVariant) -> StorageConfig: - cls = _StorageConfig_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown StorageConfig arm: {raw.tag!r}') - return cls(raw.payload) -StorageConfig.lift = staticmethod(_StorageConfig_lift) # type: ignore[attr-defined] - -class SaveLatestPolicy: - """When to update the "latest" snapshot pointer. The `trigger` arm -carries the id of an application-supplied callback that decides -per-invocation.""" - pass - -class SaveLatestPolicy_Message(SaveLatestPolicy, _WitVariantCase): - """After every message added to the conversation.""" - tag = 'message' - -class SaveLatestPolicy_Invocation(SaveLatestPolicy, _WitVariantCase): - """Once per invocation, after it completes.""" - tag = 'invocation' - -class SaveLatestPolicy_Trigger(SaveLatestPolicy, _WitVariantCase): - """Each invocation consults the named `snapshot-trigger-handler`. -The id identifies which handler to invoke.""" - tag = 'trigger' - -_SaveLatestPolicy_CASES: dict[str, type] = { - 'message': SaveLatestPolicy_Message, - 'invocation': SaveLatestPolicy_Invocation, - 'trigger': SaveLatestPolicy_Trigger, -} - -def _SaveLatestPolicy_lift(raw: _WitVariant) -> SaveLatestPolicy: - cls = _SaveLatestPolicy_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown SaveLatestPolicy arm: {raw.tag!r}') - return cls(raw.payload) -SaveLatestPolicy.lift = staticmethod(_SaveLatestPolicy_lift) # type: ignore[attr-defined] - -class SessionConfig: - """Session persistence configuration attached to an agent.""" - def __init__( - self, - *, - session_id: str, - storage: StorageConfig, - save_latest: Optional[SaveLatestPolicy], - ) -> None: - setattr(self, 'session-id', session_id) - setattr(self, 'storage', storage) - setattr(self, 'save-latest', save_latest) - - @property - def session_id(self) -> str: - return getattr(self, 'session-id') - - @property - def save_latest(self) -> Optional[SaveLatestPolicy]: - return getattr(self, 'save-latest') - - def __repr__(self) -> str: - return f'SessionConfig(session_id={getattr(self, 'session-id')!r}, storage={getattr(self, 'storage')!r}, save_latest={getattr(self, 'save-latest')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SessionConfig): - return NotImplemented - return getattr(self, 'session-id') == getattr(other, 'session-id') and getattr(self, 'storage') == getattr(other, 'storage') and getattr(self, 'save-latest') == getattr(other, 'save-latest') - - def __hash__(self) -> int: - return id(self) - - -class SnapshotScope(str): - """Which kind of state a snapshot describes.""" - __slots__ = () - - AGENT: 'SnapshotScope' - MULTI_AGENT: 'SnapshotScope' - -SnapshotScope.AGENT = SnapshotScope('agent') # type: ignore[attr-defined] -SnapshotScope.MULTI_AGENT = SnapshotScope('multi-agent') # type: ignore[attr-defined] - - -class SnapshotLocation: - """Locator for a snapshot within the storage hierarchy.""" - def __init__( - self, - *, - session_id: str, - scope: SnapshotScope, - scope_id: str, - ) -> None: - setattr(self, 'session-id', session_id) - setattr(self, 'scope', scope) - setattr(self, 'scope-id', scope_id) - - @property - def session_id(self) -> str: - return getattr(self, 'session-id') - - @property - def scope_id(self) -> str: - return getattr(self, 'scope-id') - - def __repr__(self) -> str: - return f'SnapshotLocation(session_id={getattr(self, 'session-id')!r}, scope={getattr(self, 'scope')!r}, scope_id={getattr(self, 'scope-id')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SnapshotLocation): - return NotImplemented - return getattr(self, 'session-id') == getattr(other, 'session-id') and getattr(self, 'scope') == getattr(other, 'scope') and getattr(self, 'scope-id') == getattr(other, 'scope-id') - - def __hash__(self) -> int: - return id(self) - - -class SlidingWindowState: - """Sliding-window conversation manager state at snapshot time.""" - def __init__( - self, - *, - removed_message_count: int, - ) -> None: - setattr(self, 'removed-message-count', removed_message_count) - - @property - def removed_message_count(self) -> int: - return getattr(self, 'removed-message-count') - - def __repr__(self) -> str: - return f'SlidingWindowState(removed_message_count={getattr(self, 'removed-message-count')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SlidingWindowState): - return NotImplemented - return getattr(self, 'removed-message-count') == getattr(other, 'removed-message-count') - - def __hash__(self) -> int: - return id(self) - - -class SummarizingState: - """Summarizing conversation manager state at snapshot time.""" - def __init__( - self, - *, - summary_message: Optional[Message], - removed_message_count: int, - ) -> None: - setattr(self, 'summary-message', summary_message) - setattr(self, 'removed-message-count', removed_message_count) - - @property - def summary_message(self) -> Optional[Message]: - return getattr(self, 'summary-message') - - @property - def removed_message_count(self) -> int: - return getattr(self, 'removed-message-count') - - def __repr__(self) -> str: - return f'SummarizingState(summary_message={getattr(self, 'summary-message')!r}, removed_message_count={getattr(self, 'removed-message-count')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SummarizingState): - return NotImplemented - return getattr(self, 'summary-message') == getattr(other, 'summary-message') and getattr(self, 'removed-message-count') == getattr(other, 'removed-message-count') - - def __hash__(self) -> int: - return id(self) - - -class ConversationManagerState: - """Conversation manager snapshot state. Which arm is populated depends -on the conversation manager the agent was built with.""" - pass - -class ConversationManagerState_None(ConversationManagerState, _WitVariantCase): - """No conversation manager or null manager; nothing to persist.""" - tag = 'none' - -class ConversationManagerState_SlidingWindow(ConversationManagerState, _WitVariantCase): - """Sliding-window manager state.""" - tag = 'sliding-window' - -class ConversationManagerState_Summarizing(ConversationManagerState, _WitVariantCase): - """Summarizing manager state.""" - tag = 'summarizing' - -_ConversationManagerState_CASES: dict[str, type] = { - 'none': ConversationManagerState_None, - 'sliding-window': ConversationManagerState_SlidingWindow, - 'summarizing': ConversationManagerState_Summarizing, -} - -def _ConversationManagerState_lift(raw: _WitVariant) -> ConversationManagerState: - cls = _ConversationManagerState_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ConversationManagerState arm: {raw.tag!r}') - return cls(raw.payload) -ConversationManagerState.lift = staticmethod(_ConversationManagerState_lift) # type: ignore[attr-defined] - -class RetryStrategyState: - """Retry-strategy state at snapshot time.""" - def __init__( - self, - *, - attempts_used: int, - elapsed_ms: int, - ) -> None: - setattr(self, 'attempts-used', attempts_used) - setattr(self, 'elapsed-ms', elapsed_ms) - - @property - def attempts_used(self) -> int: - return getattr(self, 'attempts-used') - - @property - def elapsed_ms(self) -> int: - return getattr(self, 'elapsed-ms') - - def __repr__(self) -> str: - return f'RetryStrategyState(attempts_used={getattr(self, 'attempts-used')!r}, elapsed_ms={getattr(self, 'elapsed-ms')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, RetryStrategyState): - return NotImplemented - return getattr(self, 'attempts-used') == getattr(other, 'attempts-used') and getattr(self, 'elapsed-ms') == getattr(other, 'elapsed-ms') - - def __hash__(self) -> int: - return id(self) - - -class PluginStateEntry: - """Named plugin state. `data` is an opaque JSON object owned by the plugin.""" - def __init__( - self, - *, - plugin_name: str, - data: str, - ) -> None: - setattr(self, 'plugin-name', plugin_name) - setattr(self, 'data', data) - - @property - def plugin_name(self) -> str: - return getattr(self, 'plugin-name') - - def __repr__(self) -> str: - return f'PluginStateEntry(plugin_name={getattr(self, 'plugin-name')!r}, data={getattr(self, 'data')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, PluginStateEntry): - return NotImplemented - return getattr(self, 'plugin-name') == getattr(other, 'plugin-name') and getattr(self, 'data') == getattr(other, 'data') - - def __hash__(self) -> int: - return id(self) - - -class SnapshotData: - """Framework-owned snapshot state. All fields are optional because an -agent may not exercise every subsystem in a given run.""" - def __init__( - self, - *, - messages: list[Message], - conversation_manager: Optional[ConversationManagerState], - retry_strategy: Optional[RetryStrategyState], - model_state: Optional[str], - plugins: list[PluginStateEntry], - ) -> None: - setattr(self, 'messages', messages) - setattr(self, 'conversation-manager', conversation_manager) - setattr(self, 'retry-strategy', retry_strategy) - setattr(self, 'model-state', model_state) - setattr(self, 'plugins', plugins) - - @property - def conversation_manager(self) -> Optional[ConversationManagerState]: - return getattr(self, 'conversation-manager') - - @property - def retry_strategy(self) -> Optional[RetryStrategyState]: - return getattr(self, 'retry-strategy') - - @property - def model_state(self) -> Optional[str]: - return getattr(self, 'model-state') - - def __repr__(self) -> str: - return f'SnapshotData(messages={getattr(self, 'messages')!r}, conversation_manager={getattr(self, 'conversation-manager')!r}, retry_strategy={getattr(self, 'retry-strategy')!r}, model_state={getattr(self, 'model-state')!r}, plugins={getattr(self, 'plugins')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SnapshotData): - return NotImplemented - return getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'conversation-manager') == getattr(other, 'conversation-manager') and getattr(self, 'retry-strategy') == getattr(other, 'retry-strategy') and getattr(self, 'model-state') == getattr(other, 'model-state') and getattr(self, 'plugins') == getattr(other, 'plugins') - - def __hash__(self) -> int: - return id(self) - - -class Snapshot: - """Point-in-time capture of agent or orchestrator state.""" - def __init__( - self, - *, - scope: SnapshotScope, - schema_version: str, - created_at: Datetime, - data: SnapshotData, - app_data: str, - ) -> None: - setattr(self, 'scope', scope) - setattr(self, 'schema-version', schema_version) - setattr(self, 'created-at', created_at) - setattr(self, 'data', data) - setattr(self, 'app-data', app_data) - - @property - def schema_version(self) -> str: - return getattr(self, 'schema-version') - - @property - def created_at(self) -> Datetime: - return getattr(self, 'created-at') - - @property - def app_data(self) -> str: - return getattr(self, 'app-data') - - def __repr__(self) -> str: - return f'Snapshot(scope={getattr(self, 'scope')!r}, schema_version={getattr(self, 'schema-version')!r}, created_at={getattr(self, 'created-at')!r}, data={getattr(self, 'data')!r}, app_data={getattr(self, 'app-data')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Snapshot): - return NotImplemented - return getattr(self, 'scope') == getattr(other, 'scope') and getattr(self, 'schema-version') == getattr(other, 'schema-version') and getattr(self, 'created-at') == getattr(other, 'created-at') and getattr(self, 'data') == getattr(other, 'data') and getattr(self, 'app-data') == getattr(other, 'app-data') - - def __hash__(self) -> int: - return id(self) - - -class SnapshotManifest: - """Metadata describing the snapshot manifest file.""" - def __init__( - self, - *, - schema_version: str, - updated_at: Datetime, - ) -> None: - setattr(self, 'schema-version', schema_version) - setattr(self, 'updated-at', updated_at) - - @property - def schema_version(self) -> str: - return getattr(self, 'schema-version') - - @property - def updated_at(self) -> Datetime: - return getattr(self, 'updated-at') - - def __repr__(self) -> str: - return f'SnapshotManifest(schema_version={getattr(self, 'schema-version')!r}, updated_at={getattr(self, 'updated-at')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SnapshotManifest): - return NotImplemented - return getattr(self, 'schema-version') == getattr(other, 'schema-version') and getattr(self, 'updated-at') == getattr(other, 'updated-at') - - def __hash__(self) -> int: - return id(self) - - -class StorageError: - """Why a snapshot operation failed.""" - pass - -class StorageError_NotFound(StorageError, _WitVariantCase): - """No snapshot or manifest at the requested location.""" - tag = 'not-found' - -class StorageError_AccessDenied(StorageError, _WitVariantCase): - """Caller lacks permission to read or write the storage.""" - tag = 'access-denied' - -class StorageError_OutOfSpace(StorageError, _WitVariantCase): - """Backing storage is full or over quota.""" - tag = 'out-of-space' - -class StorageError_Corrupt(StorageError, _WitVariantCase): - """Snapshot is malformed or cannot be deserialized.""" - tag = 'corrupt' - -class StorageError_Conflict(StorageError, _WitVariantCase): - """Concurrent writers collided; retrying may succeed.""" - tag = 'conflict' - -class StorageError_Transient(StorageError, _WitVariantCase): - """Transient I/O failure; retrying may succeed.""" - tag = 'transient' - -class StorageError_Permanent(StorageError, _WitVariantCase): - """Permanent backend failure.""" - tag = 'permanent' - -class StorageError_UnknownBackend(StorageError, _WitVariantCase): - """No custom backend registered for the given backend-id.""" - tag = 'unknown-backend' - -_StorageError_CASES: dict[str, type] = { - 'not-found': StorageError_NotFound, - 'access-denied': StorageError_AccessDenied, - 'out-of-space': StorageError_OutOfSpace, - 'corrupt': StorageError_Corrupt, - 'conflict': StorageError_Conflict, - 'transient': StorageError_Transient, - 'permanent': StorageError_Permanent, - 'unknown-backend': StorageError_UnknownBackend, -} - -def _StorageError_lift(raw: _WitVariant) -> StorageError: - cls = _StorageError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown StorageError arm: {raw.tag!r}') - return cls(raw.payload) -StorageError.lift = staticmethod(_StorageError_lift) # type: ignore[attr-defined] - -class SaveSnapshotArgs: - """Arguments for `save-snapshot`.""" - def __init__( - self, - *, - backend_id: str, - location: SnapshotLocation, - snapshot_id: str, - is_latest: bool, - snapshot: Snapshot, - ) -> None: - setattr(self, 'backend-id', backend_id) - setattr(self, 'location', location) - setattr(self, 'snapshot-id', snapshot_id) - setattr(self, 'is-latest', is_latest) - setattr(self, 'snapshot', snapshot) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - @property - def snapshot_id(self) -> str: - return getattr(self, 'snapshot-id') - - @property - def is_latest(self) -> bool: - return getattr(self, 'is-latest') - - def __repr__(self) -> str: - return f'SaveSnapshotArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, snapshot_id={getattr(self, 'snapshot-id')!r}, is_latest={getattr(self, 'is-latest')!r}, snapshot={getattr(self, 'snapshot')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SaveSnapshotArgs): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'snapshot-id') == getattr(other, 'snapshot-id') and getattr(self, 'is-latest') == getattr(other, 'is-latest') and getattr(self, 'snapshot') == getattr(other, 'snapshot') - - def __hash__(self) -> int: - return id(self) - - -class LoadSnapshotArgs: - """Arguments for `load-snapshot`.""" - def __init__( - self, - *, - backend_id: str, - location: SnapshotLocation, - snapshot_id: Optional[str], - ) -> None: - setattr(self, 'backend-id', backend_id) - setattr(self, 'location', location) - setattr(self, 'snapshot-id', snapshot_id) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - @property - def snapshot_id(self) -> Optional[str]: - return getattr(self, 'snapshot-id') - - def __repr__(self) -> str: - return f'LoadSnapshotArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, snapshot_id={getattr(self, 'snapshot-id')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, LoadSnapshotArgs): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'snapshot-id') == getattr(other, 'snapshot-id') - - def __hash__(self) -> int: - return id(self) - - -class ListSnapshotIdsArgs: - """Arguments for `list-snapshot-ids`.""" - def __init__( - self, - *, - backend_id: str, - location: SnapshotLocation, - limit: Optional[int], - start_after: Optional[str], - ) -> None: - setattr(self, 'backend-id', backend_id) - setattr(self, 'location', location) - setattr(self, 'limit', limit) - setattr(self, 'start-after', start_after) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - @property - def start_after(self) -> Optional[str]: - return getattr(self, 'start-after') - - def __repr__(self) -> str: - return f'ListSnapshotIdsArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, limit={getattr(self, 'limit')!r}, start_after={getattr(self, 'start-after')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ListSnapshotIdsArgs): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'limit') == getattr(other, 'limit') and getattr(self, 'start-after') == getattr(other, 'start-after') - - def __hash__(self) -> int: - return id(self) - - -class DeleteSessionArgs: - """Arguments for `delete-session`.""" - def __init__( - self, - *, - backend_id: str, - session_id: str, - ) -> None: - setattr(self, 'backend-id', backend_id) - setattr(self, 'session-id', session_id) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - @property - def session_id(self) -> str: - return getattr(self, 'session-id') - - def __repr__(self) -> str: - return f'DeleteSessionArgs(backend_id={getattr(self, 'backend-id')!r}, session_id={getattr(self, 'session-id')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DeleteSessionArgs): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'session-id') == getattr(other, 'session-id') - - def __hash__(self) -> int: - return id(self) - - -class ManifestArgs: - """Arguments for `load-manifest` / `save-manifest`.""" - def __init__( - self, - *, - backend_id: str, - location: SnapshotLocation, - ) -> None: - setattr(self, 'backend-id', backend_id) - setattr(self, 'location', location) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - def __repr__(self) -> str: - return f'ManifestArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ManifestArgs): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') - - def __hash__(self) -> int: - return id(self) - - -class SaveManifestArgs: - """Arguments for `save-manifest`.""" - def __init__( - self, - *, - backend_id: str, - location: SnapshotLocation, - manifest: SnapshotManifest, - ) -> None: - setattr(self, 'backend-id', backend_id) - setattr(self, 'location', location) - setattr(self, 'manifest', manifest) - - @property - def backend_id(self) -> str: - return getattr(self, 'backend-id') - - def __repr__(self) -> str: - return f'SaveManifestArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, manifest={getattr(self, 'manifest')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SaveManifestArgs): - return NotImplemented - return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'manifest') == getattr(other, 'manifest') - - def __hash__(self) -> int: - return id(self) - - -class TriggerParams: - """Context passed to the trigger on each call.""" - def __init__( - self, - *, - trigger_id: str, - message_count: int, - last_message: Optional[Message], - ) -> None: - setattr(self, 'trigger-id', trigger_id) - setattr(self, 'message-count', message_count) - setattr(self, 'last-message', last_message) - - @property - def trigger_id(self) -> str: - return getattr(self, 'trigger-id') - - @property - def message_count(self) -> int: - return getattr(self, 'message-count') - - @property - def last_message(self) -> Optional[Message]: - return getattr(self, 'last-message') - - def __repr__(self) -> str: - return f'TriggerParams(trigger_id={getattr(self, 'trigger-id')!r}, message_count={getattr(self, 'message-count')!r}, last_message={getattr(self, 'last-message')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TriggerParams): - return NotImplemented - return getattr(self, 'trigger-id') == getattr(other, 'trigger-id') and getattr(self, 'message-count') == getattr(other, 'message-count') and getattr(self, 'last-message') == getattr(other, 'last-message') - - def __hash__(self) -> int: - return id(self) - - -class TriggerError: - """Why a trigger evaluation failed.""" - pass - -class TriggerError_Unknown(TriggerError, _WitVariantCase): - """No trigger registered for the given id.""" - tag = 'unknown' - -class TriggerError_Failed(TriggerError, _WitVariantCase): - """Trigger raised an exception.""" - tag = 'failed' - -_TriggerError_CASES: dict[str, type] = { - 'unknown': TriggerError_Unknown, - 'failed': TriggerError_Failed, -} - -def _TriggerError_lift(raw: _WitVariant) -> TriggerError: - cls = _TriggerError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown TriggerError arm: {raw.tag!r}') - return cls(raw.payload) -TriggerError.lift = staticmethod(_TriggerError_lift) # type: ignore[attr-defined] - -class ToolSpec: - """Declaration of a tool the model can call.""" - def __init__( - self, - *, - name: str, - description: str, - input_schema: str, - ) -> None: - setattr(self, 'name', name) - setattr(self, 'description', description) - setattr(self, 'input-schema', input_schema) - - @property - def input_schema(self) -> str: - return getattr(self, 'input-schema') - - def __repr__(self) -> str: - return f'ToolSpec(name={getattr(self, 'name')!r}, description={getattr(self, 'description')!r}, input_schema={getattr(self, 'input-schema')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolSpec): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'input-schema') == getattr(other, 'input-schema') - - def __hash__(self) -> int: - return id(self) - - -class AgentAsToolConfig: - """Wrap a configured agent as a tool callable by the parent agent. The -child agent is instantiated at registration time.""" - def __init__( - self, - *, - name: Optional[str], - description: Optional[str], - preserve_context: bool, - agent_config: str, - ) -> None: - setattr(self, 'name', name) - setattr(self, 'description', description) - setattr(self, 'preserve-context', preserve_context) - setattr(self, 'agent-config', agent_config) - - @property - def preserve_context(self) -> bool: - return getattr(self, 'preserve-context') - - @property - def agent_config(self) -> str: - return getattr(self, 'agent-config') - - def __repr__(self) -> str: - return f'AgentAsToolConfig(name={getattr(self, 'name')!r}, description={getattr(self, 'description')!r}, preserve_context={getattr(self, 'preserve-context')!r}, agent_config={getattr(self, 'agent-config')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentAsToolConfig): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'preserve-context') == getattr(other, 'preserve-context') and getattr(self, 'agent-config') == getattr(other, 'agent-config') - - def __hash__(self) -> int: - return id(self) - - -class CallToolArgs: - """Arguments for a single tool call.""" - def __init__( - self, - *, - name: str, - input: str, - tool_use_id: str, - ) -> None: - setattr(self, 'name', name) - setattr(self, 'input', input) - setattr(self, 'tool-use-id', tool_use_id) - - @property - def tool_use_id(self) -> str: - return getattr(self, 'tool-use-id') - - def __repr__(self) -> str: - return f'CallToolArgs(name={getattr(self, 'name')!r}, input={getattr(self, 'input')!r}, tool_use_id={getattr(self, 'tool-use-id')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CallToolArgs): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') - - def __hash__(self) -> int: - return id(self) - - -class ToolChoice: - """Policy controlling whether and how the model calls tools on the next -generation step.""" - pass - -class ToolChoice_Auto(ToolChoice, _WitVariantCase): - """Model decides whether to call a tool.""" - tag = 'auto' - -class ToolChoice_Any(ToolChoice, _WitVariantCase): - """Model must call at least one tool.""" - tag = 'any' - -class ToolChoice_Named(ToolChoice, _WitVariantCase): - """Model must call the tool with this name.""" - tag = 'named' - -_ToolChoice_CASES: dict[str, type] = { - 'auto': ToolChoice_Auto, - 'any': ToolChoice_Any, - 'named': ToolChoice_Named, -} - -def _ToolChoice_lift(raw: _WitVariant) -> ToolChoice: - cls = _ToolChoice_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ToolChoice arm: {raw.tag!r}') - return cls(raw.payload) -ToolChoice.lift = staticmethod(_ToolChoice_lift) # type: ignore[attr-defined] - -class ToolEventStream: - """Pull-based stream of tool events. Sync-WIT placeholder for -`stream`.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def read(self) -> Optional[ToolStreamEvent]: - return self._invoke('[method]tool-event-stream.read', (self._handle,)) - - -class ToolError: - """Why a tool call failed.""" - pass - -class ToolError_Unknown(ToolError, _WitVariantCase): - """No tool registered under the given name.""" - tag = 'unknown' - -class ToolError_InvalidInput(ToolError, _WitVariantCase): - """Tool input didn't match the declared input schema.""" - tag = 'invalid-input' - -class ToolError_ExecutionFailed(ToolError, _WitVariantCase): - """Tool ran but returned an error result.""" - tag = 'execution-failed' - -class ToolError_TimedOut(ToolError, _WitVariantCase): - """Tool exceeded its time budget.""" - tag = 'timed-out' - -class ToolError_Cancelled(ToolError, _WitVariantCase): - """Tool was cancelled before completion.""" - tag = 'cancelled' - -class ToolError_Internal(ToolError, _WitVariantCase): - """Catch-all for internal failures.""" - tag = 'internal' - -_ToolError_CASES: dict[str, type] = { - 'unknown': ToolError_Unknown, - 'invalid-input': ToolError_InvalidInput, - 'execution-failed': ToolError_ExecutionFailed, - 'timed-out': ToolError_TimedOut, - 'cancelled': ToolError_Cancelled, - 'internal': ToolError_Internal, -} - -def _ToolError_lift(raw: _WitVariant) -> ToolError: - cls = _ToolError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown ToolError arm: {raw.tag!r}') - return cls(raw.payload) -ToolError.lift = staticmethod(_ToolError_lift) # type: ignore[attr-defined] - -ToolStreamEvent = str | list[ToolResultContent] | ToolError -"""Incremental event emitted by a streaming tool while running.""" - -class McpConnectionState(str): - """Connection state of an MCP client.""" - __slots__ = () - - DISCONNECTED: 'McpConnectionState' - CONNECTED: 'McpConnectionState' - FAILED: 'McpConnectionState' - -McpConnectionState.DISCONNECTED = McpConnectionState('disconnected') # type: ignore[attr-defined] -McpConnectionState.CONNECTED = McpConnectionState('connected') # type: ignore[attr-defined] -McpConnectionState.FAILED = McpConnectionState('failed') # type: ignore[attr-defined] - - -class EnvVar: - """Single environment variable entry.""" - def __init__( - self, - *, - key: str, - value: str, - ) -> None: - setattr(self, 'key', key) - setattr(self, 'value', value) - - def __repr__(self) -> str: - return f'EnvVar(key={getattr(self, 'key')!r}, value={getattr(self, 'value')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, EnvVar): - return NotImplemented - return getattr(self, 'key') == getattr(other, 'key') and getattr(self, 'value') == getattr(other, 'value') - - def __hash__(self) -> int: - return id(self) - - -class StdioTransportConfig: - """STDIO transport configuration.""" - def __init__( - self, - *, - command: str, - args: list[str], - env: list[EnvVar], - cwd: Optional[str], - ) -> None: - setattr(self, 'command', command) - setattr(self, 'args', args) - setattr(self, 'env', env) - setattr(self, 'cwd', cwd) - - def __repr__(self) -> str: - return f'StdioTransportConfig(command={getattr(self, 'command')!r}, args={getattr(self, 'args')!r}, env={getattr(self, 'env')!r}, cwd={getattr(self, 'cwd')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, StdioTransportConfig): - return NotImplemented - return getattr(self, 'command') == getattr(other, 'command') and getattr(self, 'args') == getattr(other, 'args') and getattr(self, 'env') == getattr(other, 'env') and getattr(self, 'cwd') == getattr(other, 'cwd') - - def __hash__(self) -> int: - return id(self) - - -class HttpHeader: - """Single HTTP header entry.""" - def __init__( - self, - *, - name: str, - value: str, - ) -> None: - setattr(self, 'name', name) - setattr(self, 'value', value) - - def __repr__(self) -> str: - return f'HttpHeader(name={getattr(self, 'name')!r}, value={getattr(self, 'value')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, HttpHeader): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'value') == getattr(other, 'value') - - def __hash__(self) -> int: - return id(self) - - -class HttpTransportConfig: - """HTTP transport configuration.""" - def __init__( - self, - *, - url: str, - headers: list[HttpHeader], - ) -> None: - setattr(self, 'url', url) - setattr(self, 'headers', headers) - - def __repr__(self) -> str: - return f'HttpTransportConfig(url={getattr(self, 'url')!r}, headers={getattr(self, 'headers')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, HttpTransportConfig): - return NotImplemented - return getattr(self, 'url') == getattr(other, 'url') and getattr(self, 'headers') == getattr(other, 'headers') - - def __hash__(self) -> int: - return id(self) - - -class SseTransportConfig: - """SSE transport configuration.""" - def __init__( - self, - *, - url: str, - headers: list[HttpHeader], - ) -> None: - setattr(self, 'url', url) - setattr(self, 'headers', headers) - - def __repr__(self) -> str: - return f'SseTransportConfig(url={getattr(self, 'url')!r}, headers={getattr(self, 'headers')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SseTransportConfig): - return NotImplemented - return getattr(self, 'url') == getattr(other, 'url') and getattr(self, 'headers') == getattr(other, 'headers') - - def __hash__(self) -> int: - return id(self) - - -class McpTransport: - """How the client talks to the MCP server.""" - pass - -class McpTransport_Stdio(McpTransport, _WitVariantCase): - """STDIO transport. Spawn a local process and talk via pipes.""" - tag = 'stdio' - -class McpTransport_StreamableHttp(McpTransport, _WitVariantCase): - """Streamable HTTP transport, per the current MCP specification.""" - tag = 'streamable-http' - -class McpTransport_Sse(McpTransport, _WitVariantCase): - """Legacy Server-Sent Events transport. Retained for older servers.""" - tag = 'sse' - -_McpTransport_CASES: dict[str, type] = { - 'stdio': McpTransport_Stdio, - 'streamable-http': McpTransport_StreamableHttp, - 'sse': McpTransport_Sse, -} - -def _McpTransport_lift(raw: _WitVariant) -> McpTransport: - cls = _McpTransport_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown McpTransport arm: {raw.tag!r}') - return cls(raw.payload) -McpTransport.lift = staticmethod(_McpTransport_lift) # type: ignore[attr-defined] - -class TasksConfig: - """Task-augmented tool execution. Enables long-running tools with -progress tracking. Experimental in the MCP specification.""" - def __init__( - self, - *, - ttl: int, - poll_timeout: int, - ) -> None: - setattr(self, 'ttl', ttl) - setattr(self, 'poll-timeout', poll_timeout) - - @property - def poll_timeout(self) -> int: - return getattr(self, 'poll-timeout') - - def __repr__(self) -> str: - return f'TasksConfig(ttl={getattr(self, 'ttl')!r}, poll_timeout={getattr(self, 'poll-timeout')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TasksConfig): - return NotImplemented - return getattr(self, 'ttl') == getattr(other, 'ttl') and getattr(self, 'poll-timeout') == getattr(other, 'poll-timeout') - - def __hash__(self) -> int: - return id(self) - - -class McpClientConfig: - """MCP client configuration.""" - def __init__( - self, - *, - client_id: str, - application_name: Optional[str], - application_version: Optional[str], - transport: McpTransport, - tasks_config: Optional[TasksConfig], - elicitation_enabled: bool, - fail_open: bool, - disable_instrumentation: bool, - ) -> None: - setattr(self, 'client-id', client_id) - setattr(self, 'application-name', application_name) - setattr(self, 'application-version', application_version) - setattr(self, 'transport', transport) - setattr(self, 'tasks-config', tasks_config) - setattr(self, 'elicitation-enabled', elicitation_enabled) - setattr(self, 'fail-open', fail_open) - setattr(self, 'disable-instrumentation', disable_instrumentation) - - @property - def client_id(self) -> str: - return getattr(self, 'client-id') - - @property - def application_name(self) -> Optional[str]: - return getattr(self, 'application-name') - - @property - def application_version(self) -> Optional[str]: - return getattr(self, 'application-version') - - @property - def tasks_config(self) -> Optional[TasksConfig]: - return getattr(self, 'tasks-config') - - @property - def elicitation_enabled(self) -> bool: - return getattr(self, 'elicitation-enabled') - - @property - def fail_open(self) -> bool: - return getattr(self, 'fail-open') - - @property - def disable_instrumentation(self) -> bool: - return getattr(self, 'disable-instrumentation') - - def __repr__(self) -> str: - return f'McpClientConfig(client_id={getattr(self, 'client-id')!r}, application_name={getattr(self, 'application-name')!r}, application_version={getattr(self, 'application-version')!r}, transport={getattr(self, 'transport')!r}, tasks_config={getattr(self, 'tasks-config')!r}, elicitation_enabled={getattr(self, 'elicitation-enabled')!r}, fail_open={getattr(self, 'fail-open')!r}, disable_instrumentation={getattr(self, 'disable-instrumentation')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, McpClientConfig): - return NotImplemented - return getattr(self, 'client-id') == getattr(other, 'client-id') and getattr(self, 'application-name') == getattr(other, 'application-name') and getattr(self, 'application-version') == getattr(other, 'application-version') and getattr(self, 'transport') == getattr(other, 'transport') and getattr(self, 'tasks-config') == getattr(other, 'tasks-config') and getattr(self, 'elicitation-enabled') == getattr(other, 'elicitation-enabled') and getattr(self, 'fail-open') == getattr(other, 'fail-open') and getattr(self, 'disable-instrumentation') == getattr(other, 'disable-instrumentation') - - def __hash__(self) -> int: - return id(self) - - -class Interrupt: - """Human-in-the-loop interrupt raised by a tool or hook.""" - def __init__( - self, - *, - id: str, - name: str, - reason: Optional[str], - ) -> None: - setattr(self, 'id', id) - setattr(self, 'name', name) - setattr(self, 'reason', reason) - - def __repr__(self) -> str: - return f'Interrupt(id={getattr(self, 'id')!r}, name={getattr(self, 'name')!r}, reason={getattr(self, 'reason')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Interrupt): - return NotImplemented - return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'reason') == getattr(other, 'reason') - - def __hash__(self) -> int: - return id(self) - - -class StopReason(str): - """Why the model stopped generating.""" - __slots__ = () - - END_TURN: 'StopReason' - TOOL_USE: 'StopReason' - MAX_TOKENS: 'StopReason' - ERROR: 'StopReason' - CONTENT_FILTERED: 'StopReason' - GUARDRAIL_INTERVENED: 'StopReason' - STOP_SEQUENCE: 'StopReason' - MODEL_CONTEXT_WINDOW_EXCEEDED: 'StopReason' - CANCELLED: 'StopReason' - -StopReason.END_TURN = StopReason('end-turn') # type: ignore[attr-defined] -StopReason.TOOL_USE = StopReason('tool-use') # type: ignore[attr-defined] -StopReason.MAX_TOKENS = StopReason('max-tokens') # type: ignore[attr-defined] -StopReason.ERROR = StopReason('error') # type: ignore[attr-defined] -StopReason.CONTENT_FILTERED = StopReason('content-filtered') # type: ignore[attr-defined] -StopReason.GUARDRAIL_INTERVENED = StopReason('guardrail-intervened') # type: ignore[attr-defined] -StopReason.STOP_SEQUENCE = StopReason('stop-sequence') # type: ignore[attr-defined] -StopReason.MODEL_CONTEXT_WINDOW_EXCEEDED = StopReason('model-context-window-exceeded') # type: ignore[attr-defined] -StopReason.CANCELLED = StopReason('cancelled') # type: ignore[attr-defined] - - -class MetadataEvent: - """Usage and metrics accumulated so far.""" - def __init__( - self, - *, - usage: Optional[Usage], - metrics: Optional[Metrics], - ) -> None: - setattr(self, 'usage', usage) - setattr(self, 'metrics', metrics) - - def __repr__(self) -> str: - return f'MetadataEvent(usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MetadataEvent): - return NotImplemented - return getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') - - def __hash__(self) -> int: - return id(self) - - -class TraceMetadataEntry: - """Single key-value pair attached to a trace. Values are string-typed -to keep traces compact; structured payloads belong on `message`.""" - def __init__( - self, - *, - key: str, - value: str, - ) -> None: - setattr(self, 'key', key) - setattr(self, 'value', value) - - def __repr__(self) -> str: - return f'TraceMetadataEntry(key={getattr(self, 'key')!r}, value={getattr(self, 'value')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TraceMetadataEntry): - return NotImplemented - return getattr(self, 'key') == getattr(other, 'key') and getattr(self, 'value') == getattr(other, 'value') - - def __hash__(self) -> int: - return id(self) - - -class AgentTrace: - """In-memory trace node. Returned flat; reconstruct the tree via `parent-id`.""" - def __init__( - self, - *, - id: str, - name: str, - parent_id: Optional[str], - start_time_ms: int, - end_time_ms: Optional[int], - duration_ms: int, - metadata: list[TraceMetadataEntry], - message: Optional[Message], - ) -> None: - setattr(self, 'id', id) - setattr(self, 'name', name) - setattr(self, 'parent-id', parent_id) - setattr(self, 'start-time-ms', start_time_ms) - setattr(self, 'end-time-ms', end_time_ms) - setattr(self, 'duration-ms', duration_ms) - setattr(self, 'metadata', metadata) - setattr(self, 'message', message) - - @property - def parent_id(self) -> Optional[str]: - return getattr(self, 'parent-id') - - @property - def start_time_ms(self) -> int: - return getattr(self, 'start-time-ms') - - @property - def end_time_ms(self) -> Optional[int]: - return getattr(self, 'end-time-ms') - - @property - def duration_ms(self) -> int: - return getattr(self, 'duration-ms') - - def __repr__(self) -> str: - return f'AgentTrace(id={getattr(self, 'id')!r}, name={getattr(self, 'name')!r}, parent_id={getattr(self, 'parent-id')!r}, start_time_ms={getattr(self, 'start-time-ms')!r}, end_time_ms={getattr(self, 'end-time-ms')!r}, duration_ms={getattr(self, 'duration-ms')!r}, metadata={getattr(self, 'metadata')!r}, message={getattr(self, 'message')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentTrace): - return NotImplemented - return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'parent-id') == getattr(other, 'parent-id') and getattr(self, 'start-time-ms') == getattr(other, 'start-time-ms') and getattr(self, 'end-time-ms') == getattr(other, 'end-time-ms') and getattr(self, 'duration-ms') == getattr(other, 'duration-ms') and getattr(self, 'metadata') == getattr(other, 'metadata') and getattr(self, 'message') == getattr(other, 'message') - - def __hash__(self) -> int: - return id(self) - - -class ToolMetrics: - """Per-tool execution metrics keyed by tool name in `agent-metrics`.""" - def __init__( - self, - *, - tool_name: str, - call_count: int, - success_count: int, - error_count: int, - total_time_ms: int, - ) -> None: - setattr(self, 'tool-name', tool_name) - setattr(self, 'call-count', call_count) - setattr(self, 'success-count', success_count) - setattr(self, 'error-count', error_count) - setattr(self, 'total-time-ms', total_time_ms) - - @property - def tool_name(self) -> str: - return getattr(self, 'tool-name') - - @property - def call_count(self) -> int: - return getattr(self, 'call-count') - - @property - def success_count(self) -> int: - return getattr(self, 'success-count') - - @property - def error_count(self) -> int: - return getattr(self, 'error-count') - - @property - def total_time_ms(self) -> int: - return getattr(self, 'total-time-ms') - - def __repr__(self) -> str: - return f'ToolMetrics(tool_name={getattr(self, 'tool-name')!r}, call_count={getattr(self, 'call-count')!r}, success_count={getattr(self, 'success-count')!r}, error_count={getattr(self, 'error-count')!r}, total_time_ms={getattr(self, 'total-time-ms')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolMetrics): - return NotImplemented - return getattr(self, 'tool-name') == getattr(other, 'tool-name') and getattr(self, 'call-count') == getattr(other, 'call-count') and getattr(self, 'success-count') == getattr(other, 'success-count') and getattr(self, 'error-count') == getattr(other, 'error-count') and getattr(self, 'total-time-ms') == getattr(other, 'total-time-ms') - - def __hash__(self) -> int: - return id(self) - - -class InvocationMetrics: - """Per-invocation metrics. Cycles are flattened into `agent-metrics.cycles` -and linked back via `invocation-id`.""" - def __init__( - self, - *, - invocation_id: str, - usage: Usage, - ) -> None: - setattr(self, 'invocation-id', invocation_id) - setattr(self, 'usage', usage) - - @property - def invocation_id(self) -> str: - return getattr(self, 'invocation-id') - - def __repr__(self) -> str: - return f'InvocationMetrics(invocation_id={getattr(self, 'invocation-id')!r}, usage={getattr(self, 'usage')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InvocationMetrics): - return NotImplemented - return getattr(self, 'invocation-id') == getattr(other, 'invocation-id') and getattr(self, 'usage') == getattr(other, 'usage') - - def __hash__(self) -> int: - return id(self) - - -class AgentLoopMetrics: - """Per-cycle usage tracking.""" - def __init__( - self, - *, - cycle_id: str, - invocation_id: str, - duration_ms: int, - usage: Usage, - ) -> None: - setattr(self, 'cycle-id', cycle_id) - setattr(self, 'invocation-id', invocation_id) - setattr(self, 'duration-ms', duration_ms) - setattr(self, 'usage', usage) - - @property - def cycle_id(self) -> str: - return getattr(self, 'cycle-id') - - @property - def invocation_id(self) -> str: - return getattr(self, 'invocation-id') - - @property - def duration_ms(self) -> int: - return getattr(self, 'duration-ms') - - def __repr__(self) -> str: - return f'AgentLoopMetrics(cycle_id={getattr(self, 'cycle-id')!r}, invocation_id={getattr(self, 'invocation-id')!r}, duration_ms={getattr(self, 'duration-ms')!r}, usage={getattr(self, 'usage')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentLoopMetrics): - return NotImplemented - return getattr(self, 'cycle-id') == getattr(other, 'cycle-id') and getattr(self, 'invocation-id') == getattr(other, 'invocation-id') and getattr(self, 'duration-ms') == getattr(other, 'duration-ms') and getattr(self, 'usage') == getattr(other, 'usage') - - def __hash__(self) -> int: - return id(self) - - -class AgentMetrics: - """Snapshot of agent metrics. Returned by `agent.get-metrics`.""" - def __init__( - self, - *, - cycle_count: int, - accumulated_usage: Usage, - accumulated_metrics: Metrics, - invocations: list[InvocationMetrics], - cycles: list[AgentLoopMetrics], - tool_metrics: list[ToolMetrics], - latest_context_size: Optional[int], - projected_context_size: Optional[int], - ) -> None: - setattr(self, 'cycle-count', cycle_count) - setattr(self, 'accumulated-usage', accumulated_usage) - setattr(self, 'accumulated-metrics', accumulated_metrics) - setattr(self, 'invocations', invocations) - setattr(self, 'cycles', cycles) - setattr(self, 'tool-metrics', tool_metrics) - setattr(self, 'latest-context-size', latest_context_size) - setattr(self, 'projected-context-size', projected_context_size) - - @property - def cycle_count(self) -> int: - return getattr(self, 'cycle-count') - - @property - def accumulated_usage(self) -> Usage: - return getattr(self, 'accumulated-usage') - - @property - def accumulated_metrics(self) -> Metrics: - return getattr(self, 'accumulated-metrics') - - @property - def tool_metrics(self) -> list[ToolMetrics]: - return getattr(self, 'tool-metrics') - - @property - def latest_context_size(self) -> Optional[int]: - return getattr(self, 'latest-context-size') - - @property - def projected_context_size(self) -> Optional[int]: - return getattr(self, 'projected-context-size') - - def __repr__(self) -> str: - return f'AgentMetrics(cycle_count={getattr(self, 'cycle-count')!r}, accumulated_usage={getattr(self, 'accumulated-usage')!r}, accumulated_metrics={getattr(self, 'accumulated-metrics')!r}, invocations={getattr(self, 'invocations')!r}, cycles={getattr(self, 'cycles')!r}, tool_metrics={getattr(self, 'tool-metrics')!r}, latest_context_size={getattr(self, 'latest-context-size')!r}, projected_context_size={getattr(self, 'projected-context-size')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentMetrics): - return NotImplemented - return getattr(self, 'cycle-count') == getattr(other, 'cycle-count') and getattr(self, 'accumulated-usage') == getattr(other, 'accumulated-usage') and getattr(self, 'accumulated-metrics') == getattr(other, 'accumulated-metrics') and getattr(self, 'invocations') == getattr(other, 'invocations') and getattr(self, 'cycles') == getattr(other, 'cycles') and getattr(self, 'tool-metrics') == getattr(other, 'tool-metrics') and getattr(self, 'latest-context-size') == getattr(other, 'latest-context-size') and getattr(self, 'projected-context-size') == getattr(other, 'projected-context-size') - - def __hash__(self) -> int: - return id(self) - - -class ToolUseData: - """Mutable tool-use descriptor. `before-tool-call` hooks may rewrite fields.""" - def __init__( - self, - *, - name: str, - tool_use_id: str, - input: str, - ) -> None: - setattr(self, 'name', name) - setattr(self, 'tool-use-id', tool_use_id) - setattr(self, 'input', input) - - @property - def tool_use_id(self) -> str: - return getattr(self, 'tool-use-id') - - def __repr__(self) -> str: - return f'ToolUseData(name={getattr(self, 'name')!r}, tool_use_id={getattr(self, 'tool-use-id')!r}, input={getattr(self, 'input')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolUseData): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') and getattr(self, 'input') == getattr(other, 'input') - - def __hash__(self) -> int: - return id(self) - - -class HookRedaction: - """Redaction information when guardrails block content.""" - def __init__( - self, - *, - user_message: str, - ) -> None: - setattr(self, 'user-message', user_message) - - @property - def user_message(self) -> str: - return getattr(self, 'user-message') - - def __repr__(self) -> str: - return f'HookRedaction(user_message={getattr(self, 'user-message')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, HookRedaction): - return NotImplemented - return getattr(self, 'user-message') == getattr(other, 'user-message') - - def __hash__(self) -> int: - return id(self) - - -class ModelStopData: - """Model response surfaced on `after-model-call`.""" - def __init__( - self, - *, - message: Message, - stop_reason: StopReason, - redaction: Optional[HookRedaction], - ) -> None: - setattr(self, 'message', message) - setattr(self, 'stop-reason', stop_reason) - setattr(self, 'redaction', redaction) - - @property - def stop_reason(self) -> StopReason: - return getattr(self, 'stop-reason') - - def __repr__(self) -> str: - return f'ModelStopData(message={getattr(self, 'message')!r}, stop_reason={getattr(self, 'stop-reason')!r}, redaction={getattr(self, 'redaction')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelStopData): - return NotImplemented - return getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'stop-reason') == getattr(other, 'stop-reason') and getattr(self, 'redaction') == getattr(other, 'redaction') - - def __hash__(self) -> int: - return id(self) - - -class BeforeInvocationData: - """Payload for `before-invocation`.""" - def __init__( - self, - *, - invocation_state: str, - ) -> None: - setattr(self, 'invocation-state', invocation_state) - - @property - def invocation_state(self) -> str: - return getattr(self, 'invocation-state') - - def __repr__(self) -> str: - return f'BeforeInvocationData(invocation_state={getattr(self, 'invocation-state')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, BeforeInvocationData): - return NotImplemented - return getattr(self, 'invocation-state') == getattr(other, 'invocation-state') - - def __hash__(self) -> int: - return id(self) - - -class AfterInvocationData: - """Payload for `after-invocation`.""" - def __init__( - self, - *, - invocation_state: str, - ) -> None: - setattr(self, 'invocation-state', invocation_state) - - @property - def invocation_state(self) -> str: - return getattr(self, 'invocation-state') - - def __repr__(self) -> str: - return f'AfterInvocationData(invocation_state={getattr(self, 'invocation-state')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AfterInvocationData): - return NotImplemented - return getattr(self, 'invocation-state') == getattr(other, 'invocation-state') - - def __hash__(self) -> int: - return id(self) - - -class MessageAddedData: - """Payload for `message-added`.""" - def __init__( - self, - *, - message: Message, - ) -> None: - setattr(self, 'message', message) - - def __repr__(self) -> str: - return f'MessageAddedData(message={getattr(self, 'message')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MessageAddedData): - return NotImplemented - return getattr(self, 'message') == getattr(other, 'message') - - def __hash__(self) -> int: - return id(self) - - -class BeforeModelCallData: - """Payload for `before-model-call`.""" - def __init__( - self, - *, - projected_input_tokens: Optional[int], - ) -> None: - setattr(self, 'projected-input-tokens', projected_input_tokens) - - @property - def projected_input_tokens(self) -> Optional[int]: - return getattr(self, 'projected-input-tokens') - - def __repr__(self) -> str: - return f'BeforeModelCallData(projected_input_tokens={getattr(self, 'projected-input-tokens')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, BeforeModelCallData): - return NotImplemented - return getattr(self, 'projected-input-tokens') == getattr(other, 'projected-input-tokens') - - def __hash__(self) -> int: - return id(self) - - -class AfterModelCallData: - """Payload for `after-model-call`.""" - def __init__( - self, - *, - attempt_count: int, - stop_data: Optional[ModelStopData], - error: Optional[ModelError], - ) -> None: - setattr(self, 'attempt-count', attempt_count) - setattr(self, 'stop-data', stop_data) - setattr(self, 'error', error) - - @property - def attempt_count(self) -> int: - return getattr(self, 'attempt-count') - - @property - def stop_data(self) -> Optional[ModelStopData]: - return getattr(self, 'stop-data') - - def __repr__(self) -> str: - return f'AfterModelCallData(attempt_count={getattr(self, 'attempt-count')!r}, stop_data={getattr(self, 'stop-data')!r}, error={getattr(self, 'error')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AfterModelCallData): - return NotImplemented - return getattr(self, 'attempt-count') == getattr(other, 'attempt-count') and getattr(self, 'stop-data') == getattr(other, 'stop-data') and getattr(self, 'error') == getattr(other, 'error') - - def __hash__(self) -> int: - return id(self) - - -class BeforeToolCallData: - """Payload for `before-tool-call`.""" - def __init__( - self, - *, - tool_use: ToolUseData, - ) -> None: - setattr(self, 'tool-use', tool_use) - - @property - def tool_use(self) -> ToolUseData: - return getattr(self, 'tool-use') - - def __repr__(self) -> str: - return f'BeforeToolCallData(tool_use={getattr(self, 'tool-use')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, BeforeToolCallData): - return NotImplemented - return getattr(self, 'tool-use') == getattr(other, 'tool-use') - - def __hash__(self) -> int: - return id(self) - - -class AfterToolCallData: - """Payload for `after-tool-call`.""" - def __init__( - self, - *, - tool_use: ToolUseData, - tool_result: ToolResultBlock, - error: Optional[ToolError], - ) -> None: - setattr(self, 'tool-use', tool_use) - setattr(self, 'tool-result', tool_result) - setattr(self, 'error', error) - - @property - def tool_use(self) -> ToolUseData: - return getattr(self, 'tool-use') - - @property - def tool_result(self) -> ToolResultBlock: - return getattr(self, 'tool-result') - - def __repr__(self) -> str: - return f'AfterToolCallData(tool_use={getattr(self, 'tool-use')!r}, tool_result={getattr(self, 'tool-result')!r}, error={getattr(self, 'error')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AfterToolCallData): - return NotImplemented - return getattr(self, 'tool-use') == getattr(other, 'tool-use') and getattr(self, 'tool-result') == getattr(other, 'tool-result') and getattr(self, 'error') == getattr(other, 'error') - - def __hash__(self) -> int: - return id(self) - - -class ToolsBatchData: - """Payload for `before-tools` / `after-tools`.""" - def __init__( - self, - *, - message: Message, - ) -> None: - setattr(self, 'message', message) - - def __repr__(self) -> str: - return f'ToolsBatchData(message={getattr(self, 'message')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolsBatchData): - return NotImplemented - return getattr(self, 'message') == getattr(other, 'message') - - def __hash__(self) -> int: - return id(self) - - -class ContentBlockData: - """Payload for `content-block`.""" - def __init__( - self, - *, - content_block: ContentBlock, - ) -> None: - setattr(self, 'content-block', content_block) - - @property - def content_block(self) -> ContentBlock: - return getattr(self, 'content-block') - - def __repr__(self) -> str: - return f'ContentBlockData(content_block={getattr(self, 'content-block')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ContentBlockData): - return NotImplemented - return getattr(self, 'content-block') == getattr(other, 'content-block') - - def __hash__(self) -> int: - return id(self) - - -class ModelMessageData: - """Payload for `model-message`.""" - def __init__( - self, - *, - message: Message, - stop_reason: StopReason, - ) -> None: - setattr(self, 'message', message) - setattr(self, 'stop-reason', stop_reason) - - @property - def stop_reason(self) -> StopReason: - return getattr(self, 'stop-reason') - - def __repr__(self) -> str: - return f'ModelMessageData(message={getattr(self, 'message')!r}, stop_reason={getattr(self, 'stop-reason')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelMessageData): - return NotImplemented - return getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'stop-reason') == getattr(other, 'stop-reason') - - def __hash__(self) -> int: - return id(self) - - -class ToolResultData: - """Payload for `tool-result-hook`.""" - def __init__( - self, - *, - tool_result: ToolResultBlock, - ) -> None: - setattr(self, 'tool-result', tool_result) - - @property - def tool_result(self) -> ToolResultBlock: - return getattr(self, 'tool-result') - - def __repr__(self) -> str: - return f'ToolResultData(tool_result={getattr(self, 'tool-result')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolResultData): - return NotImplemented - return getattr(self, 'tool-result') == getattr(other, 'tool-result') - - def __hash__(self) -> int: - return id(self) - - -class ToolStreamUpdateData: - """Payload for `tool-stream-update`.""" - def __init__( - self, - *, - data: str, - ) -> None: - setattr(self, 'data', data) - - def __repr__(self) -> str: - return f'ToolStreamUpdateData(data={getattr(self, 'data')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ToolStreamUpdateData): - return NotImplemented - return getattr(self, 'data') == getattr(other, 'data') - - def __hash__(self) -> int: - return id(self) - - -class ModelStreamUpdateData: - """Payload for `model-stream-update`.""" - def __init__( - self, - *, - event: str, - ) -> None: - setattr(self, 'event', event) - - def __repr__(self) -> str: - return f'ModelStreamUpdateData(event={getattr(self, 'event')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelStreamUpdateData): - return NotImplemented - return getattr(self, 'event') == getattr(other, 'event') - - def __hash__(self) -> int: - return id(self) - - -class InputRedaction: - """Input redaction emitted when a guardrail blocks input. Original is in history.""" - def __init__( - self, - *, - replace_content: str, - ) -> None: - setattr(self, 'replace-content', replace_content) - - @property - def replace_content(self) -> str: - return getattr(self, 'replace-content') - - def __repr__(self) -> str: - return f'InputRedaction(replace_content={getattr(self, 'replace-content')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InputRedaction): - return NotImplemented - return getattr(self, 'replace-content') == getattr(other, 'replace-content') - - def __hash__(self) -> int: - return id(self) - - -class OutputRedaction: - """Output redaction emitted when a guardrail blocks output.""" - def __init__( - self, - *, - redacted_content: Optional[str], - replace_content: str, - ) -> None: - setattr(self, 'redacted-content', redacted_content) - setattr(self, 'replace-content', replace_content) - - @property - def redacted_content(self) -> Optional[str]: - return getattr(self, 'redacted-content') - - @property - def replace_content(self) -> str: - return getattr(self, 'replace-content') - - def __repr__(self) -> str: - return f'OutputRedaction(redacted_content={getattr(self, 'redacted-content')!r}, replace_content={getattr(self, 'replace-content')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, OutputRedaction): - return NotImplemented - return getattr(self, 'redacted-content') == getattr(other, 'redacted-content') and getattr(self, 'replace-content') == getattr(other, 'replace-content') - - def __hash__(self) -> int: - return id(self) - - -class RedactionEvent: - """Redaction event. Input and output fields are independent; at least one is set.""" - def __init__( - self, - *, - input_redaction: Optional[InputRedaction], - output_redaction: Optional[OutputRedaction], - ) -> None: - setattr(self, 'input-redaction', input_redaction) - setattr(self, 'output-redaction', output_redaction) - - @property - def input_redaction(self) -> Optional[InputRedaction]: - return getattr(self, 'input-redaction') - - @property - def output_redaction(self) -> Optional[OutputRedaction]: - return getattr(self, 'output-redaction') - - def __repr__(self) -> str: - return f'RedactionEvent(input_redaction={getattr(self, 'input-redaction')!r}, output_redaction={getattr(self, 'output-redaction')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, RedactionEvent): - return NotImplemented - return getattr(self, 'input-redaction') == getattr(other, 'input-redaction') and getattr(self, 'output-redaction') == getattr(other, 'output-redaction') - - def __hash__(self) -> int: - return id(self) - - -class StopEvent: - """Terminal event for a stream.""" - def __init__( - self, - *, - reason: StopReason, - usage: Optional[Usage], - metrics: Optional[Metrics], - structured_output: Optional[str], - ) -> None: - setattr(self, 'reason', reason) - setattr(self, 'usage', usage) - setattr(self, 'metrics', metrics) - setattr(self, 'structured-output', structured_output) - - @property - def structured_output(self) -> Optional[str]: - return getattr(self, 'structured-output') - - def __repr__(self) -> str: - return f'StopEvent(reason={getattr(self, 'reason')!r}, usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r}, structured_output={getattr(self, 'structured-output')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, StopEvent): - return NotImplemented - return getattr(self, 'reason') == getattr(other, 'reason') and getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') and getattr(self, 'structured-output') == getattr(other, 'structured-output') - - def __hash__(self) -> int: - return id(self) - - -class AgentResultData: - """Payload for `agent-result`.""" - def __init__( - self, - *, - stop: StopEvent, - ) -> None: - setattr(self, 'stop', stop) - - def __repr__(self) -> str: - return f'AgentResultData(stop={getattr(self, 'stop')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentResultData): - return NotImplemented - return getattr(self, 'stop') == getattr(other, 'stop') - - def __hash__(self) -> int: - return id(self) - - -class StreamError: - """Why the agent loop surfaced an error mid-stream.""" - pass - -class StreamError_Model(StreamError, _WitVariantCase): - """A model call failed.""" - tag = 'model' - -class StreamError_Tool(StreamError, _WitVariantCase): - """A tool call failed.""" - tag = 'tool' - -class StreamError_ContextWindowExceeded(StreamError, _WitVariantCase): - """Input exceeded the model's context window and no conversation -manager could recover.""" - tag = 'context-window-exceeded' - -class StreamError_MaxTokensReached(StreamError, _WitVariantCase): - """Exceeded the model's max-tokens budget mid-response.""" - tag = 'max-tokens-reached' - -class StreamError_StructuredOutputUnavailable(StreamError, _WitVariantCase): - """Structured output was requested but the model never called the -tool, even after being forced.""" - tag = 'structured-output-unavailable' - -class StreamError_Internal(StreamError, _WitVariantCase): - """Catch-all for internal failures.""" - tag = 'internal' - -_StreamError_CASES: dict[str, type] = { - 'model': StreamError_Model, - 'tool': StreamError_Tool, - 'context-window-exceeded': StreamError_ContextWindowExceeded, - 'max-tokens-reached': StreamError_MaxTokensReached, - 'structured-output-unavailable': StreamError_StructuredOutputUnavailable, - 'internal': StreamError_Internal, -} - -def _StreamError_lift(raw: _WitVariant) -> StreamError: - cls = _StreamError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown StreamError arm: {raw.tag!r}') - return cls(raw.payload) -StreamError.lift = staticmethod(_StreamError_lift) # type: ignore[attr-defined] - -class StreamEvent: - """Events yielded during agent streaming. -Hot-path arms: `text-delta`, `tool-use`, `tool-result`. Other content -blocks flow through `content`. Lifecycle arms (`before-invocation` -through `agent-result`) mirror a hook system and can be filtered by tag.""" - pass - -class StreamEvent_TextDelta(StreamEvent, _WitVariantCase): - """Incremental text from the model.""" - tag = 'text-delta' - -class StreamEvent_ToolUse(StreamEvent, _WitVariantCase): - """Model requested a tool call.""" - tag = 'tool-use' - -class StreamEvent_ToolResult(StreamEvent, _WitVariantCase): - """Tool call completed.""" - tag = 'tool-result' - -class StreamEvent_Content(StreamEvent, _WitVariantCase): - """Non-hot-path content block (image, reasoning, citations, etc).""" - tag = 'content' - -class StreamEvent_Metadata(StreamEvent, _WitVariantCase): - """Cumulative usage and metrics snapshot.""" - tag = 'metadata' - -class StreamEvent_Stop(StreamEvent, _WitVariantCase): - """Terminal event for the stream.""" - tag = 'stop' - -class StreamEvent_Redaction(StreamEvent, _WitVariantCase): - """Guardrail redaction fired.""" - tag = 'redaction' - -class StreamEvent_Error(StreamEvent, _WitVariantCase): - """Recoverable error surfaced mid-stream.""" - tag = 'error' - -class StreamEvent_Interrupt(StreamEvent, _WitVariantCase): - """Human-in-the-loop pause; resume via `response-stream.respond`.""" - tag = 'interrupt' - -class StreamEvent_Initialized(StreamEvent, _WitVariantCase): - """Agent finished construction.""" - tag = 'initialized' - -class StreamEvent_BeforeInvocation(StreamEvent, _WitVariantCase): - """About to process a user invocation.""" - tag = 'before-invocation' - -class StreamEvent_AfterInvocation(StreamEvent, _WitVariantCase): - """Finished processing a user invocation.""" - tag = 'after-invocation' - -class StreamEvent_MessageAdded(StreamEvent, _WitVariantCase): - """A message was appended to the conversation.""" - tag = 'message-added' - -class StreamEvent_BeforeModelCall(StreamEvent, _WitVariantCase): - """About to call the model.""" - tag = 'before-model-call' - -class StreamEvent_AfterModelCall(StreamEvent, _WitVariantCase): - """Model call returned.""" - tag = 'after-model-call' - -class StreamEvent_BeforeTools(StreamEvent, _WitVariantCase): - """About to run a batch of tool calls from one assistant turn.""" - tag = 'before-tools' - -class StreamEvent_AfterTools(StreamEvent, _WitVariantCase): - """Tool batch finished.""" - tag = 'after-tools' - -class StreamEvent_BeforeToolCall(StreamEvent, _WitVariantCase): - """About to call a single tool.""" - tag = 'before-tool-call' - -class StreamEvent_AfterToolCall(StreamEvent, _WitVariantCase): - """Tool call returned.""" - tag = 'after-tool-call' - -class StreamEvent_ContentBlock(StreamEvent, _WitVariantCase): - """A content block was assembled during streaming.""" - tag = 'content-block' - -class StreamEvent_ModelMessage(StreamEvent, _WitVariantCase): - """Model finished producing a full message.""" - tag = 'model-message' - -class StreamEvent_ToolResultHook(StreamEvent, _WitVariantCase): - """Tool finished execution (completion event, not streaming update).""" - tag = 'tool-result-hook' - -class StreamEvent_ToolUpdate(StreamEvent, _WitVariantCase): - """Streaming update from a tool.""" - tag = 'tool-update' - -class StreamEvent_ModelUpdate(StreamEvent, _WitVariantCase): - """Streaming update from the model.""" - tag = 'model-update' - -class StreamEvent_AgentResult(StreamEvent, _WitVariantCase): - """Final event for an invocation, carrying the terminal result.""" - tag = 'agent-result' - -_StreamEvent_CASES: dict[str, type] = { - 'text-delta': StreamEvent_TextDelta, - 'tool-use': StreamEvent_ToolUse, - 'tool-result': StreamEvent_ToolResult, - 'content': StreamEvent_Content, - 'metadata': StreamEvent_Metadata, - 'stop': StreamEvent_Stop, - 'redaction': StreamEvent_Redaction, - 'error': StreamEvent_Error, - 'interrupt': StreamEvent_Interrupt, - 'initialized': StreamEvent_Initialized, - 'before-invocation': StreamEvent_BeforeInvocation, - 'after-invocation': StreamEvent_AfterInvocation, - 'message-added': StreamEvent_MessageAdded, - 'before-model-call': StreamEvent_BeforeModelCall, - 'after-model-call': StreamEvent_AfterModelCall, - 'before-tools': StreamEvent_BeforeTools, - 'after-tools': StreamEvent_AfterTools, - 'before-tool-call': StreamEvent_BeforeToolCall, - 'after-tool-call': StreamEvent_AfterToolCall, - 'content-block': StreamEvent_ContentBlock, - 'model-message': StreamEvent_ModelMessage, - 'tool-result-hook': StreamEvent_ToolResultHook, - 'tool-update': StreamEvent_ToolUpdate, - 'model-update': StreamEvent_ModelUpdate, - 'agent-result': StreamEvent_AgentResult, -} - -def _StreamEvent_lift(raw: _WitVariant) -> StreamEvent: - cls = _StreamEvent_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown StreamEvent arm: {raw.tag!r}') - return cls(raw.payload) -StreamEvent.lift = staticmethod(_StreamEvent_lift) # type: ignore[attr-defined] - -class ModelEventStream: - """Pull-based stream of model events from a custom provider; host produces, guest reads.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def read(self) -> Optional[StreamEvent]: - return self._invoke('[method]model-event-stream.read', (self._handle,)) - - -class ModelStreamOptions: - """Options passed alongside the messages on each streaming call.""" - def __init__( - self, - *, - system_prompt: Optional[PromptInput], - tools: Optional[list[ToolSpec]], - tool_choice: Optional[ToolChoice], - ) -> None: - setattr(self, 'system-prompt', system_prompt) - setattr(self, 'tools', tools) - setattr(self, 'tool-choice', tool_choice) - - @property - def system_prompt(self) -> Optional[PromptInput]: - return getattr(self, 'system-prompt') - - @property - def tool_choice(self) -> Optional[ToolChoice]: - return getattr(self, 'tool-choice') - - def __repr__(self) -> str: - return f'ModelStreamOptions(system_prompt={getattr(self, 'system-prompt')!r}, tools={getattr(self, 'tools')!r}, tool_choice={getattr(self, 'tool-choice')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ModelStreamOptions): - return NotImplemented - return getattr(self, 'system-prompt') == getattr(other, 'system-prompt') and getattr(self, 'tools') == getattr(other, 'tools') and getattr(self, 'tool-choice') == getattr(other, 'tool-choice') - - def __hash__(self) -> int: - return id(self) - - -class StartStreamArgs: - """Arguments for `start-stream`.""" - def __init__( - self, - *, - provider_id: str, - messages: list[Message], - options: ModelStreamOptions, - ) -> None: - setattr(self, 'provider-id', provider_id) - setattr(self, 'messages', messages) - setattr(self, 'options', options) - - @property - def provider_id(self) -> str: - return getattr(self, 'provider-id') - - def __repr__(self) -> str: - return f'StartStreamArgs(provider_id={getattr(self, 'provider-id')!r}, messages={getattr(self, 'messages')!r}, options={getattr(self, 'options')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, StartStreamArgs): - return NotImplemented - return getattr(self, 'provider-id') == getattr(other, 'provider-id') and getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'options') == getattr(other, 'options') - - def __hash__(self) -> int: - return id(self) - - -class CountTokensArgs: - """Arguments for `count-tokens`.""" - def __init__( - self, - *, - provider_id: str, - messages: list[Message], - system_prompt: Optional[PromptInput], - tools: Optional[list[ToolSpec]], - ) -> None: - setattr(self, 'provider-id', provider_id) - setattr(self, 'messages', messages) - setattr(self, 'system-prompt', system_prompt) - setattr(self, 'tools', tools) - - @property - def provider_id(self) -> str: - return getattr(self, 'provider-id') - - @property - def system_prompt(self) -> Optional[PromptInput]: - return getattr(self, 'system-prompt') - - def __repr__(self) -> str: - return f'CountTokensArgs(provider_id={getattr(self, 'provider-id')!r}, messages={getattr(self, 'messages')!r}, system_prompt={getattr(self, 'system-prompt')!r}, tools={getattr(self, 'tools')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CountTokensArgs): - return NotImplemented - return getattr(self, 'provider-id') == getattr(other, 'provider-id') and getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'system-prompt') == getattr(other, 'system-prompt') and getattr(self, 'tools') == getattr(other, 'tools') - - def __hash__(self) -> int: - return id(self) - - -class OrchestrationStatus(str): - """Lifecycle status of a node or overall run.""" - __slots__ = () - - PENDING: 'OrchestrationStatus' - EXECUTING: 'OrchestrationStatus' - COMPLETED: 'OrchestrationStatus' - FAILED: 'OrchestrationStatus' - CANCELLED: 'OrchestrationStatus' - -OrchestrationStatus.PENDING = OrchestrationStatus('pending') # type: ignore[attr-defined] -OrchestrationStatus.EXECUTING = OrchestrationStatus('executing') # type: ignore[attr-defined] -OrchestrationStatus.COMPLETED = OrchestrationStatus('completed') # type: ignore[attr-defined] -OrchestrationStatus.FAILED = OrchestrationStatus('failed') # type: ignore[attr-defined] -OrchestrationStatus.CANCELLED = OrchestrationStatus('cancelled') # type: ignore[attr-defined] - - -class TerminalStatus(str): - """Terminal status of a node or run.""" - __slots__ = () - - COMPLETED: 'TerminalStatus' - FAILED: 'TerminalStatus' - CANCELLED: 'TerminalStatus' - -TerminalStatus.COMPLETED = TerminalStatus('completed') # type: ignore[attr-defined] -TerminalStatus.FAILED = TerminalStatus('failed') # type: ignore[attr-defined] -TerminalStatus.CANCELLED = TerminalStatus('cancelled') # type: ignore[attr-defined] - - -class NodeKind(str): - """What a node is.""" - __slots__ = () - - AGENT: 'NodeKind' - MULTI_AGENT: 'NodeKind' - -NodeKind.AGENT = NodeKind('agent') # type: ignore[attr-defined] -NodeKind.MULTI_AGENT = NodeKind('multi-agent') # type: ignore[attr-defined] - - -class AgentNodeConfig: - """Definition of an agent-backed node.""" - def __init__( - self, - *, - id: str, - description: Optional[str], - timeout: Optional[int], - agent_config: str, - ) -> None: - setattr(self, 'id', id) - setattr(self, 'description', description) - setattr(self, 'timeout', timeout) - setattr(self, 'agent-config', agent_config) - - @property - def agent_config(self) -> str: - return getattr(self, 'agent-config') - - def __repr__(self) -> str: - return f'AgentNodeConfig(id={getattr(self, 'id')!r}, description={getattr(self, 'description')!r}, timeout={getattr(self, 'timeout')!r}, agent_config={getattr(self, 'agent-config')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentNodeConfig): - return NotImplemented - return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'timeout') == getattr(other, 'timeout') and getattr(self, 'agent-config') == getattr(other, 'agent-config') - - def __hash__(self) -> int: - return id(self) - - -class MultiAgentNodeConfig: - """Definition of a node that wraps another orchestrator.""" - def __init__( - self, - *, - id: str, - description: Optional[str], - orchestrator: str, - ) -> None: - setattr(self, 'id', id) - setattr(self, 'description', description) - setattr(self, 'orchestrator', orchestrator) - - def __repr__(self) -> str: - return f'MultiAgentNodeConfig(id={getattr(self, 'id')!r}, description={getattr(self, 'description')!r}, orchestrator={getattr(self, 'orchestrator')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MultiAgentNodeConfig): - return NotImplemented - return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'orchestrator') == getattr(other, 'orchestrator') - - def __hash__(self) -> int: - return id(self) - - -class NodeConfig: - """Any node a graph or swarm can execute.""" - pass - -class NodeConfig_Agent(NodeConfig, _WitVariantCase): - """Wraps a single agent.""" - tag = 'agent' - -class NodeConfig_MultiAgent(NodeConfig, _WitVariantCase): - """Wraps a nested orchestrator.""" - tag = 'multi-agent' - -_NodeConfig_CASES: dict[str, type] = { - 'agent': NodeConfig_Agent, - 'multi-agent': NodeConfig_MultiAgent, -} - -def _NodeConfig_lift(raw: _WitVariant) -> NodeConfig: - cls = _NodeConfig_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown NodeConfig arm: {raw.tag!r}') - return cls(raw.payload) -NodeConfig.lift = staticmethod(_NodeConfig_lift) # type: ignore[attr-defined] - -class EdgeHandler: - """Condition attached to a graph edge.""" - def __init__( - self, - *, - handler_id: str, - ) -> None: - setattr(self, 'handler-id', handler_id) - - @property - def handler_id(self) -> str: - return getattr(self, 'handler-id') - - def __repr__(self) -> str: - return f'EdgeHandler(handler_id={getattr(self, 'handler-id')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, EdgeHandler): - return NotImplemented - return getattr(self, 'handler-id') == getattr(other, 'handler-id') - - def __hash__(self) -> int: - return id(self) - - -class EdgeConfig: - """Edge connecting two graph nodes.""" - def __init__( - self, - *, - source: str, - target: str, - handler: Optional[EdgeHandler], - ) -> None: - setattr(self, 'source', source) - setattr(self, 'target', target) - setattr(self, 'handler', handler) - - def __repr__(self) -> str: - return f'EdgeConfig(source={getattr(self, 'source')!r}, target={getattr(self, 'target')!r}, handler={getattr(self, 'handler')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, EdgeConfig): - return NotImplemented - return getattr(self, 'source') == getattr(other, 'source') and getattr(self, 'target') == getattr(other, 'target') and getattr(self, 'handler') == getattr(other, 'handler') - - def __hash__(self) -> int: - return id(self) - - -class GraphConfig: - """Runtime configuration for a Graph.""" - def __init__( - self, - *, - id: str, - nodes: list[NodeConfig], - edges: list[EdgeConfig], - sources: list[str], - max_concurrency: Optional[int], - max_steps: Optional[int], - timeout: Optional[int], - node_timeout: Optional[int], - ) -> None: - setattr(self, 'id', id) - setattr(self, 'nodes', nodes) - setattr(self, 'edges', edges) - setattr(self, 'sources', sources) - setattr(self, 'max-concurrency', max_concurrency) - setattr(self, 'max-steps', max_steps) - setattr(self, 'timeout', timeout) - setattr(self, 'node-timeout', node_timeout) - - @property - def max_concurrency(self) -> Optional[int]: - return getattr(self, 'max-concurrency') - - @property - def max_steps(self) -> Optional[int]: - return getattr(self, 'max-steps') - - @property - def node_timeout(self) -> Optional[int]: - return getattr(self, 'node-timeout') - - def __repr__(self) -> str: - return f'GraphConfig(id={getattr(self, 'id')!r}, nodes={getattr(self, 'nodes')!r}, edges={getattr(self, 'edges')!r}, sources={getattr(self, 'sources')!r}, max_concurrency={getattr(self, 'max-concurrency')!r}, max_steps={getattr(self, 'max-steps')!r}, timeout={getattr(self, 'timeout')!r}, node_timeout={getattr(self, 'node-timeout')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, GraphConfig): - return NotImplemented - return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'nodes') == getattr(other, 'nodes') and getattr(self, 'edges') == getattr(other, 'edges') and getattr(self, 'sources') == getattr(other, 'sources') and getattr(self, 'max-concurrency') == getattr(other, 'max-concurrency') and getattr(self, 'max-steps') == getattr(other, 'max-steps') and getattr(self, 'timeout') == getattr(other, 'timeout') and getattr(self, 'node-timeout') == getattr(other, 'node-timeout') - - def __hash__(self) -> int: - return id(self) - - -class SwarmConfig: - """Runtime configuration for a Swarm.""" - def __init__( - self, - *, - id: str, - nodes: list[AgentNodeConfig], - start_node_id: str, - max_steps: Optional[int], - timeout: Optional[int], - node_timeout: Optional[int], - ) -> None: - setattr(self, 'id', id) - setattr(self, 'nodes', nodes) - setattr(self, 'start-node-id', start_node_id) - setattr(self, 'max-steps', max_steps) - setattr(self, 'timeout', timeout) - setattr(self, 'node-timeout', node_timeout) - - @property - def start_node_id(self) -> str: - return getattr(self, 'start-node-id') - - @property - def max_steps(self) -> Optional[int]: - return getattr(self, 'max-steps') - - @property - def node_timeout(self) -> Optional[int]: - return getattr(self, 'node-timeout') - - def __repr__(self) -> str: - return f'SwarmConfig(id={getattr(self, 'id')!r}, nodes={getattr(self, 'nodes')!r}, start_node_id={getattr(self, 'start-node-id')!r}, max_steps={getattr(self, 'max-steps')!r}, timeout={getattr(self, 'timeout')!r}, node_timeout={getattr(self, 'node-timeout')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SwarmConfig): - return NotImplemented - return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'nodes') == getattr(other, 'nodes') and getattr(self, 'start-node-id') == getattr(other, 'start-node-id') and getattr(self, 'max-steps') == getattr(other, 'max-steps') and getattr(self, 'timeout') == getattr(other, 'timeout') and getattr(self, 'node-timeout') == getattr(other, 'node-timeout') - - def __hash__(self) -> int: - return id(self) - - -class NodeError: - """Why a node or run ended in `failed` status.""" - pass - -class NodeError_Execution(NodeError, _WitVariantCase): - """An underlying agent or nested orchestrator failed.""" - tag = 'execution' - -class NodeError_Timeout(NodeError, _WitVariantCase): - """Wall-clock ceiling was exceeded.""" - tag = 'timeout' - -class NodeError_LimitExceeded(NodeError, _WitVariantCase): - """A declared runtime limit (max-steps, max-concurrency) was hit.""" - tag = 'limit-exceeded' - -class NodeError_EdgeHandler(NodeError, _WitVariantCase): - """Edge handler rejected the traversal with an error.""" - tag = 'edge-handler' - -class NodeError_InvalidConfig(NodeError, _WitVariantCase): - """Invalid configuration detected at run time.""" - tag = 'invalid-config' - -class NodeError_Internal(NodeError, _WitVariantCase): - """Catch-all for internal failures.""" - tag = 'internal' - -_NodeError_CASES: dict[str, type] = { - 'execution': NodeError_Execution, - 'timeout': NodeError_Timeout, - 'limit-exceeded': NodeError_LimitExceeded, - 'edge-handler': NodeError_EdgeHandler, - 'invalid-config': NodeError_InvalidConfig, - 'internal': NodeError_Internal, -} - -def _NodeError_lift(raw: _WitVariant) -> NodeError: - cls = _NodeError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown NodeError arm: {raw.tag!r}') - return cls(raw.payload) -NodeError.lift = staticmethod(_NodeError_lift) # type: ignore[attr-defined] - -class NodeResult: - """Result of a single node execution.""" - def __init__( - self, - *, - node_id: str, - status: TerminalStatus, - duration: int, - content: list[ContentBlock], - error: Optional[NodeError], - structured_output: Optional[str], - usage: Optional[Usage], - metrics: Optional[Metrics], - ) -> None: - setattr(self, 'node-id', node_id) - setattr(self, 'status', status) - setattr(self, 'duration', duration) - setattr(self, 'content', content) - setattr(self, 'error', error) - setattr(self, 'structured-output', structured_output) - setattr(self, 'usage', usage) - setattr(self, 'metrics', metrics) - - @property - def node_id(self) -> str: - return getattr(self, 'node-id') - - @property - def structured_output(self) -> Optional[str]: - return getattr(self, 'structured-output') - - def __repr__(self) -> str: - return f'NodeResult(node_id={getattr(self, 'node-id')!r}, status={getattr(self, 'status')!r}, duration={getattr(self, 'duration')!r}, content={getattr(self, 'content')!r}, error={getattr(self, 'error')!r}, structured_output={getattr(self, 'structured-output')!r}, usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NodeResult): - return NotImplemented - return getattr(self, 'node-id') == getattr(other, 'node-id') and getattr(self, 'status') == getattr(other, 'status') and getattr(self, 'duration') == getattr(other, 'duration') and getattr(self, 'content') == getattr(other, 'content') and getattr(self, 'error') == getattr(other, 'error') and getattr(self, 'structured-output') == getattr(other, 'structured-output') and getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') - - def __hash__(self) -> int: - return id(self) - - -class MultiAgentResult: - """Final result of a graph or swarm run.""" - def __init__( - self, - *, - status: TerminalStatus, - nodes: list[NodeResult], - duration: int, - usage: Optional[Usage], - metrics: Optional[Metrics], - ) -> None: - setattr(self, 'status', status) - setattr(self, 'nodes', nodes) - setattr(self, 'duration', duration) - setattr(self, 'usage', usage) - setattr(self, 'metrics', metrics) - - def __repr__(self) -> str: - return f'MultiAgentResult(status={getattr(self, 'status')!r}, nodes={getattr(self, 'nodes')!r}, duration={getattr(self, 'duration')!r}, usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MultiAgentResult): - return NotImplemented - return getattr(self, 'status') == getattr(other, 'status') and getattr(self, 'nodes') == getattr(other, 'nodes') and getattr(self, 'duration') == getattr(other, 'duration') and getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') - - def __hash__(self) -> int: - return id(self) - - -class MultiAgentInvokeArgs: - """Arguments for invoking a graph or swarm.""" - def __init__( - self, - *, - input: PromptInput, - invocation_state: Optional[str], - ) -> None: - setattr(self, 'input', input) - setattr(self, 'invocation-state', invocation_state) - - @property - def invocation_state(self) -> Optional[str]: - return getattr(self, 'invocation-state') - - def __repr__(self) -> str: - return f'MultiAgentInvokeArgs(input={getattr(self, 'input')!r}, invocation_state={getattr(self, 'invocation-state')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, MultiAgentInvokeArgs): - return NotImplemented - return getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'invocation-state') == getattr(other, 'invocation-state') - - def __hash__(self) -> int: - return id(self) - - -class NodeStartData: - """Payload for `node-start`.""" - def __init__( - self, - *, - node_id: str, - kind: NodeKind, - ) -> None: - setattr(self, 'node-id', node_id) - setattr(self, 'kind', kind) - - @property - def node_id(self) -> str: - return getattr(self, 'node-id') - - def __repr__(self) -> str: - return f'NodeStartData(node_id={getattr(self, 'node-id')!r}, kind={getattr(self, 'kind')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NodeStartData): - return NotImplemented - return getattr(self, 'node-id') == getattr(other, 'node-id') and getattr(self, 'kind') == getattr(other, 'kind') - - def __hash__(self) -> int: - return id(self) - - -class NodeEventData: - """Payload for `node-event`. Carries a nested stream event from a -running node.""" - def __init__( - self, - *, - node_id: str, - event: StreamEvent, - ) -> None: - setattr(self, 'node-id', node_id) - setattr(self, 'event', event) - - @property - def node_id(self) -> str: - return getattr(self, 'node-id') - - def __repr__(self) -> str: - return f'NodeEventData(node_id={getattr(self, 'node-id')!r}, event={getattr(self, 'event')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NodeEventData): - return NotImplemented - return getattr(self, 'node-id') == getattr(other, 'node-id') and getattr(self, 'event') == getattr(other, 'event') - - def __hash__(self) -> int: - return id(self) - - -class HandoffEvent: - """Payload for a handoff edge firing.""" - def __init__( - self, - *, - from_node_ids: list[str], - to_node_ids: list[str], - ) -> None: - setattr(self, 'from-node-ids', from_node_ids) - setattr(self, 'to-node-ids', to_node_ids) - - @property - def from_node_ids(self) -> list[str]: - return getattr(self, 'from-node-ids') - - @property - def to_node_ids(self) -> list[str]: - return getattr(self, 'to-node-ids') - - def __repr__(self) -> str: - return f'HandoffEvent(from_node_ids={getattr(self, 'from-node-ids')!r}, to_node_ids={getattr(self, 'to-node-ids')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, HandoffEvent): - return NotImplemented - return getattr(self, 'from-node-ids') == getattr(other, 'from-node-ids') and getattr(self, 'to-node-ids') == getattr(other, 'to-node-ids') - - def __hash__(self) -> int: - return id(self) - - -class MultiAgentStreamEvent: - """Events emitted while streaming a multi-agent run.""" - pass - -class MultiAgentStreamEvent_NodeStart(MultiAgentStreamEvent, _WitVariantCase): - """A node began executing.""" - tag = 'node-start' - -class MultiAgentStreamEvent_Nested(MultiAgentStreamEvent, _WitVariantCase): - """A nested stream event from a running node.""" - tag = 'nested' - -class MultiAgentStreamEvent_NodeStop(MultiAgentStreamEvent, _WitVariantCase): - """A node finished executing.""" - tag = 'node-stop' - -class MultiAgentStreamEvent_Handoff(MultiAgentStreamEvent, _WitVariantCase): - """A handoff happened between nodes.""" - tag = 'handoff' - -class MultiAgentStreamEvent_RunComplete(MultiAgentStreamEvent, _WitVariantCase): - """Terminal result for the run.""" - tag = 'run-complete' - -_MultiAgentStreamEvent_CASES: dict[str, type] = { - 'node-start': MultiAgentStreamEvent_NodeStart, - 'nested': MultiAgentStreamEvent_Nested, - 'node-stop': MultiAgentStreamEvent_NodeStop, - 'handoff': MultiAgentStreamEvent_Handoff, - 'run-complete': MultiAgentStreamEvent_RunComplete, -} - -def _MultiAgentStreamEvent_lift(raw: _WitVariant) -> MultiAgentStreamEvent: - cls = _MultiAgentStreamEvent_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown MultiAgentStreamEvent arm: {raw.tag!r}') - return cls(raw.payload) -MultiAgentStreamEvent.lift = staticmethod(_MultiAgentStreamEvent_lift) # type: ignore[attr-defined] - -class EdgeHandlerError: - """Why an edge evaluation failed.""" - pass - -class EdgeHandlerError_Unknown(EdgeHandlerError, _WitVariantCase): - """No handler registered for the given id.""" - tag = 'unknown' - -class EdgeHandlerError_Failed(EdgeHandlerError, _WitVariantCase): - """Handler raised an exception.""" - tag = 'failed' - -_EdgeHandlerError_CASES: dict[str, type] = { - 'unknown': EdgeHandlerError_Unknown, - 'failed': EdgeHandlerError_Failed, -} - -def _EdgeHandlerError_lift(raw: _WitVariant) -> EdgeHandlerError: - cls = _EdgeHandlerError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown EdgeHandlerError arm: {raw.tag!r}') - return cls(raw.payload) -EdgeHandlerError.lift = staticmethod(_EdgeHandlerError_lift) # type: ignore[attr-defined] - -class HandlerState: - """State snapshot passed to `evaluate` so the handler can branch on -prior node results.""" - def __init__( - self, - *, - results: list[NodeResult], - execution_count: int, - ) -> None: - setattr(self, 'results', results) - setattr(self, 'execution-count', execution_count) - - @property - def execution_count(self) -> int: - return getattr(self, 'execution-count') - - def __repr__(self) -> str: - return f'HandlerState(results={getattr(self, 'results')!r}, execution_count={getattr(self, 'execution-count')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, HandlerState): - return NotImplemented - return getattr(self, 'results') == getattr(other, 'results') and getattr(self, 'execution-count') == getattr(other, 'execution-count') - - def __hash__(self) -> int: - return id(self) - - -class BashToolConfig: - """Bash tool configuration.""" - def __init__( - self, - *, - default_timeout_s: Optional[int], - ) -> None: - setattr(self, 'default-timeout-s', default_timeout_s) - - @property - def default_timeout_s(self) -> Optional[int]: - return getattr(self, 'default-timeout-s') - - def __repr__(self) -> str: - return f'BashToolConfig(default_timeout_s={getattr(self, 'default-timeout-s')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, BashToolConfig): - return NotImplemented - return getattr(self, 'default-timeout-s') == getattr(other, 'default-timeout-s') - - def __hash__(self) -> int: - return id(self) - - -class FileEditorToolConfig: - """File editor tool configuration.""" - def __init__( - self, - *, - workspace_root: Optional[str], - ) -> None: - setattr(self, 'workspace-root', workspace_root) - - @property - def workspace_root(self) -> Optional[str]: - return getattr(self, 'workspace-root') - - def __repr__(self) -> str: - return f'FileEditorToolConfig(workspace_root={getattr(self, 'workspace-root')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, FileEditorToolConfig): - return NotImplemented - return getattr(self, 'workspace-root') == getattr(other, 'workspace-root') - - def __hash__(self) -> int: - return id(self) - - -class HttpRequestToolConfig: - """HTTP request tool configuration.""" - def __init__( - self, - *, - allowed_hosts: list[str], - max_response_bytes: int, - ) -> None: - setattr(self, 'allowed-hosts', allowed_hosts) - setattr(self, 'max-response-bytes', max_response_bytes) - - @property - def allowed_hosts(self) -> list[str]: - return getattr(self, 'allowed-hosts') - - @property - def max_response_bytes(self) -> int: - return getattr(self, 'max-response-bytes') - - def __repr__(self) -> str: - return f'HttpRequestToolConfig(allowed_hosts={getattr(self, 'allowed-hosts')!r}, max_response_bytes={getattr(self, 'max-response-bytes')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, HttpRequestToolConfig): - return NotImplemented - return getattr(self, 'allowed-hosts') == getattr(other, 'allowed-hosts') and getattr(self, 'max-response-bytes') == getattr(other, 'max-response-bytes') - - def __hash__(self) -> int: - return id(self) - - -class NotebookToolConfig: - """Notebook tool configuration.""" - def __init__( - self, - *, - workspace_root: Optional[str], - ) -> None: - setattr(self, 'workspace-root', workspace_root) - - @property - def workspace_root(self) -> Optional[str]: - return getattr(self, 'workspace-root') - - def __repr__(self) -> str: - return f'NotebookToolConfig(workspace_root={getattr(self, 'workspace-root')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NotebookToolConfig): - return NotImplemented - return getattr(self, 'workspace-root') == getattr(other, 'workspace-root') - - def __hash__(self) -> int: - return id(self) - - -class VendedTool: - """Built-in tools.""" - pass - -class VendedTool_Bash(VendedTool, _WitVariantCase): - """Run shell commands in a persistent bash session.""" - tag = 'bash' - -class VendedTool_FileEditor(VendedTool, _WitVariantCase): - """Create, view, and edit files on disk.""" - tag = 'file-editor' - -class VendedTool_HttpRequest(VendedTool, _WitVariantCase): - """Make HTTP requests.""" - tag = 'http-request' - -class VendedTool_Notebook(VendedTool, _WitVariantCase): - """Read and execute Jupyter notebook cells.""" - tag = 'notebook' - -_VendedTool_CASES: dict[str, type] = { - 'bash': VendedTool_Bash, - 'file-editor': VendedTool_FileEditor, - 'http-request': VendedTool_HttpRequest, - 'notebook': VendedTool_Notebook, -} - -def _VendedTool_lift(raw: _WitVariant) -> VendedTool: - cls = _VendedTool_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown VendedTool arm: {raw.tag!r}') - return cls(raw.payload) -VendedTool.lift = staticmethod(_VendedTool_lift) # type: ignore[attr-defined] - -class SkillSource: - """Location of a skill definition on disk.""" - def __init__( - self, - *, - path: str, - ) -> None: - setattr(self, 'path', path) - - def __repr__(self) -> str: - return f'SkillSource(path={getattr(self, 'path')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SkillSource): - return NotImplemented - return getattr(self, 'path') == getattr(other, 'path') - - def __hash__(self) -> int: - return id(self) - - -class SkillsPluginConfig: - """Skills plugin configuration.""" - def __init__( - self, - *, - skills: list[SkillSource], - strict: bool, - max_resource_files: Optional[int], - state_key: Optional[str], - ) -> None: - setattr(self, 'skills', skills) - setattr(self, 'strict', strict) - setattr(self, 'max-resource-files', max_resource_files) - setattr(self, 'state-key', state_key) - - @property - def max_resource_files(self) -> Optional[int]: - return getattr(self, 'max-resource-files') - - @property - def state_key(self) -> Optional[str]: - return getattr(self, 'state-key') - - def __repr__(self) -> str: - return f'SkillsPluginConfig(skills={getattr(self, 'skills')!r}, strict={getattr(self, 'strict')!r}, max_resource_files={getattr(self, 'max-resource-files')!r}, state_key={getattr(self, 'state-key')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SkillsPluginConfig): - return NotImplemented - return getattr(self, 'skills') == getattr(other, 'skills') and getattr(self, 'strict') == getattr(other, 'strict') and getattr(self, 'max-resource-files') == getattr(other, 'max-resource-files') and getattr(self, 'state-key') == getattr(other, 'state-key') - - def __hash__(self) -> int: - return id(self) - - -class ContextOffloaderPluginConfig: - """Context offloader plugin configuration.""" - def __init__( - self, - *, - max_result_tokens: Optional[int], - preview_tokens: Optional[int], - include_retrieval_tool: bool, - ) -> None: - setattr(self, 'max-result-tokens', max_result_tokens) - setattr(self, 'preview-tokens', preview_tokens) - setattr(self, 'include-retrieval-tool', include_retrieval_tool) - - @property - def max_result_tokens(self) -> Optional[int]: - return getattr(self, 'max-result-tokens') - - @property - def preview_tokens(self) -> Optional[int]: - return getattr(self, 'preview-tokens') - - @property - def include_retrieval_tool(self) -> bool: - return getattr(self, 'include-retrieval-tool') - - def __repr__(self) -> str: - return f'ContextOffloaderPluginConfig(max_result_tokens={getattr(self, 'max-result-tokens')!r}, preview_tokens={getattr(self, 'preview-tokens')!r}, include_retrieval_tool={getattr(self, 'include-retrieval-tool')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ContextOffloaderPluginConfig): - return NotImplemented - return getattr(self, 'max-result-tokens') == getattr(other, 'max-result-tokens') and getattr(self, 'preview-tokens') == getattr(other, 'preview-tokens') and getattr(self, 'include-retrieval-tool') == getattr(other, 'include-retrieval-tool') - - def __hash__(self) -> int: - return id(self) - - -class VendedPlugin: - """Built-in plugins.""" - pass - -class VendedPlugin_Skills(VendedPlugin, _WitVariantCase): - """Load and activate Anthropic-style skills from disk.""" - tag = 'skills' - -class VendedPlugin_ContextOffloader(VendedPlugin, _WitVariantCase): - """Offload large tool results to external storage.""" - tag = 'context-offloader' - -_VendedPlugin_CASES: dict[str, type] = { - 'skills': VendedPlugin_Skills, - 'context-offloader': VendedPlugin_ContextOffloader, -} - -def _VendedPlugin_lift(raw: _WitVariant) -> VendedPlugin: - cls = _VendedPlugin_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown VendedPlugin arm: {raw.tag!r}') - return cls(raw.payload) -VendedPlugin.lift = staticmethod(_VendedPlugin_lift) # type: ignore[attr-defined] - -class ConcurrentOptions: - """Concurrent-execution options.""" - def __init__( - self, - *, - max_concurrency: Optional[int], - ) -> None: - setattr(self, 'max-concurrency', max_concurrency) - - @property - def max_concurrency(self) -> Optional[int]: - return getattr(self, 'max-concurrency') - - def __repr__(self) -> str: - return f'ConcurrentOptions(max_concurrency={getattr(self, 'max-concurrency')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ConcurrentOptions): - return NotImplemented - return getattr(self, 'max-concurrency') == getattr(other, 'max-concurrency') - - def __hash__(self) -> int: - return id(self) - - -ToolExecutorStrategy = None | ConcurrentOptions -"""Strategy for executing tool calls emitted in a single assistant turn.""" - -AttributeValue = str | int | float | bool -"""Scalar attribute value attached to a trace.""" - -class TraceAttribute: - """Key-value pair attached to every OpenTelemetry span the agent emits. -Distinct from `streaming.trace-metadata-entry`, which is string-only.""" - def __init__( - self, - *, - key: str, - value: AttributeValue, - ) -> None: - setattr(self, 'key', key) - setattr(self, 'value', value) - - def __repr__(self) -> str: - return f'TraceAttribute(key={getattr(self, 'key')!r}, value={getattr(self, 'value')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TraceAttribute): - return NotImplemented - return getattr(self, 'key') == getattr(other, 'key') and getattr(self, 'value') == getattr(other, 'value') - - def __hash__(self) -> int: - return id(self) - - -class TraceContext: - """W3C Trace Context headers linking the agent's spans to a caller's trace.""" - def __init__( - self, - *, - traceparent: str, - tracestate: Optional[str], - ) -> None: - setattr(self, 'traceparent', traceparent) - setattr(self, 'tracestate', tracestate) - - def __repr__(self) -> str: - return f'TraceContext(traceparent={getattr(self, 'traceparent')!r}, tracestate={getattr(self, 'tracestate')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TraceContext): - return NotImplemented - return getattr(self, 'traceparent') == getattr(other, 'traceparent') and getattr(self, 'tracestate') == getattr(other, 'tracestate') - - def __hash__(self) -> int: - return id(self) - - -class AgentIdentity: - """Display-level identity of the agent; all fields default to sensible values.""" - def __init__( - self, - *, - name: Optional[str], - id: Optional[str], - description: Optional[str], - ) -> None: - setattr(self, 'name', name) - setattr(self, 'id', id) - setattr(self, 'description', description) - - def __repr__(self) -> str: - return f'AgentIdentity(name={getattr(self, 'name')!r}, id={getattr(self, 'id')!r}, description={getattr(self, 'description')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentIdentity): - return NotImplemented - return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'description') == getattr(other, 'description') - - def __hash__(self) -> int: - return id(self) - - -class AgentConfig: - """Configuration passed to the `agent` constructor. -Invalid config surfaces on the first `generate` as `invalid-input`.""" - def __init__( - self, - *, - model: Optional[ModelConfig], - model_params: Optional[ModelParams], - messages: Optional[list[Message]], - system_prompt: Optional[PromptInput], - tools: Optional[list[ToolSpec]], - agent_tools: Optional[list[AgentAsToolConfig]], - vended_tools: Optional[list[VendedTool]], - vended_plugins: Optional[list[VendedPlugin]], - mcp_clients: Optional[list[McpClientConfig]], - identity: Optional[AgentIdentity], - tool_executor: Optional[ToolExecutorStrategy], - display_output: Optional[bool], - trace_attributes: Optional[list[TraceAttribute]], - trace_context: Optional[TraceContext], - session: Optional[SessionConfig], - conversation_manager: Optional[ConversationManagerConfig], - retry: Optional[RetryConfig], - structured_output_schema: Optional[str], - app_state: Optional[str], - model_state: Optional[str], - ) -> None: - setattr(self, 'model', model) - setattr(self, 'model-params', model_params) - setattr(self, 'messages', messages) - setattr(self, 'system-prompt', system_prompt) - setattr(self, 'tools', tools) - setattr(self, 'agent-tools', agent_tools) - setattr(self, 'vended-tools', vended_tools) - setattr(self, 'vended-plugins', vended_plugins) - setattr(self, 'mcp-clients', mcp_clients) - setattr(self, 'identity', identity) - setattr(self, 'tool-executor', (_WitVariant('none') if tool_executor is None else _WitVariant('some', tool_executor))) - setattr(self, 'display-output', display_output) - setattr(self, 'trace-attributes', trace_attributes) - setattr(self, 'trace-context', trace_context) - setattr(self, 'session', session) - setattr(self, 'conversation-manager', conversation_manager) - setattr(self, 'retry', retry) - setattr(self, 'structured-output-schema', structured_output_schema) - setattr(self, 'app-state', app_state) - setattr(self, 'model-state', model_state) - - @property - def model_params(self) -> Optional[ModelParams]: - return getattr(self, 'model-params') - - @property - def system_prompt(self) -> Optional[PromptInput]: - return getattr(self, 'system-prompt') - - @property - def agent_tools(self) -> Optional[list[AgentAsToolConfig]]: - return getattr(self, 'agent-tools') - - @property - def vended_tools(self) -> Optional[list[VendedTool]]: - return getattr(self, 'vended-tools') - - @property - def vended_plugins(self) -> Optional[list[VendedPlugin]]: - return getattr(self, 'vended-plugins') - - @property - def mcp_clients(self) -> Optional[list[McpClientConfig]]: - return getattr(self, 'mcp-clients') - - @property - def tool_executor(self) -> Optional[ToolExecutorStrategy]: - return (None if getattr(self, 'tool-executor').tag == 'none' else getattr(self, 'tool-executor').payload) - - @property - def display_output(self) -> Optional[bool]: - return getattr(self, 'display-output') - - @property - def trace_attributes(self) -> Optional[list[TraceAttribute]]: - return getattr(self, 'trace-attributes') - - @property - def trace_context(self) -> Optional[TraceContext]: - return getattr(self, 'trace-context') - - @property - def conversation_manager(self) -> Optional[ConversationManagerConfig]: - return getattr(self, 'conversation-manager') - - @property - def structured_output_schema(self) -> Optional[str]: - return getattr(self, 'structured-output-schema') - - @property - def app_state(self) -> Optional[str]: - return getattr(self, 'app-state') - - @property - def model_state(self) -> Optional[str]: - return getattr(self, 'model-state') - - def __repr__(self) -> str: - return f'AgentConfig(model={getattr(self, 'model')!r}, model_params={getattr(self, 'model-params')!r}, messages={getattr(self, 'messages')!r}, system_prompt={getattr(self, 'system-prompt')!r}, tools={getattr(self, 'tools')!r}, agent_tools={getattr(self, 'agent-tools')!r}, vended_tools={getattr(self, 'vended-tools')!r}, vended_plugins={getattr(self, 'vended-plugins')!r}, mcp_clients={getattr(self, 'mcp-clients')!r}, identity={getattr(self, 'identity')!r}, tool_executor={getattr(self, 'tool-executor')!r}, display_output={getattr(self, 'display-output')!r}, trace_attributes={getattr(self, 'trace-attributes')!r}, trace_context={getattr(self, 'trace-context')!r}, session={getattr(self, 'session')!r}, conversation_manager={getattr(self, 'conversation-manager')!r}, retry={getattr(self, 'retry')!r}, structured_output_schema={getattr(self, 'structured-output-schema')!r}, app_state={getattr(self, 'app-state')!r}, model_state={getattr(self, 'model-state')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, AgentConfig): - return NotImplemented - return getattr(self, 'model') == getattr(other, 'model') and getattr(self, 'model-params') == getattr(other, 'model-params') and getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'system-prompt') == getattr(other, 'system-prompt') and getattr(self, 'tools') == getattr(other, 'tools') and getattr(self, 'agent-tools') == getattr(other, 'agent-tools') and getattr(self, 'vended-tools') == getattr(other, 'vended-tools') and getattr(self, 'vended-plugins') == getattr(other, 'vended-plugins') and getattr(self, 'mcp-clients') == getattr(other, 'mcp-clients') and getattr(self, 'identity') == getattr(other, 'identity') and getattr(self, 'tool-executor') == getattr(other, 'tool-executor') and getattr(self, 'display-output') == getattr(other, 'display-output') and getattr(self, 'trace-attributes') == getattr(other, 'trace-attributes') and getattr(self, 'trace-context') == getattr(other, 'trace-context') and getattr(self, 'session') == getattr(other, 'session') and getattr(self, 'conversation-manager') == getattr(other, 'conversation-manager') and getattr(self, 'retry') == getattr(other, 'retry') and getattr(self, 'structured-output-schema') == getattr(other, 'structured-output-schema') and getattr(self, 'app-state') == getattr(other, 'app-state') and getattr(self, 'model-state') == getattr(other, 'model-state') - - def __hash__(self) -> int: - return id(self) - - -class InvokeArgs: - """Arguments for `agent.generate`.""" - def __init__( - self, - *, - input: PromptInput, - tools: Optional[list[ToolSpec]], - tool_choice: Optional[ToolChoice], - structured_output_schema: Optional[str], - ) -> None: - setattr(self, 'input', input) - setattr(self, 'tools', tools) - setattr(self, 'tool-choice', tool_choice) - setattr(self, 'structured-output-schema', structured_output_schema) - - @property - def tool_choice(self) -> Optional[ToolChoice]: - return getattr(self, 'tool-choice') - - @property - def structured_output_schema(self) -> Optional[str]: - return getattr(self, 'structured-output-schema') - - def __repr__(self) -> str: - return f'InvokeArgs(input={getattr(self, 'input')!r}, tools={getattr(self, 'tools')!r}, tool_choice={getattr(self, 'tool-choice')!r}, structured_output_schema={getattr(self, 'structured-output-schema')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InvokeArgs): - return NotImplemented - return getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'tools') == getattr(other, 'tools') and getattr(self, 'tool-choice') == getattr(other, 'tool-choice') and getattr(self, 'structured-output-schema') == getattr(other, 'structured-output-schema') - - def __hash__(self) -> int: - return id(self) - - -class RespondArgs: - """Payload supplied when resuming from a human-in-the-loop interrupt.""" - def __init__( - self, - *, - interrupt_id: str, - response: str, - ) -> None: - setattr(self, 'interrupt-id', interrupt_id) - setattr(self, 'response', response) - - @property - def interrupt_id(self) -> str: - return getattr(self, 'interrupt-id') - - def __repr__(self) -> str: - return f'RespondArgs(interrupt_id={getattr(self, 'interrupt-id')!r}, response={getattr(self, 'response')!r})' - - def __eq__(self, other: object) -> bool: - if not isinstance(other, RespondArgs): - return NotImplemented - return getattr(self, 'interrupt-id') == getattr(other, 'interrupt-id') and getattr(self, 'response') == getattr(other, 'response') - - def __hash__(self) -> int: - return id(self) - - -class AgentError: - """Why an agent-resource call failed.""" - pass - -class AgentError_NoSessionConfigured(AgentError, _WitVariantCase): - """The agent was constructed without a session config.""" - tag = 'no-session-configured' - -class AgentError_Storage(AgentError, _WitVariantCase): - """The storage backend rejected the operation.""" - tag = 'storage' - -class AgentError_InvalidInput(AgentError, _WitVariantCase): - """Supplied payload did not match the expected shape.""" - tag = 'invalid-input' - -class AgentError_UnknownInterrupt(AgentError, _WitVariantCase): - """Supplied `interrupt-id` does not match any live interrupt.""" - tag = 'unknown-interrupt' - -class AgentError_Internal(AgentError, _WitVariantCase): - """Catch-all for internal failures.""" - tag = 'internal' - -_AgentError_CASES: dict[str, type] = { - 'no-session-configured': AgentError_NoSessionConfigured, - 'storage': AgentError_Storage, - 'invalid-input': AgentError_InvalidInput, - 'unknown-interrupt': AgentError_UnknownInterrupt, - 'internal': AgentError_Internal, -} - -def _AgentError_lift(raw: _WitVariant) -> AgentError: - cls = _AgentError_CASES.get(raw.tag) - if cls is None: - raise ValueError(f'unknown AgentError arm: {raw.tag!r}') - return cls(raw.payload) -AgentError.lift = staticmethod(_AgentError_lift) # type: ignore[attr-defined] - -class Agent: - """An agent instance. Persistent across `generate` calls.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - @staticmethod - def new(config: AgentConfig, *, invoke: Any) -> 'Agent': - return Agent(invoke('[constructor]agent', (config,)), invoke) - - def generate(self, args: InvokeArgs) -> Any: - return self._invoke('[method]agent.generate', (self._handle, args,)) - - def get_messages(self) -> list[Message]: - return self._invoke('[method]agent.get-messages', (self._handle,)) - - def set_messages(self, messages: list[Message]) -> Any: - return self._invoke('[method]agent.set-messages', (self._handle, messages,)) - - def get_app_state(self) -> str: - return self._invoke('[method]agent.get-app-state', (self._handle,)) - - def set_app_state(self, json: str) -> Any: - return self._invoke('[method]agent.set-app-state', (self._handle, json,)) - - def get_model_state(self) -> str: - return self._invoke('[method]agent.get-model-state', (self._handle,)) - - def set_model_state(self, json: str) -> Any: - return self._invoke('[method]agent.set-model-state', (self._handle, json,)) - - def get_traces(self) -> list[AgentTrace]: - return self._invoke('[method]agent.get-traces', (self._handle,)) - - def get_metrics(self) -> AgentMetrics: - return self._invoke('[method]agent.get-metrics', (self._handle,)) - - def save_session(self) -> Any: - return self._invoke('[method]agent.save-session', (self._handle,)) - - def list_snapshots(self) -> Any: - return self._invoke('[method]agent.list-snapshots', (self._handle,)) - - def delete_session(self) -> Any: - return self._invoke('[method]agent.delete-session', (self._handle,)) - - -class EventStream: - """Pull-based stream of agent events; sync-WIT placeholder for `stream`.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def read(self) -> Optional[StreamEvent]: - return self._invoke('[method]event-stream.read', (self._handle,)) - - -class ResponseStream: - """Handle to an in-flight `generate` invocation.""" - # Wraps a wasmtime-py ResourceAny / ResourceHost handle. - # The runtime sets ._handle to the underlying resource and - # ._invoke to a callable that dispatches a method by WIT name. - - def __init__(self, handle: Any, invoke: Any = None) -> None: - self._handle = handle - self._invoke = invoke - - def events(self) -> Any: - return self._invoke('[method]response-stream.events', (self._handle,)) - - def respond(self, args: RespondArgs) -> Any: - return self._invoke('[method]response-stream.respond', (self._handle, args,)) - - def cancel(self) -> None: - return self._invoke('[method]response-stream.cancel', (self._handle,)) - diff --git a/strands-py-wasm/src/strands/_generated/__init__.py b/strands-py-wasm/src/strands/_generated/__init__.py new file mode 100644 index 0000000000..75043cd245 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/__init__.py @@ -0,0 +1,234 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from typing import Any + +from wasmtime.component import Variant as _WitVariant + +from .strands_agent.api import Agent, AgentConfig, AgentError, AgentIdentity, AttributeValue, ConcurrentOptions, EventStream, InvokeArgs, RespondArgs, ResponseStream, ToolExecutorStrategy, TraceAttribute, TraceContext +from .strands_agent.conversation import ConversationManagerConfig, SlidingWindowConversationManager, SummarizingConversationManager +from .strands_agent.edge_handler_registry import EdgeHandlerError, HandlerState +from .strands_agent.elicitation_handler import ElicitAction, ElicitRequest, ElicitResponse, ElicitationError +from .strands_agent.host_log import LogEntry, LogLevel +from .strands_agent.mcp import EnvVar, HttpHeader, HttpTransport, McpClientConfig, McpConnectionState, McpTransport, SseTransport, StdioTransport, TasksConfig +from .strands_agent.messages import CacheKind, CachePointBlock, Citation, CitationLocation, CitationText, CitationsBlock, ContentBlock, DocumentBlock, DocumentCitationsConfig, DocumentRange, DocumentSource, GuardContentBlock, GuardContentImage, GuardContentText, GuardQualifier, ImageBlock, ImageSource, InterruptResponseBlock, JsonBlock, Message, MessageMetadata, Metrics, PromptInput, ReasoningBlock, Role, S3Location, SearchResultRange, TextBlock, ToolResultBlock, ToolResultContent, ToolResultStatus, ToolUseBlock, Usage, VideoBlock, VideoSource, WebLocation +from .strands_agent.model_provider import CountTokensArgs, ModelEventStream, ModelStreamOptions, StartStreamArgs +from .strands_agent.models import AnthropicModel, BedrockModel, CustomModel, GoogleModel, ModelConfig, ModelError, ModelParams, OpenaiModel +from .strands_agent.multi_agent import AgentNode, EdgeConfig, EdgeHandler, GraphConfig, HandoffEvent, MultiAgentInvokeArgs, MultiAgentNode, MultiAgentResult, MultiAgentStreamEvent, NodeConfig, NodeError, NodeEventData, NodeKind, NodeResult, NodeStartData, OrchestrationStatus, SwarmConfig, TerminalStatus +from .strands_agent.retry import BackoffStrategy, ConstantBackoff, ExponentialBackoff, JitterKind, LinearBackoff, ModelRetryStrategy, RetryConfig +from .strands_agent.sessions import ConversationManagerState, CustomStorage, FileStorage, PluginStateEntry, RetryStrategyState, S3Storage, SaveLatestPolicy, SessionManager, SlidingWindowState, Snapshot, SnapshotData, SnapshotLocation, SnapshotManifest, SnapshotScope, StorageConfig, StorageError, SummarizingState +from .strands_agent.snapshot_storage import DeleteSessionArgs, ListSnapshotIdsArgs, LoadSnapshotArgs, ManifestArgs, SaveManifestArgs, SaveSnapshotArgs +from .strands_agent.snapshot_trigger_handler import TriggerError, TriggerParams +from .strands_agent.streaming import AfterInvocationData, AfterModelCallData, AfterToolCallData, AgentLoopMetrics, AgentMetrics, AgentResultData, AgentTrace, BeforeInvocationData, BeforeModelCallData, BeforeToolCallData, ContentBlockData, HookRedaction, InputRedaction, Interrupt, InvocationMetrics, MessageAddedData, MetadataEvent, ModelMessageData, ModelStopData, ModelStreamUpdateData, OutputRedaction, RedactionEvent, StopEvent, StopReason, StreamError, StreamEvent, ToolMetrics, ToolResultData, ToolStreamUpdateData, ToolUseData, ToolsBatchData, TraceMetadataEntry +from .strands_agent.tools import AgentAsToolConfig, CallToolArgs, ToolChoice, ToolError, ToolEventStream, ToolSpec, ToolStreamEvent +from .strands_agent.vended import AgentSkills, BashTool, ContextOffloader, FileEditorTool, HttpRequestTool, NotebookTool, SkillSource, VendedPlugin, VendedTool +from .wasi_clocks.monotonic_clock import Duration, Instant +from .wasi_clocks.wall_clock import Datetime +from .wasi_io.error import Error +from .wasi_io.poll import Pollable +from .wasi_io.streams import InputStream, OutputStream + + +def ok(value: Any = None) -> _WitVariant: + """Wrap ``value`` as the ``ok`` arm of a ``result``.""" + return _WitVariant("ok", value) + + +def err(value: Any = None) -> _WitVariant: + """Wrap ``value`` as the ``err`` arm of a ``result``.""" + return _WitVariant("err", value) + + +__all__ = [ + 'AfterInvocationData', + 'AfterModelCallData', + 'AfterToolCallData', + 'Agent', + 'AgentAsToolConfig', + 'AgentConfig', + 'AgentError', + 'AgentIdentity', + 'AgentLoopMetrics', + 'AgentMetrics', + 'AgentNode', + 'AgentResultData', + 'AgentSkills', + 'AgentTrace', + 'AnthropicModel', + 'AttributeValue', + 'BackoffStrategy', + 'BashTool', + 'BedrockModel', + 'BeforeInvocationData', + 'BeforeModelCallData', + 'BeforeToolCallData', + 'CacheKind', + 'CachePointBlock', + 'CallToolArgs', + 'Citation', + 'CitationLocation', + 'CitationText', + 'CitationsBlock', + 'ConcurrentOptions', + 'ConstantBackoff', + 'ContentBlock', + 'ContentBlockData', + 'ContextOffloader', + 'ConversationManagerConfig', + 'ConversationManagerState', + 'CountTokensArgs', + 'CustomModel', + 'CustomStorage', + 'Datetime', + 'DeleteSessionArgs', + 'DocumentBlock', + 'DocumentCitationsConfig', + 'DocumentRange', + 'DocumentSource', + 'Duration', + 'EdgeConfig', + 'EdgeHandler', + 'EdgeHandlerError', + 'ElicitAction', + 'ElicitRequest', + 'ElicitResponse', + 'ElicitationError', + 'EnvVar', + 'Error', + 'EventStream', + 'ExponentialBackoff', + 'FileEditorTool', + 'FileStorage', + 'GoogleModel', + 'GraphConfig', + 'GuardContentBlock', + 'GuardContentImage', + 'GuardContentText', + 'GuardQualifier', + 'HandlerState', + 'HandoffEvent', + 'HookRedaction', + 'HttpHeader', + 'HttpRequestTool', + 'HttpTransport', + 'ImageBlock', + 'ImageSource', + 'InputRedaction', + 'InputStream', + 'Instant', + 'Interrupt', + 'InterruptResponseBlock', + 'InvocationMetrics', + 'InvokeArgs', + 'JitterKind', + 'JsonBlock', + 'LinearBackoff', + 'ListSnapshotIdsArgs', + 'LoadSnapshotArgs', + 'LogEntry', + 'LogLevel', + 'ManifestArgs', + 'McpClientConfig', + 'McpConnectionState', + 'McpTransport', + 'Message', + 'MessageAddedData', + 'MessageMetadata', + 'MetadataEvent', + 'Metrics', + 'ModelConfig', + 'ModelError', + 'ModelEventStream', + 'ModelMessageData', + 'ModelParams', + 'ModelRetryStrategy', + 'ModelStopData', + 'ModelStreamOptions', + 'ModelStreamUpdateData', + 'MultiAgentInvokeArgs', + 'MultiAgentNode', + 'MultiAgentResult', + 'MultiAgentStreamEvent', + 'NodeConfig', + 'NodeError', + 'NodeEventData', + 'NodeKind', + 'NodeResult', + 'NodeStartData', + 'NotebookTool', + 'OpenaiModel', + 'OrchestrationStatus', + 'OutputRedaction', + 'OutputStream', + 'PluginStateEntry', + 'Pollable', + 'PromptInput', + 'ReasoningBlock', + 'RedactionEvent', + 'RespondArgs', + 'ResponseStream', + 'RetryConfig', + 'RetryStrategyState', + 'Role', + 'S3Location', + 'S3Storage', + 'SaveLatestPolicy', + 'SaveManifestArgs', + 'SaveSnapshotArgs', + 'SearchResultRange', + 'SessionManager', + 'SkillSource', + 'SlidingWindowConversationManager', + 'SlidingWindowState', + 'Snapshot', + 'SnapshotData', + 'SnapshotLocation', + 'SnapshotManifest', + 'SnapshotScope', + 'SseTransport', + 'StartStreamArgs', + 'StdioTransport', + 'StopEvent', + 'StopReason', + 'StorageConfig', + 'StorageError', + 'StreamError', + 'StreamEvent', + 'SummarizingConversationManager', + 'SummarizingState', + 'SwarmConfig', + 'TasksConfig', + 'TerminalStatus', + 'TextBlock', + 'ToolChoice', + 'ToolError', + 'ToolEventStream', + 'ToolExecutorStrategy', + 'ToolMetrics', + 'ToolResultBlock', + 'ToolResultContent', + 'ToolResultData', + 'ToolResultStatus', + 'ToolSpec', + 'ToolStreamEvent', + 'ToolStreamUpdateData', + 'ToolUseBlock', + 'ToolUseData', + 'ToolsBatchData', + 'TraceAttribute', + 'TraceContext', + 'TraceMetadataEntry', + 'TriggerError', + 'TriggerParams', + 'Usage', + 'VendedPlugin', + 'VendedTool', + 'VideoBlock', + 'VideoSource', + 'WebLocation', + 'err', + 'ok', +] diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/__init__.py b/strands-py-wasm/src/strands/_generated/strands_agent/__init__.py new file mode 100644 index 0000000000..24687757ce --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/__init__.py @@ -0,0 +1 @@ +"""Auto-generated by bindgen. Do not edit.""" diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/api.py b/strands-py-wasm/src/strands/_generated/strands_agent/api.py new file mode 100644 index 0000000000..d5f03c03f1 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/api.py @@ -0,0 +1,221 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .conversation import ConversationManagerConfig +from .mcp import McpClientConfig +from .messages import ContentBlock, Message, PromptInput +from .models import ModelConfig, ModelParams +from .retry import RetryConfig +from .sessions import SessionManager, StorageError +from .streaming import AgentMetrics, AgentTrace, StreamEvent +from .tools import AgentAsToolConfig, ToolChoice, ToolSpec +from .vended import VendedPlugin, VendedTool + + +@dataclass(kw_only=True) +class ConcurrentOptions: + """Concurrent-execution options.""" + max_concurrency: int | None + + +ToolExecutorStrategy = None | ConcurrentOptions +"""Strategy for executing tool calls emitted in a single assistant turn.""" + +AttributeValue = str | int | float | bool +"""Scalar attribute value attached to a trace.""" + +@dataclass(kw_only=True) +class TraceAttribute: + """Key-value pair attached to every OpenTelemetry span the agent emits. +Distinct from `streaming.trace-metadata-entry`, which is string-only.""" + key: str + value: AttributeValue + + +@dataclass(kw_only=True) +class TraceContext: + """W3C Trace Context headers linking the agent's spans to a caller's trace.""" + traceparent: str + tracestate: str | None + + +@dataclass(kw_only=True) +class AgentIdentity: + """Display-level identity of the agent; all fields default to sensible values.""" + name: str | None + id: str | None + description: str | None + + +@dataclass(kw_only=True) +class AgentConfig: + """Configuration passed to the `agent` constructor. +Invalid config surfaces on the first `generate` as `invalid-input`.""" + model: ModelConfig | None + model_params: ModelParams | None + messages: list[Message] | None + system_prompt: PromptInput | None + tools: list[ToolSpec] | None + agent_tools: list[AgentAsToolConfig] | None + vended_tools: list[VendedTool] | None + vended_plugins: list[VendedPlugin] | None + mcp_clients: list[McpClientConfig] | None + identity: AgentIdentity | None + tool_executor: ToolExecutorStrategy | None + display_output: bool | None + trace_attributes: list[TraceAttribute] | None + trace_context: TraceContext | None + session: SessionManager | None + conversation_manager: ConversationManagerConfig | None + retry: RetryConfig | None + structured_output_schema: str | None + app_state: str | None + model_state: str | None + + +@dataclass(kw_only=True) +class InvokeArgs: + """Arguments for `agent.generate`.""" + input: PromptInput + tools: list[ToolSpec] | None + tool_choice: ToolChoice | None + structured_output_schema: str | None + + +@dataclass(kw_only=True) +class RespondArgs: + """Payload supplied when resuming from a human-in-the-loop interrupt.""" + interrupt_id: str + response: str + + +class AgentError: + """Why an agent-resource call failed.""" + + class NoSessionConfigured(_WitVariantCase): + """The agent was constructed without a session config.""" + tag = 'no-session-configured' + + class Storage(_WitVariantCase): + """The storage backend rejected the operation.""" + tag = 'storage' + + class InvalidInput(_WitVariantCase): + """Supplied payload did not match the expected shape.""" + tag = 'invalid-input' + + class UnknownInterrupt(_WitVariantCase): + """Supplied `interrupt-id` does not match any live interrupt.""" + tag = 'unknown-interrupt' + + class Internal(_WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + + _CASES: dict[str, type] = { + 'no-session-configured': NoSessionConfigured, + 'storage': Storage, + 'invalid-input': InvalidInput, + 'unknown-interrupt': UnknownInterrupt, + 'internal': Internal, + } + + @staticmethod + def lift(raw: _WitVariant) -> AgentError: + cls = AgentError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown AgentError arm: {raw.tag!r}') + return cls(raw.payload) + +class Agent: + """An agent instance. Persistent across `generate` calls.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + @staticmethod + def new(config: AgentConfig, *, invoke: Any) -> Agent: + return Agent(invoke('[constructor]agent', (config,)), invoke) + + def generate(self, args: InvokeArgs) -> Any: + return self._invoke('[method]agent.generate', (self._handle, args,)) + + def get_messages(self) -> list[Message]: + return self._invoke('[method]agent.get-messages', (self._handle,)) + + def set_messages(self, messages: list[Message]) -> Any: + return self._invoke('[method]agent.set-messages', (self._handle, messages,)) + + def get_app_state(self) -> str: + return self._invoke('[method]agent.get-app-state', (self._handle,)) + + def set_app_state(self, json: str) -> Any: + return self._invoke('[method]agent.set-app-state', (self._handle, json,)) + + def get_model_state(self) -> str: + return self._invoke('[method]agent.get-model-state', (self._handle,)) + + def set_model_state(self, json: str) -> Any: + return self._invoke('[method]agent.set-model-state', (self._handle, json,)) + + def get_traces(self) -> list[AgentTrace]: + return self._invoke('[method]agent.get-traces', (self._handle,)) + + def get_metrics(self) -> AgentMetrics: + return self._invoke('[method]agent.get-metrics', (self._handle,)) + + def save_session(self) -> Any: + return self._invoke('[method]agent.save-session', (self._handle,)) + + def list_snapshots(self) -> Any: + return self._invoke('[method]agent.list-snapshots', (self._handle,)) + + def delete_session(self) -> Any: + return self._invoke('[method]agent.delete-session', (self._handle,)) + + +class EventStream: + """Pull-based stream of agent events; sync-WIT placeholder for `stream`.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self) -> StreamEvent | None: + return self._invoke('[method]event-stream.read', (self._handle,)) + + +class ResponseStream: + """Handle to an in-flight `generate` invocation.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def events(self) -> Any: + return self._invoke('[method]response-stream.events', (self._handle,)) + + def respond(self, args: RespondArgs) -> Any: + return self._invoke('[method]response-stream.respond', (self._handle, args,)) + + def cancel(self) -> None: + return self._invoke('[method]response-stream.cancel', (self._handle,)) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/conversation.py b/strands-py-wasm/src/strands/_generated/strands_agent/conversation.py new file mode 100644 index 0000000000..a3c793429a --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/conversation.py @@ -0,0 +1,57 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .models import ModelConfig + + +@dataclass(kw_only=True) +class SlidingWindowConversationManager: + """Sliding-window strategy: trim oldest messages once the conversation +exceeds `window-size`.""" + window_size: int + should_truncate_results: bool + + +@dataclass(kw_only=True) +class SummarizingConversationManager: + """Summarizing strategy: once the conversation grows, summarize older +messages into a single summary message and keep the rest verbatim.""" + summary_ratio: float + preserve_recent_messages: int + summarization_system_prompt: str | None + summarization_model: ModelConfig | None + + +class ConversationManagerConfig: + """Which conversation manager the agent uses. Wrapped in +``option<>`` at the call site; ``none`` means history grows without +bound and context-overflow errors propagate to the caller.""" + + class SlidingWindow(_WitVariantCase): + """Sliding-window trimming.""" + tag = 'sliding-window' + + class Summarizing(_WitVariantCase): + """Summarization of older messages.""" + tag = 'summarizing' + + _CASES: dict[str, type] = { + 'sliding-window': SlidingWindow, + 'summarizing': Summarizing, + } + + @staticmethod + def lift(raw: _WitVariant) -> ConversationManagerConfig: + cls = ConversationManagerConfig._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ConversationManagerConfig arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/edge_handler_registry.py b/strands-py-wasm/src/strands/_generated/strands_agent/edge_handler_registry.py new file mode 100644 index 0000000000..418a5cb16c --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/edge_handler_registry.py @@ -0,0 +1,44 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .multi_agent import NodeResult + + +class EdgeHandlerError: + """Why an edge evaluation failed.""" + + class Unknown(_WitVariantCase): + """No handler registered for the given id.""" + tag = 'unknown' + + class Failed(_WitVariantCase): + """Handler raised an exception.""" + tag = 'failed' + + _CASES: dict[str, type] = { + 'unknown': Unknown, + 'failed': Failed, + } + + @staticmethod + def lift(raw: _WitVariant) -> EdgeHandlerError: + cls = EdgeHandlerError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown EdgeHandlerError arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class HandlerState: + """State snapshot passed to `evaluate` so the handler can branch on +prior node results.""" + results: list[NodeResult] + execution_count: int diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/elicitation_handler.py b/strands-py-wasm/src/strands/_generated/strands_agent/elicitation_handler.py new file mode 100644 index 0000000000..8c6bad303f --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/elicitation_handler.py @@ -0,0 +1,68 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +@dataclass(kw_only=True) +class ElicitRequest: + """Request for user input.""" + client_id: str + message: str + request: str + + +class ElicitAction(str): + """Outcome of an elicitation request.""" + __slots__ = () + + ACCEPT: ElicitAction + DECLINE: ElicitAction + CANCEL: ElicitAction + +ElicitAction.ACCEPT = ElicitAction('accept') # type: ignore[attr-defined] +ElicitAction.DECLINE = ElicitAction('decline') # type: ignore[attr-defined] +ElicitAction.CANCEL = ElicitAction('cancel') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class ElicitResponse: + """Response to an elicitation request.""" + action: ElicitAction + content: str | None + + +class ElicitationError: + """Why an elicitation call failed.""" + + class UnknownClient(_WitVariantCase): + """No handler registered for the given `client-id`.""" + tag = 'unknown-client' + + class HandlerFailed(_WitVariantCase): + """Handler raised an exception.""" + tag = 'handler-failed' + + class TimedOut(_WitVariantCase): + """Request timed out waiting for a human response.""" + tag = 'timed-out' + + _CASES: dict[str, type] = { + 'unknown-client': UnknownClient, + 'handler-failed': HandlerFailed, + 'timed-out': TimedOut, + } + + @staticmethod + def lift(raw: _WitVariant) -> ElicitationError: + cls = ElicitationError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ElicitationError arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/host_log.py b/strands-py-wasm/src/strands/_generated/strands_agent/host_log.py new file mode 100644 index 0000000000..712f78efe2 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/host_log.py @@ -0,0 +1,36 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +class LogLevel(str): + """Severity level of a log entry.""" + __slots__ = () + + TRACE: LogLevel + DEBUG: LogLevel + INFO: LogLevel + WARN: LogLevel + ERROR: LogLevel + +LogLevel.TRACE = LogLevel('trace') # type: ignore[attr-defined] +LogLevel.DEBUG = LogLevel('debug') # type: ignore[attr-defined] +LogLevel.INFO = LogLevel('info') # type: ignore[attr-defined] +LogLevel.WARN = LogLevel('warn') # type: ignore[attr-defined] +LogLevel.ERROR = LogLevel('error') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class LogEntry: + """A single structured log entry.""" + level: LogLevel + message: str + context: str | None diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/mcp.py b/strands-py-wasm/src/strands/_generated/strands_agent/mcp.py new file mode 100644 index 0000000000..fea54f3d37 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/mcp.py @@ -0,0 +1,114 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .tools import ToolSpec + + +Duration = int + +class McpConnectionState(str): + """Connection state of an MCP client.""" + __slots__ = () + + DISCONNECTED: McpConnectionState + CONNECTED: McpConnectionState + FAILED: McpConnectionState + +McpConnectionState.DISCONNECTED = McpConnectionState('disconnected') # type: ignore[attr-defined] +McpConnectionState.CONNECTED = McpConnectionState('connected') # type: ignore[attr-defined] +McpConnectionState.FAILED = McpConnectionState('failed') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class EnvVar: + """Single environment variable entry.""" + key: str + value: str + + +@dataclass(kw_only=True) +class StdioTransport: + """STDIO transport configuration.""" + command: str + args: list[str] + env: list[EnvVar] + cwd: str | None + + +@dataclass(kw_only=True) +class HttpHeader: + """Single HTTP header entry.""" + name: str + value: str + + +@dataclass(kw_only=True) +class HttpTransport: + """HTTP transport configuration.""" + url: str + headers: list[HttpHeader] + + +@dataclass(kw_only=True) +class SseTransport: + """SSE transport configuration.""" + url: str + headers: list[HttpHeader] + + +class McpTransport: + """How the client talks to the MCP server.""" + + class Stdio(_WitVariantCase): + """STDIO transport. Spawn a local process and talk via pipes.""" + tag = 'stdio' + + class StreamableHttp(_WitVariantCase): + """Streamable HTTP transport, per the current MCP specification.""" + tag = 'streamable-http' + + class Sse(_WitVariantCase): + """Legacy Server-Sent Events transport. Retained for older servers.""" + tag = 'sse' + + _CASES: dict[str, type] = { + 'stdio': Stdio, + 'streamable-http': StreamableHttp, + 'sse': Sse, + } + + @staticmethod + def lift(raw: _WitVariant) -> McpTransport: + cls = McpTransport._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown McpTransport arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class TasksConfig: + """Task-augmented tool execution. Enables long-running tools with +progress tracking. Experimental in the MCP specification.""" + ttl: int + poll_timeout: int + + +@dataclass(kw_only=True) +class McpClientConfig: + """MCP client configuration.""" + client_id: str + application_name: str | None + application_version: str | None + transport: McpTransport + tasks_config: TasksConfig | None + elicitation_enabled: bool + fail_open: bool + disable_instrumentation: bool diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/messages.py b/strands-py-wasm/src/strands/_generated/strands_agent/messages.py new file mode 100644 index 0000000000..d8cb1651be --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/messages.py @@ -0,0 +1,424 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +@dataclass(kw_only=True) +class TextBlock: + """Plain text.""" + text: str + + +@dataclass(kw_only=True) +class S3Location: + """Object stored in Amazon S3.""" + uri: str + bucket_owner: str | None + + +ImageSource = bytes | str | S3Location +"""Source of image bytes.""" + +@dataclass(kw_only=True) +class ImageBlock: + """Image attached to a message.""" + format: str + source: ImageSource + + +VideoSource = bytes | S3Location +"""Source of video bytes.""" + +@dataclass(kw_only=True) +class VideoBlock: + """Video attached to a message.""" + format: str + source: VideoSource + + +DocumentSource = bytes | str | list[TextBlock] | S3Location +"""Source of document bytes.""" + +@dataclass(kw_only=True) +class DocumentCitationsConfig: + """Citation configuration attached to a document.""" + enabled: bool + + +@dataclass(kw_only=True) +class DocumentBlock: + """Document attached to a message.""" + name: str + format: str + source: DocumentSource + citations: DocumentCitationsConfig | None + context: str | None + + +@dataclass(kw_only=True) +class ReasoningBlock: + """Model's thought process. Either plain reasoning (with an optional +signature) or an opaque redacted blob.""" + text: str | None + signature: str | None + redacted_content: bytes | None + + +class CacheKind(str): + """Prompt-caching kind. More arms will be added as providers surface +additional cache tiers (e.g. Anthropic's `ephemeral`).""" + __slots__ = () + + DEFAULT_CACHE: CacheKind + +CacheKind.DEFAULT_CACHE = CacheKind('default-cache') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class CachePointBlock: + """Marks a caching boundary in the prompt.""" + kind: CacheKind + + +class GuardQualifier(str): + """How a piece of guard content should be evaluated.""" + __slots__ = () + + GROUNDING_SOURCE: GuardQualifier + QUERY: GuardQualifier + GUARD_CONTENT: GuardQualifier + +GuardQualifier.GROUNDING_SOURCE = GuardQualifier('grounding-source') # type: ignore[attr-defined] +GuardQualifier.QUERY = GuardQualifier('query') # type: ignore[attr-defined] +GuardQualifier.GUARD_CONTENT = GuardQualifier('guard-content') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class GuardContentText: + """Text submitted to a guardrail for evaluation.""" + qualifiers: list[GuardQualifier] + text: str + + +@dataclass(kw_only=True) +class GuardContentImage: + """Image submitted to a guardrail for evaluation.""" + format: str + bytes: bytes + + +class GuardContentBlock: + """Content submitted to a guardrail for evaluation.""" + + class Text(_WitVariantCase): + """Text guard content.""" + tag = 'text' + + class Image(_WitVariantCase): + """Image guard content.""" + tag = 'image' + + _CASES: dict[str, type] = { + 'text': Text, + 'image': Image, + } + + @staticmethod + def lift(raw: _WitVariant) -> GuardContentBlock: + cls = GuardContentBlock._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown GuardContentBlock arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class DocumentRange: + """Range within a source document (characters, pages, or chunks).""" + document_index: int + start: int + end: int + + +@dataclass(kw_only=True) +class SearchResultRange: + """Range within a search result.""" + search_result_index: int + start: int + end: int + + +@dataclass(kw_only=True) +class WebLocation: + """Web citation target.""" + url: str + domain: str | None + + +class CitationLocation: + """Anchor a citation points to.""" + + class DocumentChar(_WitVariantCase): + """Character range within a document.""" + tag = 'document-char' + + class DocumentPage(_WitVariantCase): + """Page range within a document.""" + tag = 'document-page' + + class DocumentChunk(_WitVariantCase): + """Chunk range within a document.""" + tag = 'document-chunk' + + class SearchResult(_WitVariantCase): + """Range within a search result.""" + tag = 'search-result' + + class Web(_WitVariantCase): + """Web page.""" + tag = 'web' + + _CASES: dict[str, type] = { + 'document-char': DocumentChar, + 'document-page': DocumentPage, + 'document-chunk': DocumentChunk, + 'search-result': SearchResult, + 'web': Web, + } + + @staticmethod + def lift(raw: _WitVariant) -> CitationLocation: + cls = CitationLocation._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown CitationLocation arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class CitationText: + """Text fragment from a source or a generated answer.""" + text: str + + +@dataclass(kw_only=True) +class Citation: + """Link from generated content back to a source location.""" + location: CitationLocation + source: str + source_content: list[CitationText] + title: str + + +@dataclass(kw_only=True) +class CitationsBlock: + """Citations emitted by the model when citations are enabled.""" + citations: list[Citation] + content: list[CitationText] + + +@dataclass(kw_only=True) +class ToolUseBlock: + """Model's request to call a tool.""" + name: str + tool_use_id: str + input: str + reasoning_signature: str | None + + +class ToolResultStatus(str): + """Whether a tool invocation succeeded. Richer classification lives on `tools.tool-error`.""" + __slots__ = () + + SUCCESS: ToolResultStatus + ERROR: ToolResultStatus + +ToolResultStatus.SUCCESS = ToolResultStatus('success') # type: ignore[attr-defined] +ToolResultStatus.ERROR = ToolResultStatus('error') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class JsonBlock: + """Structured JSON payload. Used for tool results and agent-as-tool +outputs that carry schema-validated data, not prose.""" + json: str + + +class ToolResultContent: + """Block valid inside `tool-result-block.content`. Narrower than `content-block`.""" + + class Text(_WitVariantCase): + """Text output.""" + tag = 'text' + + class Json(_WitVariantCase): + """Structured JSON output.""" + tag = 'json' + + class Image(_WitVariantCase): + """Image output.""" + tag = 'image' + + class Video(_WitVariantCase): + """Video output.""" + tag = 'video' + + class Document(_WitVariantCase): + """Document output.""" + tag = 'document' + + _CASES: dict[str, type] = { + 'text': Text, + 'json': Json, + 'image': Image, + 'video': Video, + 'document': Document, + } + + @staticmethod + def lift(raw: _WitVariant) -> ToolResultContent: + cls = ToolResultContent._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ToolResultContent arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class ToolResultBlock: + """Outcome of a tool execution.""" + tool_use_id: str + status: ToolResultStatus + content: list[ToolResultContent] + + +@dataclass(kw_only=True) +class InterruptResponseBlock: + """User response to a previously-raised interrupt. Supplied on the +next invocation to resume the paused agent.""" + interrupt_id: str + response: str + + +class ContentBlock: + """Any block that can appear inside a message.""" + + class Text(_WitVariantCase): + """Plain text.""" + tag = 'text' + + class Json(_WitVariantCase): + """Structured JSON payload.""" + tag = 'json' + + class ToolUse(_WitVariantCase): + """Model requested a tool call.""" + tag = 'tool-use' + + class ToolResult(_WitVariantCase): + """Tool call completed.""" + tag = 'tool-result' + + class Reasoning(_WitVariantCase): + """Model reasoning.""" + tag = 'reasoning' + + class CachePoint(_WitVariantCase): + """Caching boundary marker.""" + tag = 'cache-point' + + class GuardContent(_WitVariantCase): + """Content submitted for guardrail evaluation.""" + tag = 'guard-content' + + class Image(_WitVariantCase): + """Image.""" + tag = 'image' + + class Video(_WitVariantCase): + """Video.""" + tag = 'video' + + class Document(_WitVariantCase): + """Document.""" + tag = 'document' + + class Citations(_WitVariantCase): + """Citations emitted by the model.""" + tag = 'citations' + + class InterruptResponse(_WitVariantCase): + """Response to a prior interrupt, supplied when resuming.""" + tag = 'interrupt-response' + + _CASES: dict[str, type] = { + 'text': Text, + 'json': Json, + 'tool-use': ToolUse, + 'tool-result': ToolResult, + 'reasoning': Reasoning, + 'cache-point': CachePoint, + 'guard-content': GuardContent, + 'image': Image, + 'video': Video, + 'document': Document, + 'citations': Citations, + 'interrupt-response': InterruptResponse, + } + + @staticmethod + def lift(raw: _WitVariant) -> ContentBlock: + cls = ContentBlock._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ContentBlock arm: {raw.tag!r}') + return cls(raw.payload) + +class Role(str): + """Who a message is from.""" + __slots__ = () + + USER: Role + ASSISTANT: Role + +Role.USER = Role('user') # type: ignore[attr-defined] +Role.ASSISTANT = Role('assistant') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class Usage: + """Token consumption for a model invocation.""" + input_tokens: int + output_tokens: int + total_tokens: int + cache_read_input_tokens: int | None + cache_write_input_tokens: int | None + + +@dataclass(kw_only=True) +class Metrics: + """Performance metrics for a model invocation.""" + latency_ms: float + + +@dataclass(kw_only=True) +class MessageMetadata: + """Metadata attached to a message. Not sent to model providers; persisted +alongside the message for bookkeeping.""" + usage: Usage | None + metrics: Metrics | None + custom: str | None + + +@dataclass(kw_only=True) +class Message: + """A complete message in a conversation.""" + role: Role + content: list[ContentBlock] + metadata: MessageMetadata | None + + +PromptInput = str | list[ContentBlock] +"""A prompt-style input: either prose or structured content. Used for +both system prompts and user input.""" diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/model_provider.py b/strands-py-wasm/src/strands/_generated/strands_agent/model_provider.py new file mode 100644 index 0000000000..b9a974c5a4 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/model_provider.py @@ -0,0 +1,55 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .messages import Message, PromptInput +from .models import ModelError +from .streaming import StreamEvent +from .tools import ToolChoice, ToolSpec + + +class ModelEventStream: + """Pull-based stream of model events from a custom provider; host produces, guest reads.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self) -> StreamEvent | None: + return self._invoke('[method]model-event-stream.read', (self._handle,)) + + +@dataclass(kw_only=True) +class ModelStreamOptions: + """Options passed alongside the messages on each streaming call.""" + system_prompt: PromptInput | None + tools: list[ToolSpec] | None + tool_choice: ToolChoice | None + + +@dataclass(kw_only=True) +class StartStreamArgs: + """Arguments for `start-stream`.""" + provider_id: str + messages: list[Message] + options: ModelStreamOptions + + +@dataclass(kw_only=True) +class CountTokensArgs: + """Arguments for `count-tokens`.""" + provider_id: str + messages: list[Message] + system_prompt: PromptInput | None + tools: list[ToolSpec] | None diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/models.py b/strands-py-wasm/src/strands/_generated/strands_agent/models.py new file mode 100644 index 0000000000..f124c2e707 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/models.py @@ -0,0 +1,162 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +@dataclass(kw_only=True) +class AnthropicModel: + """Anthropic API model configuration.""" + model_id: str | None + api_key: str | None + additional_config: str | None + + +@dataclass(kw_only=True) +class BedrockModel: + """AWS Bedrock model configuration.""" + model_id: str + region: str | None + access_key_id: str | None + secret_access_key: str | None + session_token: str | None + additional_config: str | None + + +@dataclass(kw_only=True) +class OpenaiModel: + """OpenAI API model configuration.""" + model_id: str | None + api_key: str | None + additional_config: str | None + + +@dataclass(kw_only=True) +class GoogleModel: + """Google Gemini API model configuration.""" + model_id: str | None + api_key: str | None + additional_config: str | None + + +@dataclass(kw_only=True) +class CustomModel: + """Custom model provider supplied by your application.""" + provider_id: str + model_id: str | None + additional_config: str | None + stateful: bool + + +class ModelConfig: + """Which model provider the agent should use.""" + + class Anthropic(_WitVariantCase): + """Anthropic API.""" + tag = 'anthropic' + + class Bedrock(_WitVariantCase): + """AWS Bedrock.""" + tag = 'bedrock' + + class Openai(_WitVariantCase): + """OpenAI API.""" + tag = 'openai' + + class Gemini(_WitVariantCase): + """Google Gemini API.""" + tag = 'gemini' + + class Custom(_WitVariantCase): + """Custom provider supplied by your application. Implement the +`model-provider` interface to serve it.""" + tag = 'custom' + + _CASES: dict[str, type] = { + 'anthropic': Anthropic, + 'bedrock': Bedrock, + 'openai': Openai, + 'gemini': Gemini, + 'custom': Custom, + } + + @staticmethod + def lift(raw: _WitVariant) -> ModelConfig: + cls = ModelConfig._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ModelConfig arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class ModelParams: + """Sampling parameters applied to every call on the chosen provider.""" + max_tokens: int | None + temperature: float | None + top_p: float | None + + +class ModelError: + """Why a model call failed. Retry logic keys off of which arm fires, so +implementations should pick the narrowest one that fits.""" + + class UnknownProvider(_WitVariantCase): + """No provider registered for the given `provider-id`.""" + tag = 'unknown-provider' + + class InvalidRequest(_WitVariantCase): + """Provider refused the request due to malformed input.""" + tag = 'invalid-request' + + class Unauthorized(_WitVariantCase): + """Caller lacks permission (missing or expired credentials).""" + tag = 'unauthorized' + + class Throttled(_WitVariantCase): + """Provider returned a rate-limit error. Retry after a backoff.""" + tag = 'throttled' + + class ServerError(_WitVariantCase): + """Provider returned a server-side error. Retry may succeed.""" + tag = 'server-error' + + class ContextWindowExceeded(_WitVariantCase): + """Request exceeded the model's context window.""" + tag = 'context-window-exceeded' + + class ContentFiltered(_WitVariantCase): + """Content was rejected by provider safety policy.""" + tag = 'content-filtered' + + class Transient(_WitVariantCase): + """Transient network or transport failure. Retry may succeed.""" + tag = 'transient' + + class Internal(_WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + + _CASES: dict[str, type] = { + 'unknown-provider': UnknownProvider, + 'invalid-request': InvalidRequest, + 'unauthorized': Unauthorized, + 'throttled': Throttled, + 'server-error': ServerError, + 'context-window-exceeded': ContextWindowExceeded, + 'content-filtered': ContentFiltered, + 'transient': Transient, + 'internal': Internal, + } + + @staticmethod + def lift(raw: _WitVariant) -> ModelError: + cls = ModelError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ModelError arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/multi_agent.py b/strands-py-wasm/src/strands/_generated/strands_agent/multi_agent.py new file mode 100644 index 0000000000..0a78fdc7c3 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/multi_agent.py @@ -0,0 +1,269 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .messages import ContentBlock, Metrics, PromptInput, Usage +from .streaming import StreamEvent + + +Duration = int + +class OrchestrationStatus(str): + """Lifecycle status of a node or overall run.""" + __slots__ = () + + PENDING: OrchestrationStatus + EXECUTING: OrchestrationStatus + COMPLETED: OrchestrationStatus + FAILED: OrchestrationStatus + CANCELLED: OrchestrationStatus + +OrchestrationStatus.PENDING = OrchestrationStatus('pending') # type: ignore[attr-defined] +OrchestrationStatus.EXECUTING = OrchestrationStatus('executing') # type: ignore[attr-defined] +OrchestrationStatus.COMPLETED = OrchestrationStatus('completed') # type: ignore[attr-defined] +OrchestrationStatus.FAILED = OrchestrationStatus('failed') # type: ignore[attr-defined] +OrchestrationStatus.CANCELLED = OrchestrationStatus('cancelled') # type: ignore[attr-defined] + + +class TerminalStatus(str): + """Terminal status of a node or run.""" + __slots__ = () + + COMPLETED: TerminalStatus + FAILED: TerminalStatus + CANCELLED: TerminalStatus + +TerminalStatus.COMPLETED = TerminalStatus('completed') # type: ignore[attr-defined] +TerminalStatus.FAILED = TerminalStatus('failed') # type: ignore[attr-defined] +TerminalStatus.CANCELLED = TerminalStatus('cancelled') # type: ignore[attr-defined] + + +class NodeKind(str): + """What a node is.""" + __slots__ = () + + AGENT: NodeKind + MULTI_AGENT: NodeKind + +NodeKind.AGENT = NodeKind('agent') # type: ignore[attr-defined] +NodeKind.MULTI_AGENT = NodeKind('multi-agent') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class AgentNode: + """Definition of an agent-backed node.""" + id: str + description: str | None + timeout: int | None + agent_config: str + + +@dataclass(kw_only=True) +class MultiAgentNode: + """Definition of a node that wraps another orchestrator.""" + id: str + description: str | None + orchestrator: str + + +class NodeConfig: + """Any node a graph or swarm can execute.""" + + class Agent(_WitVariantCase): + """Wraps a single agent.""" + tag = 'agent' + + class MultiAgent(_WitVariantCase): + """Wraps a nested orchestrator.""" + tag = 'multi-agent' + + _CASES: dict[str, type] = { + 'agent': Agent, + 'multi-agent': MultiAgent, + } + + @staticmethod + def lift(raw: _WitVariant) -> NodeConfig: + cls = NodeConfig._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown NodeConfig arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class EdgeHandler: + """Condition attached to a graph edge.""" + handler_id: str + + +@dataclass(kw_only=True) +class EdgeConfig: + """Edge connecting two graph nodes.""" + source: str + target: str + handler: EdgeHandler | None + + +@dataclass(kw_only=True) +class GraphConfig: + """Runtime configuration for a Graph.""" + id: str + nodes: list[NodeConfig] + edges: list[EdgeConfig] + sources: list[str] + max_concurrency: int | None + max_steps: int | None + timeout: int | None + node_timeout: int | None + + +@dataclass(kw_only=True) +class SwarmConfig: + """Runtime configuration for a Swarm.""" + id: str + nodes: list[AgentNode] + start_node_id: str + max_steps: int | None + timeout: int | None + node_timeout: int | None + + +class NodeError: + """Why a node or run ended in `failed` status.""" + + class Execution(_WitVariantCase): + """An underlying agent or nested orchestrator failed.""" + tag = 'execution' + + class Timeout(_WitVariantCase): + """Wall-clock ceiling was exceeded.""" + tag = 'timeout' + + class LimitExceeded(_WitVariantCase): + """A declared runtime limit (max-steps, max-concurrency) was hit.""" + tag = 'limit-exceeded' + + class EdgeHandler(_WitVariantCase): + """Edge handler rejected the traversal with an error.""" + tag = 'edge-handler' + + class InvalidConfig(_WitVariantCase): + """Invalid configuration detected at run time.""" + tag = 'invalid-config' + + class Internal(_WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + + _CASES: dict[str, type] = { + 'execution': Execution, + 'timeout': Timeout, + 'limit-exceeded': LimitExceeded, + 'edge-handler': EdgeHandler, + 'invalid-config': InvalidConfig, + 'internal': Internal, + } + + @staticmethod + def lift(raw: _WitVariant) -> NodeError: + cls = NodeError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown NodeError arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class NodeResult: + """Result of a single node execution.""" + node_id: str + status: TerminalStatus + duration: int + content: list[ContentBlock] + error: NodeError | None + structured_output: str | None + usage: Usage | None + metrics: Metrics | None + + +@dataclass(kw_only=True) +class MultiAgentResult: + """Final result of a graph or swarm run.""" + status: TerminalStatus + nodes: list[NodeResult] + duration: int + usage: Usage | None + metrics: Metrics | None + + +@dataclass(kw_only=True) +class MultiAgentInvokeArgs: + """Arguments for invoking a graph or swarm.""" + input: PromptInput + invocation_state: str | None + + +@dataclass(kw_only=True) +class NodeStartData: + """Payload for `node-start`.""" + node_id: str + kind: NodeKind + + +@dataclass(kw_only=True) +class NodeEventData: + """Payload for `node-event`. Carries a nested stream event from a +running node.""" + node_id: str + event: StreamEvent + + +@dataclass(kw_only=True) +class HandoffEvent: + """Payload for a handoff edge firing.""" + from_node_ids: list[str] + to_node_ids: list[str] + + +class MultiAgentStreamEvent: + """Events emitted while streaming a multi-agent run.""" + + class NodeStart(_WitVariantCase): + """A node began executing.""" + tag = 'node-start' + + class Nested(_WitVariantCase): + """A nested stream event from a running node.""" + tag = 'nested' + + class NodeStop(_WitVariantCase): + """A node finished executing.""" + tag = 'node-stop' + + class Handoff(_WitVariantCase): + """A handoff happened between nodes.""" + tag = 'handoff' + + class RunComplete(_WitVariantCase): + """Terminal result for the run.""" + tag = 'run-complete' + + _CASES: dict[str, type] = { + 'node-start': NodeStart, + 'nested': Nested, + 'node-stop': NodeStop, + 'handoff': Handoff, + 'run-complete': RunComplete, + } + + @staticmethod + def lift(raw: _WitVariant) -> MultiAgentStreamEvent: + cls = MultiAgentStreamEvent._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown MultiAgentStreamEvent arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/retry.py b/strands-py-wasm/src/strands/_generated/strands_agent/retry.py new file mode 100644 index 0000000000..a223d48831 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/retry.py @@ -0,0 +1,97 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +Duration = int + +class JitterKind(str): + """How much random variation to apply to computed delays.""" + __slots__ = () + + NONE: JitterKind + FULL: JitterKind + EQUAL: JitterKind + DECORRELATED: JitterKind + +JitterKind.NONE = JitterKind('none') # type: ignore[attr-defined] +JitterKind.FULL = JitterKind('full') # type: ignore[attr-defined] +JitterKind.EQUAL = JitterKind('equal') # type: ignore[attr-defined] +JitterKind.DECORRELATED = JitterKind('decorrelated') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class ConstantBackoff: + """Fixed delay between attempts.""" + delay: int + + +@dataclass(kw_only=True) +class LinearBackoff: + """Delay grows linearly with attempt number.""" + base: int + max: int + jitter: JitterKind + + +@dataclass(kw_only=True) +class ExponentialBackoff: + """Delay grows exponentially with attempt number.""" + base: int + max: int + factor: float + jitter: JitterKind + + +class BackoffStrategy: + """Backoff curve applied between attempts.""" + + class Constant(_WitVariantCase): + """Fixed delay.""" + tag = 'constant' + + class Linear(_WitVariantCase): + """Linear growth.""" + tag = 'linear' + + class Exponential(_WitVariantCase): + """Exponential growth.""" + tag = 'exponential' + + _CASES: dict[str, type] = { + 'constant': Constant, + 'linear': Linear, + 'exponential': Exponential, + } + + @staticmethod + def lift(raw: _WitVariant) -> BackoffStrategy: + cls = BackoffStrategy._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown BackoffStrategy arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class ModelRetryStrategy: + """A single retry strategy. Default is exponential backoff, full jitter, 6 attempts.""" + max_attempts: int + backoff: BackoffStrategy + total_budget: int | None + + +@dataclass(kw_only=True) +class RetryConfig: + """Retry configuration attached to an agent. +Every strategy observes every failure; the first to request a delay wins. +Empty list disables retries; omitting `agent-config.retry` applies a default +single exponential strategy. Two strategies with the same `backoff` arm +surface as `agent-error::invalid-input`.""" + strategies: list[ModelRetryStrategy] diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/sessions.py b/strands-py-wasm/src/strands/_generated/strands_agent/sessions.py new file mode 100644 index 0000000000..b1e787f44e --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/sessions.py @@ -0,0 +1,253 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .messages import Message +from ..wasi_clocks.wall_clock import Datetime + + +@dataclass(kw_only=True) +class FileStorage: + """Local filesystem snapshot storage.""" + base_dir: str + + +@dataclass(kw_only=True) +class S3Storage: + """S3 snapshot storage.""" + bucket: str + region: str | None + prefix: str | None + + +@dataclass(kw_only=True) +class CustomStorage: + """Reference to an application-implemented storage backend.""" + backend_id: str + + +class StorageConfig: + """Where to persist session snapshots.""" + + class File(_WitVariantCase): + """Local filesystem.""" + tag = 'file' + + class S3(_WitVariantCase): + """Amazon S3.""" + tag = 's3' + + class Custom(_WitVariantCase): + """Application-implemented backend.""" + tag = 'custom' + + _CASES: dict[str, type] = { + 'file': File, + 's3': S3, + 'custom': Custom, + } + + @staticmethod + def lift(raw: _WitVariant) -> StorageConfig: + cls = StorageConfig._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StorageConfig arm: {raw.tag!r}') + return cls(raw.payload) + +class SaveLatestPolicy: + """When to update the "latest" snapshot pointer. The `trigger` arm +carries the id of an application-supplied callback that decides +per-invocation.""" + + class Message(_WitVariantCase): + """After every message added to the conversation.""" + tag = 'message' + + class Invocation(_WitVariantCase): + """Once per invocation, after it completes.""" + tag = 'invocation' + + class Trigger(_WitVariantCase): + """Each invocation consults the named `snapshot-trigger-handler`. +The id identifies which handler to invoke.""" + tag = 'trigger' + + _CASES: dict[str, type] = { + 'message': Message, + 'invocation': Invocation, + 'trigger': Trigger, + } + + @staticmethod + def lift(raw: _WitVariant) -> SaveLatestPolicy: + cls = SaveLatestPolicy._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown SaveLatestPolicy arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class SessionManager: + """Session persistence configuration attached to an agent.""" + session_id: str + storage: StorageConfig + save_latest: SaveLatestPolicy | None + + +class SnapshotScope(str): + """Which kind of state a snapshot describes.""" + __slots__ = () + + AGENT: SnapshotScope + MULTI_AGENT: SnapshotScope + +SnapshotScope.AGENT = SnapshotScope('agent') # type: ignore[attr-defined] +SnapshotScope.MULTI_AGENT = SnapshotScope('multi-agent') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class SnapshotLocation: + """Locator for a snapshot within the storage hierarchy.""" + session_id: str + scope: SnapshotScope + scope_id: str + + +@dataclass(kw_only=True) +class SlidingWindowState: + """Sliding-window conversation manager state at snapshot time.""" + removed_message_count: int + + +@dataclass(kw_only=True) +class SummarizingState: + """Summarizing conversation manager state at snapshot time.""" + summary_message: Message | None + removed_message_count: int + + +class ConversationManagerState: + """Conversation manager snapshot state. Wrapped in ``option<>`` at the +call site; ``none`` means there's no manager and nothing to persist.""" + + class SlidingWindow(_WitVariantCase): + """Sliding-window manager state.""" + tag = 'sliding-window' + + class Summarizing(_WitVariantCase): + """Summarizing manager state.""" + tag = 'summarizing' + + _CASES: dict[str, type] = { + 'sliding-window': SlidingWindow, + 'summarizing': Summarizing, + } + + @staticmethod + def lift(raw: _WitVariant) -> ConversationManagerState: + cls = ConversationManagerState._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ConversationManagerState arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class RetryStrategyState: + """Retry-strategy state at snapshot time.""" + attempts_used: int + elapsed_ms: int + + +@dataclass(kw_only=True) +class PluginStateEntry: + """Named plugin state. `data` is an opaque JSON object owned by the plugin.""" + plugin_name: str + data: str + + +@dataclass(kw_only=True) +class SnapshotData: + """Framework-owned snapshot state. All fields are optional because an +agent may not exercise every subsystem in a given run.""" + messages: list[Message] + conversation_manager: ConversationManagerState | None + retry_strategy: RetryStrategyState | None + model_state: str | None + plugins: list[PluginStateEntry] + + +@dataclass(kw_only=True) +class Snapshot: + """Point-in-time capture of agent or orchestrator state.""" + scope: SnapshotScope + schema_version: str + created_at: Datetime + data: SnapshotData + app_data: str + + +@dataclass(kw_only=True) +class SnapshotManifest: + """Metadata describing the snapshot manifest file.""" + schema_version: str + updated_at: Datetime + + +class StorageError: + """Why a snapshot operation failed.""" + + class NotFound(_WitVariantCase): + """No snapshot or manifest at the requested location.""" + tag = 'not-found' + + class AccessDenied(_WitVariantCase): + """Caller lacks permission to read or write the storage.""" + tag = 'access-denied' + + class OutOfSpace(_WitVariantCase): + """Backing storage is full or over quota.""" + tag = 'out-of-space' + + class Corrupt(_WitVariantCase): + """Snapshot is malformed or cannot be deserialized.""" + tag = 'corrupt' + + class Conflict(_WitVariantCase): + """Concurrent writers collided; retrying may succeed.""" + tag = 'conflict' + + class Transient(_WitVariantCase): + """Transient I/O failure; retrying may succeed.""" + tag = 'transient' + + class Permanent(_WitVariantCase): + """Permanent backend failure.""" + tag = 'permanent' + + class UnknownBackend(_WitVariantCase): + """No custom backend registered for the given backend-id.""" + tag = 'unknown-backend' + + _CASES: dict[str, type] = { + 'not-found': NotFound, + 'access-denied': AccessDenied, + 'out-of-space': OutOfSpace, + 'corrupt': Corrupt, + 'conflict': Conflict, + 'transient': Transient, + 'permanent': Permanent, + 'unknown-backend': UnknownBackend, + } + + @staticmethod + def lift(raw: _WitVariant) -> StorageError: + cls = StorageError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StorageError arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/snapshot_storage.py b/strands-py-wasm/src/strands/_generated/strands_agent/snapshot_storage.py new file mode 100644 index 0000000000..cb0a972af0 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/snapshot_storage.py @@ -0,0 +1,62 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .sessions import Snapshot, SnapshotLocation, SnapshotManifest, StorageError + + +@dataclass(kw_only=True) +class SaveSnapshotArgs: + """Arguments for `save-snapshot`.""" + backend_id: str + location: SnapshotLocation + snapshot_id: str + is_latest: bool + snapshot: Snapshot + + +@dataclass(kw_only=True) +class LoadSnapshotArgs: + """Arguments for `load-snapshot`.""" + backend_id: str + location: SnapshotLocation + snapshot_id: str | None + + +@dataclass(kw_only=True) +class ListSnapshotIdsArgs: + """Arguments for `list-snapshot-ids`.""" + backend_id: str + location: SnapshotLocation + limit: int | None + start_after: str | None + + +@dataclass(kw_only=True) +class DeleteSessionArgs: + """Arguments for `delete-session`.""" + backend_id: str + session_id: str + + +@dataclass(kw_only=True) +class ManifestArgs: + """Arguments for `load-manifest` / `save-manifest`.""" + backend_id: str + location: SnapshotLocation + + +@dataclass(kw_only=True) +class SaveManifestArgs: + """Arguments for `save-manifest`.""" + backend_id: str + location: SnapshotLocation + manifest: SnapshotManifest diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/snapshot_trigger_handler.py b/strands-py-wasm/src/strands/_generated/strands_agent/snapshot_trigger_handler.py new file mode 100644 index 0000000000..5bf52b2da2 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/snapshot_trigger_handler.py @@ -0,0 +1,45 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .messages import Message + + +@dataclass(kw_only=True) +class TriggerParams: + """Context passed to the trigger on each call.""" + trigger_id: str + message_count: int + last_message: Message | None + + +class TriggerError: + """Why a trigger evaluation failed.""" + + class Unknown(_WitVariantCase): + """No trigger registered for the given id.""" + tag = 'unknown' + + class Failed(_WitVariantCase): + """Trigger raised an exception.""" + tag = 'failed' + + _CASES: dict[str, type] = { + 'unknown': Unknown, + 'failed': Failed, + } + + @staticmethod + def lift(raw: _WitVariant) -> TriggerError: + cls = TriggerError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown TriggerError arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/streaming.py b/strands-py-wasm/src/strands/_generated/strands_agent/streaming.py new file mode 100644 index 0000000000..fd4002874e --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/streaming.py @@ -0,0 +1,443 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .messages import ContentBlock, Message, Metrics, ToolResultBlock, ToolUseBlock, Usage +from .models import ModelError +from .tools import ToolError + + +@dataclass(kw_only=True) +class Interrupt: + """Human-in-the-loop interrupt raised by a tool or hook.""" + id: str + name: str + reason: str | None + + +class StopReason(str): + """Why the model stopped generating.""" + __slots__ = () + + END_TURN: StopReason + TOOL_USE: StopReason + MAX_TOKENS: StopReason + ERROR: StopReason + CONTENT_FILTERED: StopReason + GUARDRAIL_INTERVENED: StopReason + STOP_SEQUENCE: StopReason + MODEL_CONTEXT_WINDOW_EXCEEDED: StopReason + CANCELLED: StopReason + +StopReason.END_TURN = StopReason('end-turn') # type: ignore[attr-defined] +StopReason.TOOL_USE = StopReason('tool-use') # type: ignore[attr-defined] +StopReason.MAX_TOKENS = StopReason('max-tokens') # type: ignore[attr-defined] +StopReason.ERROR = StopReason('error') # type: ignore[attr-defined] +StopReason.CONTENT_FILTERED = StopReason('content-filtered') # type: ignore[attr-defined] +StopReason.GUARDRAIL_INTERVENED = StopReason('guardrail-intervened') # type: ignore[attr-defined] +StopReason.STOP_SEQUENCE = StopReason('stop-sequence') # type: ignore[attr-defined] +StopReason.MODEL_CONTEXT_WINDOW_EXCEEDED = StopReason('model-context-window-exceeded') # type: ignore[attr-defined] +StopReason.CANCELLED = StopReason('cancelled') # type: ignore[attr-defined] + + +@dataclass(kw_only=True) +class MetadataEvent: + """Usage and metrics accumulated so far.""" + usage: Usage | None + metrics: Metrics | None + + +@dataclass(kw_only=True) +class TraceMetadataEntry: + """Single key-value pair attached to a trace. Values are string-typed +to keep traces compact; structured payloads belong on `message`.""" + key: str + value: str + + +@dataclass(kw_only=True) +class AgentTrace: + """In-memory trace node. Returned flat; reconstruct the tree via `parent-id`.""" + id: str + name: str + parent_id: str | None + start_time_ms: int + end_time_ms: int | None + duration_ms: int + metadata: list[TraceMetadataEntry] + message: Message | None + + +@dataclass(kw_only=True) +class ToolMetrics: + """Per-tool execution metrics keyed by tool name in `agent-metrics`.""" + tool_name: str + call_count: int + success_count: int + error_count: int + total_time_ms: int + + +@dataclass(kw_only=True) +class InvocationMetrics: + """Per-invocation metrics. Cycles are flattened into `agent-metrics.cycles` +and linked back via `invocation-id`.""" + invocation_id: str + usage: Usage + + +@dataclass(kw_only=True) +class AgentLoopMetrics: + """Per-cycle usage tracking.""" + cycle_id: str + invocation_id: str + duration_ms: int + usage: Usage + + +@dataclass(kw_only=True) +class AgentMetrics: + """Snapshot of agent metrics. Returned by `agent.get-metrics`.""" + cycle_count: int + accumulated_usage: Usage + accumulated_metrics: Metrics + invocations: list[InvocationMetrics] + cycles: list[AgentLoopMetrics] + tool_metrics: list[ToolMetrics] + latest_context_size: int | None + projected_context_size: int | None + + +@dataclass(kw_only=True) +class ToolUseData: + """Mutable tool-use descriptor. `before-tool-call` hooks may rewrite fields.""" + name: str + tool_use_id: str + input: str + + +@dataclass(kw_only=True) +class HookRedaction: + """Redaction information when guardrails block content.""" + user_message: str + + +@dataclass(kw_only=True) +class ModelStopData: + """Model response surfaced on `after-model-call`.""" + message: Message + stop_reason: StopReason + redaction: HookRedaction | None + + +@dataclass(kw_only=True) +class BeforeInvocationData: + """Payload for `before-invocation`.""" + invocation_state: str + + +@dataclass(kw_only=True) +class AfterInvocationData: + """Payload for `after-invocation`.""" + invocation_state: str + + +@dataclass(kw_only=True) +class MessageAddedData: + """Payload for `message-added`.""" + message: Message + + +@dataclass(kw_only=True) +class BeforeModelCallData: + """Payload for `before-model-call`.""" + projected_input_tokens: int | None + + +@dataclass(kw_only=True) +class AfterModelCallData: + """Payload for `after-model-call`.""" + attempt_count: int + stop_data: ModelStopData | None + error: ModelError | None + + +@dataclass(kw_only=True) +class BeforeToolCallData: + """Payload for `before-tool-call`.""" + tool_use: ToolUseData + + +@dataclass(kw_only=True) +class AfterToolCallData: + """Payload for `after-tool-call`.""" + tool_use: ToolUseData + tool_result: ToolResultBlock + error: ToolError | None + + +@dataclass(kw_only=True) +class ToolsBatchData: + """Payload for `before-tools` / `after-tools`.""" + message: Message + + +@dataclass(kw_only=True) +class ContentBlockData: + """Payload for `content-block`.""" + content_block: ContentBlock + + +@dataclass(kw_only=True) +class ModelMessageData: + """Payload for `model-message`.""" + message: Message + stop_reason: StopReason + + +@dataclass(kw_only=True) +class ToolResultData: + """Payload for `tool-result-hook`.""" + tool_result: ToolResultBlock + + +@dataclass(kw_only=True) +class ToolStreamUpdateData: + """Payload for `tool-stream-update`.""" + data: str + + +@dataclass(kw_only=True) +class ModelStreamUpdateData: + """Payload for `model-stream-update`.""" + event: str + + +@dataclass(kw_only=True) +class InputRedaction: + """Input redaction emitted when a guardrail blocks input. Original is in history.""" + replace_content: str + + +@dataclass(kw_only=True) +class OutputRedaction: + """Output redaction emitted when a guardrail blocks output.""" + redacted_content: str | None + replace_content: str + + +@dataclass(kw_only=True) +class RedactionEvent: + """Redaction event. Input and output fields are independent; at least one is set.""" + input_redaction: InputRedaction | None + output_redaction: OutputRedaction | None + + +@dataclass(kw_only=True) +class StopEvent: + """Terminal event for a stream.""" + reason: StopReason + usage: Usage | None + metrics: Metrics | None + structured_output: str | None + + +@dataclass(kw_only=True) +class AgentResultData: + """Payload for `agent-result`.""" + stop: StopEvent + + +class StreamError: + """Why the agent loop surfaced an error mid-stream.""" + + class Model(_WitVariantCase): + """A model call failed.""" + tag = 'model' + + class Tool(_WitVariantCase): + """A tool call failed.""" + tag = 'tool' + + class ContextWindowExceeded(_WitVariantCase): + """Input exceeded the model's context window and no conversation +manager could recover.""" + tag = 'context-window-exceeded' + + class MaxTokensReached(_WitVariantCase): + """Exceeded the model's max-tokens budget mid-response.""" + tag = 'max-tokens-reached' + + class StructuredOutputUnavailable(_WitVariantCase): + """Structured output was requested but the model never called the +tool, even after being forced.""" + tag = 'structured-output-unavailable' + + class Internal(_WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + + _CASES: dict[str, type] = { + 'model': Model, + 'tool': Tool, + 'context-window-exceeded': ContextWindowExceeded, + 'max-tokens-reached': MaxTokensReached, + 'structured-output-unavailable': StructuredOutputUnavailable, + 'internal': Internal, + } + + @staticmethod + def lift(raw: _WitVariant) -> StreamError: + cls = StreamError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StreamError arm: {raw.tag!r}') + return cls(raw.payload) + +class StreamEvent: + """Events yielded during agent streaming. +Hot-path arms: `text-delta`, `tool-use`, `tool-result`. Other content +blocks flow through `content`. Lifecycle arms (`before-invocation` +through `agent-result`) mirror a hook system and can be filtered by tag.""" + + class TextDelta(_WitVariantCase): + """Incremental text from the model.""" + tag = 'text-delta' + + class ToolUse(_WitVariantCase): + """Model requested a tool call.""" + tag = 'tool-use' + + class ToolResult(_WitVariantCase): + """Tool call completed.""" + tag = 'tool-result' + + class Content(_WitVariantCase): + """Non-hot-path content block (image, reasoning, citations, etc).""" + tag = 'content' + + class Metadata(_WitVariantCase): + """Cumulative usage and metrics snapshot.""" + tag = 'metadata' + + class Stop(_WitVariantCase): + """Terminal event for the stream.""" + tag = 'stop' + + class Redaction(_WitVariantCase): + """Guardrail redaction fired.""" + tag = 'redaction' + + class Error(_WitVariantCase): + """Recoverable error surfaced mid-stream.""" + tag = 'error' + + class Interrupt(_WitVariantCase): + """Human-in-the-loop pause; resume via `response-stream.respond`.""" + tag = 'interrupt' + + class Initialized(_WitVariantCase): + """Agent finished construction.""" + tag = 'initialized' + + class BeforeInvocation(_WitVariantCase): + """About to process a user invocation.""" + tag = 'before-invocation' + + class AfterInvocation(_WitVariantCase): + """Finished processing a user invocation.""" + tag = 'after-invocation' + + class MessageAdded(_WitVariantCase): + """A message was appended to the conversation.""" + tag = 'message-added' + + class BeforeModelCall(_WitVariantCase): + """About to call the model.""" + tag = 'before-model-call' + + class AfterModelCall(_WitVariantCase): + """Model call returned.""" + tag = 'after-model-call' + + class BeforeTools(_WitVariantCase): + """About to run a batch of tool calls from one assistant turn.""" + tag = 'before-tools' + + class AfterTools(_WitVariantCase): + """Tool batch finished.""" + tag = 'after-tools' + + class BeforeToolCall(_WitVariantCase): + """About to call a single tool.""" + tag = 'before-tool-call' + + class AfterToolCall(_WitVariantCase): + """Tool call returned.""" + tag = 'after-tool-call' + + class ContentBlock(_WitVariantCase): + """A content block was assembled during streaming.""" + tag = 'content-block' + + class ModelMessage(_WitVariantCase): + """Model finished producing a full message.""" + tag = 'model-message' + + class ToolResultHook(_WitVariantCase): + """Tool finished execution (completion event, not streaming update).""" + tag = 'tool-result-hook' + + class ToolUpdate(_WitVariantCase): + """Streaming update from a tool.""" + tag = 'tool-update' + + class ModelUpdate(_WitVariantCase): + """Streaming update from the model.""" + tag = 'model-update' + + class AgentResult(_WitVariantCase): + """Final event for an invocation, carrying the terminal result.""" + tag = 'agent-result' + + _CASES: dict[str, type] = { + 'text-delta': TextDelta, + 'tool-use': ToolUse, + 'tool-result': ToolResult, + 'content': Content, + 'metadata': Metadata, + 'stop': Stop, + 'redaction': Redaction, + 'error': Error, + 'interrupt': Interrupt, + 'initialized': Initialized, + 'before-invocation': BeforeInvocation, + 'after-invocation': AfterInvocation, + 'message-added': MessageAdded, + 'before-model-call': BeforeModelCall, + 'after-model-call': AfterModelCall, + 'before-tools': BeforeTools, + 'after-tools': AfterTools, + 'before-tool-call': BeforeToolCall, + 'after-tool-call': AfterToolCall, + 'content-block': ContentBlock, + 'model-message': ModelMessage, + 'tool-result-hook': ToolResultHook, + 'tool-update': ToolUpdate, + 'model-update': ModelUpdate, + 'agent-result': AgentResult, + } + + @staticmethod + def lift(raw: _WitVariant) -> StreamEvent: + cls = StreamEvent._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StreamEvent arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/tool_provider.py b/strands-py-wasm/src/strands/_generated/strands_agent/tool_provider.py new file mode 100644 index 0000000000..5d9726870d --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/tool_provider.py @@ -0,0 +1,13 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .tools import CallToolArgs, ToolEventStream diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/tools.py b/strands-py-wasm/src/strands/_generated/strands_agent/tools.py new file mode 100644 index 0000000000..125487ce63 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/tools.py @@ -0,0 +1,130 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .messages import ToolResultContent + + +@dataclass(kw_only=True) +class ToolSpec: + """Declaration of a tool the model can call.""" + name: str + description: str + input_schema: str + + +@dataclass(kw_only=True) +class AgentAsToolConfig: + """Wrap a configured agent as a tool callable by the parent agent. The +child agent is instantiated at registration time.""" + name: str | None + description: str | None + preserve_context: bool + agent_config: str + + +@dataclass(kw_only=True) +class CallToolArgs: + """Arguments for a single tool call.""" + name: str + input: str + tool_use_id: str + + +class ToolChoice: + """Policy controlling whether and how the model calls tools on the next +generation step.""" + + class Auto(_WitVariantCase): + """Model decides whether to call a tool.""" + tag = 'auto' + + class Any(_WitVariantCase): + """Model must call at least one tool.""" + tag = 'any' + + class Named(_WitVariantCase): + """Model must call the tool with this name.""" + tag = 'named' + + _CASES: dict[str, type] = { + 'auto': Auto, + 'any': Any, + 'named': Named, + } + + @staticmethod + def lift(raw: _WitVariant) -> ToolChoice: + cls = ToolChoice._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ToolChoice arm: {raw.tag!r}') + return cls(raw.payload) + +class ToolEventStream: + """Pull-based stream of tool events. Sync-WIT placeholder for +`stream`.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self) -> ToolStreamEvent | None: + return self._invoke('[method]tool-event-stream.read', (self._handle,)) + + +class ToolError: + """Why a tool call failed.""" + + class Unknown(_WitVariantCase): + """No tool registered under the given name.""" + tag = 'unknown' + + class InvalidInput(_WitVariantCase): + """Tool input didn't match the declared input schema.""" + tag = 'invalid-input' + + class ExecutionFailed(_WitVariantCase): + """Tool ran but returned an error result.""" + tag = 'execution-failed' + + class TimedOut(_WitVariantCase): + """Tool exceeded its time budget.""" + tag = 'timed-out' + + class Cancelled(_WitVariantCase): + """Tool was cancelled before completion.""" + tag = 'cancelled' + + class Internal(_WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + + _CASES: dict[str, type] = { + 'unknown': Unknown, + 'invalid-input': InvalidInput, + 'execution-failed': ExecutionFailed, + 'timed-out': TimedOut, + 'cancelled': Cancelled, + 'internal': Internal, + } + + @staticmethod + def lift(raw: _WitVariant) -> ToolError: + cls = ToolError._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ToolError arm: {raw.tag!r}') + return cls(raw.payload) + +ToolStreamEvent = str | list[ToolResultContent] | ToolError +"""Incremental event emitted by a streaming tool while running.""" diff --git a/strands-py-wasm/src/strands/_generated/strands_agent/vended.py b/strands-py-wasm/src/strands/_generated/strands_agent/vended.py new file mode 100644 index 0000000000..df586e0812 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/strands_agent/vended.py @@ -0,0 +1,116 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +@dataclass(kw_only=True) +class BashTool: + """Bash tool configuration.""" + default_timeout_s: int | None + + +@dataclass(kw_only=True) +class FileEditorTool: + """File editor tool configuration.""" + workspace_root: str | None + + +@dataclass(kw_only=True) +class HttpRequestTool: + """HTTP request tool configuration.""" + allowed_hosts: list[str] + max_response_bytes: int + + +@dataclass(kw_only=True) +class NotebookTool: + """Notebook tool configuration.""" + workspace_root: str | None + + +class VendedTool: + """Built-in tools.""" + + class Bash(_WitVariantCase): + """Run shell commands in a persistent bash session.""" + tag = 'bash' + + class FileEditor(_WitVariantCase): + """Create, view, and edit files on disk.""" + tag = 'file-editor' + + class HttpRequest(_WitVariantCase): + """Make HTTP requests.""" + tag = 'http-request' + + class Notebook(_WitVariantCase): + """Read and execute Jupyter notebook cells.""" + tag = 'notebook' + + _CASES: dict[str, type] = { + 'bash': Bash, + 'file-editor': FileEditor, + 'http-request': HttpRequest, + 'notebook': Notebook, + } + + @staticmethod + def lift(raw: _WitVariant) -> VendedTool: + cls = VendedTool._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown VendedTool arm: {raw.tag!r}') + return cls(raw.payload) + +@dataclass(kw_only=True) +class SkillSource: + """Location of a skill definition on disk.""" + path: str + + +@dataclass(kw_only=True) +class AgentSkills: + """Skills plugin configuration.""" + skills: list[SkillSource] + strict: bool + max_resource_files: int | None + state_key: str | None + + +@dataclass(kw_only=True) +class ContextOffloader: + """Context offloader plugin configuration.""" + max_result_tokens: int | None + preview_tokens: int | None + include_retrieval_tool: bool + + +class VendedPlugin: + """Built-in plugins.""" + + class Skills(_WitVariantCase): + """Load and activate Anthropic-style skills from disk.""" + tag = 'skills' + + class ContextOffloader(_WitVariantCase): + """Offload large tool results to external storage.""" + tag = 'context-offloader' + + _CASES: dict[str, type] = { + 'skills': Skills, + 'context-offloader': ContextOffloader, + } + + @staticmethod + def lift(raw: _WitVariant) -> VendedPlugin: + cls = VendedPlugin._CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown VendedPlugin arm: {raw.tag!r}') + return cls(raw.payload) diff --git a/strands-py-wasm/src/strands/_generated/wasi_clocks/__init__.py b/strands-py-wasm/src/strands/_generated/wasi_clocks/__init__.py new file mode 100644 index 0000000000..24687757ce --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_clocks/__init__.py @@ -0,0 +1 @@ +"""Auto-generated by bindgen. Do not edit.""" diff --git a/strands-py-wasm/src/strands/_generated/wasi_clocks/monotonic_clock.py b/strands-py-wasm/src/strands/_generated/wasi_clocks/monotonic_clock.py new file mode 100644 index 0000000000..7fa4e8ddec --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_clocks/monotonic_clock.py @@ -0,0 +1,18 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from ..wasi_io.poll import Pollable + + +Instant = int + +Duration = int diff --git a/strands-py-wasm/src/strands/_generated/wasi_clocks/wall_clock.py b/strands-py-wasm/src/strands/_generated/wasi_clocks/wall_clock.py new file mode 100644 index 0000000000..2211eb5ae9 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_clocks/wall_clock.py @@ -0,0 +1,18 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +@dataclass(kw_only=True) +class Datetime: + """A time and date in seconds plus nanoseconds.""" + seconds: int + nanoseconds: int diff --git a/strands-py-wasm/src/strands/_generated/wasi_io/__init__.py b/strands-py-wasm/src/strands/_generated/wasi_io/__init__.py new file mode 100644 index 0000000000..24687757ce --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_io/__init__.py @@ -0,0 +1 @@ +"""Auto-generated by bindgen. Do not edit.""" diff --git a/strands-py-wasm/src/strands/_generated/wasi_io/error.py b/strands-py-wasm/src/strands/_generated/wasi_io/error.py new file mode 100644 index 0000000000..79b14b6d7f --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_io/error.py @@ -0,0 +1,41 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +class Error: + """A resource which represents some error information. + +The only method provided by this resource is `to-debug-string`, +which provides some human-readable information about the error. + +In the `wasi:io` package, this resource is returned through the +`wasi:io/streams/stream-error` type. + +To provide more specific error information, other interfaces may +offer functions to "downcast" this error into more specific types. For example, +errors returned from streams derived from filesystem types can be described using +the filesystem's own error-code type. This is done using the function +`wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` +parameter and returns an `option`. + +The set of functions which can "downcast" an `error` into a more +concrete type is open.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def to_debug_string(self) -> str: + return self._invoke('[method]error.to-debug-string', (self._handle,)) diff --git a/strands-py-wasm/src/strands/_generated/wasi_io/poll.py b/strands-py-wasm/src/strands/_generated/wasi_io/poll.py new file mode 100644 index 0000000000..2c8ddbcbb5 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_io/poll.py @@ -0,0 +1,28 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +class Pollable: + """`pollable` represents a single I/O event which may be ready, or not.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def ready(self) -> bool: + return self._invoke('[method]pollable.ready', (self._handle,)) + + def block(self) -> None: + return self._invoke('[method]pollable.block', (self._handle,)) diff --git a/strands-py-wasm/src/strands/_generated/wasi_io/streams.py b/strands-py-wasm/src/strands/_generated/wasi_io/streams.py new file mode 100644 index 0000000000..61c8506698 --- /dev/null +++ b/strands-py-wasm/src/strands/_generated/wasi_io/streams.py @@ -0,0 +1,102 @@ +"""Auto-generated by bindgen. Do not edit.""" +# ruff: noqa: E501, F401, I001 +# fmt: off + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + +from .error import Error +from .poll import Pollable + + +StreamError = Any | None +"""An error for input-stream and output-stream operations.""" + +class InputStream: + """An input bytestream. + +`input-stream`s are *non-blocking* to the extent practical on underlying +platforms. I/O operations always return promptly; if fewer bytes are +promptly available than requested, they return the number of bytes promptly +available, which could even be zero. To wait for data to be available, +use the `subscribe` function to obtain a `pollable` which can be polled +for using `wasi:io/poll`.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self, len: int) -> Any: + return self._invoke('[method]input-stream.read', (self._handle, len,)) + + def blocking_read(self, len: int) -> Any: + return self._invoke('[method]input-stream.blocking-read', (self._handle, len,)) + + def skip(self, len: int) -> Any: + return self._invoke('[method]input-stream.skip', (self._handle, len,)) + + def blocking_skip(self, len: int) -> Any: + return self._invoke('[method]input-stream.blocking-skip', (self._handle, len,)) + + def subscribe(self) -> Any: + return self._invoke('[method]input-stream.subscribe', (self._handle,)) + + +class OutputStream: + """An output bytestream. + +`output-stream`s are *non-blocking* to the extent practical on +underlying platforms. Except where specified otherwise, I/O operations also +always return promptly, after the number of bytes that can be written +promptly, which could even be zero. To wait for the stream to be ready to +accept data, the `subscribe` function to obtain a `pollable` which can be +polled for using `wasi:io/poll`. + +Dropping an `output-stream` while there's still an active write in +progress may result in the data being lost. Before dropping the stream, +be sure to fully flush your writes.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def check_write(self) -> Any: + return self._invoke('[method]output-stream.check-write', (self._handle,)) + + def write(self, contents: bytes) -> Any: + return self._invoke('[method]output-stream.write', (self._handle, contents,)) + + def blocking_write_and_flush(self, contents: bytes) -> Any: + return self._invoke('[method]output-stream.blocking-write-and-flush', (self._handle, contents,)) + + def flush(self) -> Any: + return self._invoke('[method]output-stream.flush', (self._handle,)) + + def blocking_flush(self) -> Any: + return self._invoke('[method]output-stream.blocking-flush', (self._handle,)) + + def subscribe(self) -> Any: + return self._invoke('[method]output-stream.subscribe', (self._handle,)) + + def write_zeroes(self, len: int) -> Any: + return self._invoke('[method]output-stream.write-zeroes', (self._handle, len,)) + + def blocking_write_zeroes_and_flush(self, len: int) -> Any: + return self._invoke('[method]output-stream.blocking-write-zeroes-and-flush', (self._handle, len,)) + + def splice(self, src: Any, len: int) -> Any: + return self._invoke('[method]output-stream.splice', (self._handle, src, len,)) + + def blocking_splice(self, src: Any, len: int) -> Any: + return self._invoke('[method]output-stream.blocking-splice', (self._handle, src, len,)) diff --git a/strands-py-wasm/src/strands/_runtime.py b/strands-py-wasm/src/strands/_runtime.py index bd9580dacc..54b629b71f 100644 --- a/strands-py-wasm/src/strands/_runtime.py +++ b/strands-py-wasm/src/strands/_runtime.py @@ -36,7 +36,7 @@ Variant, ) -from . import _generated as _t +from . import types as _t if TYPE_CHECKING: from . import Agent @@ -94,24 +94,23 @@ def _get_component() -> Component: class _HostToolEventStream: - """Host-side tool-event-stream backing a single ``call-tool`` invocation. + """Host-side queue backing a single ``call-tool`` invocation. - Holds a queue of WIT ``tool-stream-event`` Variants. ``read()`` returns - them one at a time, then ``None`` after the terminal ``complete`` / - ``error`` event has been delivered. + Holds a queue of ``tool-stream-event`` payloads. ``read()`` returns them + one at a time, then ``None`` after the terminal event is delivered. """ def __init__(self) -> None: - self._events: deque[Variant] = deque() + self._events: deque[Any] = deque() self._closed = False - def push(self, event: Variant) -> None: + def push(self, event: Any) -> None: self._events.append(event) def close(self) -> None: self._closed = True - def read(self) -> Variant | None: + def read(self) -> Any: if self._events: return self._events.popleft() return None @@ -165,11 +164,12 @@ def call_tool(store: Any, args: Any) -> ResourceHost: content_list = tool.invoke(raw_input) except Exception as exc: # noqa: BLE001 surface any tool exception as a tool-error logger.exception("tool %r raised; surfacing as tool-error to guest", name) - stream.push(_t.ToolStreamEvent.error(_t.ToolError.execution_failed(str(exc)))) + # tool-stream-event is untagged: push the bare arm payload. + stream.push(_t.ToolError.ExecutionFailed(str(exc))) stream.close() else: - # content_list items are already ToolResultContent wire variants. - stream.push(_t.ToolStreamEvent.complete(content_list)) + # content_list is already a list[ToolResultContent]. + stream.push(content_list) stream.close() # Register, then hand ownership to the guest. On failure, drop the rep @@ -212,12 +212,14 @@ def _tool_event_stream_read(store: Any, handle: ResourceAny) -> Variant | None: rep = host.rep stream = registry.lookup(rep) return stream.read() + return _tool_event_stream_read def _trap(name: str): def _f(*_a, **_k): raise RuntimeError(f"host import not implemented: {name}") + return _f @@ -248,8 +250,14 @@ def _register_imports(linker: Linker, agent: Agent, registry: _ToolStreamRegistr ns.add_func("count-tokens", _trap("model-provider.count-tokens")) with root.add_instance("strands:agent/snapshot-storage@0.1.0") as ns: - for fname in ("save-snapshot", "load-snapshot", "list-snapshot-ids", - "delete-session", "load-manifest", "save-manifest"): + for fname in ( + "save-snapshot", + "load-snapshot", + "list-snapshot-ids", + "delete-session", + "load-manifest", + "save-manifest", + ): ns.add_func(fname, _trap(f"snapshot-storage.{fname}")) with root.add_instance("strands:agent/snapshot-trigger-handler@0.1.0") as ns: @@ -264,9 +272,8 @@ def _register_imports(linker: Linker, agent: Agent, registry: _ToolStreamRegistr # --- Store + Linker ----------------------------------------------------- -def _make_store_and_linker( - agent: Agent, registry: _ToolStreamRegistry -) -> tuple[Store, Linker]: + +def _make_store_and_linker(agent: Agent, registry: _ToolStreamRegistry) -> tuple[Store, Linker]: engine = _get_engine() store = Store(engine) wasi = WasiConfig() diff --git a/strands-py-wasm/src/strands/types.py b/strands-py-wasm/src/strands/types.py new file mode 100644 index 0000000000..7239b83baf --- /dev/null +++ b/strands-py-wasm/src/strands/types.py @@ -0,0 +1,67 @@ +"""Public type surface for the Strands SDK. + +Re-exports the wire types from :mod:`strands._generated` (machine-written by +``wasmtime.component.bindgen``). Users import wire types from here, not from +``_generated`` directly. + +To remove a type from the public surface, add its name to ``_DROPPED``. The +underlying generated module stays untouched. + +The ``*Input`` aliases below name the unions accepted at the Agent +boundary, where we auto-coerce a bare payload into its variant arm +(e.g. ``BedrockModel(...)`` becomes ``ModelConfig.Bedrock(...)``). +Everywhere else the caller wraps explicitly. +""" + +from __future__ import annotations + +from strands._generated import * # noqa: F401, F403 +from strands._generated import __all__ as _generated_all +from strands._generated.strands_agent.conversation import ( + ConversationManagerConfig, + SlidingWindowConversationManager, + SummarizingConversationManager, +) +from strands._generated.strands_agent.models import ( + AnthropicModel, + BedrockModel, + CustomModel, + GoogleModel, + ModelConfig, + OpenaiModel, +) +from strands._generated.strands_agent.vended import ( + AgentSkills, + BashTool, + ContextOffloader, + FileEditorTool, + HttpRequestTool, + NotebookTool, + VendedPlugin, + VendedTool, +) + +ModelInput = ModelConfig | BedrockModel | AnthropicModel | OpenaiModel | GoogleModel | CustomModel +ConversationManagerInput = ConversationManagerConfig | SlidingWindowConversationManager | SummarizingConversationManager +VendedToolInput = VendedTool | BashTool | FileEditorTool | HttpRequestTool | NotebookTool +VendedPluginInput = VendedPlugin | AgentSkills | ContextOffloader + +_DROPPED: set[str] = { + "Datetime", + "Duration", + "Error", + "InputStream", + "Instant", + "OutputStream", + "Pollable", +} + +__all__ = sorted( # pyright: ignore[reportUnsupportedDunderAll] + (set(_generated_all) - _DROPPED) + | { + "ConversationManagerInput", + "ModelInput", + "VendedPluginInput", + "VendedToolInput", + } +) diff --git a/strands-wasm/entry.ts b/strands-wasm/entry.ts index 0cc9ea882b..5e1f5b6ec3 100644 --- a/strands-wasm/entry.ts +++ b/strands-wasm/entry.ts @@ -57,7 +57,6 @@ import type { } from '@strands-agents/sdk' import { ConversationManager, - NullConversationManager, SlidingWindowConversationManager, SummarizingConversationManager, } from '@strands-agents/sdk' @@ -447,8 +446,6 @@ function createConversationManager(config: AgentConfig): ConversationManager | u const cm = config.conversationManager if (!cm) return undefined switch (cm.tag) { - case 'none': - return new NullConversationManager() case 'sliding-window': return new SlidingWindowConversationManager({ windowSize: cm.val.windowSize, diff --git a/wit/agent.wit b/wit/agent.wit index d6d8b76ece..59193f7790 100644 --- a/wit/agent.wit +++ b/wit/agent.wit @@ -6,7 +6,7 @@ interface api { use messages.{message, content-block, prompt-input}; use models.{model-config, model-params}; use tools.{tool-spec, tool-choice, agent-as-tool-config}; - use sessions.{session-config, storage-error}; + use sessions.{session-manager, storage-error}; use conversation.{conversation-manager-config}; use retry.{retry-config}; use streaming.{stream-event, agent-trace, agent-metrics}; @@ -98,7 +98,7 @@ interface api { /// W3C Trace Context linking the agent's spans to a caller-supplied trace. trace-context: option, /// Session persistence. Absent means no persistence. - session: option, + session: option, /// Conversation history management. Defaults to a sliding window. conversation-manager: option, /// Retry policy for failed model calls. Defaults to exponential backoff capped at 6 attempts. @@ -199,9 +199,9 @@ world agent { import tool-provider; /// Receives structured log entries from the agent. import host-log; - /// Custom snapshot storage. Selected via `session-config.storage = custom`. + /// Custom snapshot storage. Selected via `session-manager.storage = custom`. import snapshot-storage; - /// Custom snapshot policy. Selected via `session-config.save-latest = trigger(id)`. + /// Custom snapshot policy. Selected via `session-manager.save-latest = trigger(id)`. import snapshot-trigger-handler; /// Custom model provider. Selected via `model-config.custom`. import model-provider; diff --git a/wit/conversation.wit b/wit/conversation.wit index 0c79228e04..1cacb18556 100644 --- a/wit/conversation.wit +++ b/wit/conversation.wit @@ -6,7 +6,7 @@ interface conversation { /// Sliding-window strategy: trim oldest messages once the conversation /// exceeds `window-size`. - record sliding-window-config { + record sliding-window-conversation-manager { /// Maximum number of messages retained. window-size: s32, /// Drop older tool results when trimming. @@ -15,7 +15,7 @@ interface conversation { /// Summarizing strategy: once the conversation grows, summarize older /// messages into a single summary message and keep the rest verbatim. - record summarizing-config { + record summarizing-conversation-manager { /// Fraction of messages to summarize. Must be between 0.1 and 0.8; /// out-of-range values surface on the first `generate` call as /// `agent-error::invalid-input`. @@ -28,14 +28,13 @@ interface conversation { summarization-model: option, } - /// Which conversation manager the agent uses. + /// Which conversation manager the agent uses. Wrapped in + /// ``option<>`` at the call site; ``none`` means history grows without + /// bound and context-overflow errors propagate to the caller. variant conversation-manager-config { - /// No conversation management. History grows without bound and - /// context-overflow errors propagate to the caller. - none, /// Sliding-window trimming. - sliding-window(sliding-window-config), + sliding-window(sliding-window-conversation-manager), /// Summarization of older messages. - summarizing(summarizing-config), + summarizing(summarizing-conversation-manager), } } diff --git a/wit/mcp.wit b/wit/mcp.wit index 28d469ed71..e2feecd0ae 100644 --- a/wit/mcp.wit +++ b/wit/mcp.wit @@ -18,15 +18,15 @@ interface mcp { /// How the client talks to the MCP server. variant mcp-transport { /// STDIO transport. Spawn a local process and talk via pipes. - stdio(stdio-transport-config), + stdio(stdio-transport), /// Streamable HTTP transport, per the current MCP specification. - streamable-http(http-transport-config), + streamable-http(http-transport), /// Legacy Server-Sent Events transport. Retained for older servers. - sse(sse-transport-config), + sse(sse-transport), } /// STDIO transport configuration. - record stdio-transport-config { + record stdio-transport { /// Command to execute. command: string, /// Arguments passed to the command. @@ -46,7 +46,7 @@ interface mcp { } /// HTTP transport configuration. - record http-transport-config { + record http-transport { /// Server endpoint URL. url: string, /// Extra HTTP headers. @@ -54,7 +54,7 @@ interface mcp { } /// SSE transport configuration. - record sse-transport-config { + record sse-transport { /// Server endpoint URL. url: string, /// Extra HTTP headers. diff --git a/wit/models.wit b/wit/models.wit index 9a6be6e83c..bf03eefa3b 100644 --- a/wit/models.wit +++ b/wit/models.wit @@ -3,7 +3,7 @@ package strands:agent@0.1.0; /// Model provider configuration and pluggable custom providers. interface models { /// Anthropic API model configuration. - record anthropic-config { + record anthropic-model { /// Model identifier, e.g. `claude-opus-4-7`. model-id: option, /// API key. Falls back to the `ANTHROPIC_API_KEY` environment variable. @@ -13,7 +13,7 @@ interface models { } /// AWS Bedrock model configuration. - record bedrock-config { + record bedrock-model { /// Bedrock model identifier, e.g. `us.anthropic.claude-opus-4-7-v1:0`. model-id: string, /// AWS region. Falls back to the default credential chain. @@ -29,7 +29,7 @@ interface models { } /// OpenAI API model configuration. - record openai-config { + record openai-model { /// Model identifier, e.g. `gpt-4o`. model-id: option, /// API key. Falls back to the `OPENAI_API_KEY` environment variable. @@ -39,7 +39,7 @@ interface models { } /// Google Gemini API model configuration. - record gemini-config { + record google-model { /// Model identifier, e.g. `gemini-2.0-flash`. model-id: option, /// API key. Falls back to the `GOOGLE_API_KEY` environment variable. @@ -49,7 +49,7 @@ interface models { } /// Custom model provider supplied by your application. - record custom-model-config { + record custom-model { /// Identifier routed back on each call. One application can register /// multiple providers under distinct ids. provider-id: string, @@ -64,16 +64,16 @@ interface models { /// Which model provider the agent should use. variant model-config { /// Anthropic API. - anthropic(anthropic-config), + anthropic(anthropic-model), /// AWS Bedrock. - bedrock(bedrock-config), + bedrock(bedrock-model), /// OpenAI API. - openai(openai-config), + openai(openai-model), /// Google Gemini API. - gemini(gemini-config), + gemini(google-model), /// Custom provider supplied by your application. Implement the /// `model-provider` interface to serve it. - custom(custom-model-config), + custom(custom-model), } /// Sampling parameters applied to every call on the chosen provider. diff --git a/wit/multiagent.wit b/wit/multiagent.wit index 84383e00c1..d892110eb4 100644 --- a/wit/multiagent.wit +++ b/wit/multiagent.wit @@ -39,7 +39,7 @@ interface multi-agent { } /// Definition of an agent-backed node. - record agent-node-config { + record agent-node { /// Node identifier, unique within its graph/swarm. id: string, /// Human-readable description. @@ -52,7 +52,7 @@ interface multi-agent { } /// Definition of a node that wraps another orchestrator. - record multi-agent-node-config { + record multi-agent-node { /// Node identifier, unique within its parent graph/swarm. id: string, /// Human-readable description. @@ -65,9 +65,9 @@ interface multi-agent { /// Any node a graph or swarm can execute. variant node-config { /// Wraps a single agent. - agent(agent-node-config), + agent(agent-node), /// Wraps a nested orchestrator. - multi-agent(multi-agent-node-config), + multi-agent(multi-agent-node), } /// Condition attached to a graph edge. @@ -112,7 +112,7 @@ interface multi-agent { /// Identifier of this swarm. id: string, /// Agent-backed nodes available for handoff. - nodes: list, + nodes: list, /// Agent that runs first. start-node-id: string, /// Max total agent executions. Absent means no limit. diff --git a/wit/retry.wit b/wit/retry.wit index aac70815cd..55617d0e6b 100644 --- a/wit/retry.wit +++ b/wit/retry.wit @@ -17,13 +17,13 @@ interface retry { } /// Fixed delay between attempts. - record constant-backoff-config { + record constant-backoff { /// Delay returned for every retry. delay: duration, } /// Delay grows linearly with attempt number. - record linear-backoff-config { + record linear-backoff { /// Base delay. Delay on attempt N is `base * N`. base: duration, /// Upper bound applied before jitter. @@ -33,7 +33,7 @@ interface retry { } /// Delay grows exponentially with attempt number. - record exponential-backoff-config { + record exponential-backoff { /// Base delay on the first retry. base: duration, /// Upper bound applied before jitter. @@ -47,11 +47,11 @@ interface retry { /// Backoff curve applied between attempts. variant backoff-strategy { /// Fixed delay. - constant(constant-backoff-config), + constant(constant-backoff), /// Linear growth. - linear(linear-backoff-config), + linear(linear-backoff), /// Exponential growth. - exponential(exponential-backoff-config), + exponential(exponential-backoff), } /// A single retry strategy. Default is exponential backoff, full jitter, 6 attempts. diff --git a/wit/sessions.wit b/wit/sessions.wit index 921f6115fc..b098c71cc8 100644 --- a/wit/sessions.wit +++ b/wit/sessions.wit @@ -6,13 +6,13 @@ interface sessions { use messages.{message}; /// Local filesystem snapshot storage. - record file-storage-config { + record file-storage { /// Directory under which snapshots are written. base-dir: string, } /// S3 snapshot storage. - record s3-storage-config { + record s3-storage { /// Target bucket. bucket: string, /// AWS region. Falls back to the default credential chain. @@ -22,7 +22,7 @@ interface sessions { } /// Reference to an application-implemented storage backend. - record custom-storage-config { + record custom-storage { /// Identifier routed back to the `snapshot-storage` handler on every /// call. One application can register multiple backends under /// distinct ids. @@ -32,11 +32,11 @@ interface sessions { /// Where to persist session snapshots. variant storage-config { /// Local filesystem. - file(file-storage-config), + file(file-storage), /// Amazon S3. - s3(s3-storage-config), + s3(s3-storage), /// Application-implemented backend. - custom(custom-storage-config), + custom(custom-storage), } /// When to update the "latest" snapshot pointer. The `trigger` arm @@ -53,7 +53,7 @@ interface sessions { } /// Session persistence configuration attached to an agent. - record session-config { + record session-manager { /// Identifier for this session's snapshots. session-id: string, /// Storage backend. @@ -96,11 +96,9 @@ interface sessions { removed-message-count: s32, } - /// Conversation manager snapshot state. Which arm is populated depends - /// on the conversation manager the agent was built with. + /// Conversation manager snapshot state. Wrapped in ``option<>`` at the + /// call site; ``none`` means there's no manager and nothing to persist. variant conversation-manager-state { - /// No conversation manager or null manager; nothing to persist. - none, /// Sliding-window manager state. sliding-window(sliding-window-state), /// Summarizing manager state. @@ -270,7 +268,7 @@ interface snapshot-storage { } /// Pluggable snapshot trigger called after each invocation. -/// Enabled via `session-config.save-latest = trigger(id)`. +/// Enabled via `session-manager.save-latest = trigger(id)`. interface snapshot-trigger-handler { use messages.{message}; diff --git a/wit/vended.wit b/wit/vended.wit index 7b7088e852..9fb0f81534 100644 --- a/wit/vended.wit +++ b/wit/vended.wit @@ -7,30 +7,30 @@ interface vended { /// Built-in tools. variant vended-tool { /// Run shell commands in a persistent bash session. - bash(bash-tool-config), + bash(bash-tool), /// Create, view, and edit files on disk. - file-editor(file-editor-tool-config), + file-editor(file-editor-tool), /// Make HTTP requests. - http-request(http-request-tool-config), + http-request(http-request-tool), /// Read and execute Jupyter notebook cells. - notebook(notebook-tool-config), + notebook(notebook-tool), } /// Bash tool configuration. - record bash-tool-config { + record bash-tool { /// Default timeout for `execute` calls, in seconds. default-timeout-s: option, } /// File editor tool configuration. - record file-editor-tool-config { + record file-editor-tool { /// Directory outside of which the tool refuses to operate. Absent /// permits any path the sandbox grants. workspace-root: option, } /// HTTP request tool configuration. - record http-request-tool-config { + record http-request-tool { /// Hosts the tool is allowed to reach. Empty permits any host the /// sandbox permits. allowed-hosts: list, @@ -39,7 +39,7 @@ interface vended { } /// Notebook tool configuration. - record notebook-tool-config { + record notebook-tool { /// Directory outside of which the tool refuses to operate. Absent /// permits any path the sandbox grants. workspace-root: option, @@ -52,7 +52,7 @@ interface vended { } /// Skills plugin configuration. - record skills-plugin-config { + record agent-skills { /// Skill sources to load. skills: list, /// Fail if a skill cannot be loaded. @@ -64,7 +64,7 @@ interface vended { } /// Context offloader plugin configuration. - record context-offloader-plugin-config { + record context-offloader { /// Token threshold at which tool results are offloaded. max-result-tokens: option, /// Tokens to keep inline when offloading (as a preview). @@ -77,8 +77,8 @@ interface vended { /// Built-in plugins. variant vended-plugin { /// Load and activate Anthropic-style skills from disk. - skills(skills-plugin-config), + skills(agent-skills), /// Offload large tool results to external storage. - context-offloader(context-offloader-plugin-config), + context-offloader(context-offloader), } } From a0fb25c7151987bbd8a1bfe9d876dbf839421c47 Mon Sep 17 00:00:00 2001 From: Gautam Sirdeshmukh <54588697+gautamsirdeshmukh@users.noreply.github.com> Date: Thu, 28 May 2026 14:04:48 -0400 Subject: [PATCH 3/4] fix: revert MCP from JSON to update node imports (#1113) Co-authored-by: Gautam Sirdeshmukh --- AGENTS.md | 1 - .../src/__tests__/mcp-config.test.node.ts | 404 ------------------ strands-ts/src/index.ts | 1 - strands-ts/src/mcp-config.ts | 188 -------- strands-ts/src/mcp.ts | 15 - 5 files changed, 609 deletions(-) delete mode 100644 strands-ts/src/__tests__/mcp-config.test.node.ts delete mode 100644 strands-ts/src/mcp-config.ts diff --git a/AGENTS.md b/AGENTS.md index 1c16778b78..8db2f7e06e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -225,7 +225,6 @@ sdk-typescript/ │ │ ├── index.ts # Main SDK entry point │ │ ├── interrupt.ts # Interrupt handling │ │ ├── mcp.ts # MCP client implementation -│ │ ├── mcp-config.ts # MCP config file parsing │ │ ├── mime.ts # MIME type utilities │ │ └── state-store.ts # State store implementation │ │ diff --git a/strands-ts/src/__tests__/mcp-config.test.node.ts b/strands-ts/src/__tests__/mcp-config.test.node.ts deleted file mode 100644 index 095311dc59..0000000000 --- a/strands-ts/src/__tests__/mcp-config.test.node.ts +++ /dev/null @@ -1,404 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest' -import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' -import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' -import { McpClient } from '../mcp.js' - -vi.mock('node:fs/promises', () => ({ - readFile: vi.fn(), -})) - -vi.mock('node:os', () => ({ - homedir: vi.fn(() => '/home/user'), -})) - -vi.mock('node:path', () => ({ - join: (...segments: string[]) => segments.join('/'), -})) - -vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({ - StdioClientTransport: vi.fn(function () {}), - getDefaultEnvironment: vi.fn(() => ({ PATH: '/usr/bin', HOME: '/home/user' })), -})) - -vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ - StreamableHTTPClientTransport: vi.fn(function () {}), -})) - -vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({ - SSEClientTransport: vi.fn(function () {}), -})) - -vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ - Client: vi.fn(function (this: Record) { - this.connect = vi.fn() - this.close = vi.fn() - this.listTools = vi.fn() - this.callTool = vi.fn() - this.setRequestHandler = vi.fn() - this.setNotificationHandler = vi.fn() - this.getServerCapabilities = vi.fn() - this.getServerVersion = vi.fn() - this.getInstructions = vi.fn() - this.experimental = { tasks: { callToolStream: vi.fn() } } - }), -})) - -describe('McpClient.loadServers', () => { - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('transport detection', () => { - it('creates StdioClientTransport when command is present', async () => { - const clients = await McpClient.loadServers({ - 'my-server': { command: 'node', args: ['server.js'] }, - }) - - expect(clients).toHaveLength(1) - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: 'node', - args: ['server.js'], - }) - }) - - it('creates McpClient with url when url is present', async () => { - const clients = await McpClient.loadServers({ - 'remote-server': { url: 'https://example.com/mcp' }, - }) - - expect(clients).toHaveLength(1) - expect(StdioClientTransport).not.toHaveBeenCalled() - expect(SSEClientTransport).not.toHaveBeenCalled() - }) - - it('creates SSEClientTransport when transport is "sse"', async () => { - const clients = await McpClient.loadServers({ - 'sse-server': { url: 'https://example.com/sse', transport: 'sse' }, - }) - - expect(clients).toHaveLength(1) - expect(SSEClientTransport).toHaveBeenCalledWith(new URL('https://example.com/sse'), undefined) - }) - - it('explicit transport overrides auto-detection', async () => { - const clients = await McpClient.loadServers({ - server: { url: 'https://example.com/mcp', transport: 'sse' }, - }) - - expect(clients).toHaveLength(1) - expect(SSEClientTransport).toHaveBeenCalled() - expect(StreamableHTTPClientTransport).not.toHaveBeenCalled() - }) - - it('interpolates auth credentials for streamable-http', async () => { - vi.stubEnv('CLIENT_ID', 'my-id') - vi.stubEnv('CLIENT_SECRET', 'my-secret') - - const clients = await McpClient.loadServers({ - server: { - url: 'https://example.com/mcp', - auth: { clientId: '${CLIENT_ID}', clientSecret: '${CLIENT_SECRET}' }, - }, - }) - - expect(clients).toHaveLength(1) - }) - - it('throws when auth credential references missing env var', async () => { - vi.unstubAllEnvs() - - await expect( - McpClient.loadServers({ - server: { - url: 'https://example.com/mcp', - auth: { clientId: '${MISSING_ID}', clientSecret: 'literal' }, - }, - }) - ).rejects.toThrow('Environment variable "MISSING_ID" is not set') - }) - }) - - describe('env interpolation', () => { - it('interpolates ${VAR} in env values', async () => { - vi.stubEnv('MY_SECRET', 'secret123') - - await McpClient.loadServers({ - server: { command: 'node', env: { SECRET: '${MY_SECRET}' } }, - }) - - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: 'node', - env: { PATH: '/usr/bin', HOME: '/home/user', SECRET: 'secret123' }, - }) - }) - - it('interpolates ${VAR} in headers', async () => { - vi.stubEnv('TOKEN', 'abc') - - await McpClient.loadServers({ - server: { url: 'https://example.com/sse', transport: 'sse', headers: { Authorization: 'Bearer ${TOKEN}' } }, - }) - - expect(SSEClientTransport).toHaveBeenCalledWith(new URL('https://example.com/sse'), { - requestInit: { headers: { Authorization: 'Bearer abc' } }, - }) - }) - - it('interpolates ${VAR} in url', async () => { - vi.stubEnv('HOST', 'myhost.com') - - const clients = await McpClient.loadServers({ - server: { url: 'https://${HOST}/mcp', transport: 'sse' }, - }) - - expect(clients).toHaveLength(1) - expect(SSEClientTransport).toHaveBeenCalledWith(new URL('https://myhost.com/mcp'), undefined) - }) - - it('throws when env var is not set', async () => { - vi.unstubAllEnvs() - - await expect( - McpClient.loadServers({ - server: { command: 'node', env: { VAL: '${NONEXISTENT_VAR}' } }, - }) - ).rejects.toThrow('Environment variable "NONEXISTENT_VAR" is not set') - }) - - it('skips server with missing env var when continueOnError is true', async () => { - vi.unstubAllEnvs() - - const clients = await McpClient.loadServers({ - broken: { command: 'node', env: { VAL: '${NONEXISTENT_VAR}' }, continueOnError: true }, - working: { command: 'node' }, - }) - - expect(clients).toHaveLength(1) - expect(clients[0]!.clientName).toBe('working') - }) - - it('interpolates ${VAR} in cwd', async () => { - vi.stubEnv('PROJECT_DIR', '/home/user/projects') - - await McpClient.loadServers({ - server: { command: 'node', cwd: '${PROJECT_DIR}/my-server' }, - }) - - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: 'node', - cwd: '/home/user/projects/my-server', - }) - }) - - it('merges env with default environment', async () => { - await McpClient.loadServers({ - server: { command: 'node', env: { CUSTOM: 'value' } }, - }) - - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: 'node', - env: { PATH: '/usr/bin', HOME: '/home/user', CUSTOM: 'value' }, - }) - }) - }) - - describe('file config loading', () => { - it('reads and parses a JSON file', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockResolvedValue( - JSON.stringify({ - 'my-server': { command: 'node', args: ['server.js'] }, - }) - ) - - const clients = await McpClient.loadServers('/path/to/config.json') - - expect(readFile).toHaveBeenCalledWith('/path/to/config.json', 'utf-8') - expect(clients).toHaveLength(1) - }) - - it('extracts mcpServers key when present', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockResolvedValue( - JSON.stringify({ - mcpServers: { - 'server-a': { command: 'node' }, - 'server-b': { url: 'https://example.com' }, - }, - }) - ) - - const clients = await McpClient.loadServers('/path/to/config.json') - - expect(clients).toHaveLength(2) - }) - - it('uses whole object when mcpServers key is absent', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockResolvedValue( - JSON.stringify({ - 'server-a': { command: 'node' }, - }) - ) - - const clients = await McpClient.loadServers('/path/to/config.json') - - expect(clients).toHaveLength(1) - }) - - it('expands ~ to home directory', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockResolvedValue( - JSON.stringify({ - server: { command: 'node' }, - }) - ) - - await McpClient.loadServers('~/config/mcp.json') - - expect(readFile).toHaveBeenCalledWith('/home/user/config/mcp.json', 'utf-8') - }) - }) - - describe('defaults and per-server overrides', () => { - it('per-server continueOnError overrides defaults', async () => { - const clients = await McpClient.loadServers( - { - 'strict-server': { command: 'node', continueOnError: false }, - 'lenient-server': { command: 'node', continueOnError: true }, - }, - { continueOnError: true } - ) - - expect(clients[0]!.continueOnError).toBe(false) - expect(clients[1]!.continueOnError).toBe(true) - }) - - it('applies default continueOnError when server does not override', async () => { - const clients = await McpClient.loadServers({ server: { command: 'node' } }, { continueOnError: true }) - - expect(clients[0]!.continueOnError).toBe(true) - }) - - it('uses server name as applicationName when not in defaults', async () => { - const clients = await McpClient.loadServers({ - 'my-named-server': { command: 'node' }, - }) - - expect(clients[0]!.clientName).toBe('my-named-server') - }) - - it('uses defaults applicationName over server name', async () => { - const clients = await McpClient.loadServers({ server: { command: 'node' } }, { applicationName: 'my-app' }) - - expect(clients[0]!.clientName).toBe('my-app') - }) - }) - - describe('error cases', () => { - it('throws when server has neither command nor url', async () => { - await expect(McpClient.loadServers({ bad: {} })).rejects.toThrow( - 'Server config must include either "command" (stdio) or "url" (http)' - ) - }) - - it('throws when stdio transport specified without command', async () => { - await expect(McpClient.loadServers({ bad: { transport: 'stdio' } })).rejects.toThrow( - 'Stdio transport requires "command" field' - ) - }) - - it('throws when streamable-http transport specified without url', async () => { - await expect(McpClient.loadServers({ bad: { transport: 'streamable-http' } })).rejects.toThrow( - 'Streamable HTTP transport requires "url" field' - ) - }) - - it('throws when sse transport specified without url', async () => { - await expect(McpClient.loadServers({ bad: { transport: 'sse' } })).rejects.toThrow( - 'SSE transport requires "url" field' - ) - }) - - it('throws on invalid file path', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockRejectedValue(new Error('ENOENT: no such file or directory')) - - await expect(McpClient.loadServers('/nonexistent/path.json')).rejects.toThrow('ENOENT') - }) - - it('throws on malformed JSON', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockResolvedValue('not json{{{') - - await expect(McpClient.loadServers('/path/to/bad.json')).rejects.toThrow() - }) - - it('throws when auth is used with sse transport', async () => { - await expect( - McpClient.loadServers({ - server: { - url: 'https://example.com', - transport: 'sse', - auth: { clientId: 'id', clientSecret: 'secret' }, - }, - }) - ).rejects.toThrow('SSE transport does not support auth') - }) - - it('throws on invalid config shape', async () => { - const { readFile } = await import('node:fs/promises') - vi.mocked(readFile).mockResolvedValue(JSON.stringify([1, 2, 3])) - - await expect(McpClient.loadServers('/path/to/bad.json')).rejects.toThrow('MCP config must be a JSON object') - }) - - it('throws when server has both command and url without explicit transport', async () => { - await expect(McpClient.loadServers({ bad: { command: 'node', url: 'https://example.com' } })).rejects.toThrow( - 'Server config has both "command" and "url"' - ) - }) - }) - - describe('disabled', () => { - it('skips disabled servers', async () => { - const clients = await McpClient.loadServers({ - active: { command: 'node' }, - inactive: { command: 'node', disabled: true }, - }) - - expect(clients).toHaveLength(1) - expect(clients[0]!.clientName).toBe('active') - }) - }) - - describe('env interpolation syntax', () => { - it('supports ${env:VAR} namespaced syntax', async () => { - vi.stubEnv('MY_TOKEN', 'token123') - - await McpClient.loadServers({ - server: { command: 'node', env: { TOKEN: '${env:MY_TOKEN}' } }, - }) - - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: 'node', - env: { PATH: '/usr/bin', HOME: '/home/user', TOKEN: 'token123' }, - }) - }) - - it('interpolates ${VAR} in command and args', async () => { - vi.stubEnv('MY_CMD', '/usr/local/bin/server') - vi.stubEnv('MY_ARG', '3000') - - await McpClient.loadServers({ - server: { command: '${MY_CMD}', args: ['--port=${MY_ARG}'] }, - }) - - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: '/usr/local/bin/server', - args: ['--port=3000'], - }) - }) - }) -}) diff --git a/strands-ts/src/index.ts b/strands-ts/src/index.ts index 52316b23d0..46b1022c00 100644 --- a/strands-ts/src/index.ts +++ b/strands-ts/src/index.ts @@ -283,7 +283,6 @@ export { type McpConnectionState, McpClient, } from './mcp.js' -export { type McpServerConfig } from './mcp-config.js' export type { ElicitationCallback, ElicitationContext } from './types/elicitation.js' // Session management diff --git a/strands-ts/src/mcp-config.ts b/strands-ts/src/mcp-config.ts deleted file mode 100644 index daa42ddb44..0000000000 --- a/strands-ts/src/mcp-config.ts +++ /dev/null @@ -1,188 +0,0 @@ -import type { McpClientConfig, McpClientCredentials, McpClientOptions, McpTransport, TasksConfig } from './mcp.js' -import { logger } from './logging/index.js' - -/** - * Configuration for a single MCP server entry in a config file or object. - * - * Provide either `command` (stdio transport) or `url` (streamable-http/SSE), not both. - * When `transport` is omitted, it is auto-detected from the fields present. - */ -export interface McpServerConfig { - /** Command to spawn (stdio transport, supports `${VAR}` or `${env:VAR}` interpolation). */ - command?: string - /** Arguments passed to the command (supports `${VAR}` or `${env:VAR}` interpolation). */ - args?: string[] - /** Environment variables passed to the child process (supports `${VAR}` or `${env:VAR}` interpolation). */ - env?: Record - /** Working directory for the spawned process (supports `${VAR}` or `${env:VAR}` interpolation). */ - cwd?: string - /** Server endpoint URL (streamable-http or SSE transport, supports `${VAR}` or `${env:VAR}` interpolation). */ - url?: string - /** HTTP headers sent with every request (supports `${VAR}` or `${env:VAR}` interpolation). */ - headers?: Record - /** Explicit transport type. When omitted, auto-detected: `command` → stdio, `url` → streamable-http. */ - transport?: 'stdio' | 'sse' | 'streamable-http' - /** Client credentials for OAuth machine-to-machine auth (streamable-http only). */ - auth?: McpClientCredentials - /** When true, this server is skipped during loadServers. */ - disabled?: boolean - /** When true, config or connection failures skip this server instead of throwing. */ - continueOnError?: boolean - /** Task-augmented tool execution configuration (experimental). */ - tasksConfig?: TasksConfig -} - -/** - * Resolves an MCP servers config into an array of client configurations ready for instantiation. - * - * @param config - A file path to a JSON config, or a flat server map object. - * @param defaults - Options applied to all clients unless overridden per-server. - * @returns Resolved McpClientConfig array (one per enabled, successfully-resolved server). - */ -export async function resolveServerConfigs( - config: string | Record, - defaults?: McpClientOptions -): Promise { - const servers = await loadServersObject(config) - const results: McpClientConfig[] = [] - - for (const [name, server] of Object.entries(servers)) { - if (!server || typeof server !== 'object' || Array.isArray(server)) { - throw new Error(`Server "${name}" must be an object, got ${Array.isArray(server) ? 'array' : typeof server}`) - } - - if (server.disabled) continue - - const continueOnError = server.continueOnError ?? defaults?.continueOnError ?? false - - try { - if (server.command && server.url && !server.transport) { - throw new Error('Server config has both "command" and "url" — set "transport" explicitly or remove one') - } - - const type = server.transport ?? (server.command ? 'stdio' : server.url ? 'streamable-http' : undefined) - if (!type) throw new Error('Server config must include either "command" (stdio) or "url" (http)') - - let clientConfig: McpClientConfig - switch (type) { - case 'stdio': - clientConfig = await buildStdioConfig(server) - break - case 'streamable-http': - clientConfig = buildHttpConfig(server) - break - case 'sse': - clientConfig = await buildSseConfig(server) - break - default: { - const _exhaustive: never = type - throw new Error(`Unsupported transport type: ${_exhaustive}`) - } - } - - results.push({ ...baseOptions(name, server, defaults), ...clientConfig }) - } catch (error) { - if (!continueOnError) throw error - logger.warn(`server=<${name}>, error=<${error}> | MCP server config failed, skipping (continueOnError)`) - } - } - - return results -} - -async function buildStdioConfig(server: McpServerConfig): Promise { - if (!server.command) throw new Error('Stdio transport requires "command" field') - const { StdioClientTransport, getDefaultEnvironment } = await import('@modelcontextprotocol/sdk/client/stdio.js') - - const opts: ConstructorParameters[0] = { - command: interpolateEnv(server.command), - } - if (server.args) opts.args = server.args.map(interpolateEnv) - if (server.env) opts.env = { ...getDefaultEnvironment(), ...interpolateRecord(server.env) } - if (server.cwd) opts.cwd = interpolateEnv(server.cwd) - - return { transport: new StdioClientTransport(opts) as McpTransport } -} - -function buildHttpConfig(server: McpServerConfig): McpClientConfig { - if (!server.url) throw new Error('Streamable HTTP transport requires "url" field') - - const config: McpClientConfig = { url: interpolateEnv(server.url) } - if (server.headers) config.headers = interpolateRecord(server.headers) - if (server.auth) { - config.auth = { - clientId: interpolateEnv(server.auth.clientId), - clientSecret: interpolateEnv(server.auth.clientSecret), - ...(server.auth.scopes && { scopes: server.auth.scopes.map(interpolateEnv) }), - } - } - return config -} - -async function buildSseConfig(server: McpServerConfig): Promise { - if (!server.url) throw new Error('SSE transport requires "url" field') - if (server.auth) - throw new Error('SSE transport does not support auth — use streamable-http or provide a pre-configured transport') - - const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js') - const headers = server.headers ? interpolateRecord(server.headers) : undefined - - return { - transport: new SSEClientTransport( - new URL(interpolateEnv(server.url)), - headers ? { requestInit: { headers } } : undefined - ) as McpTransport, - } -} - -function baseOptions(name: string, server: McpServerConfig, defaults?: McpClientOptions): McpClientOptions { - const opts: McpClientOptions = { ...defaults, applicationName: defaults?.applicationName ?? name } - if (server.continueOnError != null) opts.continueOnError = server.continueOnError - if (server.tasksConfig != null) opts.tasksConfig = server.tasksConfig - return opts -} - -/** - * Replaces `$\{VAR\}` and `$\{env:VAR\}` placeholders with their process.env values. - * Throws if a referenced variable is not set. - * - * @example - * ```typescript - * interpolateEnv('Bearer $\{TOKEN\}') // → 'Bearer ghp_abc123' - * interpolateEnv('$\{env:HOME\}/config') // → '/home/user/config' - * ``` - */ -function interpolateEnv(value: string): string { - return value.replace(/\$\{(?:env:)?([^}]+)\}/g, (_, key: string) => { - const resolved = process.env[key] - if (resolved === undefined) throw new Error(`Environment variable "${key}" is not set`) - return resolved - }) -} - -/** Applies {@link interpolateEnv} to every value in a string record. */ -function interpolateRecord(record: Record): Record { - return Object.fromEntries(Object.entries(record).map(([k, v]) => [k, interpolateEnv(v)])) -} - -async function loadServersObject( - config: string | Record -): Promise> { - if (typeof config !== 'string') return config - - const { readFile } = await import('node:fs/promises') - const { homedir } = await import('node:os') - const { join } = await import('node:path') - - const filePath = config.startsWith('~/') ? join(homedir(), config.slice(2)) : config - const parsed = JSON.parse(await readFile(filePath, 'utf-8')) - const servers = parsed.mcpServers ?? parsed - - if (!servers || typeof servers !== 'object' || Array.isArray(servers)) { - throw new Error( - 'MCP config must be a JSON object mapping server names to configs, e.g. { "my-server": { "command": "node" } }' - ) - } - - return servers -} diff --git a/strands-ts/src/mcp.ts b/strands-ts/src/mcp.ts index 436fdc7e95..32db080758 100644 --- a/strands-ts/src/mcp.ts +++ b/strands-ts/src/mcp.ts @@ -16,7 +16,6 @@ import type { JSONSchema, JSONValue } from './types/json.js' import type { ElicitationCallback } from './types/elicitation.js' import { McpTool } from './tools/mcp-tool.js' import { logger } from './logging/index.js' -import { type McpServerConfig, resolveServerConfigs } from './mcp-config.js' /** * Widened transport type that accepts MCP transport implementations without requiring explicit casts. @@ -127,20 +126,6 @@ export class McpClient { /** Default poll timeout for task completion in milliseconds (5 minutes). */ public static readonly DEFAULT_POLL_TIMEOUT = 300000 - /** - * Parses an MCP servers config (file path or object) and returns McpClient instances. - * - * @param config - A file path to a JSON config, or a flat server map object. - * @param defaults - Options applied to all clients unless overridden per-server. - * @returns An array of McpClient instances ready to be passed to an Agent. - */ - public static async loadServers( - config: string | Record, - defaults?: McpClientOptions - ): Promise { - return (await resolveServerConfigs(config, defaults)).map((c) => new McpClient(c)) - } - private _clientName: string private _clientVersion: string private _transport: Transport From 6285901988f554d81564ecec24efd0fdb9c5fb5f Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow Date: Fri, 29 May 2026 10:50:56 -0400 Subject: [PATCH 4/4] fix(ci): post-merge cleanup for typescript sync - Format strands-py-wasm/_runtime.py with ruff - Add missing || github.sha fallback for call-py-check ref --- .github/workflows/typescript-pr-and-push.yml | 2 +- strands-py-wasm/src/strands/_runtime.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/typescript-pr-and-push.yml b/.github/workflows/typescript-pr-and-push.yml index 0d3a1adf60..718a18c2ed 100644 --- a/.github/workflows/typescript-pr-and-push.yml +++ b/.github/workflows/typescript-pr-and-push.yml @@ -36,7 +36,7 @@ jobs: permissions: contents: read with: - ref: ${{ github.event.pull_request.head.sha }} + ref: ${{ github.event.pull_request.head.sha || github.sha }} call-ts-test: uses: ./.github/workflows/typescript-ts-test.yml diff --git a/strands-py-wasm/src/strands/_runtime.py b/strands-py-wasm/src/strands/_runtime.py index ee2f765499..384cc44047 100644 --- a/strands-py-wasm/src/strands/_runtime.py +++ b/strands-py-wasm/src/strands/_runtime.py @@ -346,14 +346,10 @@ async def async_init(self) -> None: self._funcs.constructor.post_return(self._store) async def generate(self, args: _t.InvokeArgs) -> EventStream: - response_handle: ResourceAny = await self._funcs.generate.call_async( - self._store, self._handle, args - ) + response_handle: ResourceAny = await self._funcs.generate.call_async(self._store, self._handle, args) self._funcs.generate.post_return(self._store) self._current_response = response_handle - stream_handle: ResourceAny = await self._funcs.events.call_async( - self._store, response_handle - ) + stream_handle: ResourceAny = await self._funcs.events.call_async(self._store, response_handle) self._funcs.events.post_return(self._store) return EventStream(self, stream_handle)